summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormrw1593 <botahamec@outlook.com>2023-06-18 13:40:31 -0400
committermrw1593 <botahamec@outlook.com>2023-06-18 13:40:31 -0400
commit3feb8911aeff353238f6fdb8f71d4b970625b28d (patch)
tree3450e35ce38921259839d57d570fc19b1c906751
parentac7317226405fc90e8439a0c1bef91cecd539d02 (diff)
Implement the client credentials flow
-rw-r--r--src/api/oauth.rs223
-rw-r--r--src/services/db/jwt.rs4
-rw-r--r--src/services/jwt.rs6
3 files changed, 215 insertions, 18 deletions
diff --git a/src/api/oauth.rs b/src/api/oauth.rs
index 920f488..48c3210 100644
--- a/src/api/oauth.rs
+++ b/src/api/oauth.rs
@@ -1,8 +1,10 @@
use std::ops::Deref;
use std::str::FromStr;
-use actix_web::http::header;
-use actix_web::{get, post, web, HttpRequest, HttpResponse, ResponseError, Scope};
+use actix_web::http::{header, StatusCode};
+use actix_web::{
+ get, post, web, HttpRequest, HttpResponse, HttpResponseBuilder, ResponseError, Scope,
+};
use chrono::Duration;
use serde::{Deserialize, Serialize};
use sqlx::MySqlPool;
@@ -11,8 +13,10 @@ use thiserror::Error;
use unic_langid::subtags::Language;
use url::Url;
+use crate::models::client::ClientType;
use crate::resources::{languages, templates};
use crate::scopes;
+use crate::services::jwt::VerifyJwtError;
use crate::services::{authorization, db, jwt};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
@@ -247,6 +251,8 @@ enum GrantType {
ClientCredentials {
scope: Option<Box<str>>,
},
+ #[serde(other)]
+ Unsupported,
}
#[derive(Clone, Deserialize)]
@@ -261,10 +267,131 @@ struct TokenResponse {
access_token: Box<str>,
token_type: Box<str>,
expires_in: i64,
- refresh_token: Box<str>,
+ refresh_token: Option<Box<str>>,
scope: Box<str>,
}
+#[derive(Debug, Clone, Serialize)]
+#[serde(rename_all = "snake_case")]
+enum TokenErrorType {
+ InvalidRequest,
+ InvalidClient,
+ InvalidGrant,
+ UnauthorizedClient,
+ UnsupportedGrantType,
+ InvalidScope,
+}
+
+#[derive(Debug, Clone, Error, Serialize)]
+#[error("{error_description}")]
+struct TokenError {
+ #[serde(skip)]
+ status_code: StatusCode,
+ error: TokenErrorType,
+ error_description: Box<str>,
+ // TODO error uri
+}
+
+impl TokenError {
+ fn invalid_request() -> Self {
+ // TODO make this description better, and all the other ones while you're at it
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidRequest,
+ error_description: "Invalid request".into(),
+ }
+ }
+
+ fn unsupported_grant_type() -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::UnsupportedGrantType,
+ error_description: "The given grant type is not supported".into(),
+ }
+ }
+
+ fn bad_auth_code(error: VerifyJwtError) -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidGrant,
+ error_description: error.to_string().into_boxed_str(),
+ }
+ }
+
+ fn no_authorization() -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: Box::from(
+ "Client credentials must be provided in the HTTP Authorization header",
+ ),
+ }
+ }
+
+ fn client_not_found(alias: &str) -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: format!("No client with the client id: {alias} was found")
+ .into_boxed_str(),
+ }
+ }
+
+ fn incorrect_client_secret() -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: "The client secret is incorrect".into(),
+ }
+ }
+
+ fn client_not_confidential(alias: &str) -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::UnauthorizedClient,
+ error_description: format!("Only a confidential client may be used with this endpoint. The {alias} client is a public client.")
+ .into_boxed_str(),
+ }
+ }
+
+ fn no_scope() -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidScope,
+ error_description: Box::from(
+ "No scope was provided, and the client doesn't have a default scope",
+ ),
+ }
+ }
+
+ fn excessive_scope() -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidScope,
+ error_description: Box::from(
+ "The given scope exceeds what the client is allowed to have",
+ ),
+ }
+ }
+}
+
+impl ResponseError for TokenError {
+ fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
+ let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]);
+
+ let mut builder = HttpResponseBuilder::new(self.status_code);
+
+ if self.status_code.as_u16() == 401 {
+ builder.insert_header((header::WWW_AUTHENTICATE, "Basic charset=\"UTF-8\""));
+ }
+
+ builder
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(self.clone())
+ }
+}
+
#[post("/token")]
async fn token(
db: web::Data<MySqlPool>,
@@ -275,7 +402,7 @@ async fn token(
let db = db.get_ref();
let request = serde_json::from_slice::<TokenRequest>(&req);
let Ok(request) = request else {
- todo!("invalid request")
+ return TokenError::invalid_request().error_response();
};
let self_id = Url::parse("www.google.com").unwrap(); // TODO find the actual value
@@ -290,30 +417,38 @@ async fn token(
client_alias,
} => {
let Some(client_id) = db::get_client_id_by_alias(db, &client_alias).await.unwrap() else {
- todo!("client not found")
+ return TokenError::client_not_found(&client_alias).error_response();
};
- let Ok(claims) = jwt::verify_auth_code(db, &code, self_id.clone(), client_id, redirect_uri).await else {
- todo!("invalid code");
- };
+ // validate auth code
+ let claims =
+ match jwt::verify_auth_code(db, &code, self_id.clone(), client_id, redirect_uri)
+ .await
+ {
+ Ok(claims) => claims,
+ Err(err) => {
+ let err = err.unwrap();
+ return TokenError::bad_auth_code(err).error_response();
+ }
+ };
// verify client, if the client has credentials
if let Some(hash) = db::get_client_secret(db, client_id).await.unwrap() {
let Some(authorization) = authorization else {
- todo!("no client credentials")
+ return TokenError::no_authorization().error_response();
};
if authorization.username() != client_alias.deref() {
todo!("bad username")
}
if !hash.check_password(authorization.password()).unwrap() {
- todo!("bad password")
+ return TokenError::incorrect_client_secret().error_response();
}
}
let access_token = jwt::Claims::access_token(
db,
- claims.id(),
+ Some(claims.id()),
self_id,
client_id,
duration,
@@ -327,7 +462,7 @@ async fn token(
let scope = access_token.scopes().into();
let access_token = access_token.to_jwt().unwrap();
- let refresh_token = refresh_token.to_jwt().unwrap();
+ let refresh_token = Some(refresh_token.to_jwt().unwrap());
let response = TokenResponse {
access_token,
@@ -346,7 +481,69 @@ async fn token(
password,
scope,
} => todo!(),
- GrantType::ClientCredentials { scope } => todo!(),
+ GrantType::ClientCredentials { scope } => {
+ let Some(authorization) = authorization else {
+ return TokenError::no_authorization().error_response();
+ };
+ let client_alias = authorization.username();
+ let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else {
+ return TokenError::client_not_found(client_alias).error_response();
+ };
+
+ let ty = db::get_client_type(db, client_id).await.unwrap().unwrap();
+ if ty != ClientType::Confidential {
+ return TokenError::client_not_confidential(client_alias).error_response();
+ }
+
+ // verify client
+ let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap();
+ if !hash.check_password(authorization.password()).unwrap() {
+ return TokenError::incorrect_client_secret().error_response();
+ }
+
+ // verify scope
+ let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let scope = if let Some(scope) = &scope {
+ scope.clone()
+ } else {
+ let default_scopes = db::get_client_default_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let Some(scope) = default_scopes else {
+ return TokenError::no_scope().error_response();
+ };
+ scope
+ };
+ if !scopes::is_subset_of(&scope, &allowed_scopes) {
+ return TokenError::excessive_scope().error_response();
+ }
+
+ let access_token =
+ jwt::Claims::access_token(db, None, self_id, client_id, duration, &scope)
+ .await
+ .unwrap();
+
+ let expires_in = access_token.expires_in();
+ let scope = access_token.scopes().into();
+ let access_token = access_token.to_jwt().unwrap();
+
+ let response = TokenResponse {
+ access_token,
+ token_type,
+ expires_in,
+ refresh_token: None,
+ scope,
+ };
+ HttpResponse::Ok()
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(response)
+ }
+ _ => TokenError::unsupported_grant_type().error_response(),
}
}
diff --git a/src/services/db/jwt.rs b/src/services/db/jwt.rs
index a3edef2..b2f1367 100644
--- a/src/services/db/jwt.rs
+++ b/src/services/db/jwt.rs
@@ -81,7 +81,7 @@ pub async fn create_auth_code<'c>(
pub async fn create_access_token<'c>(
executor: impl Executor<'c, Database = MySql>,
jti: Uuid,
- auth_code: Uuid,
+ auth_code: Option<Uuid>,
exp: DateTime<Utc>,
) -> Result<(), sqlx::Error> {
query!(
@@ -100,7 +100,7 @@ pub async fn create_access_token<'c>(
pub async fn create_refresh_token<'c>(
executor: impl Executor<'c, Database = MySql>,
jti: Uuid,
- auth_code: Uuid,
+ auth_code: Option<Uuid>,
exp: DateTime<Utc>,
) -> Result<(), sqlx::Error> {
query!(
diff --git a/src/services/jwt.rs b/src/services/jwt.rs
index 822101f..c86fb01 100644
--- a/src/services/jwt.rs
+++ b/src/services/jwt.rs
@@ -30,8 +30,8 @@ pub struct Claims {
jti: Uuid,
scope: Box<str>,
client_id: Uuid,
- auth_code_id: Uuid,
token_type: TokenType,
+ auth_code_id: Option<Uuid>,
redirect_uri: Option<Url>,
}
@@ -67,7 +67,7 @@ impl Claims {
jti: id,
scope: scopes.into(),
client_id,
- auth_code_id: id,
+ auth_code_id: Some(id),
token_type: TokenType::Authorization,
redirect_uri: Some(redirect_uri.clone()),
})
@@ -75,7 +75,7 @@ impl Claims {
pub async fn access_token<'c>(
db: &MySqlPool,
- auth_code_id: Uuid,
+ auth_code_id: Option<Uuid>,
self_id: Url,
client_id: Uuid,
duration: Duration,