Skip to main content

trillium_client/
pool.rs

1use crossbeam_queue::ArrayQueue;
2use dashmap::{mapref::entry::Entry, DashMap};
3use std::{
4    borrow::Borrow,
5    fmt::{self, Debug, Formatter},
6    hash::Hash,
7    sync::Arc,
8    time::Instant,
9};
10
/// Capacity of each per-key [`PoolSet`] when no explicit size is requested
/// (used by `PoolSet::default` and `Pool::default`).
pub const DEFAULT_CONNECTIONS: usize = 16;
12
/// A pooled value paired with an optional deadline after which it must not be reused.
///
/// `Debug` is derived; it renders as `PoolEntry { item: .., expiry: .. }`.
#[derive(Debug)]
pub struct PoolEntry<V> {
    item: V,
    expiry: Option<Instant>,
}

impl<V> PoolEntry<V> {
    /// Wraps `item`. An `expiry` of `None` means the entry never expires.
    pub fn new(item: V, expiry: Option<Instant>) -> Self {
        Self { item, expiry }
    }

    /// Returns `true` once the current time is strictly past `expiry`.
    ///
    /// Entries created without an expiry are never expired.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        matches!(self.expiry, Some(deadline) if deadline < Instant::now())
    }

    /// Consumes the entry and returns the item if it has not expired.
    ///
    /// An expired item is dropped here and `None` is returned.
    #[must_use]
    pub fn take(self) -> Option<V> {
        if self.is_expired() {
            None
        } else {
            Some(self.item)
        }
    }
}
47
48pub struct PoolSet<V>(Arc<ArrayQueue<PoolEntry<V>>>);
49impl<V: Debug> Debug for PoolSet<V> {
50    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
51        f.debug_tuple("PoolSet").field(&self.0).finish()
52    }
53}
54
55impl<V> Default for PoolSet<V> {
56    fn default() -> Self {
57        Self::new(DEFAULT_CONNECTIONS)
58    }
59}
60
61impl<V> Clone for PoolSet<V> {
62    fn clone(&self) -> Self {
63        Self(Arc::clone(&self.0))
64    }
65}
66
67impl<V> PoolSet<V> {
68    pub fn insert(&self, entry: PoolEntry<V>) {
69        self.0.force_push(entry);
70    }
71
72    pub fn new(size: usize) -> Self {
73        Self(Arc::new(ArrayQueue::new(size)))
74    }
75
76    pub fn is_empty(&self) -> bool {
77        self.0.is_empty()
78    }
79}
80
81impl<V> Iterator for PoolSet<V> {
82    type Item = PoolEntry<V>;
83
84    fn next(&mut self) -> Option<Self::Item> {
85        self.0.pop()
86    }
87}
88
89pub struct Pool<K, V> {
90    pub(crate) max_set_size: usize,
91    connections: Arc<DashMap<K, PoolSet<V>>>,
92}
93
94struct Connections<'a, K, V>(&'a DashMap<K, PoolSet<V>>);
95impl<K, V> Debug for Connections<'_, K, V>
96where
97    K: Hash + Debug + Eq,
98{
99    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
100        let mut map = f.debug_map();
101        for item in self.0 {
102            let (k, v) = item.pair();
103            map.entry(&k, &v.0.len());
104        }
105
106        map.finish()
107    }
108}
109
110impl<K, V> Debug for Pool<K, V>
111where
112    K: Hash + Debug + Eq,
113{
114    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
115        f.debug_struct("Pool")
116            .field("max_set_size", &self.max_set_size)
117            .field("connections", &Connections(&self.connections))
118            .finish()
119    }
120}
121
122impl<K, V> Clone for Pool<K, V> {
123    fn clone(&self) -> Self {
124        Self {
125            connections: Arc::clone(&self.connections),
126            max_set_size: self.max_set_size,
127        }
128    }
129}
130
131impl<K, V> Default for Pool<K, V>
132where
133    K: Hash + Debug + Eq,
134{
135    fn default() -> Self {
136        Self {
137            connections: Default::default(),
138            max_set_size: DEFAULT_CONNECTIONS,
139        }
140    }
141}
142
143impl<K, V> Pool<K, V>
144where
145    K: Hash + Debug + Eq + Clone + Debug,
146{
147    #[allow(dead_code)]
148    pub fn new(max_set_size: usize) -> Self {
149        Self {
150            connections: Default::default(),
151            max_set_size,
152        }
153    }
154
155    pub fn insert(&self, k: K, entry: PoolEntry<V>) {
156        log::debug!("saving connection to {:?}", &k);
157        match self.connections.entry(k) {
158            Entry::Occupied(o) => {
159                o.get().insert(entry);
160            }
161
162            Entry::Vacant(v) => {
163                let pool_set = PoolSet::new(self.max_set_size);
164                pool_set.insert(entry);
165                v.insert(pool_set);
166            }
167        }
168    }
169
170    #[allow(dead_code)]
171    pub fn keys(&self) -> impl Iterator<Item = K> + '_ {
172        self.connections.iter().map(|k| (*k.key()).clone())
173    }
174
175    pub fn candidates<Q>(&self, key: &Q) -> impl Iterator<Item = V>
176    where
177        K: Borrow<Q>,
178        Q: Hash + Eq + ?Sized,
179    {
180        self.connections
181            .get(key)
182            .map(|poolset| poolset.clone().filter_map(|v| v.take()))
183            .into_iter()
184            .flatten()
185    }
186
187    pub fn cleanup(&self) {
188        self.connections.retain(|_k, v| !v.is_empty())
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use trillium_server_common::Url;

    #[test]
    fn basic_pool_functionality() {
        let addr = "127.0.0.1:8080";
        let pool = Pool::default();
        for n in 0..5 {
            pool.insert(String::from(addr), PoolEntry::new(n, None));
        }

        // The first candidate is the oldest entry, and taking it removes it...
        assert_eq!(pool.candidates(addr).next(), Some(0));
        // ...so a fresh iterator sees only what remains.
        let rest: Vec<_> = pool.candidates(addr).collect();
        assert_eq!(rest, vec![1, 2, 3, 4]);
    }

    #[test]
    fn eviction() {
        let origin = Url::parse("http://127.0.0.1:8080").unwrap().origin();
        let pool = Pool::new(5);
        for n in 0..10 {
            pool.insert(origin.clone(), PoolEntry::new(n, None));
        }

        // Capacity is 5, so the five oldest entries were evicted.
        let survivors: Vec<_> = pool.candidates(&origin).collect();
        assert_eq!(survivors, vec![5, 6, 7, 8, 9]);
    }

    #[test]
    fn cleanup() {
        let local = Url::parse("http://127.0.0.1:8080").unwrap().origin();
        let other = Url::parse("http://0.0.0.0:1234").unwrap().origin();
        let pool = Pool::new(5);
        for n in 0..10 {
            pool.insert(local.clone(), PoolEntry::new(n, None));
            pool.insert(other.clone(), PoolEntry::new(n * 100, None));
        }
        assert_eq!(pool.keys().count(), 2);

        // Both sets still hold entries, so nothing is removed.
        pool.cleanup();
        assert_eq!(pool.keys().count(), 2);

        // Draining a set leaves its key in place until the next cleanup.
        pool.candidates(&other).for_each(drop);
        assert_eq!(pool.keys().count(), 2);
        pool.cleanup();
        assert_eq!(pool.keys().count(), 1);
    }
}