diff --git a/Cargo.lock b/Cargo.lock index 9e442bd..d66f6e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2084,6 +2084,10 @@ name = "uuid" version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" +dependencies = [ + "getrandom", + "serde", +] [[package]] name = "vcpkg" diff --git a/Cargo.toml b/Cargo.toml index c55ea94..918761b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ sqlx = { version = "0.8.2", features = [ "macros", "uuid", ] } -uuid = "1.11.0" +uuid = { version = "1.4.1", features = ["serde", "v4"] } chrono = "0.4.39" dotenv = "0.15.0" argon2 = "0.5.3" diff --git a/src/extractors/jwt.rs b/src/extractors/jwt.rs new file mode 100644 index 0000000..681423b --- /dev/null +++ b/src/extractors/jwt.rs @@ -0,0 +1,44 @@ +use std::sync::Arc; + +use crate::*; +use anyhow::bail; +use axum::{ + async_trait, + extract::{FromRef, FromRequestParts}, + http::{header, request::Parts}, +}; + +use chrono::{DateTime, Utc}; +use jwt::VerifyWithKey; +use sqlx::types::Uuid; +use util::auth::JWTClaims; + +pub struct JWT(JWTClaims); + +#[async_trait] +impl FromRequestParts for JWT +where + AppState: FromRef, + S: Send + Sync, +{ + type Rejection = AppError; + + async fn from_request_parts(parts: &mut Parts, s: &S) -> Result { + let state = AppState::from_ref(s); + + let jwt_string = parts + .headers + .get(header::AUTHORIZATION) + .ok_or(AppError::Error(Errors::Unauthorized))? + .to_str() + .map_err(|_| AppError::Error(Errors::Unauthorized))? + .strip_prefix("Bearer ") + .ok_or(AppError::Error(Errors::Unauthorized))?; + + let claims: JWTClaims = jwt_string + .verify_with_key(&state.jwt_key) + .map_err(|_| AppError::Error(Errors::Unauthorized))?; + + Ok(JWT(claims)) + } +} diff --git a/src/extractors/mod.rs b/src/extractors/mod.rs new file mode 100644 index 0000000..bd01805 --- /dev/null +++ b/src/extractors/mod.rs @@ -0,0 +1,2 @@ +pub mod users; +pub mod jwt; diff --git a/src/extractors/users.rs b/src/extractors/users.rs new file mode 100644 index 0000000..968479a --- /dev/null +++ b/src/extractors/users.rs @@ -0,0 +1,44 @@ +use std::sync::Arc; + +use crate::*; +use anyhow::bail; +use axum::{ + async_trait, + extract::{FromRef, FromRequestParts}, + http::{header, request::Parts}, +}; + +use chrono::{DateTime, Utc}; +use jwt::VerifyWithKey; +use sqlx::types::Uuid; +use util::auth::JWTClaims; + +pub struct UserId(Uuid); + +#[async_trait] +impl FromRequestParts for UserId +where + AppState: FromRef, + S: Send + Sync, +{ + type Rejection = AppError; + + async fn from_request_parts(parts: &mut Parts, s: &S) -> Result { + let state = AppState::from_ref(s); + + let jwt_string = parts + .headers + .get(header::AUTHORIZATION) + .ok_or(AppError::Error(Errors::Unauthorized))? + .to_str() + .map_err(|_| AppError::Error(Errors::Unauthorized))? + .strip_prefix("Bearer ") + .ok_or(AppError::Error(Errors::Unauthorized))?; + + let claims: JWTClaims = jwt_string + .verify_with_key(&state.jwt_key) + .map_err(|_| AppError::Error(Errors::Unauthorized))?; + + Ok(UserId(claims.sub)) + } +} diff --git a/src/main.rs b/src/main.rs index 4b2925f..400d047 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ use std::{net::Ipv4Addr, sync::Arc}; +use hmac::{Hmac, Mac}; use tokio::net::TcpListener; use utoipa::OpenApi; use utoipa_axum::{router::OpenApiRouter, routes}; @@ -12,6 +13,8 @@ mod structs; mod state; mod db; mod error; +mod util; +mod extractors; pub(crate) use anyhow::Context; pub(crate) use axum::extract::{Json, State}; @@ -60,7 +63,9 @@ async fn main() -> anyhow::Result<()> { let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "secret".to_string()); let db = db::db().await?; - let state = state::AppState { db: Arc::new(db), jwt_secret }; + let state = state::AppState { db: Arc::new(db), + jwt_key: Hmac::new_from_slice(jwt_secret.as_bytes()).context("Failed to create HMAC")? + }; let (router, api) = OpenApiRouter::with_openapi(ApiDoc::openapi()) .routes(routes!(health_check)) diff --git a/src/state.rs b/src/state.rs index b8a0a09..fdf0f99 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,9 +1,13 @@ use std::sync::Arc; +use hmac::Hmac; +use sha2::Sha256; + #[derive(Clone)] pub struct AppState { pub db: DB, - pub jwt_secret: String, + // pub jwt_secret: String, + pub jwt_key: Hmac, } pub type DB = Arc>; diff --git a/src/util/auth.rs b/src/util/auth.rs new file mode 100644 index 0000000..7ab137f --- /dev/null +++ b/src/util/auth.rs @@ -0,0 +1,49 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::*; + +#[derive(Serialize, Deserialize)] +pub struct JWTClaims { + pub sub: Uuid, + pub iat: i64, + pub exp: i64, + + pub username: String, + pub real_name: String, + pub email: String, +} + +impl JWTClaims { + pub fn new(sub: Uuid, username: String, real_name: String, email: String) -> Self { + let iat = Utc::now().timestamp(); + let exp = iat + 60 * 60 * 24 * 7; + Self { + sub, + iat, + exp, + username, + real_name, + email, + } + } + + /// Create a new JWT Claims with a custom expiration time. Expiration time is added to current time. + pub fn new_with_exp( + sub: Uuid, + username: String, + real_name: String, + email: String, + exp: i64, + ) -> Self { + let iat = Utc::now().timestamp(); + Self { + sub, + iat, + exp: iat + exp, + username, + real_name, + email, + } + } +} diff --git a/src/util/mod.rs b/src/util/mod.rs new file mode 100644 index 0000000..0e4a05d --- /dev/null +++ b/src/util/mod.rs @@ -0,0 +1 @@ +pub mod auth; diff --git a/src/v1/auth/login.rs b/src/v1/auth/login.rs index 07e7607..6d7dd0a 100644 --- a/src/v1/auth/login.rs +++ b/src/v1/auth/login.rs @@ -3,11 +3,9 @@ use argon2::{ password_hash::{PasswordHash, PasswordVerifier}, Argon2, }; -use hmac::{Hmac, Mac}; use jwt::SignWithKey; -use sha2::Sha256; use sqlx::query; -use std::collections::BTreeMap; +use util::auth::JWTClaims; #[derive(Serialize, Deserialize, ToSchema)] pub struct LoginBody { @@ -40,14 +38,11 @@ pub async fn login( return Err(AppError::Error(Errors::Unauthorized)); } - let key: Hmac = - Hmac::new_from_slice(state.jwt_secret.as_bytes()).context("Failed to create HMAC")?; - let mut claims = BTreeMap::new(); - claims.insert("id", user.id.to_string()); - claims.insert("username", user.username); - claims.insert("real_name", user.real_name); - claims.insert("email", user.email); - let token_str = claims.sign_with_key(&key).context("Failed to sign JWT")?; + let claims = JWTClaims::new(user.id, user.username, user.real_name, user.email); + + let token_str = claims + .sign_with_key(&state.jwt_key) + .context("Failed to sign JWT")?; Ok(token_str) } diff --git a/src/v1/auth/signup.rs b/src/v1/auth/signup.rs index f2406f2..0d60568 100644 --- a/src/v1/auth/signup.rs +++ b/src/v1/auth/signup.rs @@ -6,10 +6,8 @@ use argon2::{ Argon2, }; -use hmac::{Hmac, Mac}; use jwt::SignWithKey; -use sha2::Sha256; -use std::collections::BTreeMap; +use util::auth::JWTClaims; #[derive(Serialize, Deserialize, ToSchema)] pub struct SignupBody { @@ -45,14 +43,11 @@ pub async fn signup( .fetch_one(&*state.db) .await?; - let key: Hmac = - Hmac::new_from_slice(state.jwt_secret.as_bytes()).context("Failed to create HMAC")?; - let mut claims = BTreeMap::new(); - claims.insert("id", user.id.to_string()); - claims.insert("username", user.username); - claims.insert("real_name", user.real_name); - claims.insert("email", user.email); - let token_str = claims.sign_with_key(&key).context("Failed to sign JWT")?; + let claims = JWTClaims::new(user.id, user.username, user.real_name, user.email); + + let token_str = claims + .sign_with_key(&state.jwt_key) + .context("Failed to sign JWT")?; Ok(token_str) }