extractors for jwt and user id

This commit is contained in:
2024-12-10 14:48:03 -08:00
parent ef722fe0d8
commit 0ecd7c8a4c
11 changed files with 168 additions and 25 deletions

4
Cargo.lock generated
View File

@@ -2084,6 +2084,10 @@ name = "uuid"
version = "1.11.0" version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a"
dependencies = [
"getrandom",
"serde",
]
[[package]] [[package]]
name = "vcpkg" name = "vcpkg"

View File

@@ -20,7 +20,7 @@ sqlx = { version = "0.8.2", features = [
"macros", "macros",
"uuid", "uuid",
] } ] }
uuid = "1.11.0" uuid = { version = "1.4.1", features = ["serde", "v4"] }
chrono = "0.4.39" chrono = "0.4.39"
dotenv = "0.15.0" dotenv = "0.15.0"
argon2 = "0.5.3" argon2 = "0.5.3"

44
src/extractors/jwt.rs Normal file
View File

@@ -0,0 +1,44 @@
use std::sync::Arc;
use crate::*;
use anyhow::bail;
use axum::{
async_trait,
extract::{FromRef, FromRequestParts},
http::{header, request::Parts},
};
use chrono::{DateTime, Utc};
use jwt::VerifyWithKey;
use sqlx::types::Uuid;
use util::auth::JWTClaims;
pub struct JWT(JWTClaims);
#[async_trait]
impl<S> FromRequestParts<S> for JWT
where
AppState: FromRef<S>,
S: Send + Sync,
{
type Rejection = AppError;
async fn from_request_parts(parts: &mut Parts, s: &S) -> Result<Self, Self::Rejection> {
let state = AppState::from_ref(s);
let jwt_string = parts
.headers
.get(header::AUTHORIZATION)
.ok_or(AppError::Error(Errors::Unauthorized))?
.to_str()
.map_err(|_| AppError::Error(Errors::Unauthorized))?
.strip_prefix("Bearer ")
.ok_or(AppError::Error(Errors::Unauthorized))?;
let claims: JWTClaims = jwt_string
.verify_with_key(&state.jwt_key)
.map_err(|_| AppError::Error(Errors::Unauthorized))?;
Ok(JWT(claims))
}
}

2
src/extractors/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod users;
pub mod jwt;

44
src/extractors/users.rs Normal file
View File

@@ -0,0 +1,44 @@
use std::sync::Arc;
use crate::*;
use anyhow::bail;
use axum::{
async_trait,
extract::{FromRef, FromRequestParts},
http::{header, request::Parts},
};
use chrono::{DateTime, Utc};
use jwt::VerifyWithKey;
use sqlx::types::Uuid;
use util::auth::JWTClaims;
pub struct UserId(Uuid);
#[async_trait]
impl<S> FromRequestParts<S> for UserId
where
AppState: FromRef<S>,
S: Send + Sync,
{
type Rejection = AppError;
async fn from_request_parts(parts: &mut Parts, s: &S) -> Result<Self, Self::Rejection> {
let state = AppState::from_ref(s);
let jwt_string = parts
.headers
.get(header::AUTHORIZATION)
.ok_or(AppError::Error(Errors::Unauthorized))?
.to_str()
.map_err(|_| AppError::Error(Errors::Unauthorized))?
.strip_prefix("Bearer ")
.ok_or(AppError::Error(Errors::Unauthorized))?;
let claims: JWTClaims = jwt_string
.verify_with_key(&state.jwt_key)
.map_err(|_| AppError::Error(Errors::Unauthorized))?;
Ok(UserId(claims.sub))
}
}

View File

@@ -1,4 +1,5 @@
use std::{net::Ipv4Addr, sync::Arc}; use std::{net::Ipv4Addr, sync::Arc};
use hmac::{Hmac, Mac};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use utoipa::OpenApi; use utoipa::OpenApi;
use utoipa_axum::{router::OpenApiRouter, routes}; use utoipa_axum::{router::OpenApiRouter, routes};
@@ -12,6 +13,8 @@ mod structs;
mod state; mod state;
mod db; mod db;
mod error; mod error;
mod util;
mod extractors;
pub(crate) use anyhow::Context; pub(crate) use anyhow::Context;
pub(crate) use axum::extract::{Json, State}; pub(crate) use axum::extract::{Json, State};
@@ -60,7 +63,9 @@ async fn main() -> anyhow::Result<()> {
let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "secret".to_string()); let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "secret".to_string());
let db = db::db().await?; let db = db::db().await?;
let state = state::AppState { db: Arc::new(db), jwt_secret }; let state = state::AppState { db: Arc::new(db),
jwt_key: Hmac::new_from_slice(jwt_secret.as_bytes()).context("Failed to create HMAC")?
};
let (router, api) = OpenApiRouter::with_openapi(ApiDoc::openapi()) let (router, api) = OpenApiRouter::with_openapi(ApiDoc::openapi())
.routes(routes!(health_check)) .routes(routes!(health_check))

View File

@@ -1,9 +1,13 @@
use std::sync::Arc; use std::sync::Arc;
use hmac::Hmac;
use sha2::Sha256;
#[derive(Clone)] #[derive(Clone)]
pub struct AppState { pub struct AppState {
pub db: DB, pub db: DB,
pub jwt_secret: String, // pub jwt_secret: String,
pub jwt_key: Hmac<Sha256>,
} }
pub type DB = Arc<sqlx::Pool<sqlx::Postgres>>; pub type DB = Arc<sqlx::Pool<sqlx::Postgres>>;

49
src/util/auth.rs Normal file
View File

@@ -0,0 +1,49 @@
use chrono::Utc;
use uuid::Uuid;
use crate::*;
#[derive(Serialize, Deserialize)]
pub struct JWTClaims {
pub sub: Uuid,
pub iat: i64,
pub exp: i64,
pub username: String,
pub real_name: String,
pub email: String,
}
impl JWTClaims {
pub fn new(sub: Uuid, username: String, real_name: String, email: String) -> Self {
let iat = Utc::now().timestamp();
let exp = iat + 60 * 60 * 24 * 7;
Self {
sub,
iat,
exp,
username,
real_name,
email,
}
}
/// Create a new JWT Claims with a custom expiration time. Expiration time is added to current time.
pub fn new_with_exp(
sub: Uuid,
username: String,
real_name: String,
email: String,
exp: i64,
) -> Self {
let iat = Utc::now().timestamp();
Self {
sub,
iat,
exp: iat + exp,
username,
real_name,
email,
}
}
}

1
src/util/mod.rs Normal file
View File

@@ -0,0 +1 @@
pub mod auth;

View File

@@ -3,11 +3,9 @@ use argon2::{
password_hash::{PasswordHash, PasswordVerifier}, password_hash::{PasswordHash, PasswordVerifier},
Argon2, Argon2,
}; };
use hmac::{Hmac, Mac};
use jwt::SignWithKey; use jwt::SignWithKey;
use sha2::Sha256;
use sqlx::query; use sqlx::query;
use std::collections::BTreeMap; use util::auth::JWTClaims;
#[derive(Serialize, Deserialize, ToSchema)] #[derive(Serialize, Deserialize, ToSchema)]
pub struct LoginBody { pub struct LoginBody {
@@ -40,14 +38,11 @@ pub async fn login(
return Err(AppError::Error(Errors::Unauthorized)); return Err(AppError::Error(Errors::Unauthorized));
} }
let key: Hmac<Sha256> = let claims = JWTClaims::new(user.id, user.username, user.real_name, user.email);
Hmac::new_from_slice(state.jwt_secret.as_bytes()).context("Failed to create HMAC")?;
let mut claims = BTreeMap::new(); let token_str = claims
claims.insert("id", user.id.to_string()); .sign_with_key(&state.jwt_key)
claims.insert("username", user.username); .context("Failed to sign JWT")?;
claims.insert("real_name", user.real_name);
claims.insert("email", user.email);
let token_str = claims.sign_with_key(&key).context("Failed to sign JWT")?;
Ok(token_str) Ok(token_str)
} }

View File

@@ -6,10 +6,8 @@ use argon2::{
Argon2, Argon2,
}; };
use hmac::{Hmac, Mac};
use jwt::SignWithKey; use jwt::SignWithKey;
use sha2::Sha256; use util::auth::JWTClaims;
use std::collections::BTreeMap;
#[derive(Serialize, Deserialize, ToSchema)] #[derive(Serialize, Deserialize, ToSchema)]
pub struct SignupBody { pub struct SignupBody {
@@ -45,14 +43,11 @@ pub async fn signup(
.fetch_one(&*state.db) .fetch_one(&*state.db)
.await?; .await?;
let key: Hmac<Sha256> = let claims = JWTClaims::new(user.id, user.username, user.real_name, user.email);
Hmac::new_from_slice(state.jwt_secret.as_bytes()).context("Failed to create HMAC")?;
let mut claims = BTreeMap::new(); let token_str = claims
claims.insert("id", user.id.to_string()); .sign_with_key(&state.jwt_key)
claims.insert("username", user.username); .context("Failed to sign JWT")?;
claims.insert("real_name", user.real_name);
claims.insert("email", user.email);
let token_str = claims.sign_with_key(&key).context("Failed to sign JWT")?;
Ok(token_str) Ok(token_str)
} }