Skip to content

Commit

Permalink
Finalise Middleware Auth
Browse files Browse the repository at this point in the history
  • Loading branch information
djnovin committed Oct 10, 2024
1 parent 1131158 commit 67a880f
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 29 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,6 @@ curl http://localhost:3000/api/v1/public -H "Authorization: Bearer <YOUR_API_KEY

```bash

curl http://localhost:3000/api/v1/protected -H "Authorization: Bearer <YOUR_API_KEY>"
curl http://localhost:3000/api/v1/protected -H "Authorization: Bearer 3ZiAsxZbqdNUBpqoGczVmQoc"

```
116 changes: 88 additions & 28 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ struct RateLimitRequest {
#[serde(rename = "async")]
async_field: bool,
meta: HashMap<String, String>,
resources: Vec<Resource>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
resources: Vec<Resource>, // Default to an empty array if no resources are provided
}

#[derive(Serialize)]
Expand All @@ -32,10 +33,25 @@ struct Resource {
name: String,
}

#[derive(Deserialize)]
struct ApiErrorResponse {
error: ApiError,
}

#[derive(Deserialize)]
struct ApiError {
code: String,
docs: String,
message: String,
requestId: String,
}

#[derive(Deserialize)]
struct RateLimitResponse {
remaining: i32,
reset: u64,
limit: Option<i32>,
remaining: Option<i32>,
reset: Option<u64>,
success: Option<bool>,
}

#[derive(Clone)]
Expand All @@ -56,19 +72,31 @@ async fn verify_key(
req: ServiceRequest,
next: Next<impl MessageBody>,
) -> Result<ServiceResponse<impl MessageBody>, Error> {
let headers = req.headers();
info!("Middleware start");

let headers = req.headers().clone(); // Clone the headers so req is not borrowed later

info!("Headers: {:?}", headers);

let data = req.app_data::<web::Data<AppState>>().unwrap();
let client = req.app_data::<web::Data<Client>>().unwrap();

let connection_info = req.connection_info().clone();
let user_ip = connection_info.realip_remote_addr().unwrap_or("unknown").to_string();

let authorization_header = if let Some(header_value) = headers.get("Authorization") {
header_value.to_str().unwrap_or("")
match header_value.to_str() {
Ok(value) if value.starts_with("Bearer ") => value.trim_start_matches("Bearer ").to_string(),
_ => {
return Err(actix_web::error::ErrorUnauthorized(
"Invalid Authorization header format",
))
}
}
} else {
return Err(actix_web::error::ErrorUnauthorized("Authorization header missing"));
};

// TODO: Replace with your own user ID
let user_id = "some_user_id";

let verify_request = VerifyKeyRequest {
key: authorization_header.to_string(),
api_id: data.unkey_api_id.clone().into(),
Expand All @@ -77,40 +105,74 @@ async fn verify_key(
match data.unkey_client.verify_key(verify_request).await {
Ok(res) if res.valid => {
let rate_limit_request = RateLimitRequest {
namespace: "email.outbound".to_string(),
identifier: user_id.to_string(),
namespace: "test_protected".to_string(), // Namespace for the rate limit
identifier: user_ip, // Identifier for the rate limit
limit: 10,
duration: 60000,
cost: 2,
async_field: true,
meta: HashMap::new(),
resources: vec![Resource {
r#type: "project".to_string(),
id: "p_123".to_string(),
name: "dub".to_string(),
}],
resources: vec![],
};

let unkey_root_key = env::var("UNKEY_ROOT_KEY").expect("UNKEY_ROOT_KEY must be set");

let rate_limit_response = client
.post("https://api.unkey.dev/v1/ratelimits.limit")
.bearer_auth(authorization_header)
.bearer_auth(unkey_root_key)
.header("Content-Type", "application/json")
.json(&rate_limit_request)
.send()
.await
.unwrap();

let rate_limit_result = rate_limit_response.json::<RateLimitResponse>().await.unwrap();
if rate_limit_response.status().is_success() {
let rate_limit_result = match rate_limit_response.json::<RateLimitResponse>().await {
Ok(response) => response,
Err(err) => {
log::error!("Failed to deserialize rate limit response: {:?}", err);
return Err(actix_web::error::ErrorInternalServerError(
"Failed to parse rate limit response",
));
}
};

if rate_limit_result.remaining > 0 {
info!("Rate limit check passed");
if let Some(remaining) = rate_limit_result.remaining {
if remaining > 0 {
// Rate limit passed, proceed to the next middleware or handler
let res = next.call(req).await?;
Ok(res)
} else {
log::info!("Rate limit exceeded. Resets at: {:?}", rate_limit_result.reset);
return Err(actix_web::error::ErrorTooManyRequests("Rate limit exceeded"));
}
} else {
log::error!("Rate limit response missing 'remaining' field");
return Err(actix_web::error::ErrorInternalServerError(
"Invalid rate limit response",
));
}
} else {
info!("Rate limit resets at: {}", rate_limit_result.reset);
return Err(actix_web::error::ErrorTooManyRequests("Rate limit exceeded"));
}
// Parse the error response
let error_response: ApiErrorResponse = rate_limit_response.json().await.map_err(|err| {
log::error!("Failed to parse error response: {:?}", err);
actix_web::error::ErrorInternalServerError("Failed to parse error response")
})?;

let res = next.call(req).await?;
// Log the error and return a meaningful error message to the user
log::error!(
"Rate limit request failed. Code: {}, Message: {}, Docs: {}, Request ID: {}",
error_response.error.code,
error_response.error.message,
error_response.error.docs,
error_response.error.requestId
);

Ok(res)
return Err(actix_web::error::ErrorBadRequest(format!(
"Rate limit request failed: {} (Request ID: {})",
error_response.error.message, error_response.error.requestId
)));
}
}
Ok(res) => {
error!("Key verification failed: {:?}", res);
Expand Down Expand Up @@ -161,13 +223,11 @@ async fn main() -> std::io::Result<()> {
.wrap(Logger::new("%a %{User-Agent}i"))
.app_data(client.clone())
.app_data(shared_data.clone())
// .wrap(from_fn(middleware))
.wrap(from_fn(verify_key))
.service(
web::scope("/api/v1")
.route("/public", web::get().to(public))
.route("/protected",
web::get().wrap(from_fn(verify_key)).to(protected)
),
.route("/protected", web::get().to(protected)),
)
})
.bind(("127.0.0.1", port))?
Expand Down

0 comments on commit 67a880f

Please sign in to comment.