summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/api/oauth.rs304
-rw-r--r--src/main.rs2
-rw-r--r--src/services/db/client.rs44
-rw-r--r--src/services/jwt.rs54
4 files changed, 368 insertions, 36 deletions
diff --git a/src/api/oauth.rs b/src/api/oauth.rs
index d77695e..920f488 100644
--- a/src/api/oauth.rs
+++ b/src/api/oauth.rs
@@ -1,29 +1,35 @@
+use std::ops::Deref;
use std::str::FromStr;
-use actix_web::{get, post, web, HttpResponse, Scope};
+use actix_web::http::header;
+use actix_web::{get, post, web, HttpRequest, HttpResponse, ResponseError, Scope};
+use chrono::Duration;
use serde::{Deserialize, Serialize};
use sqlx::MySqlPool;
use tera::Tera;
+use thiserror::Error;
use unic_langid::subtags::Language;
use url::Url;
-use uuid::Uuid;
use crate::resources::{languages, templates};
-use crate::services::{authorization, db};
+use crate::scopes;
+use crate::services::{authorization, db, jwt};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ResponseType {
Code,
Token,
+ #[serde(other)]
+ Unsupported,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationParameters {
response_type: ResponseType,
- client_id: Uuid,
+ client_id: Box<str>,
redirect_uri: Option<Url>,
- scope: String, // TODO lol no
+ scope: Option<Box<str>>,
state: Option<Box<str>>,
}
@@ -33,14 +39,127 @@ struct AuthorizeCredentials {
password: Box<str>,
}
+#[derive(Clone, Serialize)]
+struct CodeResponse {
+ code: Box<str>,
+ state: Option<Box<str>>,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
+#[serde(rename_all = "camelCase")]
+enum AuthorizeErrorType {
+ InvalidRequest,
+ UnauthorizedClient,
+ AccessDenied,
+ UnsupportedResponseType,
+ InvalidScope,
+ ServerError,
+ TemporarilyUnavailable,
+}
+
+#[derive(Debug, Clone, Error)]
+#[error("{error_description}")]
+struct AuthorizeError {
+ error: AuthorizeErrorType,
+ error_description: Box<str>,
+ // TODO error uri
+ state: Option<Box<str>>,
+ redirect_uri: Url,
+}
+
+impl AuthorizeError {
+ fn no_scope(redirect_uri: Url, state: Option<Box<str>>) -> Self {
+ Self {
+ error: AuthorizeErrorType::InvalidScope,
+ error_description: Box::from(
+ "No scope was provided, and the client does not have a default scope",
+ ),
+ state,
+ redirect_uri,
+ }
+ }
+}
+
+impl ResponseError for AuthorizeError {
+ fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
+ let error = serde_variant::to_variant_name(&self.error).unwrap_or_default();
+ let mut url = self.redirect_uri.clone();
+ url.query_pairs_mut()
+ .append_pair("error", error)
+ .append_pair("error_description", &self.error_description);
+
+ if let Some(state) = &self.state {
+ url.query_pairs_mut().append_pair("state", &state);
+ }
+
+ HttpResponse::Found()
+ .insert_header((header::LOCATION, url.as_str()))
+ .finish()
+ }
+}
+
#[post("/authorize")]
async fn authorize(
db: web::Data<MySqlPool>,
- query: web::Query<AuthorizationParameters>,
- credentials: web::Form<AuthorizeCredentials>,
+ req: web::Query<AuthorizationParameters>,
+ credentials: web::Json<AuthorizeCredentials>,
) -> HttpResponse {
- // TODO check that the URI is valid
- todo!()
+ // TODO use sessions to verify that the request was previously validated
+ let db = db.get_ref();
+ let Some(client_id) = db::get_client_id_by_alias(db, &req.client_id).await.unwrap() else {
+ todo!("client not found")
+ };
+ let self_id = Url::parse("www.google.com").unwrap(); // TODO find the actual value
+ let state = req.state.clone();
+
+ // get redirect uri
+ let redirect_uri = if let Some(redirect_uri) = &req.redirect_uri {
+ redirect_uri.clone()
+ } else {
+ let redirect_uris = db::get_client_redirect_uris(db, client_id).await.unwrap();
+ if redirect_uris.len() != 1 {
+ todo!("no redirect uri");
+ }
+
+ redirect_uris[0].clone()
+ };
+
+ // authenticate user
+ let Some(user) = db::get_user_by_username(db, &credentials.username).await.unwrap() else {
+ todo!("bad username")
+ };
+ if !user.check_password(&credentials.password).unwrap() {
+ todo!("bad password")
+ }
+
+ // get scope
+ let scope = if let Some(scope) = &req.scope {
+ scope.clone()
+ } else {
+ let default_scopes = db::get_client_default_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let Some(scope) = default_scopes else {
+ return AuthorizeError::no_scope(redirect_uri, state).error_response()
+ };
+ scope
+ };
+
+ match req.response_type {
+ ResponseType::Code => {
+ // create auth code
+ let code = jwt::Claims::auth_code(db, self_id, client_id, &scope, &redirect_uri)
+ .await
+ .unwrap();
+ let code = code.to_jwt().unwrap();
+ let response = CodeResponse { code, state };
+
+ HttpResponse::Ok().json(response)
+ }
+ ResponseType::Token => todo!(),
+ _ => todo!("unsupported response type"),
+ }
}
#[get("/authorize")]
@@ -48,36 +167,187 @@ async fn authorize_page(
db: web::Data<MySqlPool>,
tera: web::Data<Tera>,
translations: web::Data<languages::Translations>,
- query: web::Query<AuthorizationParameters>,
+ request: HttpRequest,
) -> HttpResponse {
+ let params = request.query_string();
+ let params = serde_urlencoded::from_str::<AuthorizationParameters>(params);
+ let Ok(params) = params else {
+ todo!("invalid request")
+ };
+
+ let db = db.get_ref();
+ let Some(client_id) = db::get_client_id_by_alias(db, &params.client_id).await.unwrap() else {
+ todo!("client not found")
+ };
+
+ // verify scope
+ let Some(allowed_scopes) = db::get_client_allowed_scopes(db, client_id).await.unwrap() else {
+ todo!("client not found")
+ };
+
+ let scope = if let Some(scope) = &params.scope {
+ scope.clone()
+ } else {
+ let default_scopes = db::get_client_default_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let Some(scope) = default_scopes else {
+ todo!("invalid request")
+ };
+ scope
+ };
+
+ if !scopes::is_subset_of(&scope, &allowed_scopes) {
+ todo!("access_denied")
+ }
+
+ // verify redirect uri
+ if let Some(redirect_uri) = &params.redirect_uri {
+ if !db::client_has_redirect_uri(db, client_id, redirect_uri)
+ .await
+ .unwrap()
+ {
+ todo!("access denied")
+ }
+ } else {
+ let redirect_uris = db::get_client_redirect_uris(db, client_id).await.unwrap();
+ if redirect_uris.len() != 1 {
+ todo!("must have redirect uri")
+ }
+ }
+
+ // verify response type
+ if params.response_type == ResponseType::Unsupported {
+ todo!("unsupported response type")
+ }
+
// TODO find a better way of doing languages
- // TODO check that the URI is valid
let language = Language::from_str("en").unwrap();
let page =
- templates::login_page(&tera, &query, language, translations.get_ref().clone()).unwrap();
+ templates::login_page(&tera, &params, language, translations.get_ref().clone()).unwrap();
HttpResponse::Ok().content_type("text/html").body(page)
}
#[derive(Clone, Deserialize)]
#[serde(tag = "grant_type")]
-enum GrantType {}
+#[serde(rename_all = "snake_case")]
+enum GrantType {
+ AuthorizationCode {
+ code: Box<str>,
+ redirect_uri: Url,
+ #[serde(rename = "client_id")]
+ client_alias: Box<str>,
+ },
+ Password {
+ username: Box<str>,
+ password: Box<str>,
+ scope: Option<Box<str>>,
+ },
+ ClientCredentials {
+ scope: Option<Box<str>>,
+ },
+}
#[derive(Clone, Deserialize)]
struct TokenRequest {
#[serde(flatten)]
grant_type: GrantType,
- scope: String, // TODO lol no
- // TODO support optional client credentials in here
+ // TODO support optional client credentials in here
+}
+
+#[derive(Clone, Serialize)]
+struct TokenResponse {
+ access_token: Box<str>,
+ token_type: Box<str>,
+ expires_in: i64,
+ refresh_token: Box<str>,
+ scope: Box<str>,
}
#[post("/token")]
async fn token(
db: web::Data<MySqlPool>,
- req: web::Form<TokenRequest>,
+ req: web::Bytes,
authorization: Option<web::Header<authorization::BasicAuthorization>>,
) -> HttpResponse {
// TODO protect against brute force attacks
- todo!()
+ let db = db.get_ref();
+ let request = serde_json::from_slice::<TokenRequest>(&req);
+ let Ok(request) = request else {
+ todo!("invalid request")
+ };
+
+ let self_id = Url::parse("www.google.com").unwrap(); // TODO find the actual value
+ let duration = Duration::hours(1);
+ let token_type = Box::from("bearer");
+ let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]);
+
+ match request.grant_type {
+ GrantType::AuthorizationCode {
+ code,
+ redirect_uri,
+ client_alias,
+ } => {
+ let Some(client_id) = db::get_client_id_by_alias(db, &client_alias).await.unwrap() else {
+ todo!("client not found")
+ };
+
+ let Ok(claims) = jwt::verify_auth_code(db, &code, self_id.clone(), client_id, redirect_uri).await else {
+ todo!("invalid code");
+ };
+
+ // verify client, if the client has credentials
+ if let Some(hash) = db::get_client_secret(db, client_id).await.unwrap() {
+ let Some(authorization) = authorization else {
+ todo!("no client credentials")
+ };
+
+ if authorization.username() != client_alias.deref() {
+ todo!("bad username")
+ }
+ if !hash.check_password(authorization.password()).unwrap() {
+ todo!("bad password")
+ }
+ }
+
+ let access_token = jwt::Claims::access_token(
+ db,
+ claims.id(),
+ self_id,
+ client_id,
+ duration,
+ claims.scopes(),
+ )
+ .await
+ .unwrap();
+
+ let expires_in = access_token.expires_in();
+ let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap();
+ let scope = access_token.scopes().into();
+
+ let access_token = access_token.to_jwt().unwrap();
+ let refresh_token = refresh_token.to_jwt().unwrap();
+
+ let response = TokenResponse {
+ access_token,
+ token_type,
+ expires_in,
+ refresh_token,
+ scope,
+ };
+ HttpResponse::Ok()
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(response)
+ }
+ GrantType::Password {
+ username,
+ password,
+ scope,
+ } => todo!(),
+ GrantType::ClientCredentials { scope } => todo!(),
+ }
}
pub fn service() -> Scope {
diff --git a/src/main.rs b/src/main.rs
index 1106dc0..da740be 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -29,7 +29,7 @@ fn error_content_language<B>(
async fn delete_expired_tokens(db: MySqlPool) {
let db = db.clone();
- let mut interval = actix_rt::time::interval(Duration::from_secs(60 * 10));
+ let mut interval = actix_rt::time::interval(Duration::from_secs(60 * 20));
loop {
interval.tick().await;
if let Err(e) = db::delete_expired_auth_codes(&db).await {
diff --git a/src/services/db/client.rs b/src/services/db/client.rs
index c25ad0d..70701d7 100644
--- a/src/services/db/client.rs
+++ b/src/services/db/client.rs
@@ -21,6 +21,13 @@ pub struct ClientRow {
pub default_scopes: Option<String>,
}
+#[derive(Clone, FromRow)]
+struct HashRow {
+ secret_hash: Option<Vec<u8>>,
+ secret_salt: Option<Vec<u8>>,
+ secret_version: Option<u32>,
+}
+
pub async fn client_id_exists<'c>(
executor: impl Executor<'c, Database = MySql>,
id: Uuid,
@@ -47,6 +54,19 @@ pub async fn client_alias_exists<'c>(
.unexpect()
}
+pub async fn get_client_id_by_alias<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ alias: &str,
+) -> Result<Option<Uuid>, RawUnexpected> {
+ query_scalar!(
+ "SELECT id as `id: Uuid` FROM clients WHERE alias = ?",
+ alias
+ )
+ .fetch_optional(executor)
+ .await
+ .unexpect()
+}
+
pub async fn get_client_response<'c>(
executor: impl Executor<'c, Database = MySql>,
id: Uuid,
@@ -116,6 +136,28 @@ pub async fn get_client_default_scopes<'c>(
Ok(scopes.map(|s| s.map(Box::from)))
}
+pub async fn get_client_secret<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<PasswordHash>, RawUnexpected> {
+ let hash = query_as!(
+ HashRow,
+ r"SELECT secret_hash, secret_salt, secret_version
+ FROM clients WHERE id = ?",
+ id
+ )
+ .fetch_optional(executor)
+ .await?;
+
+ let Some(hash) = hash else { return Ok(None) };
+ let Some(version) = hash.secret_version else { return Ok(None) };
+ let Some(salt) = hash.secret_hash else { return Ok(None) };
+ let Some(hash) = hash.secret_salt else { return Ok(None) };
+
+ let hash = PasswordHash::from_fields(&hash, &salt, version as u8);
+ Ok(Some(hash))
+}
+
pub async fn get_client_redirect_uris<'c>(
executor: impl Executor<'c, Database = MySql>,
id: Uuid,
@@ -136,7 +178,7 @@ pub async fn get_client_redirect_uris<'c>(
pub async fn client_has_redirect_uri<'c>(
executor: impl Executor<'c, Database = MySql>,
id: Uuid,
- url: Url,
+ url: &Url,
) -> Result<bool, RawUnexpected> {
query_scalar!(
r"SELECT EXISTS(
diff --git a/src/services/jwt.rs b/src/services/jwt.rs
index 7841afb..822101f 100644
--- a/src/services/jwt.rs
+++ b/src/services/jwt.rs
@@ -32,6 +32,7 @@ pub struct Claims {
client_id: Uuid,
auth_code_id: Uuid,
token_type: TokenType,
+ redirect_uri: Option<Url>,
}
#[derive(Debug, Clone, Copy, sqlx::Type)]
@@ -43,18 +44,19 @@ pub enum RevokedRefreshTokenReason {
impl Claims {
pub async fn auth_code<'c>(
- db: MySqlPool,
+ db: &MySqlPool,
self_id: Url,
client_id: Uuid,
scopes: &str,
+ redirect_uri: &Url,
) -> Result<Self, RawUnexpected> {
let five_minutes = Duration::minutes(5);
- let id = new_id(&db, db::auth_code_exists).await?;
+ let id = new_id(db, db::auth_code_exists).await?;
let time = Utc::now();
let exp = time + five_minutes;
- db::create_auth_code(&db, id, exp).await?;
+ db::create_auth_code(db, id, exp).await?;
Ok(Self {
iss: self_id,
@@ -67,22 +69,23 @@ impl Claims {
client_id,
auth_code_id: id,
token_type: TokenType::Authorization,
+ redirect_uri: Some(redirect_uri.clone()),
})
}
pub async fn access_token<'c>(
- db: MySqlPool,
+ db: &MySqlPool,
auth_code_id: Uuid,
self_id: Url,
client_id: Uuid,
duration: Duration,
scopes: &str,
) -> Result<Self, RawUnexpected> {
- let id = new_id(&db, db::access_token_exists).await?;
+ let id = new_id(db, db::access_token_exists).await?;
let time = Utc::now();
let exp = time + duration;
- db::create_access_token(&db, id, auth_code_id, exp)
+ db::create_access_token(db, id, auth_code_id, exp)
.await
.unexpect()?;
@@ -97,19 +100,23 @@ impl Claims {
client_id,
auth_code_id,
token_type: TokenType::Access,
+ redirect_uri: None,
})
}
- pub async fn refresh_token(db: MySqlPool, other_token: Claims) -> Result<Self, RawUnexpected> {
+ pub async fn refresh_token(
+ db: &MySqlPool,
+ other_token: &Claims,
+ ) -> Result<Self, RawUnexpected> {
let one_day = Duration::days(1);
- let id = new_id(&db, db::refresh_token_exists).await?;
+ let id = new_id(db, db::refresh_token_exists).await?;
let time = Utc::now();
let exp = other_token.exp + one_day;
- db::create_refresh_token(&db, id, other_token.auth_code_id, exp).await?;
+ db::create_refresh_token(db, id, other_token.auth_code_id, exp).await?;
- let mut claims = other_token;
+ let mut claims = other_token.clone();
claims.exp = exp;
claims.iat = Some(time);
claims.jti = id;
@@ -119,15 +126,15 @@ impl Claims {
}
pub async fn refreshed_access_token(
- db: MySqlPool,
+ db: &MySqlPool,
refresh_token: Claims,
exp_time: Duration,
) -> Result<Self, RawUnexpected> {
- let id = new_id(&db, db::access_token_exists).await?;
+ let id = new_id(db, db::access_token_exists).await?;
let time = Utc::now();
let exp = time + exp_time;
- db::create_access_token(&db, id, refresh_token.auth_code_id, exp).await?;
+ db::create_access_token(db, id, refresh_token.auth_code_id, exp).await?;
let mut claims = refresh_token;
claims.exp = exp;
@@ -142,6 +149,10 @@ impl Claims {
self.jti
}
+ pub fn expires_in(&self) -> i64 {
+ (self.exp - Utc::now()).num_seconds()
+ }
+
pub fn scopes(&self) -> &str {
&self.scope
}
@@ -163,6 +174,8 @@ pub enum VerifyJwtError {
WrongClient,
#[error("The given audience parameter does not contain this issuer")]
BadAudience,
+ #[error("The redirect URI doesn't match what's in the token")]
+ IncorrectRedirectUri,
#[error("The token is expired")]
ExpiredToken,
#[error("The token cannot be used yet")]
@@ -211,16 +224,23 @@ fn verify_jwt(
}
pub async fn verify_auth_code<'c>(
- db: MySqlPool,
+ db: &MySqlPool,
token: &str,
self_id: Url,
client_id: Uuid,
+ redirect_uri: Url,
) -> Result<Claims, Expect<VerifyJwtError>> {
let claims = verify_jwt(token, self_id, client_id)?;
- if db::delete_auth_code(&db, claims.jti).await? {
- db::delete_access_tokens_with_auth_code(&db, claims.jti).await?;
- db::revoke_refresh_tokens_with_auth_code(&db, claims.jti).await?;
+ if let Some(claimed_uri) = &claims.redirect_uri {
+ if claimed_uri.clone() != redirect_uri {
+ yeet!(VerifyJwtError::IncorrectRedirectUri.into());
+ }
+ }
+
+ if db::delete_auth_code(db, claims.jti).await? {
+ db::delete_access_tokens_with_auth_code(db, claims.jti).await?;
+ db::revoke_refresh_tokens_with_auth_code(db, claims.jti).await?;
yeet!(VerifyJwtError::JwtRevoked.into());
}