From 3cd1319f008ac045001fb9ee5ad9346105728fd2 Mon Sep 17 00:00:00 2001 From: djnovin Date: Thu, 10 Oct 2024 14:53:21 +1100 Subject: [PATCH] Seperation of Concerns for App --- src/handlers.rs | 9 +++ src/main.rs | 191 ++-------------------------------------------- src/middleware.rs | 129 +++++++++++++++++++++++++++++++ src/models.rs | 55 +++++++++++++ 4 files changed, 201 insertions(+), 183 deletions(-) create mode 100644 src/handlers.rs create mode 100644 src/middleware.rs create mode 100644 src/models.rs diff --git a/src/handlers.rs b/src/handlers.rs new file mode 100644 index 0000000..1795559 --- /dev/null +++ b/src/handlers.rs @@ -0,0 +1,9 @@ +use actix_web::HttpRequest; + +pub async fn public(_req: HttpRequest) -> String { + "Hello, world!".to_owned() +} + +pub async fn protected(_req: HttpRequest) -> String { + "Hello, protected world!".to_owned() +} diff --git a/src/main.rs b/src/main.rs index 46fd0fc..a38c4e7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,62 +1,17 @@ -use actix_web::body::MessageBody; -use actix_web::dev::{ServiceRequest, ServiceResponse}; -use actix_web::middleware::{from_fn, Logger, Next}; -use actix_web::{web, App, Error, HttpRequest, HttpServer}; +use actix_web::middleware::{from_fn, Logger}; +use actix_web::{web, App, HttpServer}; use dotenv::dotenv; use env_logger::Env; -use log::{error, info}; +use handlers::{protected, public}; +use middleware::verify_key; +use models::{AppState, UnkeyApiId}; use reqwest::Client; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; use std::env; -use unkey::models::VerifyKeyRequest; use unkey::Client as UnkeyClient; -#[derive(Serialize)] -struct RateLimitRequest { - namespace: String, - identifier: String, - limit: u32, - duration: u64, - cost: u32, - #[serde(rename = "async")] - async_field: bool, - meta: HashMap, - #[serde(default, skip_serializing_if = "Vec::is_empty")] - resources: Vec, // Default to an empty array if no resources are provided -} - -#[derive(Serialize)] -struct Resource { - r#type: String, - id: String, - name: String, -} - -#[derive(Deserialize)] -struct ApiErrorResponse { - error: ApiError, -} - -#[derive(Deserialize)] -struct ApiError { - code: String, - docs: String, - message: String, - #[serde(rename = "requestId")] - request_id: String, -} - -#[derive(Deserialize)] -struct RateLimitResponse { - limit: Option, - remaining: Option, - reset: Option, - success: Option, -} - -#[derive(Clone)] -struct UnkeyApiId(String); +mod handlers; +mod middleware; +mod models; impl From for String { fn from(api_id: UnkeyApiId) -> Self { @@ -64,136 +19,6 @@ impl From for String { } } -struct AppState { - unkey_client: UnkeyClient, - unkey_api_id: UnkeyApiId, -} - -async fn verify_key( - req: ServiceRequest, - next: Next, -) -> Result, Error> { - info!("Middleware start"); - - let headers = req.headers().clone(); // Clone the headers so req is not borrowed later - - info!("Headers: {:?}", headers); - - let data = req.app_data::>().unwrap(); - let client = req.app_data::>().unwrap(); - - let connection_info = req.connection_info().clone(); - let user_ip = connection_info.realip_remote_addr().unwrap_or("unknown").to_string(); - - let authorization_header = if let Some(header_value) = headers.get("Authorization") { - match header_value.to_str() { - Ok(value) if value.starts_with("Bearer ") => value.trim_start_matches("Bearer ").to_string(), - _ => { - return Err(actix_web::error::ErrorUnauthorized( - "Invalid Authorization header format", - )) - } - } - } else { - return Err(actix_web::error::ErrorUnauthorized("Authorization header missing")); - }; - - let verify_request = VerifyKeyRequest { - key: authorization_header.to_string(), - api_id: data.unkey_api_id.clone().into(), - }; - - match data.unkey_client.verify_key(verify_request).await { - Ok(res) if res.valid => { - let rate_limit_request = RateLimitRequest { - namespace: "test_protected".to_string(), // Namespace for the rate limit - identifier: user_ip, // Identifier for the rate limit - limit: 10, - duration: 60000, - cost: 2, - async_field: true, - meta: HashMap::new(), - resources: vec![], - }; - - let unkey_root_key = env::var("UNKEY_ROOT_KEY").expect("UNKEY_ROOT_KEY must be set"); - - let rate_limit_response = client - .post("https://api.unkey.dev/v1/ratelimits.limit") - .bearer_auth(unkey_root_key) - .header("Content-Type", "application/json") - .json(&rate_limit_request) - .send() - .await - .unwrap(); - - if rate_limit_response.status().is_success() { - let rate_limit_result = match rate_limit_response.json::().await { - Ok(response) => response, - Err(err) => { - log::error!("Failed to deserialize rate limit response: {:?}", err); - return Err(actix_web::error::ErrorInternalServerError( - "Failed to parse rate limit response", - )); - } - }; - - if let Some(remaining) = rate_limit_result.remaining { - if remaining > 0 { - // Rate limit passed, proceed to the next middleware or handler - let res = next.call(req).await?; - Ok(res) - } else { - log::info!("Rate limit exceeded. Resets at: {:?}", rate_limit_result.reset); - return Err(actix_web::error::ErrorTooManyRequests("Rate limit exceeded")); - } - } else { - log::error!("Rate limit response missing 'remaining' field"); - return Err(actix_web::error::ErrorInternalServerError( - "Invalid rate limit response", - )); - } - } else { - // Parse the error response - let error_response: ApiErrorResponse = rate_limit_response.json().await.map_err(|err| { - log::error!("Failed to parse error response: {:?}", err); - actix_web::error::ErrorInternalServerError("Failed to parse error response") - })?; - - // Log the error and return a meaningful error message to the user - log::error!( - "Rate limit request failed. Code: {}, Message: {}, Docs: {}, Request ID: {}", - error_response.error.code, - error_response.error.message, - error_response.error.docs, - error_response.error.request_id - ); - - return Err(actix_web::error::ErrorBadRequest(format!( - "Rate limit request failed: {} (Request ID: {})", - error_response.error.message, error_response.error.request_id - ))); - } - } - Ok(res) => { - error!("Key verification failed: {:?}", res); - Err(actix_web::error::ErrorUnauthorized("Key verification failed")) - } - Err(err) => { - error!("Key verification failed: {:?}", err); - Err(actix_web::error::ErrorUnauthorized("Key verification failed")) - } - } -} - -async fn public(_req: HttpRequest) -> String { - "Hello, world!".to_owned() -} - -async fn protected(_req: HttpRequest) -> String { - "Hello, protected world!".to_owned() -} - #[actix_web::main] async fn main() -> std::io::Result<()> { dotenv().ok(); diff --git a/src/middleware.rs b/src/middleware.rs new file mode 100644 index 0000000..a81abe3 --- /dev/null +++ b/src/middleware.rs @@ -0,0 +1,129 @@ +use actix_web::body::MessageBody; +use actix_web::dev::ServiceRequest; +use actix_web::dev::ServiceResponse; +use actix_web::middleware::Next; +use actix_web::{web, Error}; +use log::{error, info}; +use reqwest::Client; +use std::collections::HashMap; +use std::env; +use unkey::models::VerifyKeyRequest; + +use crate::models::{ApiErrorResponse, AppState, RateLimitRequest, RateLimitResponse}; + +pub async fn verify_key( + req: ServiceRequest, + next: Next, +) -> Result, Error> { + info!("Middleware start"); + + let headers = req.headers().clone(); // Clone the headers so req is not borrowed later + + info!("Headers: {:?}", headers); + + let data = req.app_data::>().unwrap(); + let client = req.app_data::>().unwrap(); + + let connection_info = req.connection_info().clone(); + let user_ip = connection_info.realip_remote_addr().unwrap_or("unknown").to_string(); + + let authorization_header = if let Some(header_value) = headers.get("Authorization") { + match header_value.to_str() { + Ok(value) if value.starts_with("Bearer ") => value.trim_start_matches("Bearer ").to_string(), + _ => { + return Err(actix_web::error::ErrorUnauthorized( + "Invalid Authorization header format", + )) + } + } + } else { + return Err(actix_web::error::ErrorUnauthorized("Authorization header missing")); + }; + + let verify_request = VerifyKeyRequest { + key: authorization_header.to_string(), + api_id: data.unkey_api_id.clone().into(), + }; + + match data.unkey_client.verify_key(verify_request).await { + Ok(res) if res.valid => { + let rate_limit_request = RateLimitRequest { + namespace: "test_protected".to_string(), // Namespace for the rate limit + identifier: user_ip, // Identifier for the rate limit + limit: 10, + duration: 60000, + cost: 2, + async_field: true, + meta: HashMap::new(), + resources: vec![], + }; + + let unkey_root_key = env::var("UNKEY_ROOT_KEY").expect("UNKEY_ROOT_KEY must be set"); + + let rate_limit_response = client + .post("https://api.unkey.dev/v1/ratelimits.limit") + .bearer_auth(unkey_root_key) + .header("Content-Type", "application/json") + .json(&rate_limit_request) + .send() + .await + .unwrap(); + + if rate_limit_response.status().is_success() { + let rate_limit_result = match rate_limit_response.json::().await { + Ok(response) => response, + Err(err) => { + log::error!("Failed to deserialize rate limit response: {:?}", err); + return Err(actix_web::error::ErrorInternalServerError( + "Failed to parse rate limit response", + )); + } + }; + + if let Some(remaining) = rate_limit_result.remaining { + if remaining > 0 { + // Rate limit passed, proceed to the next middleware or handler + let res = next.call(req).await?; + Ok(res) + } else { + log::info!("Rate limit exceeded. Resets at: {:?}", rate_limit_result.reset); + return Err(actix_web::error::ErrorTooManyRequests("Rate limit exceeded")); + } + } else { + log::error!("Rate limit response missing 'remaining' field"); + return Err(actix_web::error::ErrorInternalServerError( + "Invalid rate limit response", + )); + } + } else { + // Parse the error response + let error_response: ApiErrorResponse = rate_limit_response.json().await.map_err(|err| { + log::error!("Failed to parse error response: {:?}", err); + actix_web::error::ErrorInternalServerError("Failed to parse error response") + })?; + + // Log the error and return a meaningful error message to the user + log::error!( + "Rate limit request failed. Code: {}, Message: {}, Docs: {}, Request ID: {}", + error_response.error.code, + error_response.error.message, + error_response.error.docs, + error_response.error.request_id + ); + + return Err(actix_web::error::ErrorBadRequest(format!( + "Rate limit request failed: {} (Request ID: {})", + error_response.error.message, error_response.error.request_id + ))); + } + } + Ok(res) => { + error!("Key verification failed: {:?}", res); + Err(actix_web::error::ErrorUnauthorized("Key verification failed")) + } + Err(err) => { + error!("Key verification failed: {:?}", err); + Err(actix_web::error::ErrorUnauthorized("Key verification failed")) + } + } +} diff --git a/src/models.rs b/src/models.rs new file mode 100644 index 0000000..4e7f58e --- /dev/null +++ b/src/models.rs @@ -0,0 +1,55 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; +use unkey::Client as UnkeyClient; + +#[derive(Clone)] +pub struct UnkeyApiId(pub String); + +pub struct AppState { + pub unkey_client: UnkeyClient, + pub unkey_api_id: UnkeyApiId, +} + +#[derive(Serialize)] +pub struct RateLimitRequest { + pub namespace: String, + pub identifier: String, + pub limit: u32, + pub duration: u64, + pub cost: u32, + #[serde(rename = "async")] + pub async_field: bool, + pub meta: HashMap, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub resources: Vec, // Default to an empty array if no resources are provided +} + +#[derive(Serialize)] +pub struct Resource { + pub r#type: String, + pub id: String, + pub name: String, +} + +#[derive(Deserialize)] +pub struct ApiErrorResponse { + pub error: ApiError, +} + +#[derive(Deserialize)] +pub struct ApiError { + pub code: String, + pub docs: String, + pub message: String, + #[serde(rename = "requestId")] + pub request_id: String, +} + +#[derive(Deserialize)] +pub struct RateLimitResponse { + pub limit: Option, + pub remaining: Option, + pub reset: Option, + pub success: Option, +}