1use crossbeam_queue::ArrayQueue;
2use dashmap::{mapref::entry::Entry, DashMap};
3use std::{
4 borrow::Borrow,
5 fmt::{self, Debug, Formatter},
6 hash::Hash,
7 sync::Arc,
8 time::Instant,
9};
10
/// Capacity used for each per-key [`PoolSet`] when no explicit size is given;
/// both `PoolSet::default` and `Pool::default` rely on it.
pub const DEFAULT_CONNECTIONS: usize = 16;
12
/// A pooled value paired with an optional deadline after which it must no
/// longer be handed out.
pub struct PoolEntry<V> {
    /// The pooled value itself.
    item: V,
    /// Instant after which the entry counts as expired; `None` means the
    /// entry never expires.
    expiry: Option<Instant>,
}

impl<V> Debug for PoolEntry<V>
where
    V: Debug,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self { item, expiry } = self;
        let mut repr = f.debug_struct("PoolEntry");
        repr.field("item", item);
        repr.field("expiry", expiry);
        repr.finish()
    }
}

impl<V> PoolEntry<V> {
    /// Wraps `item`, optionally recording the instant at which it expires.
    pub fn new(item: V, expiry: Option<Instant>) -> Self {
        PoolEntry { item, expiry }
    }

    /// Returns `true` when an expiry was set and it lies strictly before the
    /// current time. Entries without an expiry never expire.
    pub fn is_expired(&self) -> bool {
        matches!(self.expiry, Some(deadline) if deadline < Instant::now())
    }

    /// Consumes the entry, yielding the item unless the entry has expired.
    pub fn take(self) -> Option<V> {
        if self.is_expired() {
            return None;
        }

        Some(self.item)
    }
}
47
48pub struct PoolSet<V>(Arc<ArrayQueue<PoolEntry<V>>>);
49impl<V: Debug> Debug for PoolSet<V> {
50 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
51 f.debug_tuple("PoolSet").field(&self.0).finish()
52 }
53}
54
55impl<V> Default for PoolSet<V> {
56 fn default() -> Self {
57 Self::new(DEFAULT_CONNECTIONS)
58 }
59}
60
61impl<V> Clone for PoolSet<V> {
62 fn clone(&self) -> Self {
63 Self(Arc::clone(&self.0))
64 }
65}
66
67impl<V> PoolSet<V> {
68 pub fn insert(&self, entry: PoolEntry<V>) {
69 self.0.force_push(entry);
70 }
71
72 pub fn new(size: usize) -> Self {
73 Self(Arc::new(ArrayQueue::new(size)))
74 }
75
76 pub fn is_empty(&self) -> bool {
77 self.0.is_empty()
78 }
79}
80
81impl<V> Iterator for PoolSet<V> {
82 type Item = PoolEntry<V>;
83
84 fn next(&mut self) -> Option<Self::Item> {
85 self.0.pop()
86 }
87}
88
89pub struct Pool<K, V> {
90 pub(crate) max_set_size: usize,
91 connections: Arc<DashMap<K, PoolSet<V>>>,
92}
93
94struct Connections<'a, K, V>(&'a DashMap<K, PoolSet<V>>);
95impl<K, V> Debug for Connections<'_, K, V>
96where
97 K: Hash + Debug + Eq,
98{
99 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
100 let mut map = f.debug_map();
101 for item in self.0 {
102 let (k, v) = item.pair();
103 map.entry(&k, &v.0.len());
104 }
105
106 map.finish()
107 }
108}
109
110impl<K, V> Debug for Pool<K, V>
111where
112 K: Hash + Debug + Eq,
113{
114 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
115 f.debug_struct("Pool")
116 .field("max_set_size", &self.max_set_size)
117 .field("connections", &Connections(&self.connections))
118 .finish()
119 }
120}
121
122impl<K, V> Clone for Pool<K, V> {
123 fn clone(&self) -> Self {
124 Self {
125 connections: Arc::clone(&self.connections),
126 max_set_size: self.max_set_size,
127 }
128 }
129}
130
131impl<K, V> Default for Pool<K, V>
132where
133 K: Hash + Debug + Eq,
134{
135 fn default() -> Self {
136 Self {
137 connections: Default::default(),
138 max_set_size: DEFAULT_CONNECTIONS,
139 }
140 }
141}
142
143impl<K, V> Pool<K, V>
144where
145 K: Hash + Debug + Eq + Clone + Debug,
146{
147 #[allow(dead_code)]
148 pub fn new(max_set_size: usize) -> Self {
149 Self {
150 connections: Default::default(),
151 max_set_size,
152 }
153 }
154
155 pub fn insert(&self, k: K, entry: PoolEntry<V>) {
156 log::debug!("saving connection to {:?}", &k);
157 match self.connections.entry(k) {
158 Entry::Occupied(o) => {
159 o.get().insert(entry);
160 }
161
162 Entry::Vacant(v) => {
163 let pool_set = PoolSet::new(self.max_set_size);
164 pool_set.insert(entry);
165 v.insert(pool_set);
166 }
167 }
168 }
169
170 #[allow(dead_code)]
171 pub fn keys(&self) -> impl Iterator<Item = K> + '_ {
172 self.connections.iter().map(|k| (*k.key()).clone())
173 }
174
175 pub fn candidates<Q>(&self, key: &Q) -> impl Iterator<Item = V>
176 where
177 K: Borrow<Q>,
178 Q: Hash + Eq + ?Sized,
179 {
180 self.connections
181 .get(key)
182 .map(|poolset| poolset.clone().filter_map(|v| v.take()))
183 .into_iter()
184 .flatten()
185 }
186
187 pub fn cleanup(&self) {
188 self.connections.retain(|_k, v| !v.is_empty())
189 }
190}
191
#[cfg(test)]
mod tests {
    use trillium_server_common::Url;

    use super::*;

    /// Entries come back in insertion order, and taking one removes it.
    #[test]
    fn basic_pool_functionality() {
        let pool = Pool::default();
        (0..5).for_each(|n| {
            pool.insert(String::from("127.0.0.1:8080"), PoolEntry::new(n, None));
        });

        let first = pool.candidates("127.0.0.1:8080").next();
        assert_eq!(first, Some(0));

        let rest: Vec<_> = pool.candidates("127.0.0.1:8080").collect();
        assert_eq!(rest, vec![1, 2, 3, 4]);
    }

    /// Inserting past the set's capacity evicts the oldest entries.
    #[test]
    fn eviction() {
        let pool = Pool::new(5);
        let origin = Url::parse("http://127.0.0.1:8080").unwrap().origin();

        for n in 0..10 {
            pool.insert(origin.clone(), PoolEntry::new(n, None));
        }

        let survivors: Vec<_> = pool.candidates(&origin).collect();
        assert_eq!(survivors, vec![5, 6, 7, 8, 9]);
    }

    /// `cleanup` only removes keys whose set has been fully drained.
    #[test]
    fn cleanup() {
        let pool = Pool::new(5);
        let first = Url::parse("http://127.0.0.1:8080").unwrap().origin();
        let second = Url::parse("http://0.0.0.0:1234").unwrap().origin();

        for n in 0..10 {
            pool.insert(first.clone(), PoolEntry::new(n, None));
            pool.insert(second.clone(), PoolEntry::new(n * 100, None));
        }
        assert_eq!(pool.keys().count(), 2);

        // Nothing has been drained yet, so cleanup must keep both keys.
        pool.cleanup();
        assert_eq!(pool.keys().count(), 2);

        // Drain every entry stored for the second origin.
        pool.candidates(&second).for_each(drop);

        // The emptied set stays registered until cleanup runs.
        assert_eq!(pool.keys().count(), 2);
        pool.cleanup();
        assert_eq!(pool.keys().count(), 1);
    }
}