1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
use crate::Conn;
use futures_lite::{AsyncRead, AsyncWrite};
use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};

/// A future that watches a borrowed [`Conn`] for the end of its transport.
///
/// It holds the connection by exclusive reference for `'a`. Its `Future`
/// implementation reads from the connection's transport into the
/// connection's buffer and completes once a read yields zero bytes or fails.
pub(crate) struct LivenessFut<'a, T>(&'a mut Conn<T>);

impl<'a, T> LivenessFut<'a, T> {
    pub(crate) fn new(conn: &'a mut Conn<T>) -> Self {
        Self(conn)
    }
}

impl<T> Future for LivenessFut<'_, T>
where
    T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
{
    /// Completion carries no value; finishing is itself the signal.
    type Output = ();

    /// Performs a single read attempt on the connection's transport.
    ///
    /// * Transport not ready: the buffer is restored to its prior length and
    ///   `Poll::Pending` is returned.
    /// * Read error or zero-byte read: the buffer is restored to its prior
    ///   length and the future completes with `Poll::Ready(())`.
    /// * `n > 0` bytes read: those bytes are kept at the end of the buffer,
    ///   the waker is triggered so this future is polled again promptly, and
    ///   `Poll::Pending` is returned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let Conn {
            buffer, transport, ..
        } = &mut *this.0;

        // Remember how much data is already buffered so it is never overwritten.
        let filled = buffer.len();
        // NOTE(review): `expand` presumably grows the buffer so that
        // `buffer[filled..]` offers spare room to read into — confirm against
        // the buffer type's implementation.
        buffer.expand();

        let received = match Pin::new(transport).poll_read(cx, &mut buffer[filled..]) {
            Poll::Pending => {
                buffer.truncate(filled);
                return Poll::Pending;
            }
            // A failed read is handled exactly like a zero-byte read.
            Poll::Ready(outcome) => outcome.unwrap_or(0),
        };

        // Keep only the pre-existing data plus whatever was just read.
        buffer.truncate(filled + received);

        if received == 0 {
            return Poll::Ready(());
        }

        // Data arrived, so the transport is still delivering bytes. Request an
        // immediate re-poll rather than waiting on a transport wakeup, since
        // this poll_read returned Ready and may not have registered the waker.
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}

/// A future that drives `Fut` while also watching a [`Conn`] via
/// [`LivenessFut`].
///
/// Field `0` is the exclusively borrowed connection that is monitored; field
/// `1` is the pinned future being driven. Its `Future` implementation yields
/// `Some(output)` when the inner future finishes and `None` when the
/// liveness check completes while the inner future is still pending.
pub(crate) struct CancelOnDisconnect<'a, Fut, T>(
    pub(crate) &'a mut Conn<T>,
    pub(crate) Pin<&'a mut Fut>,
);
impl<'a, Fut, T> Future for CancelOnDisconnect<'a, Fut, T>
where
    Fut: Future + Send + 'a,
    T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
{
    type Output = Option<Fut::Output>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let CancelOnDisconnect(conn, fut) = &mut *self;
        let fut_poll = fut.as_mut().poll(cx);
        let disconnect = Pin::new(&mut LivenessFut(conn)).poll(cx);
        match (fut_poll, disconnect) {
            (Poll::Ready(output), _) => Poll::Ready(Some(output)),
            (Poll::Pending, Poll::Ready(())) => Poll::Ready(None),
            (Poll::Pending, Poll::Pending) => Poll::Pending,
        }
    }
}