Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions contrib/db_pools/lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ diesel_postgres = ["diesel-async/postgres", "diesel-async/deadpool", "deadpool",
diesel_mysql = ["diesel-async/mysql", "diesel-async/deadpool", "deadpool", "diesel"]
# implicit features: mongodb

# postgres features
postgres_rustls_native_certs = ["tokio-postgres", "tokio-postgres-rustls", "rustls", "rustls-native-certs"]

[dependencies.rocket]
path = "../../../core/lib"
version = "0.6.0-dev"
Expand Down Expand Up @@ -80,6 +83,26 @@ default-features = false
features = ["runtime-tokio-rustls"]
optional = true

[dependencies.tokio-postgres]
version = "0.7"
default-features = false
optional = true

[dependencies.tokio-postgres-rustls]
version = "0.13"
default-features = false
optional = true

[dependencies.rustls]
version = "0.23"
default-features = false
optional = true

[dependencies.rustls-native-certs]
version = "0.8"
default-features = false
optional = true

[dependencies.log]
version = "0.4"
default-features = false
Expand Down
27 changes: 26 additions & 1 deletion contrib/db_pools/lib/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ use rocket::serde::{Deserialize, Serialize};
/// For higher-level details on configuring a database, see the [crate-level
/// docs](crate#configuration).
// NOTE: Defaults provided by the figment created in the `Initializer` fairing.
#[derive(Default, Serialize, Deserialize, Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(crate = "rocket::serde")]
pub struct Config {
/// Database-specific connection and configuration URL.
Expand Down Expand Up @@ -93,3 +93,28 @@ pub struct Config {
/// _Default:_ `None`.
pub extensions: Option<Vec<String>>,
}

impl Default for Config {
fn default() -> Self {
Self {
url: Default::default(),
min_connections: Default::default(),
max_connections: rocket::Config::default().workers * 4,
connect_timeout: 5,
idle_timeout: Default::default(),
extensions: Default::default(),
}
}
}

#[cfg(test)]
mod tests {
use super::Config;

#[test]
fn default_values_sane() {
let config = Config::default();
assert_ne!(config.max_connections, 0);
assert_eq!(config.connect_timeout, 5);
}
}
58 changes: 52 additions & 6 deletions contrib/db_pools/lib/src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,23 @@ pub trait Pool: Sized + Send + Sync + 'static {
async fn close(&self);
}

#[cfg(feature = "postgres_rustls_native_certs")]
fn tokio_postgres_tls_provider() -> tokio_postgres_rustls::MakeRustlsConnect {
let mut roots = rustls::RootCertStore::empty();

let certs = rustls_native_certs::load_native_certs()
.expect("native certs should be available");
for cert in certs {
roots.add(cert).expect("native root cert should be valid");
}

let config = rustls::ClientConfig::builder()
.with_root_certificates(roots)
.with_no_client_auth();

tokio_postgres_rustls::MakeRustlsConnect::new(config)
}

#[cfg(feature = "deadpool")]
mod deadpool_postgres {
use deadpool::{Runtime, managed::{Manager, Pool, PoolError, Object}};
Expand All @@ -167,7 +184,13 @@ mod deadpool_postgres {
#[cfg(feature = "deadpool_postgres")]
impl DeadManager for deadpool_postgres::Manager {
fn new(config: &Config) -> Result<Self, Self::Error> {
Ok(Self::new(config.url.parse()?, deadpool_postgres::tokio_postgres::NoTls))
#[cfg(feature = "postgres_rustls_native_certs")]
let tls_provider = super::tokio_postgres_tls_provider();

#[cfg(not(feature = "postgres_rustls_native_certs"))]
let tls_provider = deadpool_postgres::tokio_postgres::NoTls;

Ok(Self::new(config.url.parse()?, tls_provider))
}
}

Expand All @@ -181,7 +204,32 @@ mod deadpool_postgres {
#[cfg(feature = "diesel_postgres")]
impl DeadManager for AsyncDieselConnectionManager<diesel_async::AsyncPgConnection> {
fn new(config: &Config) -> Result<Self, Self::Error> {
Ok(Self::new(config.url.as_str()))
use diesel_async::AsyncPgConnection;
use diesel_async::pooled_connection::ManagerConfig;

let diesel_config = ManagerConfig::default();

#[cfg(feature = "postgres_rustls_native_certs")]
let diesel_config = {
let mut diesel_config = diesel_config;

let tls_provider = super::tokio_postgres_tls_provider();

diesel_config.custom_setup = Box::new(move |url| {
let tls_provider = tls_provider.clone();
Box::pin(async move {
let (client, conn) = tokio_postgres::connect(url, tls_provider)
.await
.map_err(|e| diesel::ConnectionError::BadConnection(e.to_string()))?;

AsyncPgConnection::try_from_client_and_connection(client, conn).await
}) as std::pin::Pin<Box<_>>
});

diesel_config
};

Ok(Self::new_with_config(config.url.as_str(), diesel_config))
}
}

Expand Down Expand Up @@ -276,14 +324,12 @@ mod sqlx {
}
}

sqlx::pool::PoolOptions::new()
Ok(sqlx::pool::PoolOptions::new()
.max_connections(config.max_connections as u32)
.acquire_timeout(Duration::from_secs(config.connect_timeout))
.idle_timeout(config.idle_timeout.map(Duration::from_secs))
.min_connections(config.min_connections.unwrap_or_default())
.connect_with(opts)
.await
.map_err(Error::Init)
.connect_lazy_with(opts))
}

async fn get(&self) -> Result<Self::Connection, Self::Error> {
Expand Down
16 changes: 16 additions & 0 deletions core/http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,19 @@ pub use crate::method::Method;
pub use crate::status::{Status, StatusClass};
pub use crate::raw_str::{RawStr, RawStrBuf};
pub use crate::header::*;

/// HTTP Protocol version
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum HttpVersion {
/// `HTTP/0.9`
Http09,
/// `HTTP/1.0`
Http10,
/// `HTTP/1.1`
Http11,
/// `HTTP/2`
Http2,
/// `HTTP/3`
Http3,
}
4 changes: 2 additions & 2 deletions core/http/src/uri/uri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl<'a> Uri<'a> {
/// // Invalid URIs fail to parse.
/// Uri::parse::<Origin>("foo bar").expect_err("invalid URI");
/// ```
pub fn parse<T>(string: &'a str) -> Result<Uri<'a>, Error<'_>>
pub fn parse<T>(string: &'a str) -> Result<Uri<'a>, Error<'a>>
where T: Into<Uri<'a>> + TryFrom<&'a str, Error = Error<'a>>
{
T::try_from(string).map(|v| v.into())
Expand Down Expand Up @@ -127,7 +127,7 @@ impl<'a> Uri<'a> {
/// let uri: Origin = uri!("/a/b/c?query");
/// let uri: Reference = uri!("/a/b/c?query#fragment");
/// ```
pub fn parse_any(string: &'a str) -> Result<Uri<'a>, Error<'_>> {
pub fn parse_any(string: &'a str) -> Result<Uri<'a>, Error<'a>> {
crate::parse::uri::from_str(string)
}

Expand Down
4 changes: 2 additions & 2 deletions core/lib/src/form/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ impl<'v> Context<'v> {
/// let foo_bar = form.context.field_errors("foo.bar");
/// }
/// ```
pub fn field_errors<'a, N>(&'a self, name: N) -> impl Iterator<Item = &Error<'v>> + '_
pub fn field_errors<'a, N>(&'a self, name: N) -> impl Iterator<Item = &'a Error<'v>> + 'a
where N: AsRef<Name> + 'a
{
self.errors.values()
Expand Down Expand Up @@ -267,7 +267,7 @@ impl<'v> Context<'v> {
/// let foo_bar = form.context.exact_field_errors("foo.bar");
/// }
/// ```
pub fn exact_field_errors<'a, N>(&'a self, name: N) -> impl Iterator<Item = &Error<'v>> + '_
pub fn exact_field_errors<'a, N>(&'a self, name: N) -> impl Iterator<Item = &'a Error<'v>> + 'a
where N: AsRef<Name> + 'a
{
self.errors.values()
Expand Down
9 changes: 8 additions & 1 deletion core/lib/src/local/asynchronous/request.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::fmt;

use rocket_http::HttpVersion;

use crate::{Request, Data};
use crate::http::{Status, Method};
use crate::http::uri::Origin;
Expand Down Expand Up @@ -48,7 +50,7 @@ impl<'c> LocalRequest<'c> {

// Create a request. We'll handle bad URIs later, in `_dispatch`.
let origin = try_origin.clone().unwrap_or_else(|bad| bad);
let mut request = Request::new(client.rocket(), method, origin);
let mut request = Request::new(client.rocket(), method, origin, None);

// Add any cookies we know about.
if client.tracked {
Expand All @@ -62,6 +64,11 @@ impl<'c> LocalRequest<'c> {
LocalRequest { client, request, uri: try_origin, data: vec![] }
}

#[inline]
pub fn override_version(&mut self, version: HttpVersion) {
self.version = Some(version);
}

pub(crate) fn _request(&self) -> &Request<'c> {
&self.request
}
Expand Down
7 changes: 7 additions & 0 deletions core/lib/src/local/blocking/request.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::fmt;

use rocket_http::HttpVersion;

use crate::{Request, http::Method, local::asynchronous};
use crate::http::uri::Origin;

Expand Down Expand Up @@ -42,6 +44,11 @@ impl<'c> LocalRequest<'c> {
Self { inner, client }
}

#[inline]
pub fn override_version(&mut self, version: HttpVersion) {
self.inner.override_version(version);
}

#[inline]
fn _request(&self) -> &Request<'c> {
self.inner._request()
Expand Down
32 changes: 30 additions & 2 deletions core/lib/src/request/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use std::str::FromStr;
use std::future::Future;
use std::net::IpAddr;

use http::Version;
use rocket_http::HttpVersion;
use state::{TypeMap, InitCell};
use futures::future::BoxFuture;
use ref_swap::OptionRefSwap;
Expand All @@ -31,6 +33,7 @@ pub struct Request<'r> {
method: AtomicMethod,
uri: Origin<'r>,
headers: HeaderMap<'r>,
pub(crate) version: Option<HttpVersion>,
pub(crate) errors: Vec<RequestError>,
pub(crate) connection: ConnectionMeta,
pub(crate) state: RequestState<'r>,
Expand Down Expand Up @@ -84,12 +87,14 @@ impl<'r> Request<'r> {
pub(crate) fn new<'s: 'r>(
rocket: &'r Rocket<Orbit>,
method: Method,
uri: Origin<'s>
uri: Origin<'s>,
version: Option<HttpVersion>,
) -> Request<'r> {
Request {
uri,
method: AtomicMethod::new(method),
headers: HeaderMap::new(),
version,
errors: Vec::new(),
connection: ConnectionMeta::default(),
state: RequestState {
Expand All @@ -104,6 +109,22 @@ impl<'r> Request<'r> {
}
}

/// Retrieve http protocol version, when applicable.
///
/// # Example
///
/// ```rust
/// use rocket::http::HttpVersion;
///
/// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap();
/// # let mut req = c.get("/");
/// # req.override_version(HttpVersion::Http11);
/// assert_eq!(req.version(), Some(HttpVersion::Http11));
/// ```
pub fn version(&self) -> Option<HttpVersion> {
self.version
}

/// Retrieve the method from `self`.
///
/// # Example
Expand Down Expand Up @@ -1130,7 +1151,14 @@ impl<'r> Request<'r> {
});

// Construct the request object; fill in metadata and headers next.
let mut request = Request::new(rocket, method, uri);
let mut request = Request::new(rocket, method, uri, match hyper.version {
Version::HTTP_09 => Some(HttpVersion::Http09),
Version::HTTP_10 => Some(HttpVersion::Http10),
Version::HTTP_11 => Some(HttpVersion::Http11),
Version::HTTP_2 => Some(HttpVersion::Http2),
Version::HTTP_3 => Some(HttpVersion::Http3),
_ => None,
});
request.errors = errors;

// Set the passed in connection metadata.
Expand Down
6 changes: 3 additions & 3 deletions docs/guide/11-deploying.md
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,9 @@ WORKDIR /app
COPY --from=build /build/main ./

## copy runtime assets which may or may not exist
COPY --from=build /build/Rocket.tom[l] ./static
COPY --from=build /build/stati[c] ./static
COPY --from=build /build/template[s] ./templates
COPY --from=build /build/Rocket.tom[l] ./
COPY --from=build /build/stati[c] ./static/
COPY --from=build /build/template[s] ./templates/

## ensure the container listens globally on port 8080
ENV ROCKET_ADDRESS=0.0.0.0
Expand Down
2 changes: 1 addition & 1 deletion docs/guide/12-pastebin.md
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ We note the following Rocket APIs being used in our implementation:
* [`Data::open()`] to open [`Data`] as a [`DataStream`].
* [`DataStream::into_file()`] for writing the data stream into a file.
* The [`UriDisplayPath`] derive, allowing `PasteId` to be used in [`uri!`].
* The [`uri!`] macro to crate type-safe, URL-safe URIs.
* The [`uri!`] macro to create type-safe, URL-safe URIs.

[`Data::open()`]: @api/master/rocket/data/struct.Data.html#method.open
[`Data`]: @api/master/rocket/data/struct.Data.html
Expand Down
Loading