actix / actix-extras
1
use std::{collections::HashSet, convert::TryInto, iter::FromIterator, rc::Rc};
2

3
use actix_web::{
4
    dev::{RequestHead, Service, ServiceRequest, ServiceResponse, Transform},
5
    error::{Error, Result},
6
    http::{self, header::HeaderName, Error as HttpError, HeaderValue, Method, Uri},
7
    Either,
8
};
9
use futures_util::future::{self, Ready};
10
use log::error;
11
use once_cell::sync::Lazy;
12
use smallvec::smallvec;
13

14
use crate::{AllOrSome, CorsError, CorsMiddleware, Inner, OriginFn};
15

16
/// Convenience for getting mut refs to inner. Cleaner than `Rc::get_mut`.
17
/// Additionally, always causes first error (if any) to be reported during initialization.
18 1
fn cors<'a>(
19
    inner: &'a mut Rc<Inner>,
20
    err: &Option<Either<http::Error, CorsError>>,
21
) -> Option<&'a mut Inner> {
22 1
    if err.is_some() {
23 0
        return None;
24
    }
25

26 1
    Rc::get_mut(inner)
27
}
28

29 1
static ALL_METHODS_SET: Lazy<HashSet<Method>> = Lazy::new(|| {
30 1
    HashSet::from_iter(vec![
31
        Method::GET,
32
        Method::POST,
33
        Method::PUT,
34
        Method::DELETE,
35
        Method::HEAD,
36
        Method::OPTIONS,
37
        Method::CONNECT,
38
        Method::PATCH,
39
        Method::TRACE,
40
    ])
41
});
42

43
/// Builder for CORS middleware.
44
///
45
/// To construct a CORS middleware, call [`Cors::default()`] to create a blank, restrictive builder.
46
/// Then use any of the builder methods to customize CORS behavior.
47
///
48
/// The alternative [`Cors::permissive()`] constructor is available for local development, allowing
49
/// all origins and headers, etc. **The permissive constructor should not be used in production.**
50
///
51
/// # Errors
52
/// Errors surface in the middleware initialization phase. This means that, if you have logs enabled
53
/// in Actix Web (using `env_logger` or other crate that exposes logs from the `log` crate), error
54
/// messages will outline what is wrong with the CORS configuration in the server logs and the
55
/// server will fail to start up or serve requests.
56
///
57
/// # Example
58
/// ```rust
59
/// use actix_cors::Cors;
60
/// use actix_web::http::header;
61
///
62
/// let cors = Cors::default()
63
///     .allowed_origin("https://www.rust-lang.org")
64
///     .allowed_methods(vec!["GET", "POST"])
65
///     .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
66
///     .allowed_header(header::CONTENT_TYPE)
67
///     .max_age(3600);
68
///
69
/// // `cors` can now be used in `App::wrap`.
70
/// ```
71
#[derive(Debug)]
72
pub struct Cors {
73
    inner: Rc<Inner>,
74
    error: Option<Either<http::Error, CorsError>>,
75
}
76

77
impl Cors {
78
    /// A very permissive set of default for quick development. Not recommended for production use.
79
    ///
80
    /// *All* origins, methods, request headers and exposed headers allowed. Credentials supported.
81
    /// Max age 1 hour. Does not send wildcard.
82 1
    pub fn permissive() -> Self {
83
        let inner = Inner {
84
            allowed_origins: AllOrSome::All,
85 1
            allowed_origins_fns: smallvec![],
86

87 1
            allowed_methods: ALL_METHODS_SET.clone(),
88
            allowed_methods_baked: None,
89

90
            allowed_headers: AllOrSome::All,
91
            allowed_headers_baked: None,
92

93
            expose_headers: AllOrSome::All,
94
            expose_headers_baked: None,
95 1
            max_age: Some(3600),
96
            preflight: true,
97
            send_wildcard: false,
98
            supports_credentials: true,
99
            vary_header: true,
100
        };
101

102
        Cors {
103 1
            inner: Rc::new(inner),
104
            error: None,
105
        }
106
    }
107

108
    /// Resets allowed origin list to a state where any origin is accepted.
109
    ///
110
    /// See [`Cors::allowed_origin`] for more info on allowed origins.
111 1
    pub fn allow_any_origin(mut self) -> Cors {
112 1
        if let Some(cors) = cors(&mut self.inner, &self.error) {
113 1
            cors.allowed_origins = AllOrSome::All;
114
        }
115

116 1
        self
117
    }
118

119
    /// Add an origin that is allowed to make requests.
120
    ///
121
    /// By default, requests from all origins are accepted by CORS logic. This method allows to
122
    /// specify a finite set of origins to verify the value of the `Origin` request header.
123
    ///
124
    /// These are `origin-or-null` types in the [Fetch Standard].
125
    ///
126
    /// When this list is set, the client's `Origin` request header will be checked in a
127
    /// case-sensitive manner.
128
    ///
129
    /// When all origins are allowed and `send_wildcard` is set, `*` will be sent in the
130
    /// `Access-Control-Allow-Origin` response header. If `send_wildcard` is not set, the client's
131
    /// `Origin` request header will be echoed back in the `Access-Control-Allow-Origin`
132
    /// response header.
133
    ///
134
    /// If the origin of the request doesn't match any allowed origins and at least one
135
    /// `allowed_origin_fn` function is set, these functions will be used to determinate
136
    /// allowed origins.
137
    ///
138
    /// # Initialization Errors
139
    /// - If supplied origin is not valid uri
140
    /// - If supplied origin is a wildcard (`*`). [`Cors::send_wildcard`] should be used instead.
141
    ///
142
    /// [Fetch Standard]: https://fetch.spec.whatwg.org/#origin-header
143 1
    pub fn allowed_origin(mut self, origin: &str) -> Cors {
144 1
        if let Some(cors) = cors(&mut self.inner, &self.error) {
145 1
            match TryInto::<Uri>::try_into(origin) {
146 1
                Ok(_) if origin == "*" => {
147 1
                    error!("Wildcard in `allowed_origin` is not allowed. Use `send_wildcard`.");
148 1
                    self.error = Some(Either::Right(CorsError::WildcardOrigin));
149
                }
150

151 0
                Ok(_) => {
152 1
                    if cors.allowed_origins.is_all() {
153 0
                        cors.allowed_origins =
154 0
                            AllOrSome::Some(HashSet::with_capacity(8));
155
                    }
156

157 1
                    if let Some(origins) = cors.allowed_origins.as_mut() {
158
                        // any uri is a valid header value
159 1
                        let hv = origin.try_into().unwrap();
160 1
                        origins.insert(hv);
161
                    }
162
                }
163

164 0
                Err(err) => {
165 0
                    self.error = Some(Either::Left(err.into()));
166
                }
167
            }
168
        }
169

170 1
        self
171
    }
172

173
    /// Determinate allowed origins by processing requests which didn't match any origins specified
174
    /// in the `allowed_origin`.
175
    ///
176
    /// The function will receive two parameters, the Origin header value, and the `RequestHead` of
177
    /// each request, which can be used to determine whether to allow the request or not.
178
    ///
179
    /// If the function returns `true`, the client's `Origin` request header will be echoed back
180
    /// into the `Access-Control-Allow-Origin` response header.
181 1
    pub fn allowed_origin_fn<F>(mut self, f: F) -> Cors
182
    where
183
        F: (Fn(&HeaderValue, &RequestHead) -> bool) + 'static,
184
    {
185 1
        if let Some(cors) = cors(&mut self.inner, &self.error) {
186 1
            cors.allowed_origins_fns.push(OriginFn {
187 1
                boxed_fn: Rc::new(f),
188
            });
189
        }
190

191 1
        self
192
    }
193

194
    /// Resets allowed methods list to all methods.
195
    ///
196
    /// See [`Cors::allowed_methods`] for more info on allowed methods.
197 1
    pub fn allow_any_method(mut self) -> Cors {
198 1
        if let Some(cors) = cors(&mut self.inner, &self.error) {
199 1
            cors.allowed_methods = ALL_METHODS_SET.clone();
200
        }
201

202 1
        self
203
    }
204

205
    /// Set a list of methods which allowed origins can perform.
206
    ///
207
    /// These will be sent in the `Access-Control-Allow-Methods` response header as specified in
208
    /// the [Fetch Standard CORS protocol].
209
    ///
210
    /// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]`
211
    ///
212
    /// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol
213 1
    pub fn allowed_methods<U, M>(mut self, methods: U) -> Cors
214
    where
215
        U: IntoIterator<Item = M>,
216
        M: TryInto<Method>,
217
        <M as TryInto<Method>>::Error: Into<HttpError>,
218
    {
219 1
        if let Some(cors) = cors(&mut self.inner, &self.error) {
220 1
            for m in methods {
221 1
                match m.try_into() {
222 1
                    Ok(method) => {
223 1
                        cors.allowed_methods.insert(method);
224
                    }
225

226 0
                    Err(err) => {
227 0
                        self.error = Some(Either::Left(err.into()));
228 0
                        break;
229
                    }
230
                }
231
            }
232
        }
233

234 1
        self
235
    }
236

237
    /// Resets allowed request header list to a state where any header is accepted.
238
    ///
239
    /// See [`Cors::allowed_headers`] for more info on allowed request headers.
240 1
    pub fn allow_any_header(mut self) -> Cors {
241 1
        if let Some(cors) = cors(&mut self.inner, &self.error) {
242 1
            cors.allowed_headers = AllOrSome::All;
243
        }
244

245 1
        self
246
    }
247

248
    /// Add an allowed request header.
249
    ///
250
    /// See [`Cors::allowed_headers`] for more info on allowed request headers.
251 1
    pub fn allowed_header<H>(mut self, header: H) -> Cors
252
    where
253
        H: TryInto<HeaderName>,
254
        <H as TryInto<HeaderName>>::Error: Into<HttpError>,
255
    {
256 1
        if let Some(cors) = cors(&mut self.inner, &self.error) {
257 1
            match header.try_into() {
258 1
                Ok(method) => {
259 1
                    if cors.allowed_headers.is_all() {
260 0
                        cors.allowed_headers =
261 0
                            AllOrSome::Some(HashSet::with_capacity(8));
262
                    }
263

264 1
                    if let AllOrSome::Some(ref mut headers) = cors.allowed_headers {
265 1
                        headers.insert(method);
266
                    }
267
                }
268

269 0
                Err(err) => self.error = Some(Either::Left(err.into())),
270
            }
271
        }
272

273 1
        self
274
    }
275

276
    /// Set a list of request header field names which can be used when this resource is accessed by
277
    /// allowed origins.
278
    ///
279
    /// If `All` is set, whatever is requested by the client in `Access-Control-Request-Headers`
280
    /// will be echoed back in the `Access-Control-Allow-Headers` header as specified in
281
    /// the [Fetch Standard CORS protocol].
282
    ///
283
    /// Defaults to `All`.
284
    ///
285
    /// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol
286 1
    pub fn allowed_headers<U, H>(mut self, headers: U) -> Cors
287
    where
288
        U: IntoIterator<Item = H>,
289
        H: TryInto<HeaderName>,
290
        <H as TryInto<HeaderName>>::Error: Into<HttpError>,
291
    {
292 1
        if let Some(cors) = cors(&mut self.inner, &self.error) {
293 1
            for h in headers {
294 1
                match h.try_into() {
295 1
                    Ok(method) => {
296 1
                        if cors.allowed_headers.is_all() {
297 0
                            cors.allowed_headers =
298 0
                                AllOrSome::Some(HashSet::with_capacity(8));
299
                        }
300

301 1
                        if let AllOrSome::Some(ref mut headers) = cors.allowed_headers {
302 1
                            headers.insert(method);
303
                        }
304
                    }
305 0
                    Err(err) => {
306 0
                        self.error = Some(Either::Left(err.into()));
307 0
                        break;
308
                    }
309
                }
310
            }
311
        }
312

313 1
        self
314
    }
315

316
    /// Resets exposed response header list to a state where any header is accepted.
317
    ///
318
    /// See [`Cors::expose_headers`] for more info on exposed response headers.
319 0
    pub fn expose_any_header(mut self) -> Cors {
320 0
        if let Some(cors) = cors(&mut self.inner, &self.error) {
321 0
            cors.expose_headers = AllOrSome::All;
322
        }
323

324 0
        self
325
    }
326

327
    /// Set a list of headers which are safe to expose to the API of a CORS API specification.
328
    /// This corresponds to the `Access-Control-Expose-Headers` response header as specified in
329
    /// the [Fetch Standard CORS protocol].
330
    ///
331
    /// This defaults to an empty set.
332
    ///
333
    /// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol
334 1
    pub fn expose_headers<U, H>(mut self, headers: U) -> Cors
335
    where
336
        U: IntoIterator<Item = H>,
337
        H: TryInto<HeaderName>,
338
        <H as TryInto<HeaderName>>::Error: Into<HttpError>,
339
    {
340 1
        for h in headers {
341 1
            match h.try_into() {
342 1
                Ok(header) => {
343 1
                    if let Some(cors) = cors(&mut self.inner, &self.error) {
344 1
                        if cors.expose_headers.is_all() {
345 0
                            cors.expose_headers =
346 0
                                AllOrSome::Some(HashSet::with_capacity(8));
347
                        }
348 1
                        if let AllOrSome::Some(ref mut headers) = cors.expose_headers {
349 1
                            headers.insert(header);
350
                        }
351
                    }
352
                }
353 0
                Err(err) => {
354 0
                    self.error = Some(Either::Left(err.into()));
355 0
                    break;
356
                }
357
            }
358
        }
359

360 1
        self
361
    }
362

363
    /// Set a maximum time (in seconds) for which this CORS request maybe cached.
364
    /// This value is set as the `Access-Control-Max-Age` header as specified in
365
    /// the [Fetch Standard CORS protocol].
366
    ///
367
    /// Pass a number (of seconds) or use None to disable sending max age header.
368
    ///
369
    /// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol
370 1
    pub fn max_age(mut self, max_age: impl Into<Option<usize>>) -> Cors {
371 1
        if let Some(cors) = cors(&mut self.inner, &self.error) {
372 1
            cors.max_age = max_age.into()
373
        }
374

375 1
        self
376
    }
377

378
    /// Set to use wildcard origins.
379
    ///
380
    /// If send wildcard is set and the `allowed_origins` parameter is `All`, a wildcard
381
    /// `Access-Control-Allow-Origin` response header is sent, rather than the request’s
382
    /// `Origin` header.
383
    ///
384
    /// This **CANNOT** be used in conjunction with `allowed_origins` set to `All` and
385
    /// `allow_credentials` set to `true`. Depending on the mode of usage, this will either result
386
    /// in an `CorsError::CredentialsWithWildcardOrigin` error during actix launch or runtime.
387
    ///
388
    /// Defaults to `false`.
389 1
    pub fn send_wildcard(mut self) -> Cors {
390 1
        if let Some(cors) = cors(&mut self.inner, &self.error) {
391 1
            cors.send_wildcard = true
392
        }
393

394 1
        self
395
    }
396

397
    /// Allows users to make authenticated requests
398
    ///
399
    /// If true, injects the `Access-Control-Allow-Credentials` header in responses. This allows
400
    /// cookies and credentials to be submitted across domains as specified in
401
    /// the [Fetch Standard CORS protocol].
402
    ///
403
    /// This option cannot be used in conjunction with an `allowed_origin` set to `All` and
404
    /// `send_wildcards` set to `true`.
405
    ///
406
    /// Defaults to `false`.
407
    ///
408
    /// A server initialization error will occur if credentials are allowed, but the Origin is set
409
    /// to send wildcards (`*`); this is not allowed by the CORS protocol.
410
    ///
411
    /// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol
412 1
    pub fn supports_credentials(mut self) -> Cors {
413 1
        if let Some(cors) = cors(&mut self.inner, &self.error) {
414 1
            cors.supports_credentials = true
415
        }
416

417 1
        self
418
    }
419

420
    /// Disable `Vary` header support.
421
    ///
422
    /// When enabled the header `Vary: Origin` will be returned as per the Fetch Standard
423
    /// implementation guidelines.
424
    ///
425
    /// Setting this header when the `Access-Control-Allow-Origin` is dynamically generated
426
    /// (eg. when there is more than one allowed origin, and an Origin other than '*' is returned)
427
    /// informs CDNs and other caches that the CORS headers are dynamic, and cannot be cached.
428
    ///
429
    /// By default, `Vary` header support is enabled.
430 1
    pub fn disable_vary_header(mut self) -> Cors {
431 1
        if let Some(cors) = cors(&mut self.inner, &self.error) {
432 1
            cors.vary_header = false
433
        }
434

435 1
        self
436
    }
437

438
    /// Disable support for preflight requests.
439
    ///
440
    /// When enabled CORS middleware automatically handles `OPTIONS` requests.
441
    /// This is useful for application level middleware.
442
    ///
443
    /// By default *preflight* support is enabled.
444 1
    pub fn disable_preflight(mut self) -> Cors {
445 1
        if let Some(cors) = cors(&mut self.inner, &self.error) {
446 1
            cors.preflight = false
447
        }
448

449 1
        self
450
    }
451
}
452

453
impl Default for Cors {
454
    /// A restrictive (security paranoid) set of defaults.
455
    ///
456
    /// *No* allowed origins, methods, request headers or exposed headers. Credentials
457
    /// not supported. No max age (will use browser's default).
458 1
    fn default() -> Cors {
459
        let inner = Inner {
460 1
            allowed_origins: AllOrSome::Some(HashSet::with_capacity(8)),
461 1
            allowed_origins_fns: smallvec![],
462

463 1
            allowed_methods: HashSet::with_capacity(8),
464
            allowed_methods_baked: None,
465

466 1
            allowed_headers: AllOrSome::Some(HashSet::with_capacity(8)),
467
            allowed_headers_baked: None,
468

469 1
            expose_headers: AllOrSome::Some(HashSet::with_capacity(8)),
470
            expose_headers_baked: None,
471

472
            max_age: None,
473
            preflight: true,
474
            send_wildcard: false,
475
            supports_credentials: false,
476
            vary_header: true,
477
        };
478

479
        Cors {
480 1
            inner: Rc::new(inner),
481
            error: None,
482
        }
483
    }
484
}
485

486
impl<S> Transform<S, ServiceRequest> for Cors
487
where
488
    S: Service<ServiceRequest, Response = ServiceResponse, Error = Error>,
489
    S::Future: 'static,
490
{
491
    type Response = ServiceResponse;
492
    type Error = Error;
493
    type InitError = ();
494
    type Transform = CorsMiddleware<S>;
495
    type Future = Ready<Result<Self::Transform, Self::InitError>>;
496

497 1
    fn new_transform(&self, service: S) -> Self::Future {
498 1
        if let Some(ref err) = self.error {
499 0
            match err {
500 1
                Either::Left(err) => error!("{}", err),
501 1
                Either::Right(err) => error!("{}", err),
502
            }
503

504 1
            return future::err(());
505
        }
506

507 1
        let mut inner = Rc::clone(&self.inner);
508

509 1
        if inner.supports_credentials
510 1
            && inner.send_wildcard
511 1
            && inner.allowed_origins.is_all()
512
        {
513 1
            error!("Illegal combination of CORS options: credentials can not be supported when all \
514
                    origins are allowed and `send_wildcard` is enabled.");
515 1
            return future::err(());
516
        }
517

518
        // bake allowed headers value if Some and not empty
519 1
        match inner.allowed_headers.as_ref() {
520 1
            Some(header_set) if !header_set.is_empty() => {
521 1
                let allowed_headers_str = intersperse_header_values(header_set);
522 1
                Rc::make_mut(&mut inner).allowed_headers_baked =
523 1
                    Some(allowed_headers_str);
524
            }
525 0
            _ => {}
526
        }
527

528
        // bake allowed methods value if not empty
529 1
        if !inner.allowed_methods.is_empty() {
530 1
            let allowed_methods_str = intersperse_header_values(&inner.allowed_methods);
531 1
            Rc::make_mut(&mut inner).allowed_methods_baked = Some(allowed_methods_str);
532
        }
533

534
        // bake exposed headers value if Some and not empty
535 1
        match inner.expose_headers.as_ref() {
536 1
            Some(header_set) if !header_set.is_empty() => {
537 1
                let expose_headers_str = intersperse_header_values(header_set);
538 1
                Rc::make_mut(&mut inner).expose_headers_baked = Some(expose_headers_str);
539
            }
540 0
            _ => {}
541
        }
542

543 1
        future::ok(CorsMiddleware { service, inner })
544
    }
545
}
546

547
/// Only call when values are guaranteed to be valid header values and set is not empty.
548 1
fn intersperse_header_values<T>(val_set: &HashSet<T>) -> HeaderValue
549
where
550
    T: AsRef<str>,
551
{
552 1
    val_set
553
        .iter()
554 1
        .fold(String::with_capacity(32), |mut acc, val| {
555 1
            acc.push_str(", ");
556 1
            acc.push_str(val.as_ref());
557 1
            acc
558
        })
559
        // set is not empty so string will always have leading ", " to trim
560 1
        [2..]
561
        .try_into()
562
        // all method names are valid header values
563
        .unwrap()
564
}
565

566
#[cfg(test)]
mod test {
    use std::convert::{Infallible, TryInto};

    use actix_web::{
        dev::Transform,
        http::{HeaderName, StatusCode},
        test::{self, TestRequest},
    };

    use super::*;

    /// The permissive defaults allow all origins. Adding `send_wildcard` together with
    /// `supports_credentials` must make middleware construction fail.
    #[test]
    fn illegal_allow_credentials() {
        let transform = Cors::permissive()
            .supports_credentials()
            .send_wildcard()
            .new_transform(test::ok_service())
            .into_inner();

        assert!(transform.is_err());
    }

    /// With the restrictive defaults, a request carrying an `Origin` header is rejected.
    #[actix_rt::test]
    async fn restrictive_defaults() {
        let middleware = Cors::default()
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let request = TestRequest::default()
            .insert_header(("Origin", "https://www.example.com"))
            .to_srv_request();

        let response = test::call_service(&middleware, request).await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    /// `allowed_header` accepts a string slice (converted via `TryInto<HeaderName>`).
    #[actix_rt::test]
    async fn allowed_header_try_from() {
        let _builder = Cors::default().allowed_header("Content-Type");
    }

    /// `allowed_header` accepts any user type implementing `TryInto<HeaderName>`.
    #[actix_rt::test]
    async fn allowed_header_try_into() {
        struct ContentType;

        impl TryInto<HeaderName> for ContentType {
            type Error = Infallible;

            fn try_into(self) -> Result<HeaderName, Self::Error> {
                Ok(HeaderName::from_static("content-type"))
            }
        }

        let _builder = Cors::default().allowed_header(ContentType);
    }
}

Read our documentation on viewing source code .

Loading