summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/api/clients.rs48
-rw-r--r--src/api/oauth.rs73
-rw-r--r--src/models/client.rs34
-rw-r--r--src/services/db/client.rs28
4 files changed, 163 insertions, 20 deletions
diff --git a/src/api/clients.rs b/src/api/clients.rs
index 7b6ec94..27ef995 100644
--- a/src/api/clients.rs
+++ b/src/api/clients.rs
@@ -7,7 +7,7 @@ use thiserror::Error;
use url::Url;
use uuid::Uuid;
-use crate::models::client::{Client, ClientType, NoSecretError};
+use crate::models::client::{Client, ClientType, CreateClientError};
use crate::services::crypto::PasswordHash;
use crate::services::db::ClientRow;
use crate::services::{db, id};
@@ -20,6 +20,7 @@ struct ClientResponse {
client_type: ClientType,
allowed_scopes: Box<[Box<str>]>,
default_scopes: Option<Box<[Box<str>]>>,
+ is_trusted: bool,
}
impl From<ClientRow> for ClientResponse {
@@ -36,6 +37,7 @@ impl From<ClientRow> for ClientResponse {
default_scopes: value
.default_scopes
.map(|s| s.split_whitespace().map(Box::from).collect()),
+ is_trusted: value.is_trusted,
}
}
}
@@ -164,6 +166,21 @@ async fn get_client_default_scopes(
Ok(HttpResponse::Ok().json(default_scopes))
}
+#[get("/{client_id}/is-trusted")]
+async fn get_client_is_trusted(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(is_trusted) = db::is_client_trusted(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ Ok(HttpResponse::Ok().json(is_trusted))
+}
+
#[derive(Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ClientRequest {
@@ -173,6 +190,7 @@ struct ClientRequest {
secret: Option<Box<str>>,
allowed_scopes: Box<[Box<str>]>,
default_scopes: Option<Box<[Box<str>]>>,
+ trusted: bool,
}
#[derive(Debug, Clone, Error)]
@@ -216,6 +234,7 @@ async fn create_client(
body.allowed_scopes.clone(),
body.default_scopes.clone(),
&body.redirect_uris,
+ body.trusted,
)
.map_err(|e| e.unwrap())?;
@@ -233,7 +252,7 @@ enum UpdateClientError {
#[error(transparent)]
NotFound(#[from] ClientNotFound),
#[error(transparent)]
- NoSecret(#[from] NoSecretError),
+ ClientError(#[from] CreateClientError),
#[error(transparent)]
AliasTaken(#[from] AliasTakenError),
}
@@ -242,7 +261,7 @@ impl ResponseError for UpdateClientError {
fn status_code(&self) -> StatusCode {
match self {
Self::NotFound(e) => e.status_code(),
- Self::NoSecret(e) => e.status_code(),
+ Self::ClientError(e) => e.status_code(),
Self::AliasTaken(e) => e.status_code(),
}
}
@@ -273,6 +292,7 @@ async fn update_client(
body.allowed_scopes.clone(),
body.default_scopes.clone(),
&body.redirect_uris,
+ body.trusted,
)
.map_err(|e| e.unwrap())?;
@@ -370,6 +390,25 @@ async fn update_client_default_scopes(
Ok(HttpResponse::NoContent().finish())
}
+#[put("/{id}/is-trusted")]
+async fn update_client_is_trusted(
+ id: web::Path<Uuid>,
+ body: web::Json<bool>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let is_trusted = *body;
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id).into());
+ }
+
+ db::update_client_trusted(db, id, is_trusted).await.unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
#[put("/{id}/redirect-uris")]
async fn update_client_redirect_uris(
id: web::Path<Uuid>,
@@ -405,7 +444,7 @@ async fn update_client_secret(
};
if client_type == ClientType::Confidential && body.is_none() {
- yeet!(NoSecretError::new().into())
+ yeet!(CreateClientError::NoSecret.into())
}
let secret = body.0.map(|s| PasswordHash::new(&s).unwrap());
@@ -422,6 +461,7 @@ pub fn service() -> Scope {
.service(get_client_allowed_scopes)
.service(get_client_default_scopes)
.service(get_client_redirect_uris)
+ .service(get_client_is_trusted)
.service(create_client)
.service(update_client)
.service(update_client_alias)
diff --git a/src/api/oauth.rs b/src/api/oauth.rs
index 5d1f12a..43ad402 100644
--- a/src/api/oauth.rs
+++ b/src/api/oauth.rs
@@ -512,6 +512,14 @@ impl TokenError {
error_description: err.to_string().into_boxed_str(),
}
}
+
+ fn untrusted_client() -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: "Only trusted clients may use this grant".into(),
+ }
+ }
}
impl ResponseError for TokenError {
@@ -619,7 +627,70 @@ async fn token(
username,
password,
scope,
- } => todo!(),
+ } => {
+ let Some(authorization) = authorization else {
+ return TokenError::no_authorization().error_response();
+ };
+ let client_alias = authorization.username();
+ let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else {
+ return TokenError::client_not_found(client_alias).error_response();
+ };
+
+ let trusted = db::is_client_trusted(db, client_id).await.unwrap().unwrap();
+ if !trusted {
+ return TokenError::untrusted_client().error_response();
+ }
+
+ // verify client
+ let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap();
+ if !hash.check_password(authorization.password()).unwrap() {
+ return TokenError::incorrect_client_secret().error_response();
+ }
+
+ // verify scope
+ let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let scope = if let Some(scope) = &scope {
+ scope.clone()
+ } else {
+ let default_scopes = db::get_client_default_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let Some(scope) = default_scopes else {
+ return TokenError::no_scope().error_response();
+ };
+ scope
+ };
+ if !scopes::is_subset_of(&scope, &allowed_scopes) {
+ return TokenError::excessive_scope().error_response();
+ }
+
+ let access_token =
+ jwt::Claims::access_token(db, None, self_id, client_id, duration, &scope)
+ .await
+ .unwrap();
+ let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap();
+
+ let expires_in = access_token.expires_in();
+ let scope = access_token.scopes().into();
+ let access_token = access_token.to_jwt().unwrap();
+ let refresh_token = Some(refresh_token.to_jwt().unwrap());
+
+ let response = TokenResponse {
+ access_token,
+ token_type,
+ expires_in,
+ refresh_token,
+ scope,
+ };
+ HttpResponse::Ok()
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(response)
+ }
GrantType::ClientCredentials { scope } => {
let Some(authorization) = authorization else {
return TokenError::no_authorization().error_response();
diff --git a/src/models/client.rs b/src/models/client.rs
index 90c5902..56b0ae6 100644
--- a/src/models/client.rs
+++ b/src/models/client.rs
@@ -36,6 +36,7 @@ pub struct Client {
allowed_scopes: Box<[Box<str>]>,
default_scopes: Option<Box<[Box<str>]>>,
redirect_uris: Box<[Url]>,
+ trusted: bool,
}
impl PartialEq for Client {
@@ -54,24 +55,19 @@ impl Hash for Client {
#[derive(Debug, Clone, Copy, Error)]
#[error("Confidential clients must have a secret, but it was not provided")]
-pub struct NoSecretError {
- _phantom: PhantomData<()>,
+pub enum CreateClientError {
+ #[error("Confidential clients must have a secret, but it was not provided")]
+ NoSecret,
+ #[error("Only confidential clients may be trusted")]
+ TrustedError,
}
-impl ResponseError for NoSecretError {
+impl ResponseError for CreateClientError {
fn status_code(&self) -> StatusCode {
StatusCode::BAD_REQUEST
}
}
-impl NoSecretError {
- pub(crate) fn new() -> Self {
- Self {
- _phantom: PhantomData,
- }
- }
-}
-
impl Client {
pub fn new(
id: Uuid,
@@ -81,7 +77,8 @@ impl Client {
allowed_scopes: Box<[Box<str>]>,
default_scopes: Option<Box<[Box<str>]>>,
redirect_uris: &[Url],
- ) -> Result<Self, Expect<NoSecretError>> {
+ trusted: bool,
+ ) -> Result<Self, Expect<CreateClientError>> {
let secret = if let Some(secret) = secret {
Some(PasswordHash::new(secret)?)
} else {
@@ -89,17 +86,22 @@ impl Client {
};
if ty == ClientType::Confidential && secret.is_none() {
- yeet!(NoSecretError::new().into());
+ yeet!(CreateClientError::NoSecret.into());
+ }
+
+ if ty == ClientType::Public && trusted {
+ yeet!(CreateClientError::TrustedError.into());
}
Ok(Self {
id,
alias: Box::from(alias),
- ty: ClientType::Public,
+ ty,
secret,
allowed_scopes,
default_scopes,
redirect_uris: redirect_uris.into_iter().cloned().collect(),
+ trusted,
})
}
@@ -139,6 +141,10 @@ impl Client {
self.default_scopes.clone().map(|s| s.join(" "))
}
+ pub fn is_trusted(&self) -> bool {
+ self.trusted
+ }
+
pub fn check_secret(&self, secret: &str) -> Option<Result<bool, RawUnexpected>> {
self.secret.as_ref().map(|s| s.check_password(secret))
}
diff --git a/src/services/db/client.rs b/src/services/db/client.rs
index 70701d7..b8942e9 100644
--- a/src/services/db/client.rs
+++ b/src/services/db/client.rs
@@ -19,6 +19,7 @@ pub struct ClientRow {
pub client_type: ClientType,
pub allowed_scopes: String,
pub default_scopes: Option<String>,
+ pub is_trusted: bool,
}
#[derive(Clone, FromRow)]
@@ -77,7 +78,8 @@ pub async fn get_client_response<'c>(
alias,
type as `client_type: ClientType`,
allowed_scopes,
- default_scopes
+ default_scopes,
+ trusted as `is_trusted: bool`
FROM clients WHERE id = ?",
id
)
@@ -158,6 +160,16 @@ pub async fn get_client_secret<'c>(
Ok(Some(hash))
}
+pub async fn is_client_trusted<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<bool>, RawUnexpected> {
+ query_scalar!("SELECT trusted as `t: bool` FROM clients WHERE id = ?", id)
+ .fetch_optional(executor)
+ .await
+ .unexpect()
+}
+
pub async fn get_client_redirect_uris<'c>(
executor: impl Executor<'c, Database = MySql>,
id: Uuid,
@@ -328,6 +340,20 @@ pub async fn update_client_default_scopes<'c>(
.await
}
+pub async fn update_client_trusted<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+ is_trusted: bool,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ "UPDATE clients SET trusted = ? WHERE id = ?",
+ is_trusted,
+ id
+ )
+ .execute(executor)
+ .await
+}
+
pub async fn update_client_redirect_uris<'c>(
mut transaction: Transaction<'c, MySql>,
id: Uuid,