Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

proxy: refactor json parsing #9013

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ COPY --from=pg-build /home/nonroot/pg_install/v17/lib pg_i
COPY --chown=nonroot . .

ARG ADDITIONAL_RUSTFLAGS
ENV _RJEM_MALLOC_CONF="thp:never"
RUN set -e \
&& PQ_LIB_DIR=$(pwd)/pg_install/v${STABLE_PG_VERSION}/lib RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment ${ADDITIONAL_RUSTFLAGS}" cargo build \
--bin pg_sni_router \
Expand Down
6 changes: 4 additions & 2 deletions proxy/src/auth/backend/jwt.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
future::Future,
marker::PhantomData,
sync::Arc,
time::{Duration, SystemTime},
};
Expand Down Expand Up @@ -147,14 +148,15 @@ impl JwkCacheEntryLock {
Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"),
Ok(r) => {
let resp: http::Response<reqwest::Body> = r.into();
match parse_json_body_with_limit::<jose_jwk::JwkSet>(
match parse_json_body_with_limit::<jose_jwk::JwkSet, _>(
PhantomData,
resp.into_body(),
MAX_JWK_BODY_SIZE,
)
.await
{
Err(e) => {
tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
tracing::warn!(url=?rule.jwks_url, error=%e, "could not decode JWKs");
}
Ok(jwks) => {
key_sets.insert(
Expand Down
1 change: 1 addition & 0 deletions proxy/src/bin/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ async fn main() -> anyhow::Result<()> {
build_tag: BUILD_TAG,
});

proxy::jemalloc::inspect_thp()?;
let jemalloc = match proxy::jemalloc::MetricRecorder::new() {
Ok(t) => Some(t),
Err(e) => {
Expand Down
33 changes: 24 additions & 9 deletions proxy/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ pub mod health_server;

use std::time::Duration;

use anyhow::bail;
use bytes::Bytes;
use http_body_util::BodyExt;
use hyper1::body::Body;
use serde::de::DeserializeOwned;
use serde::de::DeserializeSeed;

pub(crate) use reqwest::{Request, Response};
pub(crate) use reqwest_middleware::{ClientWithMiddleware, Error};
Expand Down Expand Up @@ -113,31 +112,47 @@ impl Endpoint {
}
}

pub(crate) async fn parse_json_body_with_limit<D: DeserializeOwned>(
mut b: impl Body<Data = Bytes, Error = reqwest::Error> + Unpin,
/// Error returned by `parse_json_body_with_limit` when an HTTP body cannot be
/// turned into a value.
///
/// `E` is the type of the error carried by the `Read` variant.
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError<E> {
/// Reading the body failed; wraps the underlying error of type `E`.
#[error("could not read the HTTP body: {0}")]
Read(E),
/// The body bytes could not be parsed; converts from `serde_json::Error`
/// via the generated `From` impl, so `?` can be used on parse results.
#[error("could not parse the HTTP body: {0}")]
Parse(#[from] serde_json::Error),
/// The body is larger than the permitted size; the payload is the limit
/// in bytes that was exceeded.
#[error("could not parse the HTTP body: content length exceeds limit of {0} bytes")]
LengthExceeded(usize),
}

/// Buffers an HTTP body in memory, refusing bodies larger than `limit` bytes,
/// and then parses the buffered bytes as JSON.
///
/// * `seed` - drives deserialization of the buffered bytes and determines the
///   produced value type `D`.
/// * `b` - the body to read; it is consumed frame by frame.
/// * `limit` - maximum number of body bytes that will be accepted.
///
/// # Errors
///
/// Fails if the size hint's lower bound or the accumulated data exceeds
/// `limit`, if reading a frame from the body fails, or if the bytes are not
/// valid JSON for the requested value.
pub(crate) async fn parse_json_body_with_limit<D, E>(
seed: impl for<'de> DeserializeSeed<'de, Value = D>,
mut b: impl Body<Data = Bytes, Error = E> + Unpin,
limit: usize,
) -> anyhow::Result<D> {
) -> Result<D, ReadPayloadError<E>> {
// We could use `b.limited().collect().await.to_bytes()` here
// but this ends up being slightly more efficient as far as I can tell.

// check the lower bound of the size hint.
// in reqwest, this value is influenced by the Content-Length header.
// A lower bound that does not fit in `usize` is rejected the same way as
// one that is above `limit`.
let lower_bound = match usize::try_from(b.size_hint().lower()) {
Ok(bound) if bound <= limit => bound,
_ => bail!("Content length exceeds limit of {limit} bytes"),
_ => return Err(ReadPayloadError::LengthExceeded(limit)),
};
// Pre-size the buffer from the hint; it still grows if more data arrives.
let mut bytes = Vec::with_capacity(lower_bound);

while let Some(frame) = b.frame().await.transpose()? {
while let Some(frame) = b
.frame()
.await
.transpose()
.map_err(ReadPayloadError::Read)?
{
// Frames for which `into_data` returns `Err` are ignored.
if let Ok(data) = frame.into_data() {
// The size hint is only a lower bound, so the limit is enforced again
// on the bytes actually received, before they are appended.
if bytes.len() + data.len() > limit {
bail!("Content length exceeds limit of {limit} bytes")
return Err(ReadPayloadError::LengthExceeded(limit));
}
bytes.extend_from_slice(&data);
}
}

Ok(serde_json::from_slice::<D>(&bytes)?)
Ok(seed.deserialize(&mut serde_json::Deserializer::from_slice(&bytes))?)
}

#[cfg(test)]
Expand Down
10 changes: 9 additions & 1 deletion proxy/src/jemalloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use measured::{
text::TextEncoder,
LabelGroup, MetricGroup,
};
use tikv_jemalloc_ctl::{config, epoch, epoch_mib, stats, version};
use tikv_jemalloc_ctl::{config, epoch, epoch_mib, stats, version, Access, AsName, Name};
use tracing::info;

pub struct MetricRecorder {
epoch: epoch_mib,
Expand Down Expand Up @@ -114,3 +115,10 @@ jemalloc_gauge!(mapped, mapped_mib);
jemalloc_gauge!(metadata, metadata_mib);
jemalloc_gauge!(resident, resident_mib);
jemalloc_gauge!(retained, retained_mib);

/// Reads jemalloc's `opt.thp` option and logs its value at info level.
///
/// # Errors
///
/// Returns the `tikv_jemalloc_ctl::Error` produced when the option cannot be
/// read.
pub fn inspect_thp() -> Result<(), tikv_jemalloc_ctl::Error> {
    // The control name must be NUL-terminated, hence the C string literal.
    let thp_key: &Name = c"opt.thp".to_bytes_with_nul().name();

    // The `&str` annotation selects the string-valued read of the option.
    let s: &str = thp_key.read()?;
    info!("jemalloc opt.thp {s}");

    Ok(())
}
Loading
Loading