Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
root = true
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not strictly related to the change, but I ran into some annoying check_style issues and it wasn't immediately clear how to properly run the checks on my machine.

This should help with general file formatting, trailing whitespace, and line endings at the end of files for those who use an editorconfig editor plugin. I highly recommend it!


[*.rs]
insert_final_newline = true
trim_trailing_whitespace = true
indent_style = space
indent_size = 4
tab_width = 4
27 changes: 26 additions & 1 deletion contrib/db_pools/lib/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ use rocket::serde::{Deserialize, Serialize};
/// For higher-level details on configuring a database, see the [crate-level
/// docs](crate#configuration).
// NOTE: Defaults provided by the figment created in the `Initializer` fairing.
#[derive(Default, Serialize, Deserialize, Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(crate = "rocket::serde")]
pub struct Config {
/// Database-specific connection and configuration URL.
Expand Down Expand Up @@ -93,3 +93,28 @@ pub struct Config {
/// _Default:_ `None`.
pub extensions: Option<Vec<String>>,
}

impl Default for Config {
fn default() -> Self {
Self {
url: Default::default(),
min_connections: Default::default(),
max_connections: rocket::Config::default().workers * 4,
connect_timeout: 5,
idle_timeout: Default::default(),
extensions: Default::default(),
}
}
}

#[cfg(test)]
mod tests {
use super::Config;

#[test]
fn default_values_sane() {
let config = Config::default();
assert_ne!(config.max_connections, 0);
assert_eq!(config.connect_timeout, 5);
}
}
6 changes: 2 additions & 4 deletions contrib/db_pools/lib/src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,14 +276,12 @@ mod sqlx {
}
}

sqlx::pool::PoolOptions::new()
Ok(sqlx::pool::PoolOptions::new()
.max_connections(config.max_connections as u32)
.acquire_timeout(Duration::from_secs(config.connect_timeout))
.idle_timeout(config.idle_timeout.map(Duration::from_secs))
.min_connections(config.min_connections.unwrap_or_default())
.connect_with(opts)
.await
.map_err(Error::Init)
.connect_lazy_with(opts))
}

async fn get(&self) -> Result<Self::Connection, Self::Error> {
Expand Down
16 changes: 16 additions & 0 deletions core/http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,19 @@ pub use crate::method::Method;
pub use crate::status::{Status, StatusClass};
pub use crate::raw_str::{RawStr, RawStrBuf};
pub use crate::header::*;

/// HTTP Protocol version
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum HttpVersion {
/// `HTTP/0.9`
Http09,
/// `HTTP/1.0`
Http10,
/// `HTTP/1.1`
Http11,
/// `HTTP/2`
Http2,
/// `HTTP/3`
Http3,
}
4 changes: 2 additions & 2 deletions core/http/src/uri/uri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl<'a> Uri<'a> {
/// // Invalid URIs fail to parse.
/// Uri::parse::<Origin>("foo bar").expect_err("invalid URI");
/// ```
pub fn parse<T>(string: &'a str) -> Result<Uri<'a>, Error<'_>>
pub fn parse<T>(string: &'a str) -> Result<Uri<'a>, Error<'a>>
where T: Into<Uri<'a>> + TryFrom<&'a str, Error = Error<'a>>
{
T::try_from(string).map(|v| v.into())
Expand Down Expand Up @@ -127,7 +127,7 @@ impl<'a> Uri<'a> {
/// let uri: Origin = uri!("/a/b/c?query");
/// let uri: Reference = uri!("/a/b/c?query#fragment");
/// ```
pub fn parse_any(string: &'a str) -> Result<Uri<'a>, Error<'_>> {
pub fn parse_any(string: &'a str) -> Result<Uri<'a>, Error<'a>> {
crate::parse::uri::from_str(string)
}

Expand Down
4 changes: 2 additions & 2 deletions core/lib/src/form/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ impl<'v> Context<'v> {
/// let foo_bar = form.context.field_errors("foo.bar");
/// }
/// ```
pub fn field_errors<'a, N>(&'a self, name: N) -> impl Iterator<Item = &Error<'v>> + '_
pub fn field_errors<'a, N>(&'a self, name: N) -> impl Iterator<Item = &'a Error<'v>> + 'a
where N: AsRef<Name> + 'a
{
self.errors.values()
Expand Down Expand Up @@ -267,7 +267,7 @@ impl<'v> Context<'v> {
/// let foo_bar = form.context.exact_field_errors("foo.bar");
/// }
/// ```
pub fn exact_field_errors<'a, N>(&'a self, name: N) -> impl Iterator<Item = &Error<'v>> + '_
pub fn exact_field_errors<'a, N>(&'a self, name: N) -> impl Iterator<Item = &'a Error<'v>> + 'a
where N: AsRef<Name> + 'a
{
self.errors.values()
Expand Down
2 changes: 2 additions & 0 deletions core/lib/src/listener/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ pub trait Connection: AsyncRead + AsyncWrite + Send + Unpin {
/// Defaults to an empty vector to indicate that no certificates were
/// presented.
fn certificates(&self) -> Option<Certificates<'_>> { None }

fn server_name(&self) -> Option<&str> { None }
}

impl<A: Connection, B: Connection> Connection for Either<A, B> {
Expand Down
9 changes: 8 additions & 1 deletion core/lib/src/local/asynchronous/request.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::fmt;

use rocket_http::HttpVersion;

use crate::{Request, Data};
use crate::http::{Status, Method};
use crate::http::uri::Origin;
Expand Down Expand Up @@ -48,7 +50,7 @@ impl<'c> LocalRequest<'c> {

// Create a request. We'll handle bad URIs later, in `_dispatch`.
let origin = try_origin.clone().unwrap_or_else(|bad| bad);
let mut request = Request::new(client.rocket(), method, origin);
let mut request = Request::new(client.rocket(), method, origin, None);

// Add any cookies we know about.
if client.tracked {
Expand All @@ -62,6 +64,11 @@ impl<'c> LocalRequest<'c> {
LocalRequest { client, request, uri: try_origin, data: vec![] }
}

#[inline]
pub fn override_version(&mut self, version: HttpVersion) {
self.version = Some(version);
}

pub(crate) fn _request(&self) -> &Request<'c> {
&self.request
}
Expand Down
7 changes: 7 additions & 0 deletions core/lib/src/local/blocking/request.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::fmt;

use rocket_http::HttpVersion;

use crate::{Request, http::Method, local::asynchronous};
use crate::http::uri::Origin;

Expand Down Expand Up @@ -42,6 +44,11 @@ impl<'c> LocalRequest<'c> {
Self { inner, client }
}

#[inline]
pub fn override_version(&mut self, version: HttpVersion) {
self.inner.override_version(version);
}

#[inline]
fn _request(&self) -> &Request<'c> {
self.inner._request()
Expand Down
50 changes: 47 additions & 3 deletions core/lib/src/request/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use std::str::FromStr;
use std::future::Future;
use std::net::IpAddr;

use http::Version;
use rocket_http::HttpVersion;
use state::{TypeMap, InitCell};
use futures::future::BoxFuture;
use ref_swap::OptionRefSwap;
Expand All @@ -31,6 +33,7 @@ pub struct Request<'r> {
method: AtomicMethod,
uri: Origin<'r>,
headers: HeaderMap<'r>,
pub(crate) version: Option<HttpVersion>,
pub(crate) errors: Vec<RequestError>,
pub(crate) connection: ConnectionMeta,
pub(crate) state: RequestState<'r>,
Expand All @@ -42,13 +45,19 @@ pub(crate) struct ConnectionMeta {
pub peer_endpoint: Option<Endpoint>,
#[cfg_attr(not(feature = "mtls"), allow(dead_code))]
pub peer_certs: Option<Arc<Certificates<'static>>>,
#[cfg_attr(feature = "tls", allow(dead_code))]
pub server_name: Option<String>,
}

impl ConnectionMeta {
pub fn new(endpoint: io::Result<Endpoint>, certs: Option<Certificates<'_>>) -> Self {
pub fn new(
endpoint: io::Result<Endpoint>,
certs: Option<Certificates<'_>>,
server_name: Option<&str>) -> Self {
ConnectionMeta {
peer_endpoint: endpoint.ok(),
peer_certs: certs.map(|c| c.into_owned()).map(Arc::new),
server_name: server_name.map(|s| s.to_string()),
}
}
}
Expand Down Expand Up @@ -84,12 +93,14 @@ impl<'r> Request<'r> {
pub(crate) fn new<'s: 'r>(
rocket: &'r Rocket<Orbit>,
method: Method,
uri: Origin<'s>
uri: Origin<'s>,
version: Option<HttpVersion>,
) -> Request<'r> {
Request {
uri,
method: AtomicMethod::new(method),
headers: HeaderMap::new(),
version,
errors: Vec::new(),
connection: ConnectionMeta::default(),
state: RequestState {
Expand All @@ -104,6 +115,22 @@ impl<'r> Request<'r> {
}
}

/// Retrieve http protocol version, when applicable.
///
/// # Example
///
/// ```rust
/// use rocket::http::HttpVersion;
///
/// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap();
/// # let mut req = c.get("/");
/// # req.override_version(HttpVersion::Http11);
/// assert_eq!(req.version(), Some(HttpVersion::Http11));
/// ```
pub fn version(&self) -> Option<HttpVersion> {
self.version
}

/// Retrieve the method from `self`.
///
/// # Example
Expand Down Expand Up @@ -274,6 +301,16 @@ impl<'r> Request<'r> {
self.state.host.as_ref()
}

/// Returns the resolved SNI server name requested in the TLS handshake, if
/// any.
///
/// Ideally, this will match the `Host` header in the request.
#[cfg(feature = "tls")]
#[inline(always)]
pub fn sni(&mut self) -> Option<&str> {
self.connection.server_name.as_deref()
}

/// Sets the host of `self` to `host`.
///
/// # Example
Expand Down Expand Up @@ -1130,7 +1167,14 @@ impl<'r> Request<'r> {
});

// Construct the request object; fill in metadata and headers next.
let mut request = Request::new(rocket, method, uri);
let mut request = Request::new(rocket, method, uri, match hyper.version {
Version::HTTP_09 => Some(HttpVersion::Http09),
Version::HTTP_10 => Some(HttpVersion::Http10),
Version::HTTP_11 => Some(HttpVersion::Http11),
Version::HTTP_2 => Some(HttpVersion::Http2),
Version::HTTP_3 => Some(HttpVersion::Http3),
_ => None,
});
request.errors = errors;

// Set the passed in connection metadata.
Expand Down
8 changes: 6 additions & 2 deletions core/lib/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,11 @@ impl Rocket<Orbit> {
let (listener, rocket, server) = (listener.clone(), self.clone(), server.clone());
spawn_inspect(|e| log_server_error(&**e), async move {
let conn = listener.connect(accept).race_io(rocket.shutdown()).await?;
let meta = ConnectionMeta::new(conn.endpoint(), conn.certificates());
let meta = ConnectionMeta::new(
conn.endpoint(),
conn.certificates(),
conn.server_name()
);
let service = service_fn(|mut req| {
let upgrade = hyper::upgrade::on(&mut req);
let (parts, incoming) = req.into_parts();
Expand Down Expand Up @@ -205,7 +209,7 @@ impl Rocket<Orbit> {
while let Some(mut conn) = stream.accept().race_io(rocket.shutdown()).await? {
let rocket = rocket.clone();
spawn_inspect(|e: &io::Error| log_server_error(e), async move {
let meta = ConnectionMeta::new(conn.endpoint(), None);
let meta = ConnectionMeta::new(conn.endpoint(), None, None);
let rx = conn.rx.cancellable(rocket.shutdown.clone());
let response = rocket.clone()
.service(conn.parts, rx, None, meta)
Expand Down
9 changes: 9 additions & 0 deletions core/lib/src/tls/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,13 @@ impl<C: Connection> Connection for TlsStream<C> {
#[cfg(not(feature = "mtls"))]
None
}

fn server_name(&self) -> Option<&str> {
#[cfg(feature = "tls")] {
self.get_ref().1.server_name()
}

#[cfg(not(feature = "tls"))]
None
}
}
6 changes: 3 additions & 3 deletions docs/guide/11-deploying.md
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,9 @@ WORKDIR /app
COPY --from=build /build/main ./

## copy runtime assets which may or may not exist
COPY --from=build /build/Rocket.tom[l] ./static
COPY --from=build /build/stati[c] ./static
COPY --from=build /build/template[s] ./templates
COPY --from=build /build/Rocket.tom[l] ./
COPY --from=build /build/stati[c] ./static/
COPY --from=build /build/template[s] ./templates/

## ensure the container listens globally on port 8080
ENV ROCKET_ADDRESS=0.0.0.0
Expand Down
2 changes: 1 addition & 1 deletion docs/guide/12-pastebin.md
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ We note the following Rocket APIs being used in our implementation:
* [`Data::open()`] to open [`Data`] as a [`DataStream`].
* [`DataStream::into_file()`] for writing the data stream into a file.
* The [`UriDisplayPath`] derive, allowing `PasteId` to be used in [`uri!`].
* The [`uri!`] macro to crate type-safe, URL-safe URIs.
* The [`uri!`] macro to create type-safe, URL-safe URIs.

[`Data::open()`]: @api/master/rocket/data/struct.Data.html#method.open
[`Data`]: @api/master/rocket/data/struct.Data.html
Expand Down
Loading