actix / actix-extras
1
#![deny(rust_2018_idioms)]
2

3
use derive_more::Display;
4
use std::fmt;
5
use std::future::Future;
6
use std::ops::{Deref, DerefMut};
7
use std::pin::Pin;
8
use std::task;
9
use std::task::Poll;
10

11
use prost::DecodeError as ProtoBufDecodeError;
12
use prost::EncodeError as ProtoBufEncodeError;
13
use prost::Message;
14

15
use actix_web::dev::{HttpResponseBuilder, Payload};
16
use actix_web::error::{Error, PayloadError, ResponseError};
17
use actix_web::http::header::{CONTENT_LENGTH, CONTENT_TYPE};
18
use actix_web::web::BytesMut;
19
use actix_web::{FromRequest, HttpMessage, HttpRequest, HttpResponse, Responder};
20
use futures_util::future::{FutureExt, LocalBoxFuture};
21
use futures_util::StreamExt;
22

23
#[derive(Debug, Display)]
24
pub enum ProtoBufPayloadError {
25
    /// Payload size is bigger than 256k
26
    #[display(fmt = "Payload size is bigger than 256k")]
27
    Overflow,
28
    /// Content type error
29
    #[display(fmt = "Content type error")]
30
    ContentType,
31
    /// Serialize error
32
    #[display(fmt = "ProtoBuf serialize error: {}", _0)]
33
    Serialize(ProtoBufEncodeError),
34
    /// Deserialize error
35
    #[display(fmt = "ProtoBuf deserialize error: {}", _0)]
36
    Deserialize(ProtoBufDecodeError),
37
    /// Payload error
38
    #[display(fmt = "Error that occur during reading payload: {}", _0)]
39
    Payload(PayloadError),
40
}
41

42
impl ResponseError for ProtoBufPayloadError {
43 0
    fn error_response(&self) -> HttpResponse {
44 0
        match *self {
45 0
            ProtoBufPayloadError::Overflow => HttpResponse::PayloadTooLarge().into(),
46 0
            _ => HttpResponse::BadRequest().into(),
47
        }
48
    }
49
}
50

51
/// Lets `?` lift payload-stream errors into [`ProtoBufPayloadError::Payload`].
impl From<PayloadError> for ProtoBufPayloadError {
    fn from(source: PayloadError) -> Self {
        Self::Payload(source)
    }
}
56

57
/// Lets `?` lift prost decode errors into [`ProtoBufPayloadError::Deserialize`].
impl From<ProtoBufDecodeError> for ProtoBufPayloadError {
    fn from(source: ProtoBufDecodeError) -> Self {
        Self::Deserialize(source)
    }
}
62

63
/// Protobuf extractor and responder.
///
/// As an extractor, it decodes a `T` from an `application/protobuf` request
/// body. As a responder, it encodes the wrapped `T` into an
/// `application/protobuf` response.
pub struct ProtoBuf<T: Message>(pub T);
64

65
impl<T: Message> Deref for ProtoBuf<T> {
66
    type Target = T;
67

68 0
    fn deref(&self) -> &T {
69 0
        &self.0
70
    }
71
}
72

73
impl<T: Message> DerefMut for ProtoBuf<T> {
74 0
    fn deref_mut(&mut self) -> &mut T {
75 0
        &mut self.0
76
    }
77
}
78

79
impl<T: Message> fmt::Debug for ProtoBuf<T>
80
where
81
    T: fmt::Debug,
82
{
83 0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84 0
        write!(f, "ProtoBuf: {:?}", self.0)
85
    }
86
}
87

88
impl<T: Message> fmt::Display for ProtoBuf<T>
89
where
90
    T: fmt::Display,
91
{
92 0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93 0
        fmt::Display::fmt(&self.0, f)
94
    }
95
}
96

97
/// Configuration for the [`ProtoBuf`] extractor.
///
/// Register it as application data to change the maximum accepted payload
/// size. When none is registered, the extractor uses the `Default` limit of
/// 256 KiB (262_144 bytes).
#[derive(Clone, Debug)]
pub struct ProtoBufConfig {
    /// Maximum accepted payload size, in bytes.
    limit: usize,
}
100

101
impl ProtoBufConfig {
102
    /// Change max size of payload. By default max size is 256Kb
103 0
    pub fn limit(&mut self, limit: usize) -> &mut Self {
104 0
        self.limit = limit;
105 0
        self
106
    }
107
}
108

109
impl Default for ProtoBufConfig {
110 0
    fn default() -> Self {
111
        ProtoBufConfig { limit: 262_144 }
112
    }
113
}
114

115
impl<T> FromRequest for ProtoBuf<T>
116
where
117
    T: Message + Default + 'static,
118
{
119
    type Config = ProtoBufConfig;
120
    type Error = Error;
121
    type Future = LocalBoxFuture<'static, Result<Self, Error>>;
122

123
    #[inline]
124 0
    fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
125 0
        let limit = req
126
            .app_data::<ProtoBufConfig>()
127 0
            .map(|c| c.limit)
128
            .unwrap_or(262_144);
129 0
        ProtoBufMessage::new(req, payload)
130 0
            .limit(limit)
131 0
            .map(move |res| match res {
132 0
                Err(e) => Err(e.into()),
133 0
                Ok(item) => Ok(ProtoBuf(item)),
134
            })
135
            .boxed_local()
136
    }
137
}
138

139
impl<T: Message + Default> Responder for ProtoBuf<T> {
140 1
    fn respond_to(self, _: &HttpRequest) -> HttpResponse {
141 1
        let mut buf = Vec::new();
142 1
        match self.0.encode(&mut buf) {
143 1
            Ok(()) => HttpResponse::Ok()
144
                .content_type("application/protobuf")
145 1
                .body(buf),
146 0
            Err(err) => HttpResponse::from_error(Error::from(
147 0
                ProtoBufPayloadError::Serialize(err),
148
            )),
149
        }
150
    }
151
}
152

153
/// Future that reads a request's payload and decodes it into a protobuf message `T`.
pub struct ProtoBufMessage<T: Message + Default> {
    /// Maximum number of payload bytes accepted before failing with `Overflow`.
    limit: usize,
    /// `Content-Length` header value, when present and parseable as `usize`.
    length: Option<usize>,
    /// Request payload stream. It is `None` when the content type was rejected,
    /// and it is taken once reading starts.
    stream: Option<Payload>,
    /// Error detected in `new` (wrong content type), returned on first poll.
    err: Option<ProtoBufPayloadError>,
    /// Boxed future that reads and decodes the body, created lazily in `poll`.
    fut: Option<LocalBoxFuture<'static, Result<T, ProtoBufPayloadError>>>,
}
160

161
impl<T: Message + Default> ProtoBufMessage<T> {
162
    /// Create `ProtoBufMessage` for request.
163 1
    pub fn new(req: &HttpRequest, payload: &mut Payload) -> Self {
164 1
        if req.content_type() != "application/protobuf" {
165 1
            return ProtoBufMessage {
166 0
                limit: 262_144,
167 1
                length: None,
168 1
                stream: None,
169 1
                fut: None,
170 1
                err: Some(ProtoBufPayloadError::ContentType),
171
            };
172
        }
173

174 1
        let mut len = None;
175 1
        if let Some(l) = req.headers().get(CONTENT_LENGTH) {
176 1
            if let Ok(s) = l.to_str() {
177 1
                if let Ok(l) = s.parse::<usize>() {
178 1
                    len = Some(l)
179
                }
180
            }
181
        }
182

183
        ProtoBufMessage {
184
            limit: 262_144,
185
            length: len,
186 1
            stream: Some(payload.take()),
187
            fut: None,
188
            err: None,
189
        }
190
    }
191

192
    /// Change max size of payload. By default max size is 256Kb
193 1
    pub fn limit(mut self, limit: usize) -> Self {
194 1
        self.limit = limit;
195 1
        self
196
    }
197
}
198

199
impl<T: Message + Default + 'static> Future for ProtoBufMessage<T> {
200
    type Output = Result<T, ProtoBufPayloadError>;
201

202 1
    fn poll(
203
        mut self: Pin<&mut Self>,
204
        task: &mut task::Context<'_>,
205
    ) -> Poll<Self::Output> {
206 1
        if let Some(ref mut fut) = self.fut {
207 0
            return Pin::new(fut).poll(task);
208
        }
209

210 1
        if let Some(err) = self.err.take() {
211 1
            return Poll::Ready(Err(err));
212
        }
213

214 1
        let limit = self.limit;
215 1
        if let Some(len) = self.length.take() {
216 1
            if len > limit {
217 1
                return Poll::Ready(Err(ProtoBufPayloadError::Overflow));
218
            }
219
        }
220

221 0
        let mut stream = self
222 0
            .stream
223
            .take()
224
            .expect("ProtoBufMessage could not be used second time");
225

226 0
        self.fut = Some(
227 0
            async move {
228 0
                let mut body = BytesMut::with_capacity(8192);
229

230 0
                while let Some(item) = stream.next().await {
231 0
                    let chunk = item?;
232 0
                    if (body.len() + chunk.len()) > limit {
233 0
                        return Err(ProtoBufPayloadError::Overflow);
234
                    } else {
235 0
                        body.extend_from_slice(&chunk);
236
                    }
237
                }
238

239 0
                Ok(<T>::decode(&mut body)?)
240
            }
241 0
            .boxed_local(),
242
        );
243 0
        self.poll(task)
244
    }
245
}
246

247
/// Extension trait for finishing a response with a protobuf-encoded body.
pub trait ProtoBufResponseBuilder {
    /// Encodes `value` as the response body, sets `Content-Type` to
    /// `application/protobuf`, and returns the finished response.
    ///
    /// # Errors
    /// The `HttpResponseBuilder` implementation returns an error built from
    /// `ProtoBufPayloadError::Serialize` if encoding fails.
    fn protobuf<T: Message>(&mut self, value: T) -> Result<HttpResponse, Error>;
}
250

251
impl ProtoBufResponseBuilder for HttpResponseBuilder {
252 0
    fn protobuf<T: Message>(&mut self, value: T) -> Result<HttpResponse, Error> {
253 0
        self.insert_header((CONTENT_TYPE, "application/protobuf"));
254

255 0
        let mut body = Vec::new();
256 0
        value
257 0
            .encode(&mut body)
258 0
            .map_err(ProtoBufPayloadError::Serialize)?;
259 0
        Ok(self.body(body))
260
    }
261
}
262

263
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::http::header;
    use actix_web::test::TestRequest;

    // Test-only equality. The two payload-less variants compare equal to
    // themselves; every variant that carries an inner error never compares equal.
    impl PartialEq for ProtoBufPayloadError {
        fn eq(&self, other: &ProtoBufPayloadError) -> bool {
            matches!(
                (self, other),
                (ProtoBufPayloadError::Overflow, ProtoBufPayloadError::Overflow)
                    | (ProtoBufPayloadError::ContentType, ProtoBufPayloadError::ContentType)
            )
        }
    }

    #[derive(Clone, PartialEq, Message)]
    pub struct MyObject {
        #[prost(int32, tag = "1")]
        pub number: i32,
        #[prost(string, tag = "2")]
        pub name: String,
    }

    #[actix_rt::test]
    async fn test_protobuf() {
        let message = MyObject {
            number: 9,
            name: String::from("test"),
        };
        let request = TestRequest::default().to_http_request();

        let response = ProtoBuf(message).respond_to(&request).await.unwrap();

        let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
        assert_eq!(content_type, "application/protobuf");
    }

    #[actix_rt::test]
    async fn test_protobuf_message() {
        // No Content-Type header at all.
        let (request, mut payload) = TestRequest::default().to_http_parts();
        let outcome = ProtoBufMessage::<MyObject>::new(&request, &mut payload).await;
        assert_eq!(outcome.unwrap_err(), ProtoBufPayloadError::ContentType);

        // Wrong Content-Type.
        let (request, mut payload) = TestRequest::get()
            .insert_header((header::CONTENT_TYPE, "application/text"))
            .to_http_parts();
        let outcome = ProtoBufMessage::<MyObject>::new(&request, &mut payload).await;
        assert_eq!(outcome.unwrap_err(), ProtoBufPayloadError::ContentType);

        // Declared Content-Length exceeds the configured limit.
        let (request, mut payload) = TestRequest::get()
            .insert_header((header::CONTENT_TYPE, "application/protobuf"))
            .insert_header((header::CONTENT_LENGTH, "10000"))
            .to_http_parts();
        let outcome = ProtoBufMessage::<MyObject>::new(&request, &mut payload)
            .limit(100)
            .await;
        assert_eq!(outcome.unwrap_err(), ProtoBufPayloadError::Overflow);
    }
}

Read our documentation on viewing source code .

Loading