diff options
| author | mrw1593 <botahamec@outlook.com> | 2023-06-30 19:27:33 -0400 |
|---|---|---|
| committer | mrw1593 <botahamec@outlook.com> | 2023-06-30 19:27:33 -0400 |
| commit | 55cfb8187cb814e17a2a99d02bfd9296fc01dcc2 (patch) | |
| tree | c5f7ed60c8a814addd60b1cfb843fb9a107f1458 | |
| parent | 9058b01d6c0e3d1e9e485a537258a312ccfc841c (diff) | |
Added config file
| -rw-r--r-- | Cargo.lock | 102 | ||||
| -rw-r--r-- | Cargo.toml | 2 | ||||
| -rw-r--r-- | src/api/oauth.rs | 28 | ||||
| -rw-r--r-- | src/main.rs | 17 | ||||
| -rw-r--r-- | src/services/config.rs | 74 | ||||
| -rw-r--r-- | src/services/jwt.rs | 20 | ||||
| -rw-r--r-- | src/services/mod.rs | 1 | ||||
| -rw-r--r-- | static/config/dev.toml | 0 | ||||
| -rw-r--r-- | static/config/local.toml | 5 | ||||
| -rw-r--r-- | static/config/prod.toml | 0 | ||||
| -rw-r--r-- | static/config/staging.toml | 0 |
11 files changed, 224 insertions, 25 deletions
@@ -358,6 +358,26 @@ dependencies = [ ] [[package]] +name = "bpaf" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bac242287491ba761f8943d48c2b3eca2b30485187a7a13fa6b2168c058f342" +dependencies = [ + "bpaf_derive", +] + +[[package]] +name = "bpaf_derive" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af3c1dc174c8c49192fe1553cb25f75ba410a4b26b2bf5ca620307579e9ca078" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.2", +] + +[[package]] name = "brotli" version = "3.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -691,6 +711,12 @@ dependencies = [ ] [[package]] +name = "equivalent" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88bffebc5d80432c9b140ee17875ff173a8ab62faad5b257da912bd2f6c1c0a1" + +[[package]] name = "event-listener" version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -841,7 +867,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e4feeef87d958eebd4d55431040768b93a5b088202198e0b203adc3c1d468c6" dependencies = [ "codemap", - "indexmap", + "indexmap 1.9.2", "lasso", "once_cell", "phf", @@ -860,7 +886,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap", + "indexmap 1.9.2", "slab", "tokio", "tokio-util", @@ -886,6 +912,12 @@ dependencies = [ ] [[package]] +name = "hashbrown" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" + +[[package]] name = "hashlink" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1029,6 +1061,16 @@ dependencies = [ ] [[package]] +name = "indexmap" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" +dependencies = [ + "equivalent", + "hashbrown 0.14.0", +] + +[[package]] name = "instant" version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1671,6 +1713,7 @@ dependencies = [ "actix-rt", "actix-web", "base64 0.21.0", + "bpaf", "chrono", "dotenv", "exun", @@ -1693,6 +1736,7 @@ dependencies = [ "sqlx", "tera", "thiserror", + "toml", "unic-langid", "url", "uuid", @@ -1797,6 +1841,15 @@ dependencies = [ ] [[package]] +name = "serde_spanned" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96426c9936fd7a0124915f9185ea1d20aa9445cc9821142f0a73bc9207a2e186" +dependencies = [ + "serde", +] + +[[package]] name = "serde_urlencoded" version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1950,7 +2003,7 @@ dependencies = [ "generic-array", "hashlink", "hex", - "indexmap", + "indexmap 1.9.2", "itoa", "libc", "log", @@ -2233,6 +2286,40 @@ dependencies = [ ] [[package]] +name = "toml" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ebafdf5ad1220cb59e7d17cf4d2c72015297b75b19a10472f99b89225089240" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cda73e2f1397b1262d6dfdcef8aafae14d1de7748d66822d3bfeeb6d03e5e4b" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.19.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "266f016b7f039eec8a1a80dfe6156b633d208b9fccca5e4db1d6775b0c4e34a7" +dependencies = [ + "indexmap 2.0.0", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + +[[package]] name = "tracing" version = "0.1.37" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2692,6 +2779,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" [[package]] +name = "winnow" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca0ace3845f0d96209f0375e6d367e3eb87eb65d27d445bdc9f1843a26f39448" +dependencies = [ + "memchr", +] + +[[package]] name = "zeroize" version = "1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -28,7 +28,9 @@ grass = "0.12" sha2 = "0.10" unic-langid = { version = "0.9", features = ["serde"] } rand = "0.8" +bpaf = { version = "0.8", features = ["derive"] } serde_urlencoded = "0.7" +toml = { version = "0.7", features = ["parse"] } sqlx = { version = "0.6", features = [ "runtime-actix-rustls", "mysql", "uuid", "chrono", "offline" ] } log = "0.4" chrono = { version = "0.4", features = ["serde"] } diff --git a/src/api/oauth.rs b/src/api/oauth.rs index ef40637..fe1c361 100644 --- a/src/api/oauth.rs +++ b/src/api/oauth.rs @@ -20,7 +20,7 @@ use crate::models::client::ClientType; use crate::resources::{languages, templates}; use crate::scopes; use crate::services::jwt::VerifyJwtError; -use crate::services::{authorization, db, jwt}; +use crate::services::{authorization, config, db, jwt}; const REALLY_BAD_ERROR_PAGE: &str = "<!DOCTYPE html><html><head><title>Internal Server Error</title></head><body>Internal Server Error</body></html>"; @@ -243,7 +243,12 @@ async fn authorize( let page = error_page(&tera, &translations, templates::ErrorPage::ClientNotFound).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); return HttpResponse::NotFound().content_type("text/html").body(page); }; - let self_id = Url::parse("www.google.com").unwrap(); // TODO find the actual value + let Ok(config) = config::get_config() else { + let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); + return HttpResponse::InternalServerError().content_type("text/html").body(page); + }; + + let self_id = config.id; let state = req.state.clone(); // get redirect uri @@ -284,7 +289,7 @@ async fn authorize( match req.response_type { ResponseType::Code => { // create auth code - let code = jwt::Claims::auth_code(db, self_id, client_id, &scope, &redirect_uri) + let code = jwt::Claims::auth_code(db, &self_id, client_id, &scope, &redirect_uri) .await .unwrap(); let code = code.to_jwt().unwrap(); @@ -302,7 +307,7 @@ async fn authorize( // create access token let duration = Duration::hours(1); let access_token = - jwt::Claims::access_token(db, None, self_id, client_id, duration, &scope) + jwt::Claims::access_token(db, None, &self_id, client_id, duration, &scope) .await .unwrap(); @@ -628,8 +633,9 @@ async fn token( let Ok(request) = request else { return TokenError::invalid_request().error_response(); }; + let config = config::get_config().unwrap(); - let self_id = Url::parse("www.google.com").unwrap(); // TODO find the actual value + let self_id = config.id; let duration = Duration::hours(1); let token_type = Box::from("bearer"); let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]); @@ -646,9 +652,7 @@ async fn token( // validate auth code let claims = - match jwt::verify_auth_code(db, &code, self_id.clone(), client_id, redirect_uri) - .await - { + match jwt::verify_auth_code(db, &code, &self_id, client_id, redirect_uri).await { Ok(claims) => claims, Err(err) => { let err = err.unwrap(); @@ -673,7 +677,7 @@ async fn token( let access_token = jwt::Claims::access_token( db, Some(claims.id()), - self_id, + &self_id, client_id, duration, claims.scopes(), @@ -751,7 +755,7 @@ async fn token( } let access_token = - jwt::Claims::access_token(db, None, self_id, client_id, duration, &scope) + jwt::Claims::access_token(db, None, &self_id, client_id, duration, &scope) .await .unwrap(); let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap(); @@ -815,7 +819,7 @@ async fn token( } let access_token = - jwt::Claims::access_token(db, None, self_id, client_id, duration, &scope) + jwt::Claims::access_token(db, None, &self_id, client_id, duration, &scope) .await .unwrap(); @@ -851,7 +855,7 @@ async fn token( } let claims = - match jwt::verify_refresh_token(db, &refresh_token, self_id, client_id).await { + match jwt::verify_refresh_token(db, &refresh_token, &self_id, client_id).await { Ok(claims) => claims, Err(e) => { let e = e.unwrap(); diff --git a/src/main.rs b/src/main.rs index da740be..e946161 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,7 @@ use actix_web::middleware::{ErrorHandlerResponse, ErrorHandlers, Logger, Normali use actix_web::web::Data; use actix_web::{dev, App, HttpServer}; +use bpaf::Bpaf; use exun::*; mod api; @@ -44,12 +45,28 @@ async fn delete_expired_tokens(db: MySqlPool) { } } +#[derive(Debug, Clone, Bpaf)] +#[bpaf(options, version)] +struct Opts { + /// The environment that the server is running in. Must be one of: local, + /// dev, staging, prod. + #[bpaf( + env("LOCKDAGGER_ENVIRONMENT"), + fallback(config::Environment::Local), + display_fallback + )] + env: config::Environment, +} + #[actix_web::main] async fn main() -> Result<(), RawUnexpected> { // load the environment file, but only in debug mode #[cfg(debug_assertions)] dotenv::dotenv()?; + let args = opts().run(); + config::set_environment(args.env); + // initialize the database let db_url = secrets::database_url()?; let sql_pool = db::initialize(&db_url).await?; diff --git a/src/services/config.rs b/src/services/config.rs new file mode 100644 index 0000000..6468126 --- /dev/null +++ b/src/services/config.rs @@ -0,0 +1,74 @@ +use std::{ + fmt::{self, Display}, + str::FromStr, +}; + +use exun::RawUnexpected; +use parking_lot::RwLock; +use serde::Deserialize; +use thiserror::Error; +use url::Url; + +static ENVIRONMENT: RwLock<Environment> = RwLock::new(Environment::Local); + +#[derive(Debug, Clone, Deserialize)] +pub struct Config { + pub id: Box<str>, + pub url: Url, +} + +pub fn get_config() -> Result<Config, RawUnexpected> { + let env = get_environment(); + let path = format!("static/config/{env}.toml"); + let string = std::fs::read_to_string(path)?; + let config = toml::from_str(&string)?; + Ok(config) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Environment { + Local, + Dev, + Staging, + Production, +} + +impl Display for Environment { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Local => f.write_str("local"), + Self::Dev => f.write_str("dev"), + Self::Staging => f.write_str("staging"), + Self::Production => f.write_str("prod"), + } + } +} + +#[derive(Debug, Clone, Error)] +#[error("Expected one of the following environments: local, dev, staging, prod. Found {string}")] +pub struct ParseEnvironmentError { + string: Box<str>, +} + +impl FromStr for Environment { + type Err = ParseEnvironmentError; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s { + "local" => Ok(Self::Local), + "dev" => Ok(Self::Dev), + "staging" => Ok(Self::Staging), + "prod" => Ok(Self::Production), + _ => Err(ParseEnvironmentError { string: s.into() }), + } + } +} + +pub fn set_environment(env: Environment) { + let mut env_ptr = ENVIRONMENT.write(); + *env_ptr = env; +} + +fn get_environment() -> Environment { + ENVIRONMENT.read().clone() +} diff --git a/src/services/jwt.rs b/src/services/jwt.rs index 488e0ac..86252c4 100644 --- a/src/services/jwt.rs +++ b/src/services/jwt.rs @@ -19,7 +19,7 @@ pub enum TokenType { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Claims { - iss: Url, + iss: Box<str>, aud: Option<Box<[String]>>, #[serde(with = "ts_milliseconds")] exp: DateTime<Utc>, @@ -45,7 +45,7 @@ pub enum RevokedRefreshTokenReason { impl Claims { pub async fn auth_code<'c>( db: &MySqlPool, - self_id: Url, + self_id: &str, client_id: Uuid, scopes: &str, redirect_uri: &Url, @@ -59,7 +59,7 @@ impl Claims { db::create_auth_code(db, id, exp).await?; Ok(Self { - iss: self_id, + iss: Box::from(self_id), aud: None, exp, nbf: None, @@ -76,7 +76,7 @@ impl Claims { pub async fn access_token<'c>( db: &MySqlPool, auth_code_id: Option<Uuid>, - self_id: Url, + self_id: &str, client_id: Uuid, duration: Duration, scopes: &str, @@ -90,7 +90,7 @@ impl Claims { .unexpect()?; Ok(Self { - iss: self_id, + iss: Box::from(self_id), aud: None, exp, nbf: None, @@ -186,7 +186,7 @@ pub enum VerifyJwtError { fn verify_jwt( token: &str, - self_id: Url, + self_id: &str, client_id: Option<Uuid>, ) -> Result<Claims, Expect<VerifyJwtError>> { let key = secrets::signing_key()?; @@ -194,7 +194,7 @@ fn verify_jwt( .verify_with_key(&key) .map_err(|e| VerifyJwtError::from(e))?; - if claims.iss != self_id { + if claims.iss != self_id.into() { yeet!(VerifyJwtError::IncorrectIssuer.into()) } @@ -228,7 +228,7 @@ fn verify_jwt( pub async fn verify_auth_code<'c>( db: &MySqlPool, token: &str, - self_id: Url, + self_id: &str, client_id: Uuid, redirect_uri: Url, ) -> Result<Claims, Expect<VerifyJwtError>> { @@ -252,7 +252,7 @@ pub async fn verify_auth_code<'c>( pub async fn verify_access_token<'c>( db: impl Executor<'c, Database = MySql>, token: &str, - self_id: Url, + self_id: &str, client_id: Uuid, ) -> Result<Claims, Expect<VerifyJwtError>> { let claims = verify_jwt(token, self_id, Some(client_id))?; @@ -267,7 +267,7 @@ pub async fn verify_access_token<'c>( pub async fn verify_refresh_token<'c>( db: impl Executor<'c, Database = MySql>, token: &str, - self_id: Url, + self_id: &str, client_id: Option<Uuid>, ) -> Result<Claims, Expect<VerifyJwtError>> { let claims = verify_jwt(token, self_id, client_id)?; diff --git a/src/services/mod.rs b/src/services/mod.rs index 5339594..de08b58 100644 --- a/src/services/mod.rs +++ b/src/services/mod.rs @@ -1,4 +1,5 @@ pub mod authorization; +pub mod config; pub mod crypto; pub mod db; pub mod id; diff --git a/static/config/dev.toml b/static/config/dev.toml new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/static/config/dev.toml diff --git a/static/config/local.toml b/static/config/local.toml new file mode 100644 index 0000000..ed6f9d0 --- /dev/null +++ b/static/config/local.toml @@ -0,0 +1,5 @@ +# used to identify the issuer of JWTs +self_id = "LockDagger" + +# The URL which the server is hosted on +url = "http://localhost:8080" diff --git a/static/config/prod.toml b/static/config/prod.toml new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/static/config/prod.toml diff --git a/static/config/staging.toml b/static/config/staging.toml new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/static/config/staging.toml |
