Skip to main content

trillium_grpc/frame/
reader.rs

1//! Decode a gRPC body into a stream of messages. See [`MessageStream`].
2
3use crate::{Encoding, Status};
4use futures_lite::{AsyncRead, Stream};
5use std::{
6    pin::Pin,
7    task::{Context, Poll},
8};
9
10pub use crate::encoding::DEFAULT_MAX_MESSAGE_SIZE;
11
/// Length of the header that precedes every gRPC message on the wire: a
/// single Compressed-Flag byte followed by the message length as a
/// big-endian `u32`. The payload bytes come directly after this header.
const PREFIX_LEN: usize = 1 + std::mem::size_of::<u32>();
15
16/// Stream of decoded messages over a length-prefixed gRPC body.
17///
18/// Wraps any `AsyncRead` (request body or response body) and yields decoded
19/// messages until the underlying reader signals EOF cleanly between frames.
20/// EOF mid-frame produces an error item and ends the stream.
21///
22/// When the per-message Compressed-Flag is set, the payload is run through
23/// the [`Encoding`] configured via [`with_encoding`](Self::with_encoding).
24/// `Identity` (the default) rejects compressed frames with `Internal` —
25/// the peer claimed compression after we advertised none.
26pub struct MessageStream<T, R> {
27    reader: R,
28    state: ReadState,
29    max_message_size: usize,
30    encoding: Encoding,
31    decode: fn(&[u8]) -> Result<T, Status>,
32}
33
34pub(crate) enum ReadState {
35    ReadingPrefix {
36        buf: [u8; PREFIX_LEN],
37        filled: usize,
38    },
39    ReadingPayload {
40        compressed: bool,
41        payload: Vec<u8>,
42        filled: usize,
43    },
44    Done,
45}
46
47impl ReadState {
48    pub(crate) fn new() -> Self {
49        Self::ReadingPrefix {
50            buf: [0u8; PREFIX_LEN],
51            filled: 0,
52        }
53    }
54}
55
56impl<T, R> MessageStream<T, R> {
57    /// Wrap `reader` (a gRPC body), decoding each frame's payload with
58    /// `decode`. Defaults to `Identity` encoding and the
59    /// [`DEFAULT_MAX_MESSAGE_SIZE`] cap; adjust with
60    /// [`with_encoding`](Self::with_encoding) and
61    /// [`with_max_message_size`](Self::with_max_message_size).
62    pub fn new(reader: R, decode: fn(&[u8]) -> Result<T, Status>) -> Self {
63        Self {
64            reader,
65            state: ReadState::new(),
66            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
67            encoding: Encoding::Identity,
68            decode,
69        }
70    }
71
72    /// Reject any single message whose framed length (or decompressed size)
73    /// exceeds `max` bytes, with `ResourceExhausted`.
74    pub fn with_max_message_size(mut self, max: usize) -> Self {
75        self.max_message_size = max;
76        self
77    }
78
79    /// Decompress payloads with the per-message Compressed-Flag set using
80    /// this encoding. Compressed frames received without an encoding
81    /// configured (the default `Identity`) are rejected.
82    pub fn with_encoding(mut self, encoding: Encoding) -> Self {
83        self.encoding = encoding;
84        self
85    }
86}
87
88impl<T, R> Stream for MessageStream<T, R>
89where
90    R: AsyncRead + Unpin,
91    T: 'static,
92{
93    type Item = Result<T, Status>;
94
95    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
96        let this = self.get_mut();
97        poll_read_message(
98            Pin::new(&mut this.reader),
99            &mut this.state,
100            cx,
101            this.decode,
102            this.encoding,
103            this.max_message_size,
104        )
105    }
106}
107
108/// Drive one step of the message-read state machine on an `AsyncRead`.
109///
110/// Returns `Poll::Ready(None)` for a clean EOF between frames, `Poll::Ready(Some(Err))`
111/// for a per-message error (state transitions to `Done` afterwards), and
112/// `Poll::Pending` when the reader has no more bytes ready.
113///
114/// The same `state` value must be passed in across polls so partial prefix
115/// and partial payload reads can resume. Once the state reaches `Done`,
116/// further calls return `Poll::Ready(None)`.
117pub(crate) fn poll_read_message<T, R>(
118    mut reader: Pin<&mut R>,
119    state: &mut ReadState,
120    cx: &mut Context<'_>,
121    decode: fn(&[u8]) -> Result<T, Status>,
122    encoding: Encoding,
123    max_message_size: usize,
124) -> Poll<Option<Result<T, Status>>>
125where
126    R: AsyncRead + ?Sized,
127{
128    loop {
129        match state {
130            ReadState::Done => return Poll::Ready(None),
131
132            ReadState::ReadingPrefix { buf, filled } => {
133                while *filled < PREFIX_LEN {
134                    let dst = &mut buf[*filled..];
135                    match reader.as_mut().poll_read(cx, dst) {
136                        Poll::Pending => return Poll::Pending,
137                        Poll::Ready(Err(e)) => {
138                            *state = ReadState::Done;
139                            return Poll::Ready(Some(Err(Status::unavailable(format!(
140                                "read error: {e}"
141                            )))));
142                        }
143                        Poll::Ready(Ok(0)) => {
144                            if *filled == 0 {
145                                *state = ReadState::Done;
146                                return Poll::Ready(None);
147                            } else {
148                                *state = ReadState::Done;
149                                return Poll::Ready(Some(Err(Status::internal(
150                                    "unexpected EOF in frame prefix",
151                                ))));
152                            }
153                        }
154                        Poll::Ready(Ok(n)) => *filled += n,
155                    }
156                }
157
158                let compressed = buf[0] != 0;
159                let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
160
161                if len > max_message_size {
162                    *state = ReadState::Done;
163                    return Poll::Ready(Some(Err(Status::resource_exhausted(format!(
164                        "received message of {len} bytes exceeds limit of {max_message_size}"
165                    )))));
166                }
167
168                *state = ReadState::ReadingPayload {
169                    compressed,
170                    payload: vec![0u8; len],
171                    filled: 0,
172                };
173            }
174
175            ReadState::ReadingPayload {
176                compressed,
177                payload,
178                filled,
179            } => {
180                while *filled < payload.len() {
181                    let dst = &mut payload[*filled..];
182                    match reader.as_mut().poll_read(cx, dst) {
183                        Poll::Pending => return Poll::Pending,
184                        Poll::Ready(Err(e)) => {
185                            *state = ReadState::Done;
186                            return Poll::Ready(Some(Err(Status::unavailable(format!(
187                                "read error: {e}"
188                            )))));
189                        }
190                        Poll::Ready(Ok(0)) => {
191                            *state = ReadState::Done;
192                            return Poll::Ready(Some(Err(Status::internal(
193                                "unexpected EOF in frame payload",
194                            ))));
195                        }
196                        Poll::Ready(Ok(n)) => *filled += n,
197                    }
198                }
199
200                let compressed = *compressed;
201                let payload = std::mem::take(payload);
202                *state = ReadState::new();
203
204                let bytes = if compressed {
205                    if matches!(encoding, Encoding::Identity) {
206                        return Poll::Ready(Some(Err(Status::internal(
207                            "received compressed message but no encoding negotiated",
208                        ))));
209                    }
210                    match encoding.decompress(&payload, max_message_size) {
211                        Ok(b) => b,
212                        Err(status) => return Poll::Ready(Some(Err(status))),
213                    }
214                } else {
215                    payload
216                };
217
218                return Poll::Ready(Some(decode(&bytes)));
219            }
220        }
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Code, Codec, codec::Prost};
    use futures_lite::{StreamExt, future::block_on};

    /// Helper: build a single framed message: [compressed=0, len BE u32, payload].
    fn frame(payload: &[u8]) -> Vec<u8> {
        let mut out = Vec::with_capacity(PREFIX_LEN + payload.len());
        out.push(0); // compressed flag
        out.extend_from_slice(&(payload.len() as u32).to_be_bytes());
        out.extend_from_slice(payload);
        out
    }

    fn vec_decode(bytes: &[u8]) -> Result<Vec<u8>, Status> {
        <Prost as Codec<Vec<u8>>>::decode(bytes)
    }

    type Stream<'a> = MessageStream<Vec<u8>, &'a [u8]>;

    fn new_stream(bytes: &[u8]) -> Stream<'_> {
        MessageStream::new(bytes, vec_decode)
    }

    /// Reader that returns `Pending` (after waking the task) before every
    /// read, then yields at most one byte. Every prefix and payload is
    /// therefore assembled from many partial reads spread over many polls,
    /// which exercises the resumption paths that `&[u8]` never hits.
    struct TrickleReader {
        data: Vec<u8>,
        pos: usize,
        pending_next: bool,
    }

    impl TrickleReader {
        fn new(data: Vec<u8>) -> Self {
            Self {
                data,
                pos: 0,
                pending_next: true,
            }
        }
    }

    impl AsyncRead for TrickleReader {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            let this = self.get_mut();
            if this.pending_next {
                this.pending_next = false;
                // Wake right away so `block_on` polls us again.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            this.pending_next = true;
            match (this.data.get(this.pos), buf.first_mut()) {
                (Some(&byte), Some(slot)) => {
                    *slot = byte;
                    this.pos += 1;
                    Poll::Ready(Ok(1))
                }
                _ => Poll::Ready(Ok(0)),
            }
        }
    }

    #[test]
    fn empty_input_yields_none() {
        let bytes: &[u8] = &[];
        let mut s = new_stream(bytes);
        assert!(block_on(s.next()).is_none());
    }

    #[test]
    fn single_empty_message() {
        let body = frame(&[]);
        let mut s = new_stream(&body[..]);
        // Vec<u8> as a prost Message decodes from empty bytes to empty Vec
        let msg = block_on(s.next()).unwrap().unwrap();
        assert!(msg.is_empty());
        assert!(block_on(s.next()).is_none());
    }

    #[test]
    fn multiple_messages() {
        // Vec<u8> as a prost top-level Message: a `bytes` field at tag 1.
        // Encoding: tag byte 0x0A, then varint-len, then payload.
        // For payload b"hi": [0x0A, 0x02, b'h', b'i']
        let mut body = Vec::new();
        body.extend_from_slice(&frame(&[0x0A, 0x02, b'h', b'i']));
        body.extend_from_slice(&frame(&[0x0A, 0x03, b'b', b'y', b'e']));

        let mut s = new_stream(&body[..]);
        let m1 = block_on(s.next()).unwrap().unwrap();
        let m2 = block_on(s.next()).unwrap().unwrap();
        assert_eq!(m1, b"hi");
        assert_eq!(m2, b"bye");
        assert!(block_on(s.next()).is_none());
    }

    #[test]
    fn resumes_across_pending_and_partial_reads() {
        let mut body = Vec::new();
        body.extend_from_slice(&frame(&[0x0A, 0x02, b'h', b'i']));
        body.extend_from_slice(&frame(&[0x0A, 0x03, b'b', b'y', b'e']));

        let mut s: MessageStream<Vec<u8>, TrickleReader> =
            MessageStream::new(TrickleReader::new(body), vec_decode);
        assert_eq!(block_on(s.next()).unwrap().unwrap(), b"hi");
        assert_eq!(block_on(s.next()).unwrap().unwrap(), b"bye");
        assert!(block_on(s.next()).is_none());
    }

    #[test]
    fn partial_prefix_at_eof_is_error() {
        let body = [0u8, 0u8, 0u8]; // 3 of 5 prefix bytes
        let mut s = new_stream(&body[..]);
        let err = block_on(s.next()).unwrap().unwrap_err();
        assert_eq!(err.code, Code::Internal);
        assert!(block_on(s.next()).is_none());
    }

    #[test]
    fn partial_prefix_across_polls_then_eof_is_error() {
        let mut s: MessageStream<Vec<u8>, TrickleReader> =
            MessageStream::new(TrickleReader::new(vec![0, 0, 0]), vec_decode);
        let err = block_on(s.next()).unwrap().unwrap_err();
        assert_eq!(err.code, Code::Internal);
        assert!(block_on(s.next()).is_none());
    }

    #[test]
    fn partial_payload_at_eof_is_error() {
        let mut body = Vec::new();
        body.push(0); // compressed
        body.extend_from_slice(&10u32.to_be_bytes()); // claim 10 bytes
        body.extend_from_slice(&[1, 2, 3]); // only deliver 3
        let mut s = new_stream(&body[..]);
        let err = block_on(s.next()).unwrap().unwrap_err();
        assert_eq!(err.code, Code::Internal);
    }

    #[test]
    fn oversized_message_is_resource_exhausted() {
        let mut body = Vec::new();
        body.push(0);
        body.extend_from_slice(&100u32.to_be_bytes());
        let mut s = new_stream(&body[..]).with_max_message_size(50);
        let err = block_on(s.next()).unwrap().unwrap_err();
        assert_eq!(err.code, Code::ResourceExhausted);
    }

    #[test]
    fn message_exactly_at_size_limit_is_accepted() {
        // The cap is inclusive: a 4-byte frame passes a 4-byte limit.
        let body = frame(&[0x0A, 0x02, b'h', b'i']);
        let mut s = new_stream(&body[..]).with_max_message_size(4);
        assert_eq!(block_on(s.next()).unwrap().unwrap(), b"hi");
    }

    #[test]
    fn compressed_flag_with_identity_encoding_is_internal() {
        // Peer set the Compressed-Flag but we negotiated no encoding —
        // protocol error.
        let mut body = Vec::new();
        body.push(1); // compressed
        body.extend_from_slice(&0u32.to_be_bytes());
        let mut s = new_stream(&body[..]);
        let err = block_on(s.next()).unwrap().unwrap_err();
        assert_eq!(err.code, Code::Internal);
    }

    #[cfg(feature = "gzip")]
    #[test]
    fn compressed_frame_decompressed_with_gzip() {
        // Frame body: gzip-compressed prost-encoded `Vec<u8>` of b"hi".
        let inner = [0x0Au8, 0x02, b'h', b'i'];
        let compressed = Encoding::Gzip.compress(&inner).unwrap();

        let mut body = Vec::new();
        body.push(1); // compressed flag
        body.extend_from_slice(&(compressed.len() as u32).to_be_bytes());
        body.extend_from_slice(&compressed);

        let mut s = new_stream(&body[..]).with_encoding(Encoding::Gzip);
        let msg = block_on(s.next()).unwrap().unwrap();
        assert_eq!(msg, b"hi");
    }

    #[test]
    fn codec_decode_failure_propagates_invalid_argument() {
        // 0xFF is an invalid prost tag (non-terminating varint).
        let body = frame(&[0xFF, 0xFF, 0xFF, 0xFF]);
        let mut s = new_stream(&body[..]);
        let err = block_on(s.next()).unwrap().unwrap_err();
        assert_eq!(err.code, Code::InvalidArgument);
    }
}