1use crate::{Encoding, Status};
4use futures_lite::{AsyncRead, Stream};
5use std::{
6 pin::Pin,
7 task::{Context, Poll},
8};
9
10pub use crate::encoding::DEFAULT_MAX_MESSAGE_SIZE;
11
/// Size in bytes of every frame prefix: one compression-flag byte followed by
/// the payload length encoded as a big-endian `u32`.
const PREFIX_LEN: usize = 1 + std::mem::size_of::<u32>();
15
/// A [`Stream`] of messages decoded from length-prefixed frames read from `R`.
///
/// Each frame is a 1-byte compression flag, a 4-byte big-endian payload
/// length, and then the payload; every complete payload is handed to `decode`
/// to produce a `T`.
pub struct MessageStream<T, R> {
    /// Source of the framed bytes.
    reader: R,
    /// Progress through the current frame, preserved across `Poll::Pending`.
    state: ReadState,
    /// Largest accepted frame length in bytes.
    max_message_size: usize,
    /// Encoding used for frames whose compression flag is set.
    encoding: Encoding,
    /// Converts a complete (decompressed) payload into a message.
    decode: fn(&[u8]) -> Result<T, Status>,
}
33
34pub(crate) enum ReadState {
35 ReadingPrefix {
36 buf: [u8; PREFIX_LEN],
37 filled: usize,
38 },
39 ReadingPayload {
40 compressed: bool,
41 payload: Vec<u8>,
42 filled: usize,
43 },
44 Done,
45}
46
47impl ReadState {
48 pub(crate) fn new() -> Self {
49 Self::ReadingPrefix {
50 buf: [0u8; PREFIX_LEN],
51 filled: 0,
52 }
53 }
54}
55
56impl<T, R> MessageStream<T, R> {
57 pub fn new(reader: R, decode: fn(&[u8]) -> Result<T, Status>) -> Self {
63 Self {
64 reader,
65 state: ReadState::new(),
66 max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
67 encoding: Encoding::Identity,
68 decode,
69 }
70 }
71
72 pub fn with_max_message_size(mut self, max: usize) -> Self {
75 self.max_message_size = max;
76 self
77 }
78
79 pub fn with_encoding(mut self, encoding: Encoding) -> Self {
83 self.encoding = encoding;
84 self
85 }
86}
87
88impl<T, R> Stream for MessageStream<T, R>
89where
90 R: AsyncRead + Unpin,
91 T: 'static,
92{
93 type Item = Result<T, Status>;
94
95 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
96 let this = self.get_mut();
97 poll_read_message(
98 Pin::new(&mut this.reader),
99 &mut this.state,
100 cx,
101 this.decode,
102 this.encoding,
103 this.max_message_size,
104 )
105 }
106}
107
108pub(crate) fn poll_read_message<T, R>(
118 mut reader: Pin<&mut R>,
119 state: &mut ReadState,
120 cx: &mut Context<'_>,
121 decode: fn(&[u8]) -> Result<T, Status>,
122 encoding: Encoding,
123 max_message_size: usize,
124) -> Poll<Option<Result<T, Status>>>
125where
126 R: AsyncRead + ?Sized,
127{
128 loop {
129 match state {
130 ReadState::Done => return Poll::Ready(None),
131
132 ReadState::ReadingPrefix { buf, filled } => {
133 while *filled < PREFIX_LEN {
134 let dst = &mut buf[*filled..];
135 match reader.as_mut().poll_read(cx, dst) {
136 Poll::Pending => return Poll::Pending,
137 Poll::Ready(Err(e)) => {
138 *state = ReadState::Done;
139 return Poll::Ready(Some(Err(Status::unavailable(format!(
140 "read error: {e}"
141 )))));
142 }
143 Poll::Ready(Ok(0)) => {
144 if *filled == 0 {
145 *state = ReadState::Done;
146 return Poll::Ready(None);
147 } else {
148 *state = ReadState::Done;
149 return Poll::Ready(Some(Err(Status::internal(
150 "unexpected EOF in frame prefix",
151 ))));
152 }
153 }
154 Poll::Ready(Ok(n)) => *filled += n,
155 }
156 }
157
158 let compressed = buf[0] != 0;
159 let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
160
161 if len > max_message_size {
162 *state = ReadState::Done;
163 return Poll::Ready(Some(Err(Status::resource_exhausted(format!(
164 "received message of {len} bytes exceeds limit of {max_message_size}"
165 )))));
166 }
167
168 *state = ReadState::ReadingPayload {
169 compressed,
170 payload: vec![0u8; len],
171 filled: 0,
172 };
173 }
174
175 ReadState::ReadingPayload {
176 compressed,
177 payload,
178 filled,
179 } => {
180 while *filled < payload.len() {
181 let dst = &mut payload[*filled..];
182 match reader.as_mut().poll_read(cx, dst) {
183 Poll::Pending => return Poll::Pending,
184 Poll::Ready(Err(e)) => {
185 *state = ReadState::Done;
186 return Poll::Ready(Some(Err(Status::unavailable(format!(
187 "read error: {e}"
188 )))));
189 }
190 Poll::Ready(Ok(0)) => {
191 *state = ReadState::Done;
192 return Poll::Ready(Some(Err(Status::internal(
193 "unexpected EOF in frame payload",
194 ))));
195 }
196 Poll::Ready(Ok(n)) => *filled += n,
197 }
198 }
199
200 let compressed = *compressed;
201 let payload = std::mem::take(payload);
202 *state = ReadState::new();
203
204 let bytes = if compressed {
205 if matches!(encoding, Encoding::Identity) {
206 return Poll::Ready(Some(Err(Status::internal(
207 "received compressed message but no encoding negotiated",
208 ))));
209 }
210 match encoding.decompress(&payload, max_message_size) {
211 Ok(b) => b,
212 Err(status) => return Poll::Ready(Some(Err(status))),
213 }
214 } else {
215 payload
216 };
217
218 return Poll::Ready(Some(decode(&bytes)));
219 }
220 }
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Code, Codec, codec::Prost};
    use futures_lite::{StreamExt, future::block_on};

    /// Protobuf encoding of field 1 (length-delimited) holding the bytes `hi`.
    const HI: [u8; 4] = [0x0A, 0x02, b'h', b'i'];

    /// Wraps `payload` in a frame carrying the given compression `flag`.
    fn frame_with_flag(flag: u8, payload: &[u8]) -> Vec<u8> {
        let mut out = Vec::with_capacity(PREFIX_LEN + payload.len());
        out.push(flag);
        out.extend_from_slice(&(payload.len() as u32).to_be_bytes());
        out.extend_from_slice(payload);
        out
    }

    /// Wraps `payload` in an uncompressed frame.
    fn frame(payload: &[u8]) -> Vec<u8> {
        frame_with_flag(0, payload)
    }

    /// Decodes a payload with the crate's Prost codec.
    fn vec_decode(bytes: &[u8]) -> Result<Vec<u8>, Status> {
        <Prost as Codec<Vec<u8>>>::decode(bytes)
    }

    // Named to avoid shadowing the `Stream` trait brought in by `super::*`.
    type VecStream<'a> = MessageStream<Vec<u8>, &'a [u8]>;

    fn new_stream(bytes: &[u8]) -> VecStream<'_> {
        MessageStream::new(bytes, vec_decode)
    }

    #[test]
    fn empty_input_yields_none() {
        let bytes: &[u8] = &[];
        let mut s = new_stream(bytes);
        assert!(block_on(s.next()).is_none());
    }

    #[test]
    fn single_empty_message() {
        let body = frame(&[]);
        let mut s = new_stream(&body[..]);
        let msg = block_on(s.next()).unwrap().unwrap();
        assert!(msg.is_empty());
        assert!(block_on(s.next()).is_none());
    }

    #[test]
    fn multiple_messages() {
        let mut body = frame(&HI);
        body.extend_from_slice(&frame(&[0x0A, 0x03, b'b', b'y', b'e']));

        let mut s = new_stream(&body[..]);
        let m1 = block_on(s.next()).unwrap().unwrap();
        let m2 = block_on(s.next()).unwrap().unwrap();
        assert_eq!(m1, b"hi");
        assert_eq!(m2, b"bye");
        assert!(block_on(s.next()).is_none());
    }

    #[test]
    fn partial_prefix_at_eof_is_error() {
        let body = [0u8, 0u8, 0u8];
        let mut s = new_stream(&body[..]);
        let err = block_on(s.next()).unwrap().unwrap_err();
        assert_eq!(err.code, Code::Internal);
        assert!(block_on(s.next()).is_none());
    }

    #[test]
    fn partial_payload_at_eof_is_error() {
        // The prefix announces 10 payload bytes but only 3 follow.
        let mut body = vec![0u8];
        body.extend_from_slice(&10u32.to_be_bytes());
        body.extend_from_slice(&[1, 2, 3]);
        let mut s = new_stream(&body[..]);
        let err = block_on(s.next()).unwrap().unwrap_err();
        assert_eq!(err.code, Code::Internal);
        assert!(block_on(s.next()).is_none());
    }

    #[test]
    fn oversized_message_is_resource_exhausted() {
        let mut body = vec![0u8];
        body.extend_from_slice(&100u32.to_be_bytes());
        let mut s = new_stream(&body[..]).with_max_message_size(50);
        let err = block_on(s.next()).unwrap().unwrap_err();
        assert_eq!(err.code, Code::ResourceExhausted);
        assert!(block_on(s.next()).is_none());
    }

    #[test]
    fn message_at_size_limit_is_accepted() {
        let body = frame(&HI);
        let mut s = new_stream(&body[..]).with_max_message_size(HI.len());
        let msg = block_on(s.next()).unwrap().unwrap();
        assert_eq!(msg, b"hi");
    }

    #[test]
    fn compressed_flag_with_identity_encoding_is_internal() {
        let body = frame_with_flag(1, &[]);
        let mut s = new_stream(&body[..]);
        let err = block_on(s.next()).unwrap().unwrap_err();
        assert_eq!(err.code, Code::Internal);
    }

    #[cfg(feature = "gzip")]
    #[test]
    fn compressed_frame_decompressed_with_gzip() {
        let compressed = Encoding::Gzip.compress(&HI).unwrap();
        let body = frame_with_flag(1, &compressed);

        let mut s = new_stream(&body[..]).with_encoding(Encoding::Gzip);
        let msg = block_on(s.next()).unwrap().unwrap();
        assert_eq!(msg, b"hi");
    }

    #[test]
    fn codec_decode_failure_propagates_invalid_argument() {
        let body = frame(&[0xFF, 0xFF, 0xFF, 0xFF]);
        let mut s = new_stream(&body[..]);
        let err = block_on(s.next()).unwrap().unwrap_err();
        assert_eq!(err.code, Code::InvalidArgument);
    }

    #[test]
    fn codec_decode_failure_does_not_end_stream() {
        // The whole frame was consumed before decoding, so framing is intact
        // and the next frame must still be delivered.
        let mut body = frame(&[0xFF, 0xFF, 0xFF, 0xFF]);
        body.extend_from_slice(&frame(&HI));
        let mut s = new_stream(&body[..]);
        let err = block_on(s.next()).unwrap().unwrap_err();
        assert_eq!(err.code, Code::InvalidArgument);
        let msg = block_on(s.next()).unwrap().unwrap();
        assert_eq!(msg, b"hi");
        assert!(block_on(s.next()).is_none());
    }
}