zephyr/sync/channel/counter.rs

// SPDX-License-Identifier: Apache-2.0
//! Reference counter for channels.

// This file is taken from crossbeam-channel, with modifications to make it `no_std`.

extern crate alloc;

use crate::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use alloc::boxed::Box;
use core::ops;
use core::ptr::NonNull;

/// Reference counter internals.
struct Counter<C> {
    /// The number of senders associated with the channel.
    senders: AtomicUsize,

    /// The number of receivers associated with the channel.
    receivers: AtomicUsize,

    /// Set to `true` once one side (all senders or all receivers) has disconnected; the last
    /// reference on the other side then deallocates the channel.
    destroy: AtomicBool,

    /// The internal channel.
    chan: C,
}

/// Wraps a channel into the reference counter.
pub(crate) fn new<C>(chan: C) -> (Sender<C>, Receiver<C>) {
    let counter = NonNull::from(Box::leak(Box::new(Counter {
        senders: AtomicUsize::new(1),
        receivers: AtomicUsize::new(1),
        destroy: AtomicBool::new(false),
        chan,
    })));
    let s = Sender { counter };
    let r = Receiver { counter };
    (s, r)
}
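
// A minimal usage sketch, not part of this module: a public handle type built on top of this
// counter would typically forward `Clone` to `acquire` and `Drop` to `release`. `PublicSender`,
// its `inner` field, and the channel's `disconnect()` method below are placeholder names, not
// items defined in this crate.
//
//     impl<T> Clone for PublicSender<T> {
//         fn clone(&self) -> Self {
//             // Bump the sender count; the underlying channel allocation is shared.
//             Self { inner: self.inner.acquire() }
//         }
//     }
//
//     impl<T> Drop for PublicSender<T> {
//         fn drop(&mut self) {
//             // SAFETY: each handle owns exactly one counter reference, released exactly once.
//             unsafe { self.inner.release(|chan| chan.disconnect()) };
//         }
//     }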

/// The sending side.
pub(crate) struct Sender<C> {
    counter: NonNull<Counter<C>>,
}

impl<C> Sender<C> {
    /// Returns the internal `Counter`.
    fn counter(&self) -> &Counter<C> {
        unsafe { self.counter.as_ref() }
    }

    /// Acquires another sender reference.
    pub(crate) fn acquire(&self) -> Self {
        let count = self.counter().senders.fetch_add(1, Ordering::Relaxed);

        // Cloning senders and calling `mem::forget` on the clones could potentially overflow the
        // counter. It's very difficult to recover sensibly from such degenerate scenarios so we
        // just abort when the count becomes very large.
        if count > isize::MAX as usize {
            // TODO: We need some kind of equivalent here.
            unimplemented!();
        }

        Self {
            counter: self.counter,
        }
    }

    /// Releases the sender reference.
    ///
    /// Function `disconnect` will be called if this is the last sender reference.
    pub(crate) unsafe fn release<F: FnOnce(&C) -> bool>(&self, disconnect: F) {
        if self.counter().senders.fetch_sub(1, Ordering::AcqRel) == 1 {
            disconnect(&self.counter().chan);

            // `swap` returns the previous value, so the allocation is freed only by whichever
            // side disconnects second, and exactly once.
            if self.counter().destroy.swap(true, Ordering::AcqRel) {
                drop(unsafe { Box::from_raw(self.counter.as_ptr()) });
            }
        }
    }
}
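
// Teardown walk-through for the `release` path above, assuming one live sender and one live
// receiver remain (`disconnect_senders`/`disconnect_receivers` stand in for whatever the real
// channel type exposes):
//
//     // Last sender handle goes away:
//     unsafe { s.release(|chan| chan.disconnect_senders()) };
//     //   senders: 1 -> 0, so `disconnect` runs (typically waking any blocked receivers);
//     //   destroy: false -> true, the swap returned `false`, nothing is freed yet.
//
//     // Last receiver handle goes away:
//     unsafe { r.release(|chan| chan.disconnect_receivers()) };
//     //   receivers: 1 -> 0, so `disconnect` runs for the receiving side;
//     //   destroy was already `true`, the swap returned `true`, so the leaked `Counter`
//     //   box is reconstructed with `Box::from_raw` and dropped.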

impl<C> ops::Deref for Sender<C> {
    type Target = C;

    fn deref(&self) -> &C {
        &self.counter().chan
    }
}

impl<C> PartialEq for Sender<C> {
    fn eq(&self, other: &Self) -> bool {
        self.counter == other.counter
    }
}

/// The receiving side.
pub(crate) struct Receiver<C> {
    counter: NonNull<Counter<C>>,
}

impl<C> Receiver<C> {
    /// Returns the internal `Counter`.
    fn counter(&self) -> &Counter<C> {
        unsafe { self.counter.as_ref() }
    }

    /// Acquires another receiver reference.
    pub(crate) fn acquire(&self) -> Self {
        let count = self.counter().receivers.fetch_add(1, Ordering::Relaxed);

        // Cloning receivers and calling `mem::forget` on the clones could potentially overflow the
        // counter. It's very difficult to recover sensibly from such degenerate scenarios so we
        // just abort when the count becomes very large.
        if count > isize::MAX as usize {
            // TODO: We need some kind of equivalent here.
            unimplemented!();
        }

        Self {
            counter: self.counter,
        }
    }

    /// Releases the receiver reference.
    ///
    /// Function `disconnect` will be called if this is the last receiver reference.
    pub(crate) unsafe fn release<F: FnOnce(&C) -> bool>(&self, disconnect: F) {
        if self.counter().receivers.fetch_sub(1, Ordering::AcqRel) == 1 {
            disconnect(&self.counter().chan);

            // As on the sender side, only the side that disconnects second frees the allocation.
            if self.counter().destroy.swap(true, Ordering::AcqRel) {
                drop(unsafe { Box::from_raw(self.counter.as_ptr()) });
            }
        }
    }
}

impl<C> ops::Deref for Receiver<C> {
    type Target = C;

    fn deref(&self) -> &C {
        &self.counter().chan
    }
}

impl<C> PartialEq for Receiver<C> {
    fn eq(&self, other: &Self) -> bool {
        self.counter == other.counter
    }
}
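
#[cfg(test)]
mod tests {
    // A host-side sketch, assuming this crate's unit tests can run under the standard test
    // harness; `DummyChan` is a stand-in channel, since the counter is agnostic to `C`.
    use super::*;

    struct DummyChan;

    #[test]
    fn counter_frees_exactly_once() {
        let (s, r) = new(DummyChan);

        // A second sender reference, as `Clone` on a public handle would create.
        let s2 = s.acquire();

        unsafe {
            // Two sender references remain, so this release does not disconnect.
            s2.release(|_| true);
            // Last sender reference: `disconnect` runs and `destroy` is marked.
            s.release(|_| true);
            // Last receiver reference: `disconnect` runs and the allocation is freed.
            r.release(|_| true);
        }
    }
}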