diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/api/oauth.rs | 278 | ||||
| -rw-r--r-- | src/resources/templates.rs | 17 |
2 files changed, 197 insertions, 98 deletions
diff --git a/src/api/oauth.rs b/src/api/oauth.rs index 43ad402..ef40637 100644 --- a/src/api/oauth.rs +++ b/src/api/oauth.rs @@ -6,6 +6,8 @@ use actix_web::{ get, post, web, HttpRequest, HttpResponse, HttpResponseBuilder, ResponseError, Scope, }; use chrono::Duration; +use exun::{Expect, RawUnexpected, ResultErrorExt, UnexpectedError}; +use raise::yeet; use serde::{Deserialize, Serialize}; use sqlx::MySqlPool; use tera::Tera; @@ -20,6 +22,8 @@ use crate::scopes; use crate::services::jwt::VerifyJwtError; use crate::services::{authorization, db, jwt}; +const REALLY_BAD_ERROR_PAGE: &str = "<!DOCTYPE html><html><head><title>Internal Server Error</title></head><body>Internal Server Error</body></html>"; + #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] enum ResponseType { @@ -111,6 +115,15 @@ impl AuthorizeError { redirect_uri, } } + + fn internal_server_error(redirect_uri: Url, state: Option<Box<str>>) -> Self { + Self { + error: AuthorizeErrorType::ServerError, + error_description: "An unexpected error occurred".into(), + state, + redirect_uri, + } + } } impl ResponseError for AuthorizeError { @@ -126,6 +139,91 @@ impl ResponseError for AuthorizeError { } } +fn error_page( + tera: &Tera, + translations: &languages::Translations, + error: templates::ErrorPage, +) -> Result<String, RawUnexpected> { + // TODO find a better way of doing languages + let language = Language::from_str("en").unwrap(); + let translations = translations.clone(); + let page = templates::error_page(&tera, language, translations, error)?; + Ok(page) +} + +async fn get_redirect_uri( + redirect_uri: &Option<Url>, + db: &MySqlPool, + client_id: Uuid, +) -> Result<Url, Expect<templates::ErrorPage>> { + if let Some(uri) = &redirect_uri { + let redirect_uri = uri.clone(); + if !db::client_has_redirect_uri(db, client_id, &redirect_uri) + .await + .map_err(|e| UnexpectedError::from(e)) + .unexpect()? + { + yeet!(Expect::Expected(templates::ErrorPage::InvalidRedirectUri)); + } + + Ok(redirect_uri) + } else { + let redirect_uris = db::get_client_redirect_uris(db, client_id) + .await + .map_err(|e| UnexpectedError::from(e)) + .unexpect()?; + if redirect_uris.len() != 1 { + yeet!(Expect::Expected(templates::ErrorPage::MissingRedirectUri)); + } + + Ok(redirect_uris.get(0).unwrap().clone()) + } +} + +async fn get_scope( + scope: &Option<Box<str>>, + db: &MySqlPool, + client_id: Uuid, + redirect_uri: &Url, + state: &Option<Box<str>>, +) -> Result<Box<str>, Expect<AuthorizeError>> { + let scope = if let Some(scope) = &scope { + scope.clone() + } else { + let default_scopes = db::get_client_default_scopes(db, client_id) + .await + .unwrap() + .unwrap(); + let Some(scope) = default_scopes else { + yeet!(AuthorizeError::no_scope(redirect_uri.clone(), state.clone()).into()) + }; + scope + }; + + // verify scope is valid + let allowed_scopes = db::get_client_allowed_scopes(db, client_id) + .await + .unwrap() + .unwrap(); + if !scopes::is_subset_of(&scope, &allowed_scopes) { + yeet!(AuthorizeError::invalid_scope(redirect_uri.clone(), state.clone()).into()); + } + + Ok(scope) +} + +async fn authenticate_user( + db: &MySqlPool, + username: &str, + password: &str, +) -> Result<bool, RawUnexpected> { + let Some(user) = db::get_user_by_username(db, username).await? else { + return Ok(false); + }; + + Ok(user.check_password(password)?) +} + #[post("/authorize")] async fn authorize( db: web::Data<MySqlPool>, @@ -134,62 +232,53 @@ async fn authorize( tera: web::Data<Tera>, translations: web::Data<languages::Translations>, ) -> HttpResponse { - // TODO use sessions to verify that the request was previously validated // TODO handle internal server error + // TODO protect against brute force attacks let db = db.get_ref(); - let Some(client_id) = db::get_client_id_by_alias(db, &req.client_id).await.unwrap() else { - // TODO find a better way of doing languages - let language = Language::from_str("en").unwrap(); - let translations = translations.get_ref().clone(); - let page = templates::error_page(&tera, language, translations, templates::ErrorPage::ClientNotFound).unwrap(); + let Ok(client_id) = db::get_client_id_by_alias(db, &req.client_id).await else { + let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); + return HttpResponse::InternalServerError().content_type("text/html").body(page); + }; + let Some(client_id) = client_id else { + let page = error_page(&tera, &translations, templates::ErrorPage::ClientNotFound).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); return HttpResponse::NotFound().content_type("text/html").body(page); }; let self_id = Url::parse("www.google.com").unwrap(); // TODO find the actual value let state = req.state.clone(); // get redirect uri - let mut redirect_uri = if let Some(redirect_uri) = &req.redirect_uri { - redirect_uri.clone() - } else { - let redirect_uris = db::get_client_redirect_uris(db, client_id).await.unwrap(); - if redirect_uris.len() != 1 { - let language = Language::from_str("en").unwrap(); - let translations = translations.get_ref().clone(); - let page = templates::error_page( - &tera, - language, - translations, - templates::ErrorPage::MissingRedirectUri, - ) - .unwrap(); - return HttpResponse::NotFound() + let mut redirect_uri = match get_redirect_uri(&req.redirect_uri, db, client_id).await { + Ok(uri) => uri, + Err(e) => { + let e = e + .expected() + .unwrap_or(templates::ErrorPage::InternalServerError); + let page = error_page(&tera, &translations, e) + .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); + return HttpResponse::BadRequest() .content_type("text/html") .body(page); } - - redirect_uris[0].clone() }; // authenticate user - let Some(user) = db::get_user_by_username(db, &credentials.username).await.unwrap() else { - todo!("bad username") + if !authenticate_user(db, &credentials.username, &credentials.password) + .await + .unwrap() + { + let language = Language::from_str("en").unwrap(); + let translations = translations.get_ref().clone(); + let page = templates::login_error_page(&tera, &req, language, translations).unwrap(); + return HttpResponse::Ok().content_type("text/html").body(page); }; - if !user.check_password(&credentials.password).unwrap() { - todo!("bad password") - } // get scope - let scope = if let Some(scope) = &req.scope { - scope.clone() - } else { - let default_scopes = db::get_client_default_scopes(db, client_id) - .await - .unwrap() - .unwrap(); - let Some(scope) = default_scopes else { - return AuthorizeError::no_scope(redirect_uri, state).error_response() - }; - scope + let scope = match get_scope(&req.scope, db, client_id, &redirect_uri, &state).await { + Ok(scope) => scope, + Err(e) => { + let e = e.unwrap(); + return e.error_response(); + } }; match req.response_type { @@ -248,97 +337,77 @@ async fn authorize_page( request: HttpRequest, ) -> HttpResponse { // TODO handle internal server error - let language = Language::from_str("en").unwrap(); + let Ok(language) = Language::from_str("en") else { + let page = String::from(REALLY_BAD_ERROR_PAGE); + return HttpResponse::InternalServerError() + .content_type("text/html") + .body(page); + }; let translations = translations.get_ref().clone(); let params = request.query_string(); let params = serde_urlencoded::from_str::<AuthorizationParameters>(params); let Ok(params) = params else { - let page = templates::error_page( + let page = error_page( &tera, - language, - translations, + &translations, templates::ErrorPage::InvalidRequest, ) - .unwrap(); + .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); return HttpResponse::BadRequest() .content_type("text/html") .body(page); }; let db = db.get_ref(); - let Some(client_id) = db::get_client_id_by_alias(db, ¶ms.client_id).await.unwrap() else { + let Ok(client_id) = db::get_client_id_by_alias(db, ¶ms.client_id).await else { + let page = templates::error_page( + &tera, + language, + translations, + templates::ErrorPage::InternalServerError, + ) + .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); + return HttpResponse::InternalServerError() + .content_type("text/html") + .body(page); + }; + let Some(client_id) = client_id else { let page = templates::error_page( &tera, language, translations, templates::ErrorPage::ClientNotFound, ) - .unwrap(); + .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); return HttpResponse::NotFound() .content_type("text/html") .body(page); }; - // verify scope - let allowed_scopes = db::get_client_allowed_scopes(db, client_id) - .await - .unwrap() - .unwrap(); - // verify redirect uri - let redirect_uri: Url; - if let Some(uri) = ¶ms.redirect_uri { - redirect_uri = uri.clone(); - if !db::client_has_redirect_uri(db, client_id, &redirect_uri) - .await - .unwrap() - { - let page = templates::error_page( - &tera, - language, - translations, - templates::ErrorPage::InvalidRedirectUri, - ) - .unwrap(); + let redirect_uri = match get_redirect_uri(¶ms.redirect_uri, db, client_id).await { + Ok(uri) => uri, + Err(e) => { + let e = e + .expected() + .unwrap_or(templates::ErrorPage::InternalServerError); + let page = error_page(&tera, &translations, e) + .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); return HttpResponse::BadRequest() .content_type("text/html") .body(page); } - } else { - let redirect_uris = db::get_client_redirect_uris(db, client_id).await.unwrap(); - if redirect_uris.len() != 1 { - let page = templates::error_page( - &tera, - language, - translations, - templates::ErrorPage::MissingRedirectUri, - ) - .unwrap(); - return HttpResponse::NotFound() - .content_type("text/html") - .body(page); - } - - redirect_uri = redirect_uris.get(0).unwrap().clone(); - } - - let scope = if let Some(scope) = ¶ms.scope { - scope.clone() - } else { - let default_scopes = db::get_client_default_scopes(db, client_id) - .await - .unwrap() - .unwrap(); - let Some(scope) = default_scopes else { - return AuthorizeError::no_scope(redirect_uri, params.state).error_response(); - }; - scope }; - if !scopes::is_subset_of(&scope, &allowed_scopes) { - return AuthorizeError::invalid_scope(redirect_uri, params.state).error_response(); - } + // verify scope + let _ = match get_scope(¶ms.scope, db, client_id, &redirect_uri, ¶ms.state).await { + Ok(scope) => scope, + Err(e) => { + let e = e.unwrap(); + return e.error_response(); + } + }; // verify response type if params.response_type == ResponseType::Unsupported { @@ -520,6 +589,14 @@ impl TokenError { error_description: "Only trusted clients may use this grant".into(), } } + + fn incorrect_user_credentials() -> Self { + Self { + status_code: StatusCode::BAD_REQUEST, + error: TokenErrorType::InvalidRequest, + error_description: "The given credentials are incorrect".into(), + } + } } impl ResponseError for TokenError { @@ -647,6 +724,11 @@ async fn token( return TokenError::incorrect_client_secret().error_response(); } + // authenticate user + if !authenticate_user(db, &username, &password).await.unwrap() { + return TokenError::incorrect_user_credentials().error_response(); + }; + // verify scope let allowed_scopes = db::get_client_allowed_scopes(db, client_id) .await diff --git a/src/resources/templates.rs b/src/resources/templates.rs index 88c1fad..9168fb9 100644 --- a/src/resources/templates.rs +++ b/src/resources/templates.rs @@ -44,6 +44,7 @@ pub enum ErrorPage { ClientNotFound, MissingRedirectUri, InvalidRedirectUri, + InternalServerError, } pub fn error_page( @@ -82,3 +83,19 @@ pub fn login_page( context.insert("params", &serde_urlencoded::to_string(params)?); tera.render("login.html", &context).unexpect() } + +pub fn login_error_page( + tera: &Tera, + params: &AuthorizationParameters, + language: Language, + mut translations: languages::Translations, +) -> Result<String, RawUnexpected> { + translations.refresh()?; + let mut tera = extend_tera(tera, language, translations)?; + tera.full_reload()?; + let mut context = tera::Context::new(); + context.insert("lang", language.as_str()); + context.insert("params", &serde_urlencoded::to_string(params)?); + context.insert("errorMessage", "loginErrorMessage"); + tera.render("login.html", &context).unexpect() +} |
