summaryrefslogtreecommitdiff
path: root/src/api/oauth.rs
diff options
context:
space:
mode:
authormrw1593 <botahamec@outlook.com>2023-06-11 15:34:00 -0400
committermrw1593 <botahamec@outlook.com>2023-06-11 15:34:00 -0400
commitac7317226405fc90e8439a0c1bef91cecd539d02 (patch)
tree61e7009238b6002d51e93ede8f52aa7bef526f2b /src/api/oauth.rs
parent83fdd59b13d4bf45bd35d9693ae361ff896636ab (diff)
Implement the authorization code grant
Diffstat (limited to 'src/api/oauth.rs')
-rw-r--r--src/api/oauth.rs304
1 files changed, 287 insertions, 17 deletions
diff --git a/src/api/oauth.rs b/src/api/oauth.rs
index d77695e..920f488 100644
--- a/src/api/oauth.rs
+++ b/src/api/oauth.rs
@@ -1,29 +1,35 @@
+use std::ops::Deref;
use std::str::FromStr;
-use actix_web::{get, post, web, HttpResponse, Scope};
+use actix_web::http::header;
+use actix_web::{get, post, web, HttpRequest, HttpResponse, ResponseError, Scope};
+use chrono::Duration;
use serde::{Deserialize, Serialize};
use sqlx::MySqlPool;
use tera::Tera;
+use thiserror::Error;
use unic_langid::subtags::Language;
use url::Url;
-use uuid::Uuid;
use crate::resources::{languages, templates};
-use crate::services::{authorization, db};
+use crate::scopes;
+use crate::services::{authorization, db, jwt};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ResponseType {
Code,
Token,
+ #[serde(other)]
+ Unsupported,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationParameters {
response_type: ResponseType,
- client_id: Uuid,
+ client_id: Box<str>,
redirect_uri: Option<Url>,
- scope: String, // TODO lol no
+ scope: Option<Box<str>>,
state: Option<Box<str>>,
}
@@ -33,14 +39,127 @@ struct AuthorizeCredentials {
password: Box<str>,
}
+#[derive(Clone, Serialize)]
+struct CodeResponse {
+ code: Box<str>,
+ state: Option<Box<str>>,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
+#[serde(rename_all = "camelCase")]
+enum AuthorizeErrorType {
+ InvalidRequest,
+ UnauthorizedClient,
+ AccessDenied,
+ UnsupportedResponseType,
+ InvalidScope,
+ ServerError,
+ TemporarilyUnavailable,
+}
+
+#[derive(Debug, Clone, Error)]
+#[error("{error_description}")]
+struct AuthorizeError {
+ error: AuthorizeErrorType,
+ error_description: Box<str>,
+ // TODO error uri
+ state: Option<Box<str>>,
+ redirect_uri: Url,
+}
+
+impl AuthorizeError {
+ fn no_scope(redirect_uri: Url, state: Option<Box<str>>) -> Self {
+ Self {
+ error: AuthorizeErrorType::InvalidScope,
+ error_description: Box::from(
+ "No scope was provided, and the client does not have a default scope",
+ ),
+ state,
+ redirect_uri,
+ }
+ }
+}
+
+impl ResponseError for AuthorizeError {
+ fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
+ let error = serde_variant::to_variant_name(&self.error).unwrap_or_default();
+ let mut url = self.redirect_uri.clone();
+ url.query_pairs_mut()
+ .append_pair("error", error)
+ .append_pair("error_description", &self.error_description);
+
+ if let Some(state) = &self.state {
+ url.query_pairs_mut().append_pair("state", &state);
+ }
+
+ HttpResponse::Found()
+ .insert_header((header::LOCATION, url.as_str()))
+ .finish()
+ }
+}
+
#[post("/authorize")]
async fn authorize(
db: web::Data<MySqlPool>,
- query: web::Query<AuthorizationParameters>,
- credentials: web::Form<AuthorizeCredentials>,
+ req: web::Query<AuthorizationParameters>,
+ credentials: web::Json<AuthorizeCredentials>,
) -> HttpResponse {
- // TODO check that the URI is valid
- todo!()
+ // TODO use sessions to verify that the request was previously validated
+ let db = db.get_ref();
+ let Some(client_id) = db::get_client_id_by_alias(db, &req.client_id).await.unwrap() else {
+ todo!("client not found")
+ };
+ let self_id = Url::parse("www.google.com").unwrap(); // TODO find the actual value
+ let state = req.state.clone();
+
+ // get redirect uri
+ let redirect_uri = if let Some(redirect_uri) = &req.redirect_uri {
+ redirect_uri.clone()
+ } else {
+ let redirect_uris = db::get_client_redirect_uris(db, client_id).await.unwrap();
+ if redirect_uris.len() != 1 {
+ todo!("no redirect uri");
+ }
+
+ redirect_uris[0].clone()
+ };
+
+ // authenticate user
+ let Some(user) = db::get_user_by_username(db, &credentials.username).await.unwrap() else {
+ todo!("bad username")
+ };
+ if !user.check_password(&credentials.password).unwrap() {
+ todo!("bad password")
+ }
+
+ // get scope
+ let scope = if let Some(scope) = &req.scope {
+ scope.clone()
+ } else {
+ let default_scopes = db::get_client_default_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let Some(scope) = default_scopes else {
+ return AuthorizeError::no_scope(redirect_uri, state).error_response()
+ };
+ scope
+ };
+
+ match req.response_type {
+ ResponseType::Code => {
+ // create auth code
+ let code = jwt::Claims::auth_code(db, self_id, client_id, &scope, &redirect_uri)
+ .await
+ .unwrap();
+ let code = code.to_jwt().unwrap();
+ let response = CodeResponse { code, state };
+
+ HttpResponse::Ok().json(response)
+ }
+ ResponseType::Token => todo!(),
+ _ => todo!("unsupported response type"),
+ }
}
#[get("/authorize")]
@@ -48,36 +167,187 @@ async fn authorize_page(
db: web::Data<MySqlPool>,
tera: web::Data<Tera>,
translations: web::Data<languages::Translations>,
- query: web::Query<AuthorizationParameters>,
+ request: HttpRequest,
) -> HttpResponse {
+ let params = request.query_string();
+ let params = serde_urlencoded::from_str::<AuthorizationParameters>(params);
+ let Ok(params) = params else {
+ todo!("invalid request")
+ };
+
+ let db = db.get_ref();
+ let Some(client_id) = db::get_client_id_by_alias(db, &params.client_id).await.unwrap() else {
+ todo!("client not found")
+ };
+
+ // verify scope
+ let Some(allowed_scopes) = db::get_client_allowed_scopes(db, client_id).await.unwrap() else {
+ todo!("client not found")
+ };
+
+ let scope = if let Some(scope) = &params.scope {
+ scope.clone()
+ } else {
+ let default_scopes = db::get_client_default_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let Some(scope) = default_scopes else {
+ todo!("invalid request")
+ };
+ scope
+ };
+
+ if !scopes::is_subset_of(&scope, &allowed_scopes) {
+ todo!("access_denied")
+ }
+
+ // verify redirect uri
+ if let Some(redirect_uri) = &params.redirect_uri {
+ if !db::client_has_redirect_uri(db, client_id, redirect_uri)
+ .await
+ .unwrap()
+ {
+ todo!("access denied")
+ }
+ } else {
+ let redirect_uris = db::get_client_redirect_uris(db, client_id).await.unwrap();
+ if redirect_uris.len() != 1 {
+ todo!("must have redirect uri")
+ }
+ }
+
+ // verify response type
+ if params.response_type == ResponseType::Unsupported {
+ todo!("unsupported response type")
+ }
+
// TODO find a better way of doing languages
- // TODO check that the URI is valid
let language = Language::from_str("en").unwrap();
let page =
- templates::login_page(&tera, &query, language, translations.get_ref().clone()).unwrap();
+ templates::login_page(&tera, &params, language, translations.get_ref().clone()).unwrap();
HttpResponse::Ok().content_type("text/html").body(page)
}
#[derive(Clone, Deserialize)]
#[serde(tag = "grant_type")]
-enum GrantType {}
+#[serde(rename_all = "snake_case")]
+enum GrantType {
+ AuthorizationCode {
+ code: Box<str>,
+ redirect_uri: Url,
+ #[serde(rename = "client_id")]
+ client_alias: Box<str>,
+ },
+ Password {
+ username: Box<str>,
+ password: Box<str>,
+ scope: Option<Box<str>>,
+ },
+ ClientCredentials {
+ scope: Option<Box<str>>,
+ },
+}
#[derive(Clone, Deserialize)]
struct TokenRequest {
#[serde(flatten)]
grant_type: GrantType,
- scope: String, // TODO lol no
- // TODO support optional client credentials in here
+ // TODO support optional client credentials in here
+}
+
+#[derive(Clone, Serialize)]
+struct TokenResponse {
+ access_token: Box<str>,
+ token_type: Box<str>,
+ expires_in: i64,
+ refresh_token: Box<str>,
+ scope: Box<str>,
}
#[post("/token")]
async fn token(
db: web::Data<MySqlPool>,
- req: web::Form<TokenRequest>,
+ req: web::Bytes,
authorization: Option<web::Header<authorization::BasicAuthorization>>,
) -> HttpResponse {
// TODO protect against brute force attacks
- todo!()
+ let db = db.get_ref();
+ let request = serde_json::from_slice::<TokenRequest>(&req);
+ let Ok(request) = request else {
+ todo!("invalid request")
+ };
+
+ let self_id = Url::parse("www.google.com").unwrap(); // TODO find the actual value
+ let duration = Duration::hours(1);
+ let token_type = Box::from("bearer");
+ let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]);
+
+ match request.grant_type {
+ GrantType::AuthorizationCode {
+ code,
+ redirect_uri,
+ client_alias,
+ } => {
+ let Some(client_id) = db::get_client_id_by_alias(db, &client_alias).await.unwrap() else {
+ todo!("client not found")
+ };
+
+ let Ok(claims) = jwt::verify_auth_code(db, &code, self_id.clone(), client_id, redirect_uri).await else {
+ todo!("invalid code");
+ };
+
+ // verify client, if the client has credentials
+ if let Some(hash) = db::get_client_secret(db, client_id).await.unwrap() {
+ let Some(authorization) = authorization else {
+ todo!("no client credentials")
+ };
+
+ if authorization.username() != client_alias.deref() {
+ todo!("bad username")
+ }
+ if !hash.check_password(authorization.password()).unwrap() {
+ todo!("bad password")
+ }
+ }
+
+ let access_token = jwt::Claims::access_token(
+ db,
+ claims.id(),
+ self_id,
+ client_id,
+ duration,
+ claims.scopes(),
+ )
+ .await
+ .unwrap();
+
+ let expires_in = access_token.expires_in();
+ let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap();
+ let scope = access_token.scopes().into();
+
+ let access_token = access_token.to_jwt().unwrap();
+ let refresh_token = refresh_token.to_jwt().unwrap();
+
+ let response = TokenResponse {
+ access_token,
+ token_type,
+ expires_in,
+ refresh_token,
+ scope,
+ };
+ HttpResponse::Ok()
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(response)
+ }
+ GrantType::Password {
+ username,
+ password,
+ scope,
+ } => todo!(),
+ GrantType::ClientCredentials { scope } => todo!(),
+ }
}
pub fn service() -> Scope {