summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/api/clients.rs966
-rw-r--r--src/api/liveops.rs22
-rw-r--r--src/api/mod.rs26
-rw-r--r--src/api/oauth.rs1852
-rw-r--r--src/api/ops.rs140
-rw-r--r--src/api/users.rs544
-rw-r--r--src/main.rs216
-rw-r--r--src/models/client.rs330
-rw-r--r--src/models/mod.rs4
-rw-r--r--src/models/user.rs98
-rw-r--r--src/resources/languages.rs134
-rw-r--r--src/resources/mod.rs8
-rw-r--r--src/resources/scripts.rs76
-rw-r--r--src/resources/style.rs108
-rw-r--r--src/resources/templates.rs202
-rw-r--r--src/scopes/admin.rs56
-rw-r--r--src/scopes/mod.rs256
-rw-r--r--src/services/authorization.rs164
-rw-r--r--src/services/config.rs148
-rw-r--r--src/services/crypto.rs194
-rw-r--r--src/services/db.rs30
-rw-r--r--src/services/db/client.rs784
-rw-r--r--src/services/db/jwt.rs398
-rw-r--r--src/services/db/user.rs472
-rw-r--r--src/services/id.rs54
-rw-r--r--src/services/jwt.rs582
-rw-r--r--src/services/mod.rs14
-rw-r--r--src/services/secrets.rs48
28 files changed, 3963 insertions, 3963 deletions
diff --git a/src/api/clients.rs b/src/api/clients.rs
index 3f906bb..ded8b81 100644
--- a/src/api/clients.rs
+++ b/src/api/clients.rs
@@ -1,483 +1,483 @@
-use actix_web::http::{header, StatusCode};
-use actix_web::{get, post, put, web, HttpResponse, ResponseError, Scope};
-use raise::yeet;
-use serde::{Deserialize, Serialize};
-use sqlx::MySqlPool;
-use thiserror::Error;
-use url::Url;
-use uuid::Uuid;
-
-use crate::models::client::{Client, ClientType, CreateClientError};
-use crate::services::crypto::PasswordHash;
-use crate::services::db::ClientRow;
-use crate::services::{db, id};
-
-#[derive(Debug, Clone, Serialize)]
-#[serde(rename_all = "camelCase")]
-struct ClientResponse {
- client_id: Uuid,
- alias: Box<str>,
- client_type: ClientType,
- allowed_scopes: Box<[Box<str>]>,
- default_scopes: Option<Box<[Box<str>]>>,
- is_trusted: bool,
-}
-
-impl From<ClientRow> for ClientResponse {
- fn from(value: ClientRow) -> Self {
- Self {
- client_id: value.id,
- alias: value.alias.into_boxed_str(),
- client_type: value.client_type,
- allowed_scopes: value
- .allowed_scopes
- .split_whitespace()
- .map(Box::from)
- .collect(),
- default_scopes: value
- .default_scopes
- .map(|s| s.split_whitespace().map(Box::from).collect()),
- is_trusted: value.is_trusted,
- }
- }
-}
-
-#[derive(Debug, Clone, Copy, Error)]
-#[error("No client with the given client ID was found")]
-struct ClientNotFound {
- id: Uuid,
-}
-
-impl ResponseError for ClientNotFound {
- fn status_code(&self) -> StatusCode {
- StatusCode::NOT_FOUND
- }
-}
-
-impl ClientNotFound {
- fn new(id: Uuid) -> Self {
- Self { id }
- }
-}
-
-#[get("/{client_id}")]
-async fn get_client(
- client_id: web::Path<Uuid>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, ClientNotFound> {
- let db = db.as_ref();
- let id = *client_id;
-
- let Some(client) = db::get_client_response(db, id).await.unwrap() else {
- yeet!(ClientNotFound::new(id))
- };
-
- let redirect_uris_link = format!("</clients/{client_id}/redirect-uris>; rel=\"redirect-uris\"");
- let response: ClientResponse = client.into();
- let response = HttpResponse::Ok()
- .append_header((header::LINK, redirect_uris_link))
- .json(response);
- Ok(response)
-}
-
-#[get("/{client_id}/alias")]
-async fn get_client_alias(
- client_id: web::Path<Uuid>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, ClientNotFound> {
- let db = db.as_ref();
- let id = *client_id;
-
- let Some(alias) = db::get_client_alias(db, id).await.unwrap() else {
- yeet!(ClientNotFound::new(id))
- };
-
- Ok(HttpResponse::Ok().json(alias))
-}
-
-#[get("/{client_id}/type")]
-async fn get_client_type(
- client_id: web::Path<Uuid>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, ClientNotFound> {
- let db = db.as_ref();
- let id = *client_id;
-
- let Some(client_type) = db::get_client_type(db, id).await.unwrap() else {
- yeet!(ClientNotFound::new(id))
- };
-
- Ok(HttpResponse::Ok().json(client_type))
-}
-
-#[get("/{client_id}/redirect-uris")]
-async fn get_client_redirect_uris(
- client_id: web::Path<Uuid>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, ClientNotFound> {
- let db = db.as_ref();
- let id = *client_id;
-
- if !db::client_id_exists(db, id).await.unwrap() {
- yeet!(ClientNotFound::new(id))
- };
-
- let redirect_uris = db::get_client_redirect_uris(db, id).await.unwrap();
-
- Ok(HttpResponse::Ok().json(redirect_uris))
-}
-
-#[get("/{client_id}/allowed-scopes")]
-async fn get_client_allowed_scopes(
- client_id: web::Path<Uuid>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, ClientNotFound> {
- let db = db.as_ref();
- let id = *client_id;
-
- let Some(allowed_scopes) = db::get_client_allowed_scopes(db, id).await.unwrap() else {
- yeet!(ClientNotFound::new(id))
- };
-
- let allowed_scopes = allowed_scopes.split_whitespace().collect::<Box<[&str]>>();
-
- Ok(HttpResponse::Ok().json(allowed_scopes))
-}
-
-#[get("/{client_id}/default-scopes")]
-async fn get_client_default_scopes(
- client_id: web::Path<Uuid>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, ClientNotFound> {
- let db = db.as_ref();
- let id = *client_id;
-
- let Some(default_scopes) = db::get_client_default_scopes(db, id).await.unwrap() else {
- yeet!(ClientNotFound::new(id))
- };
-
- let default_scopes = default_scopes.map(|scopes| {
- scopes
- .split_whitespace()
- .map(Box::from)
- .collect::<Box<[Box<str>]>>()
- });
-
- Ok(HttpResponse::Ok().json(default_scopes))
-}
-
-#[get("/{client_id}/is-trusted")]
-async fn get_client_is_trusted(
- client_id: web::Path<Uuid>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, ClientNotFound> {
- let db = db.as_ref();
- let id = *client_id;
-
- let Some(is_trusted) = db::is_client_trusted(db, id).await.unwrap() else {
- yeet!(ClientNotFound::new(id))
- };
-
- Ok(HttpResponse::Ok().json(is_trusted))
-}
-
-#[derive(Clone, Deserialize)]
-#[serde(rename_all = "camelCase")]
-struct ClientRequest {
- alias: Box<str>,
- ty: ClientType,
- redirect_uris: Box<[Url]>,
- secret: Option<Box<str>>,
- allowed_scopes: Box<[Box<str>]>,
- default_scopes: Option<Box<[Box<str>]>>,
- trusted: bool,
-}
-
-#[derive(Debug, Clone, Error)]
-#[error("The given client alias is already taken")]
-struct AliasTakenError {
- alias: Box<str>,
-}
-
-impl ResponseError for AliasTakenError {
- fn status_code(&self) -> StatusCode {
- StatusCode::CONFLICT
- }
-}
-
-impl AliasTakenError {
- fn new(alias: &str) -> Self {
- Self {
- alias: Box::from(alias),
- }
- }
-}
-
-#[post("")]
-async fn create_client(
- body: web::Json<ClientRequest>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateClientError> {
- let db = db.get_ref();
- let alias = &body.alias;
-
- if db::client_alias_exists(db, &alias).await.unwrap() {
- yeet!(AliasTakenError::new(&alias).into());
- }
-
- let id = id::new_id(db, db::client_id_exists).await.unwrap();
- let client = Client::new(
- id,
- &alias,
- body.ty,
- body.secret.as_deref(),
- body.allowed_scopes.clone(),
- body.default_scopes.clone(),
- &body.redirect_uris,
- body.trusted,
- )
- .map_err(|e| e.unwrap())?;
-
- let transaction = db.begin().await.unwrap();
- db::create_client(transaction, &client).await.unwrap();
-
- let response = HttpResponse::Created()
- .insert_header((header::LOCATION, format!("clients/{id}")))
- .finish();
- Ok(response)
-}
-
-#[derive(Debug, Clone, Error)]
-enum UpdateClientError {
- #[error(transparent)]
- NotFound(#[from] ClientNotFound),
- #[error(transparent)]
- ClientError(#[from] CreateClientError),
- #[error(transparent)]
- AliasTaken(#[from] AliasTakenError),
-}
-
-impl ResponseError for UpdateClientError {
- fn status_code(&self) -> StatusCode {
- match self {
- Self::NotFound(e) => e.status_code(),
- Self::ClientError(e) => e.status_code(),
- Self::AliasTaken(e) => e.status_code(),
- }
- }
-}
-
-#[put("/{id}")]
-async fn update_client(
- id: web::Path<Uuid>,
- body: web::Json<ClientRequest>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateClientError> {
- let db = db.get_ref();
- let id = *id;
- let alias = &body.alias;
-
- let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else {
- yeet!(ClientNotFound::new(id).into())
- };
- if old_alias != alias.clone() && db::client_alias_exists(db, &alias).await.unwrap() {
- yeet!(AliasTakenError::new(&alias).into());
- }
-
- let client = Client::new(
- id,
- &alias,
- body.ty,
- body.secret.as_deref(),
- body.allowed_scopes.clone(),
- body.default_scopes.clone(),
- &body.redirect_uris,
- body.trusted,
- )
- .map_err(|e| e.unwrap())?;
-
- let transaction = db.begin().await.unwrap();
- db::update_client(transaction, &client).await.unwrap();
-
- let response = HttpResponse::NoContent().finish();
- Ok(response)
-}
-
-#[put("/{id}/alias")]
-async fn update_client_alias(
- id: web::Path<Uuid>,
- body: web::Json<Box<str>>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateClientError> {
- let db = db.get_ref();
- let id = *id;
- let alias = body.0;
-
- let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else {
- yeet!(ClientNotFound::new(id).into())
- };
- if old_alias == alias {
- return Ok(HttpResponse::NoContent().finish());
- }
- if db::client_alias_exists(db, &alias).await.unwrap() {
- yeet!(AliasTakenError::new(&alias).into());
- }
-
- db::update_client_alias(db, id, &alias).await.unwrap();
-
- let response = HttpResponse::NoContent().finish();
- Ok(response)
-}
-
-#[put("/{id}/type")]
-async fn update_client_type(
- id: web::Path<Uuid>,
- body: web::Json<ClientType>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateClientError> {
- let db = db.get_ref();
- let id = *id;
- let ty = body.0;
-
- if !db::client_id_exists(db, id).await.unwrap() {
- yeet!(ClientNotFound::new(id).into());
- }
-
- db::update_client_type(db, id, ty).await.unwrap();
-
- Ok(HttpResponse::NoContent().finish())
-}
-
-#[put("/{id}/allowed-scopes")]
-async fn update_client_allowed_scopes(
- id: web::Path<Uuid>,
- body: web::Json<Box<[Box<str>]>>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateClientError> {
- let db = db.get_ref();
- let id = *id;
- let allowed_scopes = body.0.join(" ");
-
- if !db::client_id_exists(db, id).await.unwrap() {
- yeet!(ClientNotFound::new(id).into());
- }
-
- db::update_client_allowed_scopes(db, id, &allowed_scopes)
- .await
- .unwrap();
-
- Ok(HttpResponse::NoContent().finish())
-}
-
-#[put("/{id}/default-scopes")]
-async fn update_client_default_scopes(
- id: web::Path<Uuid>,
- body: web::Json<Option<Box<[Box<str>]>>>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateClientError> {
- let db = db.get_ref();
- let id = *id;
- let default_scopes = body.0.map(|s| s.join(" "));
-
- if !db::client_id_exists(db, id).await.unwrap() {
- yeet!(ClientNotFound::new(id).into());
- }
-
- db::update_client_default_scopes(db, id, default_scopes)
- .await
- .unwrap();
-
- Ok(HttpResponse::NoContent().finish())
-}
-
-#[put("/{id}/is-trusted")]
-async fn update_client_is_trusted(
- id: web::Path<Uuid>,
- body: web::Json<bool>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateClientError> {
- let db = db.get_ref();
- let id = *id;
- let is_trusted = *body;
-
- if !db::client_id_exists(db, id).await.unwrap() {
- yeet!(ClientNotFound::new(id).into());
- }
-
- db::update_client_trusted(db, id, is_trusted).await.unwrap();
-
- Ok(HttpResponse::NoContent().finish())
-}
-
-#[put("/{id}/redirect-uris")]
-async fn update_client_redirect_uris(
- id: web::Path<Uuid>,
- body: web::Json<Box<[Url]>>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateClientError> {
- let db = db.get_ref();
- let id = *id;
-
- for uri in body.0.iter() {
- if uri.scheme() != "https" {
- yeet!(CreateClientError::NonHttpsUri.into());
- }
-
- if uri.fragment().is_some() {
- yeet!(CreateClientError::UriFragment.into())
- }
- }
-
- if !db::client_id_exists(db, id).await.unwrap() {
- yeet!(ClientNotFound::new(id).into());
- }
-
- let transaction = db.begin().await.unwrap();
- db::update_client_redirect_uris(transaction, id, &body.0)
- .await
- .unwrap();
-
- Ok(HttpResponse::NoContent().finish())
-}
-
-#[put("{id}/secret")]
-async fn update_client_secret(
- id: web::Path<Uuid>,
- body: web::Json<Option<Box<str>>>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateClientError> {
- let db = db.get_ref();
- let id = *id;
-
- let Some(client_type) = db::get_client_type(db, id).await.unwrap() else {
- yeet!(ClientNotFound::new(id).into())
- };
-
- if client_type == ClientType::Confidential && body.is_none() {
- yeet!(CreateClientError::NoSecret.into())
- }
-
- let secret = body.0.map(|s| PasswordHash::new(&s).unwrap());
- db::update_client_secret(db, id, secret).await.unwrap();
-
- Ok(HttpResponse::NoContent().finish())
-}
-
-pub fn service() -> Scope {
- web::scope("/clients")
- .service(get_client)
- .service(get_client_alias)
- .service(get_client_type)
- .service(get_client_allowed_scopes)
- .service(get_client_default_scopes)
- .service(get_client_redirect_uris)
- .service(get_client_is_trusted)
- .service(create_client)
- .service(update_client)
- .service(update_client_alias)
- .service(update_client_type)
- .service(update_client_allowed_scopes)
- .service(update_client_default_scopes)
- .service(update_client_redirect_uris)
- .service(update_client_secret)
-}
+use actix_web::http::{header, StatusCode};
+use actix_web::{get, post, put, web, HttpResponse, ResponseError, Scope};
+use raise::yeet;
+use serde::{Deserialize, Serialize};
+use sqlx::MySqlPool;
+use thiserror::Error;
+use url::Url;
+use uuid::Uuid;
+
+use crate::models::client::{Client, ClientType, CreateClientError};
+use crate::services::crypto::PasswordHash;
+use crate::services::db::ClientRow;
+use crate::services::{db, id};
+
+#[derive(Debug, Clone, Serialize)]
+#[serde(rename_all = "camelCase")]
+struct ClientResponse {
+ client_id: Uuid,
+ alias: Box<str>,
+ client_type: ClientType,
+ allowed_scopes: Box<[Box<str>]>,
+ default_scopes: Option<Box<[Box<str>]>>,
+ is_trusted: bool,
+}
+
+impl From<ClientRow> for ClientResponse {
+ fn from(value: ClientRow) -> Self {
+ Self {
+ client_id: value.id,
+ alias: value.alias.into_boxed_str(),
+ client_type: value.client_type,
+ allowed_scopes: value
+ .allowed_scopes
+ .split_whitespace()
+ .map(Box::from)
+ .collect(),
+ default_scopes: value
+ .default_scopes
+ .map(|s| s.split_whitespace().map(Box::from).collect()),
+ is_trusted: value.is_trusted,
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy, Error)]
+#[error("No client with the given client ID was found")]
+struct ClientNotFound {
+ id: Uuid,
+}
+
+impl ResponseError for ClientNotFound {
+ fn status_code(&self) -> StatusCode {
+ StatusCode::NOT_FOUND
+ }
+}
+
+impl ClientNotFound {
+ fn new(id: Uuid) -> Self {
+ Self { id }
+ }
+}
+
+#[get("/{client_id}")]
+async fn get_client(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(client) = db::get_client_response(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ let redirect_uris_link = format!("</clients/{client_id}/redirect-uris>; rel=\"redirect-uris\"");
+ let response: ClientResponse = client.into();
+ let response = HttpResponse::Ok()
+ .append_header((header::LINK, redirect_uris_link))
+ .json(response);
+ Ok(response)
+}
+
+#[get("/{client_id}/alias")]
+async fn get_client_alias(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(alias) = db::get_client_alias(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ Ok(HttpResponse::Ok().json(alias))
+}
+
+#[get("/{client_id}/type")]
+async fn get_client_type(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(client_type) = db::get_client_type(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ Ok(HttpResponse::Ok().json(client_type))
+}
+
+#[get("/{client_id}/redirect-uris")]
+async fn get_client_redirect_uris(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ let redirect_uris = db::get_client_redirect_uris(db, id).await.unwrap();
+
+ Ok(HttpResponse::Ok().json(redirect_uris))
+}
+
+#[get("/{client_id}/allowed-scopes")]
+async fn get_client_allowed_scopes(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(allowed_scopes) = db::get_client_allowed_scopes(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ let allowed_scopes = allowed_scopes.split_whitespace().collect::<Box<[&str]>>();
+
+ Ok(HttpResponse::Ok().json(allowed_scopes))
+}
+
+#[get("/{client_id}/default-scopes")]
+async fn get_client_default_scopes(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(default_scopes) = db::get_client_default_scopes(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ let default_scopes = default_scopes.map(|scopes| {
+ scopes
+ .split_whitespace()
+ .map(Box::from)
+ .collect::<Box<[Box<str>]>>()
+ });
+
+ Ok(HttpResponse::Ok().json(default_scopes))
+}
+
+#[get("/{client_id}/is-trusted")]
+async fn get_client_is_trusted(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(is_trusted) = db::is_client_trusted(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ Ok(HttpResponse::Ok().json(is_trusted))
+}
+
+#[derive(Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct ClientRequest {
+ alias: Box<str>,
+ ty: ClientType,
+ redirect_uris: Box<[Url]>,
+ secret: Option<Box<str>>,
+ allowed_scopes: Box<[Box<str>]>,
+ default_scopes: Option<Box<[Box<str>]>>,
+ trusted: bool,
+}
+
+#[derive(Debug, Clone, Error)]
+#[error("The given client alias is already taken")]
+struct AliasTakenError {
+ alias: Box<str>,
+}
+
+impl ResponseError for AliasTakenError {
+ fn status_code(&self) -> StatusCode {
+ StatusCode::CONFLICT
+ }
+}
+
+impl AliasTakenError {
+ fn new(alias: &str) -> Self {
+ Self {
+ alias: Box::from(alias),
+ }
+ }
+}
+
+#[post("")]
+async fn create_client(
+ body: web::Json<ClientRequest>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let alias = &body.alias;
+
+ if db::client_alias_exists(db, &alias).await.unwrap() {
+ yeet!(AliasTakenError::new(&alias).into());
+ }
+
+ let id = id::new_id(db, db::client_id_exists).await.unwrap();
+ let client = Client::new(
+ id,
+ &alias,
+ body.ty,
+ body.secret.as_deref(),
+ body.allowed_scopes.clone(),
+ body.default_scopes.clone(),
+ &body.redirect_uris,
+ body.trusted,
+ )
+ .map_err(|e| e.unwrap())?;
+
+ let transaction = db.begin().await.unwrap();
+ db::create_client(transaction, &client).await.unwrap();
+
+ let response = HttpResponse::Created()
+ .insert_header((header::LOCATION, format!("clients/{id}")))
+ .finish();
+ Ok(response)
+}
+
+#[derive(Debug, Clone, Error)]
+enum UpdateClientError {
+ #[error(transparent)]
+ NotFound(#[from] ClientNotFound),
+ #[error(transparent)]
+ ClientError(#[from] CreateClientError),
+ #[error(transparent)]
+ AliasTaken(#[from] AliasTakenError),
+}
+
+impl ResponseError for UpdateClientError {
+ fn status_code(&self) -> StatusCode {
+ match self {
+ Self::NotFound(e) => e.status_code(),
+ Self::ClientError(e) => e.status_code(),
+ Self::AliasTaken(e) => e.status_code(),
+ }
+ }
+}
+
+#[put("/{id}")]
+async fn update_client(
+ id: web::Path<Uuid>,
+ body: web::Json<ClientRequest>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let alias = &body.alias;
+
+ let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id).into())
+ };
+ if old_alias != alias.clone() && db::client_alias_exists(db, &alias).await.unwrap() {
+ yeet!(AliasTakenError::new(&alias).into());
+ }
+
+ let client = Client::new(
+ id,
+ &alias,
+ body.ty,
+ body.secret.as_deref(),
+ body.allowed_scopes.clone(),
+ body.default_scopes.clone(),
+ &body.redirect_uris,
+ body.trusted,
+ )
+ .map_err(|e| e.unwrap())?;
+
+ let transaction = db.begin().await.unwrap();
+ db::update_client(transaction, &client).await.unwrap();
+
+ let response = HttpResponse::NoContent().finish();
+ Ok(response)
+}
+
+#[put("/{id}/alias")]
+async fn update_client_alias(
+ id: web::Path<Uuid>,
+ body: web::Json<Box<str>>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let alias = body.0;
+
+ let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id).into())
+ };
+ if old_alias == alias {
+ return Ok(HttpResponse::NoContent().finish());
+ }
+ if db::client_alias_exists(db, &alias).await.unwrap() {
+ yeet!(AliasTakenError::new(&alias).into());
+ }
+
+ db::update_client_alias(db, id, &alias).await.unwrap();
+
+ let response = HttpResponse::NoContent().finish();
+ Ok(response)
+}
+
+#[put("/{id}/type")]
+async fn update_client_type(
+ id: web::Path<Uuid>,
+ body: web::Json<ClientType>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let ty = body.0;
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id).into());
+ }
+
+ db::update_client_type(db, id, ty).await.unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+#[put("/{id}/allowed-scopes")]
+async fn update_client_allowed_scopes(
+ id: web::Path<Uuid>,
+ body: web::Json<Box<[Box<str>]>>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let allowed_scopes = body.0.join(" ");
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id).into());
+ }
+
+ db::update_client_allowed_scopes(db, id, &allowed_scopes)
+ .await
+ .unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+#[put("/{id}/default-scopes")]
+async fn update_client_default_scopes(
+ id: web::Path<Uuid>,
+ body: web::Json<Option<Box<[Box<str>]>>>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let default_scopes = body.0.map(|s| s.join(" "));
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id).into());
+ }
+
+ db::update_client_default_scopes(db, id, default_scopes)
+ .await
+ .unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+#[put("/{id}/is-trusted")]
+async fn update_client_is_trusted(
+ id: web::Path<Uuid>,
+ body: web::Json<bool>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let is_trusted = *body;
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id).into());
+ }
+
+ db::update_client_trusted(db, id, is_trusted).await.unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+#[put("/{id}/redirect-uris")]
+async fn update_client_redirect_uris(
+ id: web::Path<Uuid>,
+ body: web::Json<Box<[Url]>>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+
+ for uri in body.0.iter() {
+ if uri.scheme() != "https" {
+ yeet!(CreateClientError::NonHttpsUri.into());
+ }
+
+ if uri.fragment().is_some() {
+ yeet!(CreateClientError::UriFragment.into())
+ }
+ }
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id).into());
+ }
+
+ let transaction = db.begin().await.unwrap();
+ db::update_client_redirect_uris(transaction, id, &body.0)
+ .await
+ .unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+#[put("{id}/secret")]
+async fn update_client_secret(
+ id: web::Path<Uuid>,
+ body: web::Json<Option<Box<str>>>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+
+ let Some(client_type) = db::get_client_type(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id).into())
+ };
+
+ if client_type == ClientType::Confidential && body.is_none() {
+ yeet!(CreateClientError::NoSecret.into())
+ }
+
+ let secret = body.0.map(|s| PasswordHash::new(&s).unwrap());
+ db::update_client_secret(db, id, secret).await.unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+pub fn service() -> Scope {
+ web::scope("/clients")
+ .service(get_client)
+ .service(get_client_alias)
+ .service(get_client_type)
+ .service(get_client_allowed_scopes)
+ .service(get_client_default_scopes)
+ .service(get_client_redirect_uris)
+ .service(get_client_is_trusted)
+ .service(create_client)
+ .service(update_client)
+ .service(update_client_alias)
+ .service(update_client_type)
+ .service(update_client_allowed_scopes)
+ .service(update_client_default_scopes)
+ .service(update_client_redirect_uris)
+ .service(update_client_secret)
+}
diff --git a/src/api/liveops.rs b/src/api/liveops.rs
index d4bf129..2caf6e3 100644
--- a/src/api/liveops.rs
+++ b/src/api/liveops.rs
@@ -1,11 +1,11 @@
-use actix_web::{get, web, HttpResponse, Scope};
-
-/// Simple ping
-#[get("/ping")]
-async fn ping() -> HttpResponse {
- HttpResponse::Ok().finish()
-}
-
-pub fn service() -> Scope {
- web::scope("/liveops").service(ping)
-}
+use actix_web::{get, web, HttpResponse, Scope};
+
+/// Simple ping
+#[get("/ping")]
+async fn ping() -> HttpResponse {
+ HttpResponse::Ok().finish()
+}
+
+pub fn service() -> Scope {
+ web::scope("/liveops").service(ping)
+}
diff --git a/src/api/mod.rs b/src/api/mod.rs
index 0ab4037..9059e71 100644
--- a/src/api/mod.rs
+++ b/src/api/mod.rs
@@ -1,13 +1,13 @@
-mod clients;
-mod liveops;
-mod oauth;
-mod ops;
-mod users;
-
-pub use clients::service as clients;
-pub use liveops::service as liveops;
-pub use oauth::service as oauth;
-pub use ops::service as ops;
-pub use users::service as users;
-
-pub use oauth::AuthorizationParameters;
+mod clients;
+mod liveops;
+mod oauth;
+mod ops;
+mod users;
+
+pub use clients::service as clients;
+pub use liveops::service as liveops;
+pub use oauth::service as oauth;
+pub use ops::service as ops;
+pub use users::service as users;
+
+pub use oauth::AuthorizationParameters;
diff --git a/src/api/oauth.rs b/src/api/oauth.rs
index f1aa012..3422d2f 100644
--- a/src/api/oauth.rs
+++ b/src/api/oauth.rs
@@ -1,926 +1,926 @@
-use std::ops::Deref;
-use std::str::FromStr;
-
-use actix_web::http::{header, StatusCode};
-use actix_web::{
- get, post, web, HttpRequest, HttpResponse, HttpResponseBuilder, ResponseError, Scope,
-};
-use chrono::Duration;
-use exun::{Expect, RawUnexpected, ResultErrorExt, UnexpectedError};
-use raise::yeet;
-use serde::{Deserialize, Serialize};
-use sqlx::MySqlPool;
-use tera::Tera;
-use thiserror::Error;
-use unic_langid::subtags::Language;
-use url::Url;
-use uuid::Uuid;
-
-use crate::models::client::ClientType;
-use crate::resources::{languages, templates};
-use crate::scopes;
-use crate::services::jwt::VerifyJwtError;
-use crate::services::{authorization, config, db, jwt};
-
-const REALLY_BAD_ERROR_PAGE: &str = "<!DOCTYPE html><html><head><title>Internal Server Error</title></head><body>Internal Server Error</body></html>";
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
-#[serde(rename_all = "snake_case")]
-enum ResponseType {
- Code,
- Token,
- #[serde(other)]
- Unsupported,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct AuthorizationParameters {
- response_type: ResponseType,
- client_id: Box<str>,
- redirect_uri: Option<Url>,
- scope: Option<Box<str>>,
- state: Option<Box<str>>,
-}
-
-#[derive(Clone, Deserialize)]
-struct AuthorizeCredentials {
- username: Box<str>,
- password: Box<str>,
-}
-
-#[derive(Clone, Serialize)]
-struct AuthCodeResponse {
- code: Box<str>,
- state: Option<Box<str>>,
-}
-
-#[derive(Clone, Serialize)]
-struct AuthTokenResponse {
- access_token: Box<str>,
- token_type: &'static str,
- expires_in: i64,
- scope: Box<str>,
- state: Option<Box<str>>,
-}
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
-#[serde(rename_all = "camelCase")]
-enum AuthorizeErrorType {
- InvalidRequest,
- UnauthorizedClient,
- AccessDenied,
- UnsupportedResponseType,
- InvalidScope,
- ServerError,
- TemporarilyUnavailable,
-}
-
-#[derive(Debug, Clone, Error, Serialize)]
-#[error("{error_description}")]
-struct AuthorizeError {
- error: AuthorizeErrorType,
- error_description: Box<str>,
- // TODO error uri
- state: Option<Box<str>>,
- #[serde(skip)]
- redirect_uri: Url,
-}
-
-impl AuthorizeError {
- fn no_scope(redirect_uri: Url, state: Option<Box<str>>) -> Self {
- Self {
- error: AuthorizeErrorType::InvalidScope,
- error_description: Box::from(
- "No scope was provided, and the client does not have a default scope",
- ),
- state,
- redirect_uri,
- }
- }
-
- fn unsupported_response_type(redirect_uri: Url, state: Option<Box<str>>) -> Self {
- Self {
- error: AuthorizeErrorType::UnsupportedResponseType,
- error_description: Box::from("The given response type is not supported"),
- state,
- redirect_uri,
- }
- }
-
- fn invalid_scope(redirect_uri: Url, state: Option<Box<str>>) -> Self {
- Self {
- error: AuthorizeErrorType::InvalidScope,
- error_description: Box::from("The given scope exceeds what the client is allowed"),
- state,
- redirect_uri,
- }
- }
-
- fn internal_server_error(redirect_uri: Url, state: Option<Box<str>>) -> Self {
- Self {
- error: AuthorizeErrorType::ServerError,
- error_description: "An unexpected error occurred".into(),
- state,
- redirect_uri,
- }
- }
-}
-
-impl ResponseError for AuthorizeError {
- fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
- let query = Some(serde_urlencoded::to_string(self).unwrap());
- let query = query.as_deref();
- let mut url = self.redirect_uri.clone();
- url.set_query(query);
-
- HttpResponse::Found()
- .insert_header((header::LOCATION, url.as_str()))
- .finish()
- }
-}
-
-fn error_page(
- tera: &Tera,
- translations: &languages::Translations,
- error: templates::ErrorPage,
-) -> Result<String, RawUnexpected> {
- // TODO find a better way of doing languages
- let language = Language::from_str("en").unwrap();
- let translations = translations.clone();
- let page = templates::error_page(&tera, language, translations, error)?;
- Ok(page)
-}
-
-async fn get_redirect_uri(
- redirect_uri: &Option<Url>,
- db: &MySqlPool,
- client_id: Uuid,
-) -> Result<Url, Expect<templates::ErrorPage>> {
- if let Some(uri) = &redirect_uri {
- let redirect_uri = uri.clone();
- if !db::client_has_redirect_uri(db, client_id, &redirect_uri)
- .await
- .map_err(|e| UnexpectedError::from(e))
- .unexpect()?
- {
- yeet!(Expect::Expected(templates::ErrorPage::InvalidRedirectUri));
- }
-
- Ok(redirect_uri)
- } else {
- let redirect_uris = db::get_client_redirect_uris(db, client_id)
- .await
- .map_err(|e| UnexpectedError::from(e))
- .unexpect()?;
- if redirect_uris.len() != 1 {
- yeet!(Expect::Expected(templates::ErrorPage::MissingRedirectUri));
- }
-
- Ok(redirect_uris.get(0).unwrap().clone())
- }
-}
-
-async fn get_scope(
- scope: &Option<Box<str>>,
- db: &MySqlPool,
- client_id: Uuid,
- redirect_uri: &Url,
- state: &Option<Box<str>>,
-) -> Result<Box<str>, Expect<AuthorizeError>> {
- let scope = if let Some(scope) = &scope {
- scope.clone()
- } else {
- let default_scopes = db::get_client_default_scopes(db, client_id)
- .await
- .unwrap()
- .unwrap();
- let Some(scope) = default_scopes else {
- yeet!(AuthorizeError::no_scope(redirect_uri.clone(), state.clone()).into())
- };
- scope
- };
-
- // verify scope is valid
- let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
- .await
- .unwrap()
- .unwrap();
- if !scopes::is_subset_of(&scope, &allowed_scopes) {
- yeet!(AuthorizeError::invalid_scope(redirect_uri.clone(), state.clone()).into());
- }
-
- Ok(scope)
-}
-
-async fn authenticate_user(
- db: &MySqlPool,
- username: &str,
- password: &str,
-) -> Result<Option<Uuid>, RawUnexpected> {
- let Some(user) = db::get_user_by_username(db, username).await? else {
- return Ok(None);
- };
-
- if user.check_password(password)? {
- Ok(Some(user.id))
- } else {
- Ok(None)
- }
-}
-
-#[post("/authorize")]
-async fn authorize(
- db: web::Data<MySqlPool>,
- req: web::Query<AuthorizationParameters>,
- credentials: web::Json<AuthorizeCredentials>,
- tera: web::Data<Tera>,
- translations: web::Data<languages::Translations>,
-) -> Result<HttpResponse, AuthorizeError> {
- // TODO protect against brute force attacks
- let db = db.get_ref();
- let Ok(client_id) = db::get_client_id_by_alias(db, &req.client_id).await else {
- let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return Ok(HttpResponse::InternalServerError().content_type("text/html").body(page));
- };
- let Some(client_id) = client_id else {
- let page = error_page(&tera, &translations, templates::ErrorPage::ClientNotFound).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return Ok(HttpResponse::NotFound().content_type("text/html").body(page));
- };
- let Ok(config) = config::get_config() else {
- let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return Ok(HttpResponse::InternalServerError().content_type("text/html").body(page));
- };
-
- let self_id = config.url;
- let state = req.state.clone();
-
- // get redirect uri
- let mut redirect_uri = match get_redirect_uri(&req.redirect_uri, db, client_id).await {
- Ok(uri) => uri,
- Err(e) => {
- let e = e
- .expected()
- .unwrap_or(templates::ErrorPage::InternalServerError);
- let page = error_page(&tera, &translations, e)
- .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return Ok(HttpResponse::BadRequest()
- .content_type("text/html")
- .body(page));
- }
- };
-
- // authenticate user
- let Some(user_id) = authenticate_user(db, &credentials.username, &credentials.password)
- .await
- .unwrap() else
- {
- let language = Language::from_str("en").unwrap();
- let translations = translations.get_ref().clone();
- let page = templates::login_error_page(&tera, &req, language, translations).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return Ok(HttpResponse::Ok().content_type("text/html").body(page));
- };
-
- let internal_server_error =
- AuthorizeError::internal_server_error(redirect_uri.clone(), state.clone());
-
- // get scope
- let scope = match get_scope(&req.scope, db, client_id, &redirect_uri, &state).await {
- Ok(scope) => scope,
- Err(e) => {
- let e = e.expected().unwrap_or(internal_server_error);
- return Err(e);
- }
- };
-
- match req.response_type {
- ResponseType::Code => {
- // create auth code
- let code =
- jwt::Claims::auth_code(db, self_id, client_id, user_id, &scope, &redirect_uri)
- .await
- .map_err(|_| internal_server_error.clone())?;
- let code = code.to_jwt().map_err(|_| internal_server_error.clone())?;
-
- let response = AuthCodeResponse { code, state };
- let query =
- Some(serde_urlencoded::to_string(response).map_err(|_| internal_server_error)?);
- let query = query.as_deref();
- redirect_uri.set_query(query);
-
- Ok(HttpResponse::Found()
- .append_header((header::LOCATION, redirect_uri.as_str()))
- .finish())
- }
- ResponseType::Token => {
- // create access token
- let duration = Duration::hours(1);
- let access_token =
- jwt::Claims::access_token(db, None, self_id, client_id, user_id, duration, &scope)
- .await
- .map_err(|_| internal_server_error.clone())?;
-
- let access_token = access_token
- .to_jwt()
- .map_err(|_| internal_server_error.clone())?;
- let expires_in = duration.num_seconds();
- let token_type = "bearer";
- let response = AuthTokenResponse {
- access_token,
- expires_in,
- token_type,
- scope,
- state,
- };
-
- let fragment = Some(
- serde_urlencoded::to_string(response).map_err(|_| internal_server_error.clone())?,
- );
- let fragment = fragment.as_deref();
- redirect_uri.set_fragment(fragment);
-
- Ok(HttpResponse::Found()
- .append_header((header::LOCATION, redirect_uri.as_str()))
- .finish())
- }
- _ => Err(AuthorizeError::invalid_scope(redirect_uri, state)),
- }
-}
-
-#[get("/authorize")]
-async fn authorize_page(
- db: web::Data<MySqlPool>,
- tera: web::Data<Tera>,
- translations: web::Data<languages::Translations>,
- request: HttpRequest,
-) -> Result<HttpResponse, AuthorizeError> {
- let Ok(language) = Language::from_str("en") else {
- let page = String::from(REALLY_BAD_ERROR_PAGE);
- return Ok(HttpResponse::InternalServerError()
- .content_type("text/html")
- .body(page));
- };
- let translations = translations.get_ref().clone();
-
- let params = request.query_string();
- let params = serde_urlencoded::from_str::<AuthorizationParameters>(params);
- let Ok(params) = params else {
- let page = error_page(
- &tera,
- &translations,
- templates::ErrorPage::InvalidRequest,
- )
- .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return Ok(HttpResponse::BadRequest()
- .content_type("text/html")
- .body(page));
- };
-
- let db = db.get_ref();
- let Ok(client_id) = db::get_client_id_by_alias(db, &params.client_id).await else {
- let page = templates::error_page(
- &tera,
- language,
- translations,
- templates::ErrorPage::InternalServerError,
- )
- .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return Ok(HttpResponse::InternalServerError()
- .content_type("text/html")
- .body(page));
- };
- let Some(client_id) = client_id else {
- let page = templates::error_page(
- &tera,
- language,
- translations,
- templates::ErrorPage::ClientNotFound,
- )
- .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return Ok(HttpResponse::NotFound()
- .content_type("text/html")
- .body(page));
- };
-
- // verify redirect uri
- let redirect_uri = match get_redirect_uri(&params.redirect_uri, db, client_id).await {
- Ok(uri) => uri,
- Err(e) => {
- let e = e
- .expected()
- .unwrap_or(templates::ErrorPage::InternalServerError);
- let page = error_page(&tera, &translations, e)
- .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return Ok(HttpResponse::BadRequest()
- .content_type("text/html")
- .body(page));
- }
- };
-
- let state = &params.state;
- let internal_server_error =
- AuthorizeError::internal_server_error(redirect_uri.clone(), state.clone());
-
- // verify scope
- let _ = match get_scope(&params.scope, db, client_id, &redirect_uri, &params.state).await {
- Ok(scope) => scope,
- Err(e) => {
- let e = e.expected().unwrap_or(internal_server_error);
- return Err(e);
- }
- };
-
- // verify response type
- if params.response_type == ResponseType::Unsupported {
- return Err(AuthorizeError::unsupported_response_type(
- redirect_uri,
- params.state,
- ));
- }
-
- // TODO find a better way of doing languages
- let language = Language::from_str("en").unwrap();
- let page = templates::login_page(&tera, &params, language, translations).unwrap();
- Ok(HttpResponse::Ok().content_type("text/html").body(page))
-}
-
-#[derive(Clone, Deserialize)]
-#[serde(tag = "grant_type")]
-#[serde(rename_all = "snake_case")]
-enum GrantType {
- AuthorizationCode {
- code: Box<str>,
- redirect_uri: Url,
- #[serde(rename = "client_id")]
- client_alias: Box<str>,
- },
- Password {
- username: Box<str>,
- password: Box<str>,
- scope: Option<Box<str>>,
- },
- ClientCredentials {
- scope: Option<Box<str>>,
- },
- RefreshToken {
- refresh_token: Box<str>,
- scope: Option<Box<str>>,
- },
- #[serde(other)]
- Unsupported,
-}
-
-#[derive(Clone, Deserialize)]
-struct TokenRequest {
- #[serde(flatten)]
- grant_type: GrantType,
- // TODO support optional client credentials in here
-}
-
-#[derive(Clone, Serialize)]
-struct TokenResponse {
- access_token: Box<str>,
- token_type: Box<str>,
- expires_in: i64,
- refresh_token: Option<Box<str>>,
- scope: Box<str>,
-}
-
-#[derive(Debug, Clone, Serialize)]
-#[serde(rename_all = "snake_case")]
-enum TokenErrorType {
- InvalidRequest,
- InvalidClient,
- InvalidGrant,
- UnauthorizedClient,
- UnsupportedGrantType,
- InvalidScope,
-}
-
-#[derive(Debug, Clone, Error, Serialize)]
-#[error("{error_description}")]
-struct TokenError {
- #[serde(skip)]
- status_code: StatusCode,
- error: TokenErrorType,
- error_description: Box<str>,
- // TODO error uri
-}
-
-impl TokenError {
- fn invalid_request() -> Self {
- // TODO make this description better, and all the other ones while you're at it
- Self {
- status_code: StatusCode::BAD_REQUEST,
- error: TokenErrorType::InvalidRequest,
- error_description: "Invalid request".into(),
- }
- }
-
- fn unsupported_grant_type() -> Self {
- Self {
- status_code: StatusCode::BAD_REQUEST,
- error: TokenErrorType::UnsupportedGrantType,
- error_description: "The given grant type is not supported".into(),
- }
- }
-
- fn bad_auth_code(error: VerifyJwtError) -> Self {
- Self {
- status_code: StatusCode::BAD_REQUEST,
- error: TokenErrorType::InvalidGrant,
- error_description: error.to_string().into_boxed_str(),
- }
- }
-
- fn no_authorization() -> Self {
- Self {
- status_code: StatusCode::UNAUTHORIZED,
- error: TokenErrorType::InvalidClient,
- error_description: Box::from(
- "Client credentials must be provided in the HTTP Authorization header",
- ),
- }
- }
-
- fn client_not_found(alias: &str) -> Self {
- Self {
- status_code: StatusCode::UNAUTHORIZED,
- error: TokenErrorType::InvalidClient,
- error_description: format!("No client with the client id: {alias} was found")
- .into_boxed_str(),
- }
- }
-
- fn mismatch_client_id() -> Self {
- Self {
- status_code: StatusCode::UNAUTHORIZED,
- error: TokenErrorType::InvalidClient,
- error_description: Box::from("The client ID in the Authorization header is not the same as the client ID in the request body"),
- }
- }
-
- fn incorrect_client_secret() -> Self {
- Self {
- status_code: StatusCode::UNAUTHORIZED,
- error: TokenErrorType::InvalidClient,
- error_description: "The client secret is incorrect".into(),
- }
- }
-
- fn client_not_confidential(alias: &str) -> Self {
- Self {
- status_code: StatusCode::BAD_REQUEST,
- error: TokenErrorType::UnauthorizedClient,
- error_description: format!("Only a confidential client may be used with this endpoint. The {alias} client is a public client.")
- .into_boxed_str(),
- }
- }
-
- fn no_scope() -> Self {
- Self {
- status_code: StatusCode::BAD_REQUEST,
- error: TokenErrorType::InvalidScope,
- error_description: Box::from(
- "No scope was provided, and the client doesn't have a default scope",
- ),
- }
- }
-
- fn excessive_scope() -> Self {
- Self {
- status_code: StatusCode::BAD_REQUEST,
- error: TokenErrorType::InvalidScope,
- error_description: Box::from(
- "The given scope exceeds what the client is allowed to have",
- ),
- }
- }
-
- fn bad_refresh_token(err: VerifyJwtError) -> Self {
- Self {
- status_code: StatusCode::BAD_REQUEST,
- error: TokenErrorType::InvalidGrant,
- error_description: err.to_string().into_boxed_str(),
- }
- }
-
- fn untrusted_client() -> Self {
- Self {
- status_code: StatusCode::UNAUTHORIZED,
- error: TokenErrorType::InvalidClient,
- error_description: "Only trusted clients may use this grant".into(),
- }
- }
-
- fn incorrect_user_credentials() -> Self {
- Self {
- status_code: StatusCode::BAD_REQUEST,
- error: TokenErrorType::InvalidRequest,
- error_description: "The given credentials are incorrect".into(),
- }
- }
-}
-
-impl ResponseError for TokenError {
- fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
- let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]);
-
- let mut builder = HttpResponseBuilder::new(self.status_code);
-
- if self.status_code.as_u16() == 401 {
- builder.insert_header((header::WWW_AUTHENTICATE, "Basic charset=\"UTF-8\""));
- }
-
- builder
- .insert_header(cache_control)
- .insert_header((header::PRAGMA, "no-cache"))
- .json(self.clone())
- }
-}
-
-#[post("/token")]
-async fn token(
- db: web::Data<MySqlPool>,
- req: web::Bytes,
- authorization: Option<web::Header<authorization::BasicAuthorization>>,
-) -> HttpResponse {
- // TODO protect against brute force attacks
- let db = db.get_ref();
- let request = serde_json::from_slice::<TokenRequest>(&req);
- let Ok(request) = request else {
- return TokenError::invalid_request().error_response();
- };
- let config = config::get_config().unwrap();
-
- let self_id = config.url;
- let duration = Duration::hours(1);
- let token_type = Box::from("bearer");
- let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]);
-
- match request.grant_type {
- GrantType::AuthorizationCode {
- code,
- redirect_uri,
- client_alias,
- } => {
- let Some(client_id) = db::get_client_id_by_alias(db, &client_alias).await.unwrap() else {
- return TokenError::client_not_found(&client_alias).error_response();
- };
-
- // validate auth code
- let claims =
- match jwt::verify_auth_code(db, &code, &self_id, client_id, redirect_uri).await {
- Ok(claims) => claims,
- Err(err) => {
- let err = err.unwrap();
- return TokenError::bad_auth_code(err).error_response();
- }
- };
-
- // verify client, if the client has credentials
- if let Some(hash) = db::get_client_secret(db, client_id).await.unwrap() {
- let Some(authorization) = authorization else {
- return TokenError::no_authorization().error_response();
- };
-
- if authorization.username() != client_alias.deref() {
- return TokenError::mismatch_client_id().error_response();
- }
- if !hash.check_password(authorization.password()).unwrap() {
- return TokenError::incorrect_client_secret().error_response();
- }
- }
-
- let access_token = jwt::Claims::access_token(
- db,
- Some(claims.id()),
- self_id,
- client_id,
- claims.subject(),
- duration,
- claims.scopes(),
- )
- .await
- .unwrap();
-
- let expires_in = access_token.expires_in();
- let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap();
- let scope = access_token.scopes().into();
-
- let access_token = access_token.to_jwt().unwrap();
- let refresh_token = Some(refresh_token.to_jwt().unwrap());
-
- let response = TokenResponse {
- access_token,
- token_type,
- expires_in,
- refresh_token,
- scope,
- };
- HttpResponse::Ok()
- .insert_header(cache_control)
- .insert_header((header::PRAGMA, "no-cache"))
- .json(response)
- }
- GrantType::Password {
- username,
- password,
- scope,
- } => {
- let Some(authorization) = authorization else {
- return TokenError::no_authorization().error_response();
- };
- let client_alias = authorization.username();
- let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else {
- return TokenError::client_not_found(client_alias).error_response();
- };
-
- let trusted = db::is_client_trusted(db, client_id).await.unwrap().unwrap();
- if !trusted {
- return TokenError::untrusted_client().error_response();
- }
-
- // verify client
- let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap();
- if !hash.check_password(authorization.password()).unwrap() {
- return TokenError::incorrect_client_secret().error_response();
- }
-
- // authenticate user
- let Some(user_id) = authenticate_user(db, &username, &password).await.unwrap() else {
- return TokenError::incorrect_user_credentials().error_response();
- };
-
- // verify scope
- let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
- .await
- .unwrap()
- .unwrap();
- let scope = if let Some(scope) = &scope {
- scope.clone()
- } else {
- let default_scopes = db::get_client_default_scopes(db, client_id)
- .await
- .unwrap()
- .unwrap();
- let Some(scope) = default_scopes else {
- return TokenError::no_scope().error_response();
- };
- scope
- };
- if !scopes::is_subset_of(&scope, &allowed_scopes) {
- return TokenError::excessive_scope().error_response();
- }
-
- let access_token =
- jwt::Claims::access_token(db, None, self_id, client_id, user_id, duration, &scope)
- .await
- .unwrap();
- let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap();
-
- let expires_in = access_token.expires_in();
- let scope = access_token.scopes().into();
- let access_token = access_token.to_jwt().unwrap();
- let refresh_token = Some(refresh_token.to_jwt().unwrap());
-
- let response = TokenResponse {
- access_token,
- token_type,
- expires_in,
- refresh_token,
- scope,
- };
- HttpResponse::Ok()
- .insert_header(cache_control)
- .insert_header((header::PRAGMA, "no-cache"))
- .json(response)
- }
- GrantType::ClientCredentials { scope } => {
- let Some(authorization) = authorization else {
- return TokenError::no_authorization().error_response();
- };
- let client_alias = authorization.username();
- let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else {
- return TokenError::client_not_found(client_alias).error_response();
- };
-
- let ty = db::get_client_type(db, client_id).await.unwrap().unwrap();
- if ty != ClientType::Confidential {
- return TokenError::client_not_confidential(client_alias).error_response();
- }
-
- // verify client
- let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap();
- if !hash.check_password(authorization.password()).unwrap() {
- return TokenError::incorrect_client_secret().error_response();
- }
-
- // verify scope
- let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
- .await
- .unwrap()
- .unwrap();
- let scope = if let Some(scope) = &scope {
- scope.clone()
- } else {
- let default_scopes = db::get_client_default_scopes(db, client_id)
- .await
- .unwrap()
- .unwrap();
- let Some(scope) = default_scopes else {
- return TokenError::no_scope().error_response();
- };
- scope
- };
- if !scopes::is_subset_of(&scope, &allowed_scopes) {
- return TokenError::excessive_scope().error_response();
- }
-
- let access_token = jwt::Claims::access_token(
- db, None, self_id, client_id, client_id, duration, &scope,
- )
- .await
- .unwrap();
-
- let expires_in = access_token.expires_in();
- let scope = access_token.scopes().into();
- let access_token = access_token.to_jwt().unwrap();
-
- let response = TokenResponse {
- access_token,
- token_type,
- expires_in,
- refresh_token: None,
- scope,
- };
- HttpResponse::Ok()
- .insert_header(cache_control)
- .insert_header((header::PRAGMA, "no-cache"))
- .json(response)
- }
- GrantType::RefreshToken {
- refresh_token,
- scope,
- } => {
- let client_id: Option<Uuid>;
- if let Some(authorization) = authorization {
- let client_alias = authorization.username();
- let Some(id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else {
- return TokenError::client_not_found(client_alias).error_response();
- };
- client_id = Some(id);
- } else {
- client_id = None;
- }
-
- let claims =
- match jwt::verify_refresh_token(db, &refresh_token, &self_id, client_id).await {
- Ok(claims) => claims,
- Err(e) => {
- let e = e.unwrap();
- return TokenError::bad_refresh_token(e).error_response();
- }
- };
-
- let scope = if let Some(scope) = scope {
- if !scopes::is_subset_of(&scope, claims.scopes()) {
- return TokenError::excessive_scope().error_response();
- }
-
- scope
- } else {
- claims.scopes().into()
- };
-
- let exp_time = Duration::hours(1);
- let access_token = jwt::Claims::refreshed_access_token(db, &claims, exp_time)
- .await
- .unwrap();
- let refresh_token = jwt::Claims::refresh_token(db, &claims).await.unwrap();
-
- let access_token = access_token.to_jwt().unwrap();
- let refresh_token = Some(refresh_token.to_jwt().unwrap());
- let expires_in = exp_time.num_seconds();
-
- let response = TokenResponse {
- access_token,
- token_type,
- expires_in,
- refresh_token,
- scope,
- };
- HttpResponse::Ok()
- .insert_header(cache_control)
- .insert_header((header::PRAGMA, "no-cache"))
- .json(response)
- }
- _ => TokenError::unsupported_grant_type().error_response(),
- }
-}
-
-pub fn service() -> Scope {
- web::scope("/oauth")
- .service(authorize_page)
- .service(authorize)
- .service(token)
-}
+use std::ops::Deref;
+use std::str::FromStr;
+
+use actix_web::http::{header, StatusCode};
+use actix_web::{
+ get, post, web, HttpRequest, HttpResponse, HttpResponseBuilder, ResponseError, Scope,
+};
+use chrono::Duration;
+use exun::{Expect, RawUnexpected, ResultErrorExt, UnexpectedError};
+use raise::yeet;
+use serde::{Deserialize, Serialize};
+use sqlx::MySqlPool;
+use tera::Tera;
+use thiserror::Error;
+use unic_langid::subtags::Language;
+use url::Url;
+use uuid::Uuid;
+
+use crate::models::client::ClientType;
+use crate::resources::{languages, templates};
+use crate::scopes;
+use crate::services::jwt::VerifyJwtError;
+use crate::services::{authorization, config, db, jwt};
+
+const REALLY_BAD_ERROR_PAGE: &str = "<!DOCTYPE html><html><head><title>Internal Server Error</title></head><body>Internal Server Error</body></html>";
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+enum ResponseType {
+ Code,
+ Token,
+ #[serde(other)]
+ Unsupported,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct AuthorizationParameters {
+ response_type: ResponseType,
+ client_id: Box<str>,
+ redirect_uri: Option<Url>,
+ scope: Option<Box<str>>,
+ state: Option<Box<str>>,
+}
+
+#[derive(Clone, Deserialize)]
+struct AuthorizeCredentials {
+ username: Box<str>,
+ password: Box<str>,
+}
+
+#[derive(Clone, Serialize)]
+struct AuthCodeResponse {
+ code: Box<str>,
+ state: Option<Box<str>>,
+}
+
+#[derive(Clone, Serialize)]
+struct AuthTokenResponse {
+ access_token: Box<str>,
+ token_type: &'static str,
+ expires_in: i64,
+ scope: Box<str>,
+ state: Option<Box<str>>,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
+#[serde(rename_all = "camelCase")]
+enum AuthorizeErrorType {
+ InvalidRequest,
+ UnauthorizedClient,
+ AccessDenied,
+ UnsupportedResponseType,
+ InvalidScope,
+ ServerError,
+ TemporarilyUnavailable,
+}
+
+#[derive(Debug, Clone, Error, Serialize)]
+#[error("{error_description}")]
+struct AuthorizeError {
+ error: AuthorizeErrorType,
+ error_description: Box<str>,
+ // TODO error uri
+ state: Option<Box<str>>,
+ #[serde(skip)]
+ redirect_uri: Url,
+}
+
+impl AuthorizeError {
+ fn no_scope(redirect_uri: Url, state: Option<Box<str>>) -> Self {
+ Self {
+ error: AuthorizeErrorType::InvalidScope,
+ error_description: Box::from(
+ "No scope was provided, and the client does not have a default scope",
+ ),
+ state,
+ redirect_uri,
+ }
+ }
+
+ fn unsupported_response_type(redirect_uri: Url, state: Option<Box<str>>) -> Self {
+ Self {
+ error: AuthorizeErrorType::UnsupportedResponseType,
+ error_description: Box::from("The given response type is not supported"),
+ state,
+ redirect_uri,
+ }
+ }
+
+ fn invalid_scope(redirect_uri: Url, state: Option<Box<str>>) -> Self {
+ Self {
+ error: AuthorizeErrorType::InvalidScope,
+ error_description: Box::from("The given scope exceeds what the client is allowed"),
+ state,
+ redirect_uri,
+ }
+ }
+
+ fn internal_server_error(redirect_uri: Url, state: Option<Box<str>>) -> Self {
+ Self {
+ error: AuthorizeErrorType::ServerError,
+ error_description: "An unexpected error occurred".into(),
+ state,
+ redirect_uri,
+ }
+ }
+}
+
+impl ResponseError for AuthorizeError {
+ fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
+ let query = Some(serde_urlencoded::to_string(self).unwrap());
+ let query = query.as_deref();
+ let mut url = self.redirect_uri.clone();
+ url.set_query(query);
+
+ HttpResponse::Found()
+ .insert_header((header::LOCATION, url.as_str()))
+ .finish()
+ }
+}
+
+fn error_page(
+ tera: &Tera,
+ translations: &languages::Translations,
+ error: templates::ErrorPage,
+) -> Result<String, RawUnexpected> {
+ // TODO find a better way of doing languages
+ let language = Language::from_str("en").unwrap();
+ let translations = translations.clone();
+ let page = templates::error_page(&tera, language, translations, error)?;
+ Ok(page)
+}
+
+async fn get_redirect_uri(
+ redirect_uri: &Option<Url>,
+ db: &MySqlPool,
+ client_id: Uuid,
+) -> Result<Url, Expect<templates::ErrorPage>> {
+ if let Some(uri) = &redirect_uri {
+ let redirect_uri = uri.clone();
+ if !db::client_has_redirect_uri(db, client_id, &redirect_uri)
+ .await
+ .map_err(|e| UnexpectedError::from(e))
+ .unexpect()?
+ {
+ yeet!(Expect::Expected(templates::ErrorPage::InvalidRedirectUri));
+ }
+
+ Ok(redirect_uri)
+ } else {
+ let redirect_uris = db::get_client_redirect_uris(db, client_id)
+ .await
+ .map_err(|e| UnexpectedError::from(e))
+ .unexpect()?;
+ if redirect_uris.len() != 1 {
+ yeet!(Expect::Expected(templates::ErrorPage::MissingRedirectUri));
+ }
+
+ Ok(redirect_uris.get(0).unwrap().clone())
+ }
+}
+
+async fn get_scope(
+ scope: &Option<Box<str>>,
+ db: &MySqlPool,
+ client_id: Uuid,
+ redirect_uri: &Url,
+ state: &Option<Box<str>>,
+) -> Result<Box<str>, Expect<AuthorizeError>> {
+ let scope = if let Some(scope) = &scope {
+ scope.clone()
+ } else {
+ let default_scopes = db::get_client_default_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let Some(scope) = default_scopes else {
+ yeet!(AuthorizeError::no_scope(redirect_uri.clone(), state.clone()).into())
+ };
+ scope
+ };
+
+ // verify scope is valid
+ let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ if !scopes::is_subset_of(&scope, &allowed_scopes) {
+ yeet!(AuthorizeError::invalid_scope(redirect_uri.clone(), state.clone()).into());
+ }
+
+ Ok(scope)
+}
+
+async fn authenticate_user(
+ db: &MySqlPool,
+ username: &str,
+ password: &str,
+) -> Result<Option<Uuid>, RawUnexpected> {
+ let Some(user) = db::get_user_by_username(db, username).await? else {
+ return Ok(None);
+ };
+
+ if user.check_password(password)? {
+ Ok(Some(user.id))
+ } else {
+ Ok(None)
+ }
+}
+
+#[post("/authorize")]
+async fn authorize(
+ db: web::Data<MySqlPool>,
+ req: web::Query<AuthorizationParameters>,
+ credentials: web::Json<AuthorizeCredentials>,
+ tera: web::Data<Tera>,
+ translations: web::Data<languages::Translations>,
+) -> Result<HttpResponse, AuthorizeError> {
+ // TODO protect against brute force attacks
+ let db = db.get_ref();
+ let Ok(client_id) = db::get_client_id_by_alias(db, &req.client_id).await else {
+ let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::InternalServerError().content_type("text/html").body(page));
+ };
+ let Some(client_id) = client_id else {
+ let page = error_page(&tera, &translations, templates::ErrorPage::ClientNotFound).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::NotFound().content_type("text/html").body(page));
+ };
+ let Ok(config) = config::get_config() else {
+ let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::InternalServerError().content_type("text/html").body(page));
+ };
+
+ let self_id = config.url;
+ let state = req.state.clone();
+
+ // get redirect uri
+ let mut redirect_uri = match get_redirect_uri(&req.redirect_uri, db, client_id).await {
+ Ok(uri) => uri,
+ Err(e) => {
+ let e = e
+ .expected()
+ .unwrap_or(templates::ErrorPage::InternalServerError);
+ let page = error_page(&tera, &translations, e)
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::BadRequest()
+ .content_type("text/html")
+ .body(page));
+ }
+ };
+
+ // authenticate user
+ let Some(user_id) = authenticate_user(db, &credentials.username, &credentials.password)
+ .await
+ .unwrap() else
+ {
+ let language = Language::from_str("en").unwrap();
+ let translations = translations.get_ref().clone();
+ let page = templates::login_error_page(&tera, &req, language, translations).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::Ok().content_type("text/html").body(page));
+ };
+
+ let internal_server_error =
+ AuthorizeError::internal_server_error(redirect_uri.clone(), state.clone());
+
+ // get scope
+ let scope = match get_scope(&req.scope, db, client_id, &redirect_uri, &state).await {
+ Ok(scope) => scope,
+ Err(e) => {
+ let e = e.expected().unwrap_or(internal_server_error);
+ return Err(e);
+ }
+ };
+
+ match req.response_type {
+ ResponseType::Code => {
+ // create auth code
+ let code =
+ jwt::Claims::auth_code(db, self_id, client_id, user_id, &scope, &redirect_uri)
+ .await
+ .map_err(|_| internal_server_error.clone())?;
+ let code = code.to_jwt().map_err(|_| internal_server_error.clone())?;
+
+ let response = AuthCodeResponse { code, state };
+ let query =
+ Some(serde_urlencoded::to_string(response).map_err(|_| internal_server_error)?);
+ let query = query.as_deref();
+ redirect_uri.set_query(query);
+
+ Ok(HttpResponse::Found()
+ .append_header((header::LOCATION, redirect_uri.as_str()))
+ .finish())
+ }
+ ResponseType::Token => {
+ // create access token
+ let duration = Duration::hours(1);
+ let access_token =
+ jwt::Claims::access_token(db, None, self_id, client_id, user_id, duration, &scope)
+ .await
+ .map_err(|_| internal_server_error.clone())?;
+
+ let access_token = access_token
+ .to_jwt()
+ .map_err(|_| internal_server_error.clone())?;
+ let expires_in = duration.num_seconds();
+ let token_type = "bearer";
+ let response = AuthTokenResponse {
+ access_token,
+ expires_in,
+ token_type,
+ scope,
+ state,
+ };
+
+ let fragment = Some(
+ serde_urlencoded::to_string(response).map_err(|_| internal_server_error.clone())?,
+ );
+ let fragment = fragment.as_deref();
+ redirect_uri.set_fragment(fragment);
+
+ Ok(HttpResponse::Found()
+ .append_header((header::LOCATION, redirect_uri.as_str()))
+ .finish())
+ }
+ _ => Err(AuthorizeError::invalid_scope(redirect_uri, state)),
+ }
+}
+
+#[get("/authorize")]
+async fn authorize_page(
+ db: web::Data<MySqlPool>,
+ tera: web::Data<Tera>,
+ translations: web::Data<languages::Translations>,
+ request: HttpRequest,
+) -> Result<HttpResponse, AuthorizeError> {
+ let Ok(language) = Language::from_str("en") else {
+ let page = String::from(REALLY_BAD_ERROR_PAGE);
+ return Ok(HttpResponse::InternalServerError()
+ .content_type("text/html")
+ .body(page));
+ };
+ let translations = translations.get_ref().clone();
+
+ let params = request.query_string();
+ let params = serde_urlencoded::from_str::<AuthorizationParameters>(params);
+ let Ok(params) = params else {
+ let page = error_page(
+ &tera,
+ &translations,
+ templates::ErrorPage::InvalidRequest,
+ )
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::BadRequest()
+ .content_type("text/html")
+ .body(page));
+ };
+
+ let db = db.get_ref();
+ let Ok(client_id) = db::get_client_id_by_alias(db, &params.client_id).await else {
+ let page = templates::error_page(
+ &tera,
+ language,
+ translations,
+ templates::ErrorPage::InternalServerError,
+ )
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::InternalServerError()
+ .content_type("text/html")
+ .body(page));
+ };
+ let Some(client_id) = client_id else {
+ let page = templates::error_page(
+ &tera,
+ language,
+ translations,
+ templates::ErrorPage::ClientNotFound,
+ )
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::NotFound()
+ .content_type("text/html")
+ .body(page));
+ };
+
+ // verify redirect uri
+ let redirect_uri = match get_redirect_uri(&params.redirect_uri, db, client_id).await {
+ Ok(uri) => uri,
+ Err(e) => {
+ let e = e
+ .expected()
+ .unwrap_or(templates::ErrorPage::InternalServerError);
+ let page = error_page(&tera, &translations, e)
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::BadRequest()
+ .content_type("text/html")
+ .body(page));
+ }
+ };
+
+ let state = &params.state;
+ let internal_server_error =
+ AuthorizeError::internal_server_error(redirect_uri.clone(), state.clone());
+
+ // verify scope
+ let _ = match get_scope(&params.scope, db, client_id, &redirect_uri, &params.state).await {
+ Ok(scope) => scope,
+ Err(e) => {
+ let e = e.expected().unwrap_or(internal_server_error);
+ return Err(e);
+ }
+ };
+
+ // verify response type
+ if params.response_type == ResponseType::Unsupported {
+ return Err(AuthorizeError::unsupported_response_type(
+ redirect_uri,
+ params.state,
+ ));
+ }
+
+ // TODO find a better way of doing languages
+ let language = Language::from_str("en").unwrap();
+ let page = templates::login_page(&tera, &params, language, translations).unwrap();
+ Ok(HttpResponse::Ok().content_type("text/html").body(page))
+}
+
+#[derive(Clone, Deserialize)]
+#[serde(tag = "grant_type")]
+#[serde(rename_all = "snake_case")]
+enum GrantType {
+ AuthorizationCode {
+ code: Box<str>,
+ redirect_uri: Url,
+ #[serde(rename = "client_id")]
+ client_alias: Box<str>,
+ },
+ Password {
+ username: Box<str>,
+ password: Box<str>,
+ scope: Option<Box<str>>,
+ },
+ ClientCredentials {
+ scope: Option<Box<str>>,
+ },
+ RefreshToken {
+ refresh_token: Box<str>,
+ scope: Option<Box<str>>,
+ },
+ #[serde(other)]
+ Unsupported,
+}
+
+#[derive(Clone, Deserialize)]
+struct TokenRequest {
+ #[serde(flatten)]
+ grant_type: GrantType,
+ // TODO support optional client credentials in here
+}
+
+#[derive(Clone, Serialize)]
+struct TokenResponse {
+ access_token: Box<str>,
+ token_type: Box<str>,
+ expires_in: i64,
+ refresh_token: Option<Box<str>>,
+ scope: Box<str>,
+}
+
+#[derive(Debug, Clone, Serialize)]
+#[serde(rename_all = "snake_case")]
+enum TokenErrorType {
+ InvalidRequest,
+ InvalidClient,
+ InvalidGrant,
+ UnauthorizedClient,
+ UnsupportedGrantType,
+ InvalidScope,
+}
+
+#[derive(Debug, Clone, Error, Serialize)]
+#[error("{error_description}")]
+struct TokenError {
+ #[serde(skip)]
+ status_code: StatusCode,
+ error: TokenErrorType,
+ error_description: Box<str>,
+ // TODO error uri
+}
+
+impl TokenError {
+ fn invalid_request() -> Self {
+ // TODO make this description better, and all the other ones while you're at it
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidRequest,
+ error_description: "Invalid request".into(),
+ }
+ }
+
+ fn unsupported_grant_type() -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::UnsupportedGrantType,
+ error_description: "The given grant type is not supported".into(),
+ }
+ }
+
+ fn bad_auth_code(error: VerifyJwtError) -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidGrant,
+ error_description: error.to_string().into_boxed_str(),
+ }
+ }
+
+ fn no_authorization() -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: Box::from(
+ "Client credentials must be provided in the HTTP Authorization header",
+ ),
+ }
+ }
+
+ fn client_not_found(alias: &str) -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: format!("No client with the client id: {alias} was found")
+ .into_boxed_str(),
+ }
+ }
+
+ fn mismatch_client_id() -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: Box::from("The client ID in the Authorization header is not the same as the client ID in the request body"),
+ }
+ }
+
+ fn incorrect_client_secret() -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: "The client secret is incorrect".into(),
+ }
+ }
+
+ fn client_not_confidential(alias: &str) -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::UnauthorizedClient,
+ error_description: format!("Only a confidential client may be used with this endpoint. The {alias} client is a public client.")
+ .into_boxed_str(),
+ }
+ }
+
+ fn no_scope() -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidScope,
+ error_description: Box::from(
+ "No scope was provided, and the client doesn't have a default scope",
+ ),
+ }
+ }
+
+ fn excessive_scope() -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidScope,
+ error_description: Box::from(
+ "The given scope exceeds what the client is allowed to have",
+ ),
+ }
+ }
+
+ fn bad_refresh_token(err: VerifyJwtError) -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidGrant,
+ error_description: err.to_string().into_boxed_str(),
+ }
+ }
+
+ fn untrusted_client() -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: "Only trusted clients may use this grant".into(),
+ }
+ }
+
+ fn incorrect_user_credentials() -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidRequest,
+ error_description: "The given credentials are incorrect".into(),
+ }
+ }
+}
+
+impl ResponseError for TokenError {
+ fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
+ let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]);
+
+ let mut builder = HttpResponseBuilder::new(self.status_code);
+
+ if self.status_code.as_u16() == 401 {
+ builder.insert_header((header::WWW_AUTHENTICATE, "Basic charset=\"UTF-8\""));
+ }
+
+ builder
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(self.clone())
+ }
+}
+
+#[post("/token")]
+async fn token(
+ db: web::Data<MySqlPool>,
+ req: web::Bytes,
+ authorization: Option<web::Header<authorization::BasicAuthorization>>,
+) -> HttpResponse {
+ // TODO protect against brute force attacks
+ let db = db.get_ref();
+ let request = serde_json::from_slice::<TokenRequest>(&req);
+ let Ok(request) = request else {
+ return TokenError::invalid_request().error_response();
+ };
+ let config = config::get_config().unwrap();
+
+ let self_id = config.url;
+ let duration = Duration::hours(1);
+ let token_type = Box::from("bearer");
+ let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]);
+
+ match request.grant_type {
+ GrantType::AuthorizationCode {
+ code,
+ redirect_uri,
+ client_alias,
+ } => {
+ let Some(client_id) = db::get_client_id_by_alias(db, &client_alias).await.unwrap() else {
+ return TokenError::client_not_found(&client_alias).error_response();
+ };
+
+ // validate auth code
+ let claims =
+ match jwt::verify_auth_code(db, &code, &self_id, client_id, redirect_uri).await {
+ Ok(claims) => claims,
+ Err(err) => {
+ let err = err.unwrap();
+ return TokenError::bad_auth_code(err).error_response();
+ }
+ };
+
+ // verify client, if the client has credentials
+ if let Some(hash) = db::get_client_secret(db, client_id).await.unwrap() {
+ let Some(authorization) = authorization else {
+ return TokenError::no_authorization().error_response();
+ };
+
+ if authorization.username() != client_alias.deref() {
+ return TokenError::mismatch_client_id().error_response();
+ }
+ if !hash.check_password(authorization.password()).unwrap() {
+ return TokenError::incorrect_client_secret().error_response();
+ }
+ }
+
+ let access_token = jwt::Claims::access_token(
+ db,
+ Some(claims.id()),
+ self_id,
+ client_id,
+ claims.subject(),
+ duration,
+ claims.scopes(),
+ )
+ .await
+ .unwrap();
+
+ let expires_in = access_token.expires_in();
+ let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap();
+ let scope = access_token.scopes().into();
+
+ let access_token = access_token.to_jwt().unwrap();
+ let refresh_token = Some(refresh_token.to_jwt().unwrap());
+
+ let response = TokenResponse {
+ access_token,
+ token_type,
+ expires_in,
+ refresh_token,
+ scope,
+ };
+ HttpResponse::Ok()
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(response)
+ }
+ GrantType::Password {
+ username,
+ password,
+ scope,
+ } => {
+ let Some(authorization) = authorization else {
+ return TokenError::no_authorization().error_response();
+ };
+ let client_alias = authorization.username();
+ let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else {
+ return TokenError::client_not_found(client_alias).error_response();
+ };
+
+ let trusted = db::is_client_trusted(db, client_id).await.unwrap().unwrap();
+ if !trusted {
+ return TokenError::untrusted_client().error_response();
+ }
+
+ // verify client
+ let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap();
+ if !hash.check_password(authorization.password()).unwrap() {
+ return TokenError::incorrect_client_secret().error_response();
+ }
+
+ // authenticate user
+ let Some(user_id) = authenticate_user(db, &username, &password).await.unwrap() else {
+ return TokenError::incorrect_user_credentials().error_response();
+ };
+
+ // verify scope
+ let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let scope = if let Some(scope) = &scope {
+ scope.clone()
+ } else {
+ let default_scopes = db::get_client_default_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let Some(scope) = default_scopes else {
+ return TokenError::no_scope().error_response();
+ };
+ scope
+ };
+ if !scopes::is_subset_of(&scope, &allowed_scopes) {
+ return TokenError::excessive_scope().error_response();
+ }
+
+ let access_token =
+ jwt::Claims::access_token(db, None, self_id, client_id, user_id, duration, &scope)
+ .await
+ .unwrap();
+ let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap();
+
+ let expires_in = access_token.expires_in();
+ let scope = access_token.scopes().into();
+ let access_token = access_token.to_jwt().unwrap();
+ let refresh_token = Some(refresh_token.to_jwt().unwrap());
+
+ let response = TokenResponse {
+ access_token,
+ token_type,
+ expires_in,
+ refresh_token,
+ scope,
+ };
+ HttpResponse::Ok()
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(response)
+ }
+ GrantType::ClientCredentials { scope } => {
+ let Some(authorization) = authorization else {
+ return TokenError::no_authorization().error_response();
+ };
+ let client_alias = authorization.username();
+ let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else {
+ return TokenError::client_not_found(client_alias).error_response();
+ };
+
+ let ty = db::get_client_type(db, client_id).await.unwrap().unwrap();
+ if ty != ClientType::Confidential {
+ return TokenError::client_not_confidential(client_alias).error_response();
+ }
+
+ // verify client
+ let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap();
+ if !hash.check_password(authorization.password()).unwrap() {
+ return TokenError::incorrect_client_secret().error_response();
+ }
+
+ // verify scope
+ let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let scope = if let Some(scope) = &scope {
+ scope.clone()
+ } else {
+ let default_scopes = db::get_client_default_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let Some(scope) = default_scopes else {
+ return TokenError::no_scope().error_response();
+ };
+ scope
+ };
+ if !scopes::is_subset_of(&scope, &allowed_scopes) {
+ return TokenError::excessive_scope().error_response();
+ }
+
+ let access_token = jwt::Claims::access_token(
+ db, None, self_id, client_id, client_id, duration, &scope,
+ )
+ .await
+ .unwrap();
+
+ let expires_in = access_token.expires_in();
+ let scope = access_token.scopes().into();
+ let access_token = access_token.to_jwt().unwrap();
+
+ let response = TokenResponse {
+ access_token,
+ token_type,
+ expires_in,
+ refresh_token: None,
+ scope,
+ };
+ HttpResponse::Ok()
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(response)
+ }
+ GrantType::RefreshToken {
+ refresh_token,
+ scope,
+ } => {
+ let client_id: Option<Uuid>;
+ if let Some(authorization) = authorization {
+ let client_alias = authorization.username();
+ let Some(id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else {
+ return TokenError::client_not_found(client_alias).error_response();
+ };
+ client_id = Some(id);
+ } else {
+ client_id = None;
+ }
+
+ let claims =
+ match jwt::verify_refresh_token(db, &refresh_token, &self_id, client_id).await {
+ Ok(claims) => claims,
+ Err(e) => {
+ let e = e.unwrap();
+ return TokenError::bad_refresh_token(e).error_response();
+ }
+ };
+
+ let scope = if let Some(scope) = scope {
+ if !scopes::is_subset_of(&scope, claims.scopes()) {
+ return TokenError::excessive_scope().error_response();
+ }
+
+ scope
+ } else {
+ claims.scopes().into()
+ };
+
+ let exp_time = Duration::hours(1);
+ let access_token = jwt::Claims::refreshed_access_token(db, &claims, exp_time)
+ .await
+ .unwrap();
+ let refresh_token = jwt::Claims::refresh_token(db, &claims).await.unwrap();
+
+ let access_token = access_token.to_jwt().unwrap();
+ let refresh_token = Some(refresh_token.to_jwt().unwrap());
+ let expires_in = exp_time.num_seconds();
+
+ let response = TokenResponse {
+ access_token,
+ token_type,
+ expires_in,
+ refresh_token,
+ scope,
+ };
+ HttpResponse::Ok()
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(response)
+ }
+ _ => TokenError::unsupported_grant_type().error_response(),
+ }
+}
+
+pub fn service() -> Scope {
+ web::scope("/oauth")
+ .service(authorize_page)
+ .service(authorize)
+ .service(token)
+}
diff --git a/src/api/ops.rs b/src/api/ops.rs
index 555bb1b..2164f1f 100644
--- a/src/api/ops.rs
+++ b/src/api/ops.rs
@@ -1,70 +1,70 @@
-use std::str::FromStr;
-
-use actix_web::{get, http::StatusCode, post, web, HttpResponse, ResponseError, Scope};
-use raise::yeet;
-use serde::Deserialize;
-use sqlx::MySqlPool;
-use tera::Tera;
-use thiserror::Error;
-use unic_langid::subtags::Language;
-
-use crate::resources::{languages, templates};
-use crate::services::db;
-
-/// A request to login
-#[derive(Debug, Clone, Deserialize)]
-struct LoginRequest {
- username: Box<str>,
- password: Box<str>,
-}
-
-/// An error occurred when authenticating, because either the username or
-/// password was invalid.
-#[derive(Debug, Clone, Error)]
-enum LoginFailure {
- #[error("No user found with the given username")]
- UserNotFound { username: Box<str> },
- #[error("The given password is incorrect")]
- IncorrectPassword { username: Box<str> },
-}
-
-impl ResponseError for LoginFailure {
- fn status_code(&self) -> actix_web::http::StatusCode {
- match self {
- Self::UserNotFound { .. } => StatusCode::NOT_FOUND,
- Self::IncorrectPassword { .. } => StatusCode::UNAUTHORIZED,
- }
- }
-}
-
-/// Returns `200` if login was successful.
-/// Returns `404` if the username is invalid.
-/// Returns `401` if the password was invalid.
-#[post("/login")]
-async fn login(
- body: web::Json<LoginRequest>,
- conn: web::Data<MySqlPool>,
-) -> Result<HttpResponse, LoginFailure> {
- let conn = conn.get_ref();
-
- let user = db::get_user_by_username(conn, &body.username)
- .await
- .unwrap();
- let Some(user) = user else {
- yeet!(LoginFailure::UserNotFound{ username: body.username.clone() });
- };
-
- let good_password = user.check_password(&body.password).unwrap();
- let response = if good_password {
- HttpResponse::Ok().finish()
- } else {
- yeet!(LoginFailure::IncorrectPassword {
- username: body.username.clone()
- });
- };
- Ok(response)
-}
-
-pub fn service() -> Scope {
- web::scope("").service(login)
-}
+use std::str::FromStr;
+
+use actix_web::{get, http::StatusCode, post, web, HttpResponse, ResponseError, Scope};
+use raise::yeet;
+use serde::Deserialize;
+use sqlx::MySqlPool;
+use tera::Tera;
+use thiserror::Error;
+use unic_langid::subtags::Language;
+
+use crate::resources::{languages, templates};
+use crate::services::db;
+
+/// A request to login
+#[derive(Debug, Clone, Deserialize)]
+struct LoginRequest {
+ username: Box<str>,
+ password: Box<str>,
+}
+
+/// An error occurred when authenticating, because either the username or
+/// password was invalid.
+#[derive(Debug, Clone, Error)]
+enum LoginFailure {
+ #[error("No user found with the given username")]
+ UserNotFound { username: Box<str> },
+ #[error("The given password is incorrect")]
+ IncorrectPassword { username: Box<str> },
+}
+
+impl ResponseError for LoginFailure {
+ fn status_code(&self) -> actix_web::http::StatusCode {
+ match self {
+ Self::UserNotFound { .. } => StatusCode::NOT_FOUND,
+ Self::IncorrectPassword { .. } => StatusCode::UNAUTHORIZED,
+ }
+ }
+}
+
+/// Returns `200` if login was successful.
+/// Returns `404` if the username is invalid.
+/// Returns `401` if the password was invalid.
+#[post("/login")]
+async fn login(
+ body: web::Json<LoginRequest>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, LoginFailure> {
+ let conn = conn.get_ref();
+
+ let user = db::get_user_by_username(conn, &body.username)
+ .await
+ .unwrap();
+ let Some(user) = user else {
+ yeet!(LoginFailure::UserNotFound{ username: body.username.clone() });
+ };
+
+ let good_password = user.check_password(&body.password).unwrap();
+ let response = if good_password {
+ HttpResponse::Ok().finish()
+ } else {
+ yeet!(LoginFailure::IncorrectPassword {
+ username: body.username.clone()
+ });
+ };
+ Ok(response)
+}
+
+pub fn service() -> Scope {
+ web::scope("").service(login)
+}
diff --git a/src/api/users.rs b/src/api/users.rs
index 391a059..da2a0d0 100644
--- a/src/api/users.rs
+++ b/src/api/users.rs
@@ -1,272 +1,272 @@
-use actix_web::http::{header, StatusCode};
-use actix_web::{get, post, put, web, HttpResponse, ResponseError, Scope};
-use raise::yeet;
-use serde::{Deserialize, Serialize};
-use sqlx::MySqlPool;
-use thiserror::Error;
-use uuid::Uuid;
-
-use crate::models::user::User;
-use crate::services::crypto::PasswordHash;
-use crate::services::{db, id};
-
-/// Just a username. No password hash, because that'd be tempting fate.
-#[derive(Debug, Clone, Serialize)]
-#[serde(rename_all = "camelCase")]
-struct UserResponse {
- id: Uuid,
- username: Box<str>,
-}
-
-impl From<User> for UserResponse {
- fn from(user: User) -> Self {
- Self {
- id: user.id,
- username: user.username,
- }
- }
-}
-
-#[derive(Debug, Clone, Deserialize)]
-#[serde(rename_all = "camelCase")]
-struct SearchUsers {
- username: Option<Box<str>>,
- limit: Option<u32>,
- offset: Option<u32>,
-}
-
-#[get("")]
-async fn search_users(params: web::Query<SearchUsers>, conn: web::Data<MySqlPool>) -> HttpResponse {
- let conn = conn.get_ref();
-
- let username = params.username.clone().unwrap_or_default();
- let offset = params.offset.unwrap_or_default();
-
- let results: Box<[UserResponse]> = if let Some(limit) = params.limit {
- db::search_users_limit(conn, &username, offset, limit)
- .await
- .unwrap()
- .iter()
- .cloned()
- .map(|u| u.into())
- .collect()
- } else {
- db::search_users(conn, &username)
- .await
- .unwrap()
- .into_iter()
- .skip(offset as usize)
- .cloned()
- .map(|u| u.into())
- .collect()
- };
-
- let response = HttpResponse::Ok().json(results);
- response
-}
-
-#[derive(Debug, Clone, Error)]
-#[error("No user with the given ID exists")]
-struct UserNotFoundError {
- user_id: Uuid,
-}
-
-impl ResponseError for UserNotFoundError {
- fn status_code(&self) -> StatusCode {
- StatusCode::NOT_FOUND
- }
-}
-
-#[get("/{user_id}")]
-async fn get_user(
- user_id: web::Path<Uuid>,
- conn: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UserNotFoundError> {
- let conn = conn.get_ref();
-
- let id = user_id.to_owned();
- let username = db::get_username(conn, id).await.unwrap();
-
- let Some(username) = username else {
- yeet!(UserNotFoundError { user_id: id });
- };
-
- let response = UserResponse { id, username };
- let response = HttpResponse::Ok().json(response);
- Ok(response)
-}
-
-#[get("/{user_id}/username")]
-async fn get_username(
- user_id: web::Path<Uuid>,
- conn: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UserNotFoundError> {
- let conn = conn.get_ref();
-
- let user_id = user_id.to_owned();
- let username = db::get_username(conn, user_id).await.unwrap();
-
- let Some(username) = username else {
- yeet!(UserNotFoundError { user_id });
- };
-
- let response = HttpResponse::Ok().json(username);
- Ok(response)
-}
-
-/// A request to create or update user information
-#[derive(Clone, Deserialize)]
-#[serde(rename_all = "camelCase")]
-struct UserRequest {
- username: Box<str>,
- password: Box<str>,
-}
-
-#[derive(Debug, Clone, Error)]
-#[error("An account with the given username already exists.")]
-struct UsernameTakenError {
- username: Box<str>,
-}
-
-impl ResponseError for UsernameTakenError {
- fn status_code(&self) -> StatusCode {
- StatusCode::CONFLICT
- }
-}
-
-#[post("")]
-async fn create_user(
- body: web::Json<UserRequest>,
- conn: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UsernameTakenError> {
- let conn = conn.get_ref();
-
- let user_id = id::new_id(conn, db::user_id_exists).await.unwrap();
- let username = body.username.clone();
- let password = PasswordHash::new(&body.password).unwrap();
-
- if db::username_is_used(conn, &body.username).await.unwrap() {
- yeet!(UsernameTakenError { username });
- }
-
- let user = User {
- id: user_id,
- username,
- password,
- };
-
- db::create_user(conn, &user).await.unwrap();
-
- let response = HttpResponse::Created()
- .insert_header((header::LOCATION, format!("users/{user_id}")))
- .finish();
- Ok(response)
-}
-
-#[derive(Debug, Clone, Error)]
-enum UpdateUserError {
- #[error(transparent)]
- UsernameTaken(#[from] UsernameTakenError),
- #[error(transparent)]
- NotFound(#[from] UserNotFoundError),
-}
-
-impl ResponseError for UpdateUserError {
- fn status_code(&self) -> StatusCode {
- match self {
- Self::UsernameTaken(e) => e.status_code(),
- Self::NotFound(e) => e.status_code(),
- }
- }
-}
-
-#[put("/{user_id}")]
-async fn update_user(
- user_id: web::Path<Uuid>,
- body: web::Json<UserRequest>,
- conn: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateUserError> {
- let conn = conn.get_ref();
-
- let user_id = user_id.to_owned();
- let username = body.username.clone();
- let password = PasswordHash::new(&body.password).unwrap();
-
- let old_username = db::get_username(conn, user_id).await.unwrap().unwrap();
- if username != old_username && db::username_is_used(conn, &body.username).await.unwrap() {
- yeet!(UsernameTakenError { username }.into())
- }
-
- if !db::user_id_exists(conn, user_id).await.unwrap() {
- yeet!(UserNotFoundError { user_id }.into())
- }
-
- let user = User {
- id: user_id,
- username,
- password,
- };
-
- db::update_user(conn, &user).await.unwrap();
-
- let response = HttpResponse::NoContent().finish();
- Ok(response)
-}
-
-#[put("/{user_id}/username")]
-async fn update_username(
- user_id: web::Path<Uuid>,
- body: web::Json<Box<str>>,
- conn: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateUserError> {
- let conn = conn.get_ref();
-
- let user_id = user_id.to_owned();
- let username = body.clone();
-
- let old_username = db::get_username(conn, user_id).await.unwrap().unwrap();
- if username != old_username && db::username_is_used(conn, &body).await.unwrap() {
- yeet!(UsernameTakenError { username }.into())
- }
-
- if !db::user_id_exists(conn, user_id).await.unwrap() {
- yeet!(UserNotFoundError { user_id }.into())
- }
-
- db::update_username(conn, user_id, &body).await.unwrap();
-
- let response = HttpResponse::NoContent().finish();
- Ok(response)
-}
-
-#[put("/{user_id}/password")]
-async fn update_password(
- user_id: web::Path<Uuid>,
- body: web::Json<Box<str>>,
- conn: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UserNotFoundError> {
- let conn = conn.get_ref();
-
- let user_id = user_id.to_owned();
- let password = PasswordHash::new(&body).unwrap();
-
- if !db::user_id_exists(conn, user_id).await.unwrap() {
- yeet!(UserNotFoundError { user_id })
- }
-
- db::update_password(conn, user_id, &password).await.unwrap();
-
- let response = HttpResponse::NoContent().finish();
- Ok(response)
-}
-
-pub fn service() -> Scope {
- web::scope("/users")
- .service(search_users)
- .service(get_user)
- .service(get_username)
- .service(create_user)
- .service(update_user)
- .service(update_username)
- .service(update_password)
-}
+use actix_web::http::{header, StatusCode};
+use actix_web::{get, post, put, web, HttpResponse, ResponseError, Scope};
+use raise::yeet;
+use serde::{Deserialize, Serialize};
+use sqlx::MySqlPool;
+use thiserror::Error;
+use uuid::Uuid;
+
+use crate::models::user::User;
+use crate::services::crypto::PasswordHash;
+use crate::services::{db, id};
+
+/// Just a username. No password hash, because that'd be tempting fate.
+#[derive(Debug, Clone, Serialize)]
+#[serde(rename_all = "camelCase")]
+struct UserResponse {
+ id: Uuid,
+ username: Box<str>,
+}
+
+impl From<User> for UserResponse {
+ fn from(user: User) -> Self {
+ Self {
+ id: user.id,
+ username: user.username,
+ }
+ }
+}
+
+#[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct SearchUsers {
+ username: Option<Box<str>>,
+ limit: Option<u32>,
+ offset: Option<u32>,
+}
+
+#[get("")]
+async fn search_users(params: web::Query<SearchUsers>, conn: web::Data<MySqlPool>) -> HttpResponse {
+ let conn = conn.get_ref();
+
+ let username = params.username.clone().unwrap_or_default();
+ let offset = params.offset.unwrap_or_default();
+
+ let results: Box<[UserResponse]> = if let Some(limit) = params.limit {
+ db::search_users_limit(conn, &username, offset, limit)
+ .await
+ .unwrap()
+ .iter()
+ .cloned()
+ .map(|u| u.into())
+ .collect()
+ } else {
+ db::search_users(conn, &username)
+ .await
+ .unwrap()
+ .into_iter()
+ .skip(offset as usize)
+ .cloned()
+ .map(|u| u.into())
+ .collect()
+ };
+
+ let response = HttpResponse::Ok().json(results);
+ response
+}
+
+#[derive(Debug, Clone, Error)]
+#[error("No user with the given ID exists")]
+struct UserNotFoundError {
+ user_id: Uuid,
+}
+
+impl ResponseError for UserNotFoundError {
+ fn status_code(&self) -> StatusCode {
+ StatusCode::NOT_FOUND
+ }
+}
+
+#[get("/{user_id}")]
+async fn get_user(
+ user_id: web::Path<Uuid>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UserNotFoundError> {
+ let conn = conn.get_ref();
+
+ let id = user_id.to_owned();
+ let username = db::get_username(conn, id).await.unwrap();
+
+ let Some(username) = username else {
+ yeet!(UserNotFoundError { user_id: id });
+ };
+
+ let response = UserResponse { id, username };
+ let response = HttpResponse::Ok().json(response);
+ Ok(response)
+}
+
+#[get("/{user_id}/username")]
+async fn get_username(
+ user_id: web::Path<Uuid>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UserNotFoundError> {
+ let conn = conn.get_ref();
+
+ let user_id = user_id.to_owned();
+ let username = db::get_username(conn, user_id).await.unwrap();
+
+ let Some(username) = username else {
+ yeet!(UserNotFoundError { user_id });
+ };
+
+ let response = HttpResponse::Ok().json(username);
+ Ok(response)
+}
+
+/// A request to create or update user information
+#[derive(Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct UserRequest {
+ username: Box<str>,
+ password: Box<str>,
+}
+
+#[derive(Debug, Clone, Error)]
+#[error("An account with the given username already exists.")]
+struct UsernameTakenError {
+ username: Box<str>,
+}
+
+impl ResponseError for UsernameTakenError {
+ fn status_code(&self) -> StatusCode {
+ StatusCode::CONFLICT
+ }
+}
+
+#[post("")]
+async fn create_user(
+ body: web::Json<UserRequest>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UsernameTakenError> {
+ let conn = conn.get_ref();
+
+ let user_id = id::new_id(conn, db::user_id_exists).await.unwrap();
+ let username = body.username.clone();
+ let password = PasswordHash::new(&body.password).unwrap();
+
+ if db::username_is_used(conn, &body.username).await.unwrap() {
+ yeet!(UsernameTakenError { username });
+ }
+
+ let user = User {
+ id: user_id,
+ username,
+ password,
+ };
+
+ db::create_user(conn, &user).await.unwrap();
+
+ let response = HttpResponse::Created()
+ .insert_header((header::LOCATION, format!("users/{user_id}")))
+ .finish();
+ Ok(response)
+}
+
+#[derive(Debug, Clone, Error)]
+enum UpdateUserError {
+ #[error(transparent)]
+ UsernameTaken(#[from] UsernameTakenError),
+ #[error(transparent)]
+ NotFound(#[from] UserNotFoundError),
+}
+
+impl ResponseError for UpdateUserError {
+ fn status_code(&self) -> StatusCode {
+ match self {
+ Self::UsernameTaken(e) => e.status_code(),
+ Self::NotFound(e) => e.status_code(),
+ }
+ }
+}
+
+#[put("/{user_id}")]
+async fn update_user(
+ user_id: web::Path<Uuid>,
+ body: web::Json<UserRequest>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateUserError> {
+ let conn = conn.get_ref();
+
+ let user_id = user_id.to_owned();
+ let username = body.username.clone();
+ let password = PasswordHash::new(&body.password).unwrap();
+
+ let old_username = db::get_username(conn, user_id).await.unwrap().unwrap();
+ if username != old_username && db::username_is_used(conn, &body.username).await.unwrap() {
+ yeet!(UsernameTakenError { username }.into())
+ }
+
+ if !db::user_id_exists(conn, user_id).await.unwrap() {
+ yeet!(UserNotFoundError { user_id }.into())
+ }
+
+ let user = User {
+ id: user_id,
+ username,
+ password,
+ };
+
+ db::update_user(conn, &user).await.unwrap();
+
+ let response = HttpResponse::NoContent().finish();
+ Ok(response)
+}
+
+#[put("/{user_id}/username")]
+async fn update_username(
+ user_id: web::Path<Uuid>,
+ body: web::Json<Box<str>>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateUserError> {
+ let conn = conn.get_ref();
+
+ let user_id = user_id.to_owned();
+ let username = body.clone();
+
+ let old_username = db::get_username(conn, user_id).await.unwrap().unwrap();
+ if username != old_username && db::username_is_used(conn, &body).await.unwrap() {
+ yeet!(UsernameTakenError { username }.into())
+ }
+
+ if !db::user_id_exists(conn, user_id).await.unwrap() {
+ yeet!(UserNotFoundError { user_id }.into())
+ }
+
+ db::update_username(conn, user_id, &body).await.unwrap();
+
+ let response = HttpResponse::NoContent().finish();
+ Ok(response)
+}
+
+#[put("/{user_id}/password")]
+async fn update_password(
+ user_id: web::Path<Uuid>,
+ body: web::Json<Box<str>>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UserNotFoundError> {
+ let conn = conn.get_ref();
+
+ let user_id = user_id.to_owned();
+ let password = PasswordHash::new(&body).unwrap();
+
+ if !db::user_id_exists(conn, user_id).await.unwrap() {
+ yeet!(UserNotFoundError { user_id })
+ }
+
+ db::update_password(conn, user_id, &password).await.unwrap();
+
+ let response = HttpResponse::NoContent().finish();
+ Ok(response)
+}
+
+pub fn service() -> Scope {
+ web::scope("/users")
+ .service(search_users)
+ .service(get_user)
+ .service(get_username)
+ .service(create_user)
+ .service(update_user)
+ .service(update_username)
+ .service(update_password)
+}
diff --git a/src/main.rs b/src/main.rs
index e946161..e403798 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,108 +1,108 @@
-use std::time::Duration;
-
-use actix_web::http::header::{self, HeaderValue};
-use actix_web::middleware::{ErrorHandlerResponse, ErrorHandlers, Logger, NormalizePath};
-use actix_web::web::Data;
-use actix_web::{dev, App, HttpServer};
-
-use bpaf::Bpaf;
-use exun::*;
-
-mod api;
-mod models;
-mod resources;
-mod scopes;
-mod services;
-
-use resources::*;
-use services::*;
-use sqlx::MySqlPool;
-
-fn error_content_language<B>(
- mut res: dev::ServiceResponse,
-) -> actix_web::Result<ErrorHandlerResponse<B>> {
- res.response_mut()
- .headers_mut()
- .insert(header::CONTENT_LANGUAGE, HeaderValue::from_static("en"));
-
- Ok(ErrorHandlerResponse::Response(res.map_into_right_body()))
-}
-
-async fn delete_expired_tokens(db: MySqlPool) {
- let db = db.clone();
- let mut interval = actix_rt::time::interval(Duration::from_secs(60 * 20));
- loop {
- interval.tick().await;
- if let Err(e) = db::delete_expired_auth_codes(&db).await {
- log::error!("{}", e);
- }
- if let Err(e) = db::delete_expired_access_tokens(&db).await {
- log::error!("{}", e);
- }
- if let Err(e) = db::delete_expired_refresh_tokens(&db).await {
- log::error!("{}", e);
- }
- }
-}
-
-#[derive(Debug, Clone, Bpaf)]
-#[bpaf(options, version)]
-struct Opts {
- /// The environment that the server is running in. Must be one of: local,
- /// dev, staging, prod.
- #[bpaf(
- env("LOCKDAGGER_ENVIRONMENT"),
- fallback(config::Environment::Local),
- display_fallback
- )]
- env: config::Environment,
-}
-
-#[actix_web::main]
-async fn main() -> Result<(), RawUnexpected> {
- // load the environment file, but only in debug mode
- #[cfg(debug_assertions)]
- dotenv::dotenv()?;
-
- let args = opts().run();
- config::set_environment(args.env);
-
- // initialize the database
- let db_url = secrets::database_url()?;
- let sql_pool = db::initialize(&db_url).await?;
-
- let tera = templates::initialize()?;
-
- let translations = languages::initialize()?;
-
- actix_rt::spawn(delete_expired_tokens(sql_pool.clone()));
-
- // start the server
- HttpServer::new(move || {
- App::new()
- // middleware
- .wrap(ErrorHandlers::new().default_handler(error_content_language))
- .wrap(NormalizePath::trim())
- .wrap(Logger::new("\"%r\" %s %Dms"))
- // app shared state
- .app_data(Data::new(sql_pool.clone()))
- .app_data(Data::new(tera.clone()))
- .app_data(Data::new(translations.clone()))
- // frontend services
- .service(style::get_css)
- .service(scripts::get_js)
- .service(languages::languages())
- // api services
- .service(api::liveops())
- .service(api::users())
- .service(api::clients())
- .service(api::oauth())
- .service(api::ops())
- })
- .shutdown_timeout(1)
- .bind(("127.0.0.1", 8080))?
- .run()
- .await?;
-
- Ok(())
-}
+use std::time::Duration;
+
+use actix_web::http::header::{self, HeaderValue};
+use actix_web::middleware::{ErrorHandlerResponse, ErrorHandlers, Logger, NormalizePath};
+use actix_web::web::Data;
+use actix_web::{dev, App, HttpServer};
+
+use bpaf::Bpaf;
+use exun::*;
+
+mod api;
+mod models;
+mod resources;
+mod scopes;
+mod services;
+
+use resources::*;
+use services::*;
+use sqlx::MySqlPool;
+
+fn error_content_language<B>(
+ mut res: dev::ServiceResponse,
+) -> actix_web::Result<ErrorHandlerResponse<B>> {
+ res.response_mut()
+ .headers_mut()
+ .insert(header::CONTENT_LANGUAGE, HeaderValue::from_static("en"));
+
+ Ok(ErrorHandlerResponse::Response(res.map_into_right_body()))
+}
+
+async fn delete_expired_tokens(db: MySqlPool) {
+ let db = db.clone();
+ let mut interval = actix_rt::time::interval(Duration::from_secs(60 * 20));
+ loop {
+ interval.tick().await;
+ if let Err(e) = db::delete_expired_auth_codes(&db).await {
+ log::error!("{}", e);
+ }
+ if let Err(e) = db::delete_expired_access_tokens(&db).await {
+ log::error!("{}", e);
+ }
+ if let Err(e) = db::delete_expired_refresh_tokens(&db).await {
+ log::error!("{}", e);
+ }
+ }
+}
+
+#[derive(Debug, Clone, Bpaf)]
+#[bpaf(options, version)]
+struct Opts {
+ /// The environment that the server is running in. Must be one of: local,
+ /// dev, staging, prod.
+ #[bpaf(
+ env("LOCKDAGGER_ENVIRONMENT"),
+ fallback(config::Environment::Local),
+ display_fallback
+ )]
+ env: config::Environment,
+}
+
+#[actix_web::main]
+async fn main() -> Result<(), RawUnexpected> {
+ // load the environment file, but only in debug mode
+ #[cfg(debug_assertions)]
+ dotenv::dotenv()?;
+
+ let args = opts().run();
+ config::set_environment(args.env);
+
+ // initialize the database
+ let db_url = secrets::database_url()?;
+ let sql_pool = db::initialize(&db_url).await?;
+
+ let tera = templates::initialize()?;
+
+ let translations = languages::initialize()?;
+
+ actix_rt::spawn(delete_expired_tokens(sql_pool.clone()));
+
+ // start the server
+ HttpServer::new(move || {
+ App::new()
+ // middleware
+ .wrap(ErrorHandlers::new().default_handler(error_content_language))
+ .wrap(NormalizePath::trim())
+ .wrap(Logger::new("\"%r\" %s %Dms"))
+ // app shared state
+ .app_data(Data::new(sql_pool.clone()))
+ .app_data(Data::new(tera.clone()))
+ .app_data(Data::new(translations.clone()))
+ // frontend services
+ .service(style::get_css)
+ .service(scripts::get_js)
+ .service(languages::languages())
+ // api services
+ .service(api::liveops())
+ .service(api::users())
+ .service(api::clients())
+ .service(api::oauth())
+ .service(api::ops())
+ })
+ .shutdown_timeout(1)
+ .bind(("127.0.0.1", 8080))?
+ .run()
+ .await?;
+
+ Ok(())
+}
diff --git a/src/models/client.rs b/src/models/client.rs
index 38be37f..6d0c909 100644
--- a/src/models/client.rs
+++ b/src/models/client.rs
@@ -1,165 +1,165 @@
-use std::{hash::Hash, marker::PhantomData};
-
-use actix_web::{http::StatusCode, ResponseError};
-use exun::{Expect, RawUnexpected};
-use raise::yeet;
-use serde::{Deserialize, Serialize};
-use thiserror::Error;
-use url::Url;
-use uuid::Uuid;
-
-use crate::services::crypto::PasswordHash;
-
-/// There are two types of clients, based on their ability to maintain the
-/// security of their client credentials.
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
-#[sqlx(rename_all = "lowercase")]
-#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
-pub enum ClientType {
- /// A client that is capable of maintaining the confidentiality of their
- /// credentials, or capable of secure client authentication using other
- /// means. An example would be a secure server with restricted access to
- /// the client credentials.
- Confidential,
- /// A client that is incapable of maintaining the confidentiality of their
- /// credentials and cannot authenticate securely by any other means, such
- /// as an installed application, or a web-browser based application.
- Public,
-}
-
-#[derive(Debug, Clone)]
-pub struct Client {
- id: Uuid,
- ty: ClientType,
- alias: Box<str>,
- secret: Option<PasswordHash>,
- allowed_scopes: Box<[Box<str>]>,
- default_scopes: Option<Box<[Box<str>]>>,
- redirect_uris: Box<[Url]>,
- trusted: bool,
-}
-
-impl PartialEq for Client {
- fn eq(&self, other: &Self) -> bool {
- self.id == other.id
- }
-}
-
-impl Eq for Client {}
-
-impl Hash for Client {
- fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
- state.write_u128(self.id.as_u128())
- }
-}
-
-#[derive(Debug, Clone, Copy, Error)]
-#[error("Confidential clients must have a secret, but it was not provided")]
-pub enum CreateClientError {
- #[error("Confidential clients must have a secret, but it was not provided")]
- NoSecret,
- #[error("Only confidential clients may be trusted")]
- TrustedError,
- #[error("Redirect URIs must not include a fragment component")]
- UriFragment,
- #[error("Redirect URIs must use HTTPS")]
- NonHttpsUri,
-}
-
-impl ResponseError for CreateClientError {
- fn status_code(&self) -> StatusCode {
- StatusCode::BAD_REQUEST
- }
-}
-
-impl Client {
- pub fn new(
- id: Uuid,
- alias: &str,
- ty: ClientType,
- secret: Option<&str>,
- allowed_scopes: Box<[Box<str>]>,
- default_scopes: Option<Box<[Box<str>]>>,
- redirect_uris: &[Url],
- trusted: bool,
- ) -> Result<Self, Expect<CreateClientError>> {
- let secret = if let Some(secret) = secret {
- Some(PasswordHash::new(secret)?)
- } else {
- None
- };
-
- if ty == ClientType::Confidential && secret.is_none() {
- yeet!(CreateClientError::NoSecret.into());
- }
-
- if ty == ClientType::Public && trusted {
- yeet!(CreateClientError::TrustedError.into());
- }
-
- for redirect_uri in redirect_uris {
- if redirect_uri.scheme() != "https" {
- yeet!(CreateClientError::NonHttpsUri.into())
- }
-
- if redirect_uri.fragment().is_some() {
- yeet!(CreateClientError::UriFragment.into())
- }
- }
-
- Ok(Self {
- id,
- alias: Box::from(alias),
- ty,
- secret,
- allowed_scopes,
- default_scopes,
- redirect_uris: redirect_uris.into_iter().cloned().collect(),
- trusted,
- })
- }
-
- pub fn id(&self) -> Uuid {
- self.id
- }
-
- pub fn alias(&self) -> &str {
- &self.alias
- }
-
- pub fn client_type(&self) -> ClientType {
- self.ty
- }
-
- pub fn redirect_uris(&self) -> &[Url] {
- &self.redirect_uris
- }
-
- pub fn secret_hash(&self) -> Option<&[u8]> {
- self.secret.as_ref().map(|s| s.hash())
- }
-
- pub fn secret_salt(&self) -> Option<&[u8]> {
- self.secret.as_ref().map(|s| s.salt())
- }
-
- pub fn secret_version(&self) -> Option<u8> {
- self.secret.as_ref().map(|s| s.version())
- }
-
- pub fn allowed_scopes(&self) -> String {
- self.allowed_scopes.join(" ")
- }
-
- pub fn default_scopes(&self) -> Option<String> {
- self.default_scopes.clone().map(|s| s.join(" "))
- }
-
- pub fn is_trusted(&self) -> bool {
- self.trusted
- }
-
- pub fn check_secret(&self, secret: &str) -> Option<Result<bool, RawUnexpected>> {
- self.secret.as_ref().map(|s| s.check_password(secret))
- }
-}
+use std::{hash::Hash, marker::PhantomData};
+
+use actix_web::{http::StatusCode, ResponseError};
+use exun::{Expect, RawUnexpected};
+use raise::yeet;
+use serde::{Deserialize, Serialize};
+use thiserror::Error;
+use url::Url;
+use uuid::Uuid;
+
+use crate::services::crypto::PasswordHash;
+
+/// There are two types of clients, based on their ability to maintain the
+/// security of their client credentials.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
+#[sqlx(rename_all = "lowercase")]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum ClientType {
+ /// A client that is capable of maintaining the confidentiality of their
+ /// credentials, or capable of secure client authentication using other
+ /// means. An example would be a secure server with restricted access to
+ /// the client credentials.
+ Confidential,
+ /// A client that is incapable of maintaining the confidentiality of their
+ /// credentials and cannot authenticate securely by any other means, such
+ /// as an installed application, or a web-browser based application.
+ Public,
+}
+
+#[derive(Debug, Clone)]
+pub struct Client {
+ id: Uuid,
+ ty: ClientType,
+ alias: Box<str>,
+ secret: Option<PasswordHash>,
+ allowed_scopes: Box<[Box<str>]>,
+ default_scopes: Option<Box<[Box<str>]>>,
+ redirect_uris: Box<[Url]>,
+ trusted: bool,
+}
+
+impl PartialEq for Client {
+ fn eq(&self, other: &Self) -> bool {
+ self.id == other.id
+ }
+}
+
+impl Eq for Client {}
+
+impl Hash for Client {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ state.write_u128(self.id.as_u128())
+ }
+}
+
+#[derive(Debug, Clone, Copy, Error)]
+#[error("Confidential clients must have a secret, but it was not provided")]
+pub enum CreateClientError {
+ #[error("Confidential clients must have a secret, but it was not provided")]
+ NoSecret,
+ #[error("Only confidential clients may be trusted")]
+ TrustedError,
+ #[error("Redirect URIs must not include a fragment component")]
+ UriFragment,
+ #[error("Redirect URIs must use HTTPS")]
+ NonHttpsUri,
+}
+
+impl ResponseError for CreateClientError {
+ fn status_code(&self) -> StatusCode {
+ StatusCode::BAD_REQUEST
+ }
+}
+
+impl Client {
+ pub fn new(
+ id: Uuid,
+ alias: &str,
+ ty: ClientType,
+ secret: Option<&str>,
+ allowed_scopes: Box<[Box<str>]>,
+ default_scopes: Option<Box<[Box<str>]>>,
+ redirect_uris: &[Url],
+ trusted: bool,
+ ) -> Result<Self, Expect<CreateClientError>> {
+ let secret = if let Some(secret) = secret {
+ Some(PasswordHash::new(secret)?)
+ } else {
+ None
+ };
+
+ if ty == ClientType::Confidential && secret.is_none() {
+ yeet!(CreateClientError::NoSecret.into());
+ }
+
+ if ty == ClientType::Public && trusted {
+ yeet!(CreateClientError::TrustedError.into());
+ }
+
+ for redirect_uri in redirect_uris {
+ if redirect_uri.scheme() != "https" {
+ yeet!(CreateClientError::NonHttpsUri.into())
+ }
+
+ if redirect_uri.fragment().is_some() {
+ yeet!(CreateClientError::UriFragment.into())
+ }
+ }
+
+ Ok(Self {
+ id,
+ alias: Box::from(alias),
+ ty,
+ secret,
+ allowed_scopes,
+ default_scopes,
+ redirect_uris: redirect_uris.into_iter().cloned().collect(),
+ trusted,
+ })
+ }
+
+ pub fn id(&self) -> Uuid {
+ self.id
+ }
+
+ pub fn alias(&self) -> &str {
+ &self.alias
+ }
+
+ pub fn client_type(&self) -> ClientType {
+ self.ty
+ }
+
+ pub fn redirect_uris(&self) -> &[Url] {
+ &self.redirect_uris
+ }
+
+ pub fn secret_hash(&self) -> Option<&[u8]> {
+ self.secret.as_ref().map(|s| s.hash())
+ }
+
+ pub fn secret_salt(&self) -> Option<&[u8]> {
+ self.secret.as_ref().map(|s| s.salt())
+ }
+
+ pub fn secret_version(&self) -> Option<u8> {
+ self.secret.as_ref().map(|s| s.version())
+ }
+
+ pub fn allowed_scopes(&self) -> String {
+ self.allowed_scopes.join(" ")
+ }
+
+ pub fn default_scopes(&self) -> Option<String> {
+ self.default_scopes.clone().map(|s| s.join(" "))
+ }
+
+ pub fn is_trusted(&self) -> bool {
+ self.trusted
+ }
+
+ pub fn check_secret(&self, secret: &str) -> Option<Result<bool, RawUnexpected>> {
+ self.secret.as_ref().map(|s| s.check_password(secret))
+ }
+}
diff --git a/src/models/mod.rs b/src/models/mod.rs
index 633f846..1379893 100644
--- a/src/models/mod.rs
+++ b/src/models/mod.rs
@@ -1,2 +1,2 @@
-pub mod client;
-pub mod user;
+pub mod client;
+pub mod user;
diff --git a/src/models/user.rs b/src/models/user.rs
index 8555ee2..493a267 100644
--- a/src/models/user.rs
+++ b/src/models/user.rs
@@ -1,49 +1,49 @@
-use std::hash::Hash;
-
-use exun::RawUnexpected;
-use uuid::Uuid;
-
-use crate::services::crypto::PasswordHash;
-
-#[derive(Debug, Clone)]
-pub struct User {
- pub id: Uuid,
- pub username: Box<str>,
- pub password: PasswordHash,
-}
-
-impl PartialEq for User {
- fn eq(&self, other: &Self) -> bool {
- self.id == other.id
- }
-}
-
-impl Eq for User {}
-
-impl Hash for User {
- fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
- state.write_u128(self.id.as_u128())
- }
-}
-
-impl User {
- pub fn username(&self) -> &str {
- &self.username
- }
-
- pub fn password_hash(&self) -> &[u8] {
- self.password.hash()
- }
-
- pub fn password_salt(&self) -> &[u8] {
- self.password.salt()
- }
-
- pub fn password_version(&self) -> u8 {
- self.password.version()
- }
-
- pub fn check_password(&self, password: &str) -> Result<bool, RawUnexpected> {
- self.password.check_password(password)
- }
-}
+use std::hash::Hash;
+
+use exun::RawUnexpected;
+use uuid::Uuid;
+
+use crate::services::crypto::PasswordHash;
+
+#[derive(Debug, Clone)]
+pub struct User {
+ pub id: Uuid,
+ pub username: Box<str>,
+ pub password: PasswordHash,
+}
+
+impl PartialEq for User {
+ fn eq(&self, other: &Self) -> bool {
+ self.id == other.id
+ }
+}
+
+impl Eq for User {}
+
+impl Hash for User {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ state.write_u128(self.id.as_u128())
+ }
+}
+
+impl User {
+ pub fn username(&self) -> &str {
+ &self.username
+ }
+
+ pub fn password_hash(&self) -> &[u8] {
+ self.password.hash()
+ }
+
+ pub fn password_salt(&self) -> &[u8] {
+ self.password.salt()
+ }
+
+ pub fn password_version(&self) -> u8 {
+ self.password.version()
+ }
+
+ pub fn check_password(&self, password: &str) -> Result<bool, RawUnexpected> {
+ self.password.check_password(password)
+ }
+}
diff --git a/src/resources/languages.rs b/src/resources/languages.rs
index 8ef7553..b01daf9 100644
--- a/src/resources/languages.rs
+++ b/src/resources/languages.rs
@@ -1,67 +1,67 @@
-use std::collections::HashMap;
-use std::path::PathBuf;
-
-use actix_web::{get, web, HttpResponse, Scope};
-use exun::RawUnexpected;
-use ini::{Ini, Properties};
-use raise::yeet;
-use unic_langid::subtags::Language;
-
-#[derive(Debug, Clone, PartialEq)]
-pub struct Translations {
- languages: HashMap<Language, Properties>,
-}
-
-pub fn initialize() -> Result<Translations, RawUnexpected> {
- let mut translations = Translations {
- languages: HashMap::new(),
- };
- translations.refresh()?;
- Ok(translations)
-}
-
-impl Translations {
- pub fn languages(&self) -> Box<[Language]> {
- self.languages.keys().cloned().collect()
- }
-
- pub fn get_message(&self, language: Language, key: &str) -> Option<String> {
- Some(self.languages.get(&language)?.get(key)?.to_owned())
- }
-
- pub fn refresh(&mut self) -> Result<(), RawUnexpected> {
- let mut languages = HashMap::with_capacity(1);
- for entry in PathBuf::from("static/languages").read_dir()? {
- let entry = entry?;
- if entry.file_type()?.is_dir() {
- continue;
- }
-
- let path = entry.path();
- let path = path.to_string_lossy();
- let Some(language) = path.as_bytes().get(0..2) else { yeet!(RawUnexpected::msg(format!("{} not long enough to be a language name", path))) };
- let language = Language::from_bytes(language)?;
- let messages = Ini::load_from_file(entry.path())?.general_section().clone();
-
- languages.insert(language, messages);
- }
-
- self.languages = languages;
- Ok(())
- }
-}
-
-#[get("")]
-pub async fn all_languages(translations: web::Data<Translations>) -> HttpResponse {
- HttpResponse::Ok().json(
- translations
- .languages()
- .into_iter()
- .map(|l| l.as_str())
- .collect::<Box<[&str]>>(),
- )
-}
-
-pub fn languages() -> Scope {
- web::scope("/languages").service(all_languages)
-}
+use std::collections::HashMap;
+use std::path::PathBuf;
+
+use actix_web::{get, web, HttpResponse, Scope};
+use exun::RawUnexpected;
+use ini::{Ini, Properties};
+use raise::yeet;
+use unic_langid::subtags::Language;
+
+#[derive(Debug, Clone, PartialEq)]
+pub struct Translations {
+ languages: HashMap<Language, Properties>,
+}
+
+pub fn initialize() -> Result<Translations, RawUnexpected> {
+ let mut translations = Translations {
+ languages: HashMap::new(),
+ };
+ translations.refresh()?;
+ Ok(translations)
+}
+
+impl Translations {
+ pub fn languages(&self) -> Box<[Language]> {
+ self.languages.keys().cloned().collect()
+ }
+
+ pub fn get_message(&self, language: Language, key: &str) -> Option<String> {
+ Some(self.languages.get(&language)?.get(key)?.to_owned())
+ }
+
+ pub fn refresh(&mut self) -> Result<(), RawUnexpected> {
+ let mut languages = HashMap::with_capacity(1);
+ for entry in PathBuf::from("static/languages").read_dir()? {
+ let entry = entry?;
+ if entry.file_type()?.is_dir() {
+ continue;
+ }
+
+ let path = entry.path();
+ let path = path.to_string_lossy();
+ let Some(language) = path.as_bytes().get(0..2) else { yeet!(RawUnexpected::msg(format!("{} not long enough to be a language name", path))) };
+ let language = Language::from_bytes(language)?;
+ let messages = Ini::load_from_file(entry.path())?.general_section().clone();
+
+ languages.insert(language, messages);
+ }
+
+ self.languages = languages;
+ Ok(())
+ }
+}
+
+#[get("")]
+pub async fn all_languages(translations: web::Data<Translations>) -> HttpResponse {
+ HttpResponse::Ok().json(
+ translations
+ .languages()
+ .into_iter()
+ .map(|l| l.as_str())
+ .collect::<Box<[&str]>>(),
+ )
+}
+
+pub fn languages() -> Scope {
+ web::scope("/languages").service(all_languages)
+}
diff --git a/src/resources/mod.rs b/src/resources/mod.rs
index 9251d2c..d9f14ba 100644
--- a/src/resources/mod.rs
+++ b/src/resources/mod.rs
@@ -1,4 +1,4 @@
-pub mod languages;
-pub mod scripts;
-pub mod style;
-pub mod templates;
+pub mod languages;
+pub mod scripts;
+pub mod style;
+pub mod templates;
diff --git a/src/resources/scripts.rs b/src/resources/scripts.rs
index 1b27859..66b9693 100644
--- a/src/resources/scripts.rs
+++ b/src/resources/scripts.rs
@@ -1,38 +1,38 @@
-use std::path::Path;
-
-use actix_web::{get, http::StatusCode, web, HttpResponse, ResponseError};
-use exun::{Expect, ResultErrorExt};
-use path_clean::clean;
-use raise::yeet;
-use serde::Serialize;
-use thiserror::Error;
-
-#[derive(Debug, Clone, Error, Serialize)]
-pub enum LoadScriptError {
- #[error("The requested script does not exist")]
- FileNotFound(Box<Path>),
-}
-
-impl ResponseError for LoadScriptError {
- fn status_code(&self) -> StatusCode {
- match self {
- Self::FileNotFound(..) => StatusCode::NOT_FOUND,
- }
- }
-}
-
-fn load(script: &str) -> Result<String, Expect<LoadScriptError>> {
- let path = clean(format!("static/scripts/{}.js", script));
- if !path.exists() {
- yeet!(LoadScriptError::FileNotFound(path.into()).into());
- }
- let js = std::fs::read_to_string(format!("static/scripts/{}.js", script)).unexpect()?;
- Ok(js)
-}
-
-#[get("/{script}.js")]
-pub async fn get_js(script: web::Path<Box<str>>) -> Result<HttpResponse, LoadScriptError> {
- let js = load(&script).map_err(|e| e.unwrap())?;
- let response = HttpResponse::Ok().content_type("text/javascript").body(js);
- Ok(response)
-}
+use std::path::Path;
+
+use actix_web::{get, http::StatusCode, web, HttpResponse, ResponseError};
+use exun::{Expect, ResultErrorExt};
+use path_clean::clean;
+use raise::yeet;
+use serde::Serialize;
+use thiserror::Error;
+
+#[derive(Debug, Clone, Error, Serialize)]
+pub enum LoadScriptError {
+ #[error("The requested script does not exist")]
+ FileNotFound(Box<Path>),
+}
+
+impl ResponseError for LoadScriptError {
+ fn status_code(&self) -> StatusCode {
+ match self {
+ Self::FileNotFound(..) => StatusCode::NOT_FOUND,
+ }
+ }
+}
+
+fn load(script: &str) -> Result<String, Expect<LoadScriptError>> {
+ let path = clean(format!("static/scripts/{}.js", script));
+ if !path.exists() {
+ yeet!(LoadScriptError::FileNotFound(path.into()).into());
+ }
+ let js = std::fs::read_to_string(format!("static/scripts/{}.js", script)).unexpect()?;
+ Ok(js)
+}
+
+#[get("/{script}.js")]
+pub async fn get_js(script: web::Path<Box<str>>) -> Result<HttpResponse, LoadScriptError> {
+ let js = load(&script).map_err(|e| e.unwrap())?;
+ let response = HttpResponse::Ok().content_type("text/javascript").body(js);
+ Ok(response)
+}
diff --git a/src/resources/style.rs b/src/resources/style.rs
index 3ea56d2..8b21dc4 100644
--- a/src/resources/style.rs
+++ b/src/resources/style.rs
@@ -1,54 +1,54 @@
-use std::path::Path;
-
-use actix_web::{get, http::StatusCode, web, HttpResponse, ResponseError};
-use exun::{Expect, ResultErrorExt};
-use grass::OutputStyle;
-use path_clean::clean;
-use raise::yeet;
-use serde::Serialize;
-use thiserror::Error;
-
-fn output_style() -> OutputStyle {
- if cfg!(debug_assertions) {
- OutputStyle::Expanded
- } else {
- OutputStyle::Compressed
- }
-}
-
-fn options() -> grass::Options<'static> {
- grass::Options::default()
- .load_path("static/style")
- .style(output_style())
-}
-
-#[derive(Debug, Clone, Error, Serialize)]
-pub enum LoadStyleError {
- #[error("The requested stylesheet was not found")]
- FileNotFound(Box<Path>),
-}
-
-impl ResponseError for LoadStyleError {
- fn status_code(&self) -> StatusCode {
- match self {
- Self::FileNotFound(..) => StatusCode::NOT_FOUND,
- }
- }
-}
-
-pub fn load(stylesheet: &str) -> Result<String, Expect<LoadStyleError>> {
- let options = options();
- let path = clean(format!("static/style/{}.scss", stylesheet));
- if !path.exists() {
- yeet!(LoadStyleError::FileNotFound(path.into()).into());
- }
- let css = grass::from_path(format!("static/style/{}.scss", stylesheet), &options).unexpect()?;
- Ok(css)
-}
-
-#[get("/{stylesheet}.css")]
-pub async fn get_css(stylesheet: web::Path<Box<str>>) -> Result<HttpResponse, LoadStyleError> {
- let css = load(&stylesheet).map_err(|e| e.unwrap())?;
- let response = HttpResponse::Ok().content_type("text/css").body(css);
- Ok(response)
-}
+use std::path::Path;
+
+use actix_web::{get, http::StatusCode, web, HttpResponse, ResponseError};
+use exun::{Expect, ResultErrorExt};
+use grass::OutputStyle;
+use path_clean::clean;
+use raise::yeet;
+use serde::Serialize;
+use thiserror::Error;
+
+fn output_style() -> OutputStyle {
+ if cfg!(debug_assertions) {
+ OutputStyle::Expanded
+ } else {
+ OutputStyle::Compressed
+ }
+}
+
+fn options() -> grass::Options<'static> {
+ grass::Options::default()
+ .load_path("static/style")
+ .style(output_style())
+}
+
+#[derive(Debug, Clone, Error, Serialize)]
+pub enum LoadStyleError {
+ #[error("The requested stylesheet was not found")]
+ FileNotFound(Box<Path>),
+}
+
+impl ResponseError for LoadStyleError {
+ fn status_code(&self) -> StatusCode {
+ match self {
+ Self::FileNotFound(..) => StatusCode::NOT_FOUND,
+ }
+ }
+}
+
+pub fn load(stylesheet: &str) -> Result<String, Expect<LoadStyleError>> {
+ let options = options();
+ let path = clean(format!("static/style/{}.scss", stylesheet));
+ if !path.exists() {
+ yeet!(LoadStyleError::FileNotFound(path.into()).into());
+ }
+ let css = grass::from_path(format!("static/style/{}.scss", stylesheet), &options).unexpect()?;
+ Ok(css)
+}
+
+#[get("/{stylesheet}.css")]
+pub async fn get_css(stylesheet: web::Path<Box<str>>) -> Result<HttpResponse, LoadStyleError> {
+ let css = load(&stylesheet).map_err(|e| e.unwrap())?;
+ let response = HttpResponse::Ok().content_type("text/css").body(css);
+ Ok(response)
+}
diff --git a/src/resources/templates.rs b/src/resources/templates.rs
index 9168fb9..baf2ee8 100644
--- a/src/resources/templates.rs
+++ b/src/resources/templates.rs
@@ -1,101 +1,101 @@
-use std::collections::HashMap;
-
-use exun::{RawUnexpected, ResultErrorExt};
-use raise::yeet;
-use serde::Serialize;
-use tera::{Function, Tera, Value};
-use unic_langid::subtags::Language;
-
-use crate::api::AuthorizationParameters;
-
-use super::languages;
-
-fn make_msg(language: Language, translations: languages::Translations) -> impl Function {
- Box::new(
- move |args: &HashMap<String, Value>| -> tera::Result<Value> {
- let Some(key) = args.get("key") else { yeet!("No parameter 'key' provided".into()) };
- let Some(key) = key.as_str() else { yeet!(format!("{} is not a string", key).into()) };
- let Some(value) = translations.get_message(language, key) else { yeet!(format!("{} does not exist", key).into()) };
- Ok(Value::String(value))
- },
- )
-}
-
-fn extend_tera(
- tera: &Tera,
- language: Language,
- translations: languages::Translations,
-) -> Result<Tera, RawUnexpected> {
- let mut new_tera = initialize()?;
- new_tera.extend(tera)?;
- new_tera.register_function("msg", make_msg(language, translations));
- Ok(new_tera)
-}
-
-pub fn initialize() -> tera::Result<Tera> {
- let tera = Tera::new("static/templates/*")?;
- Ok(tera)
-}
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
-#[serde(rename_all = "camelCase")]
-pub enum ErrorPage {
- InvalidRequest,
- ClientNotFound,
- MissingRedirectUri,
- InvalidRedirectUri,
- InternalServerError,
-}
-
-pub fn error_page(
- tera: &Tera,
- language: Language,
- mut translations: languages::Translations,
- error: ErrorPage,
-) -> Result<String, RawUnexpected> {
- translations.refresh()?;
- let mut tera = extend_tera(tera, language, translations)?;
- tera.full_reload()?;
-
- let error = serde_variant::to_variant_name(&error)?;
- let header = format!("errorHeader_{error}");
- let message = format!("errorMessage_{error}");
-
- let mut context = tera::Context::new();
- context.insert("lang", language.as_str());
- context.insert("errorHeader", &header);
- context.insert("errormessage", &message);
-
- tera.render("error.html", &context).unexpect()
-}
-
-pub fn login_page(
- tera: &Tera,
- params: &AuthorizationParameters,
- language: Language,
- mut translations: languages::Translations,
-) -> Result<String, RawUnexpected> {
- translations.refresh()?;
- let mut tera = extend_tera(tera, language, translations)?;
- tera.full_reload()?;
- let mut context = tera::Context::new();
- context.insert("lang", language.as_str());
- context.insert("params", &serde_urlencoded::to_string(params)?);
- tera.render("login.html", &context).unexpect()
-}
-
-pub fn login_error_page(
- tera: &Tera,
- params: &AuthorizationParameters,
- language: Language,
- mut translations: languages::Translations,
-) -> Result<String, RawUnexpected> {
- translations.refresh()?;
- let mut tera = extend_tera(tera, language, translations)?;
- tera.full_reload()?;
- let mut context = tera::Context::new();
- context.insert("lang", language.as_str());
- context.insert("params", &serde_urlencoded::to_string(params)?);
- context.insert("errorMessage", "loginErrorMessage");
- tera.render("login.html", &context).unexpect()
-}
+use std::collections::HashMap;
+
+use exun::{RawUnexpected, ResultErrorExt};
+use raise::yeet;
+use serde::Serialize;
+use tera::{Function, Tera, Value};
+use unic_langid::subtags::Language;
+
+use crate::api::AuthorizationParameters;
+
+use super::languages;
+
+fn make_msg(language: Language, translations: languages::Translations) -> impl Function {
+ Box::new(
+ move |args: &HashMap<String, Value>| -> tera::Result<Value> {
+ let Some(key) = args.get("key") else { yeet!("No parameter 'key' provided".into()) };
+ let Some(key) = key.as_str() else { yeet!(format!("{} is not a string", key).into()) };
+ let Some(value) = translations.get_message(language, key) else { yeet!(format!("{} does not exist", key).into()) };
+ Ok(Value::String(value))
+ },
+ )
+}
+
+fn extend_tera(
+ tera: &Tera,
+ language: Language,
+ translations: languages::Translations,
+) -> Result<Tera, RawUnexpected> {
+ let mut new_tera = initialize()?;
+ new_tera.extend(tera)?;
+ new_tera.register_function("msg", make_msg(language, translations));
+ Ok(new_tera)
+}
+
+pub fn initialize() -> tera::Result<Tera> {
+ let tera = Tera::new("static/templates/*")?;
+ Ok(tera)
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub enum ErrorPage {
+ InvalidRequest,
+ ClientNotFound,
+ MissingRedirectUri,
+ InvalidRedirectUri,
+ InternalServerError,
+}
+
+pub fn error_page(
+ tera: &Tera,
+ language: Language,
+ mut translations: languages::Translations,
+ error: ErrorPage,
+) -> Result<String, RawUnexpected> {
+ translations.refresh()?;
+ let mut tera = extend_tera(tera, language, translations)?;
+ tera.full_reload()?;
+
+ let error = serde_variant::to_variant_name(&error)?;
+ let header = format!("errorHeader_{error}");
+ let message = format!("errorMessage_{error}");
+
+ let mut context = tera::Context::new();
+ context.insert("lang", language.as_str());
+ context.insert("errorHeader", &header);
+ context.insert("errormessage", &message);
+
+ tera.render("error.html", &context).unexpect()
+}
+
+pub fn login_page(
+ tera: &Tera,
+ params: &AuthorizationParameters,
+ language: Language,
+ mut translations: languages::Translations,
+) -> Result<String, RawUnexpected> {
+ translations.refresh()?;
+ let mut tera = extend_tera(tera, language, translations)?;
+ tera.full_reload()?;
+ let mut context = tera::Context::new();
+ context.insert("lang", language.as_str());
+ context.insert("params", &serde_urlencoded::to_string(params)?);
+ tera.render("login.html", &context).unexpect()
+}
+
+pub fn login_error_page(
+ tera: &Tera,
+ params: &AuthorizationParameters,
+ language: Language,
+ mut translations: languages::Translations,
+) -> Result<String, RawUnexpected> {
+ translations.refresh()?;
+ let mut tera = extend_tera(tera, language, translations)?;
+ tera.full_reload()?;
+ let mut context = tera::Context::new();
+ context.insert("lang", language.as_str());
+ context.insert("params", &serde_urlencoded::to_string(params)?);
+ context.insert("errorMessage", "loginErrorMessage");
+ tera.render("login.html", &context).unexpect()
+}
diff --git a/src/scopes/admin.rs b/src/scopes/admin.rs
index 1e13b85..31e7880 100644
--- a/src/scopes/admin.rs
+++ b/src/scopes/admin.rs
@@ -1,28 +1,28 @@
-use std::fmt::{self, Display};
-
-use crate::models::{client::Client, user::User};
-
-use super::{Action, Scope};
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
-pub struct Admin;
-
-impl Display for Admin {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- f.write_str("admin")
- }
-}
-
-impl Scope for Admin {
- fn parse_modifiers(_modifiers: &str) -> Result<Self, Box<str>> {
- Ok(Self)
- }
-
- fn has_user_permission(&self, _: &User, _: &Action<User>) -> bool {
- true
- }
-
- fn has_client_permission(&self, _: &User, _: &Action<Client>) -> bool {
- true
- }
-}
+use std::fmt::{self, Display};
+
+use crate::models::{client::Client, user::User};
+
+use super::{Action, Scope};
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub struct Admin;
+
+impl Display for Admin {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.write_str("admin")
+ }
+}
+
+impl Scope for Admin {
+ fn parse_modifiers(_modifiers: &str) -> Result<Self, Box<str>> {
+ Ok(Self)
+ }
+
+ fn has_user_permission(&self, _: &User, _: &Action<User>) -> bool {
+ true
+ }
+
+ fn has_client_permission(&self, _: &User, _: &Action<Client>) -> bool {
+ true
+ }
+}
diff --git a/src/scopes/mod.rs b/src/scopes/mod.rs
index fb7780f..25296fd 100644
--- a/src/scopes/mod.rs
+++ b/src/scopes/mod.rs
@@ -1,128 +1,128 @@
-use std::collections::HashSet;
-
-use self::admin::Admin;
-use crate::models::{client::Client, user::User};
-
-mod admin;
-
-/// The action which was attempted on a resource
-pub enum Action<T> {
- Create(T),
- Read(T),
- Update(T, T),
- Delete(T),
-}
-
-trait ScopeSuperSet {
- fn is_superset_of(&self, other: &Self) -> bool;
-}
-
-trait Scope: ToString {
- /// Parse a scope of the format: `{Scope::NAME}:{modifiers}`
- fn parse_modifiers(modifiers: &str) -> Result<Self, Box<str>>
- where
- Self: Sized;
-
- /// Returns `true` if and only if the given `user` is allowed to take the
- /// given `action` with this scope
- fn has_user_permission(&self, user: &User, action: &Action<User>) -> bool;
-
- // Returns `true` if and only if the given `user` is allowed to take the
- /// given `action` with this scope
- fn has_client_permission(&self, user: &User, action: &Action<Client>) -> bool;
-}
-
-pub struct ParseScopeError {
- scope: Box<str>,
- error: ParseScopeErrorType,
-}
-
-impl ParseScopeError {
- fn invalid_type(scope: &str, scope_type: &str) -> Self {
- let scope = scope.into();
- let error = ParseScopeErrorType::InvalidType(scope_type.into());
- Self { scope, error }
- }
-}
-
-pub enum ParseScopeErrorType {
- InvalidType(Box<str>),
- InvalidModifiers(Box<str>),
-}
-
-fn parse_scope(scope: &str) -> Result<Box<dyn Scope>, ParseScopeError> {
- let mut split = scope.split(':');
- let scope_type = split.next().unwrap();
- let _modifiers: String = split.collect();
-
- match scope_type {
- "admin" => Ok(Box::new(Admin)),
- _ => Err(ParseScopeError::invalid_type(scope, scope_type)),
- }
-}
-
-fn parse_scopes(scopes: &str) -> Result<Vec<Box<dyn Scope>>, ParseScopeError> {
- scopes
- .split_whitespace()
- .map(|scope| parse_scope(scope))
- .collect()
-}
-
-fn parse_scopes_errors(
- results: &[Result<Box<dyn Scope>, ParseScopeError>],
-) -> Vec<&ParseScopeError> {
- let mut errors = Vec::with_capacity(results.len());
- for result in results {
- if let Err(pse) = result {
- errors.push(pse)
- }
- }
-
- errors
-}
-
-/// Returns `true` if and only if all values in `left_scopes` are contained in
-/// `right_scopes`.
-pub fn is_subset_of(left_scopes: &str, right_scopes: &str) -> bool {
- let right_scopes: HashSet<&str> = right_scopes.split_whitespace().collect();
-
- for scope in left_scopes.split_whitespace() {
- if !right_scopes.contains(scope) {
- return false;
- }
- }
-
- true
-}
-
-pub fn has_user_permission(
- user: User,
- action: Action<User>,
- client_scopes: &str,
-) -> Result<bool, ParseScopeError> {
- let scopes = parse_scopes(client_scopes)?;
-
- for scope in scopes {
- if scope.has_user_permission(&user, &action) {
- return Ok(true);
- }
- }
-
- Ok(false)
-}
-
-pub fn has_client_permission(
- user: User,
- action: Action<Client>,
- client_scopes: &str,
-) -> Result<bool, ParseScopeError> {
- let scopes = parse_scopes(client_scopes)?;
-
- for scope in scopes {
- if scope.has_client_permission(&user, &action) {
- return Ok(true);
- }
- }
-
- Ok(false)
-}
+use std::collections::HashSet;
+
+use self::admin::Admin;
+use crate::models::{client::Client, user::User};
+
+mod admin;
+
+/// The action which was attempted on a resource
+pub enum Action<T> {
+ Create(T),
+ Read(T),
+ Update(T, T),
+ Delete(T),
+}
+
+trait ScopeSuperSet {
+ fn is_superset_of(&self, other: &Self) -> bool;
+}
+
+trait Scope: ToString {
+ /// Parse a scope of the format: `{Scope::NAME}:{modifiers}`
+ fn parse_modifiers(modifiers: &str) -> Result<Self, Box<str>>
+ where
+ Self: Sized;
+
+ /// Returns `true` if and only if the given `user` is allowed to take the
+ /// given `action` with this scope
+ fn has_user_permission(&self, user: &User, action: &Action<User>) -> bool;
+
+ // Returns `true` if and only if the given `user` is allowed to take the
+ /// given `action` with this scope
+ fn has_client_permission(&self, user: &User, action: &Action<Client>) -> bool;
+}
+
+pub struct ParseScopeError {
+ scope: Box<str>,
+ error: ParseScopeErrorType,
+}
+
+impl ParseScopeError {
+ fn invalid_type(scope: &str, scope_type: &str) -> Self {
+ let scope = scope.into();
+ let error = ParseScopeErrorType::InvalidType(scope_type.into());
+ Self { scope, error }
+ }
+}
+
+pub enum ParseScopeErrorType {
+ InvalidType(Box<str>),
+ InvalidModifiers(Box<str>),
+}
+
+fn parse_scope(scope: &str) -> Result<Box<dyn Scope>, ParseScopeError> {
+ let mut split = scope.split(':');
+ let scope_type = split.next().unwrap();
+ let _modifiers: String = split.collect();
+
+ match scope_type {
+ "admin" => Ok(Box::new(Admin)),
+ _ => Err(ParseScopeError::invalid_type(scope, scope_type)),
+ }
+}
+
+fn parse_scopes(scopes: &str) -> Result<Vec<Box<dyn Scope>>, ParseScopeError> {
+ scopes
+ .split_whitespace()
+ .map(|scope| parse_scope(scope))
+ .collect()
+}
+
+fn parse_scopes_errors(
+ results: &[Result<Box<dyn Scope>, ParseScopeError>],
+) -> Vec<&ParseScopeError> {
+ let mut errors = Vec::with_capacity(results.len());
+ for result in results {
+ if let Err(pse) = result {
+ errors.push(pse)
+ }
+ }
+
+ errors
+}
+
+/// Returns `true` if and only if all values in `left_scopes` are contained in
+/// `right_scopes`.
+pub fn is_subset_of(left_scopes: &str, right_scopes: &str) -> bool {
+ let right_scopes: HashSet<&str> = right_scopes.split_whitespace().collect();
+
+ for scope in left_scopes.split_whitespace() {
+ if !right_scopes.contains(scope) {
+ return false;
+ }
+ }
+
+ true
+}
+
+pub fn has_user_permission(
+ user: User,
+ action: Action<User>,
+ client_scopes: &str,
+) -> Result<bool, ParseScopeError> {
+ let scopes = parse_scopes(client_scopes)?;
+
+ for scope in scopes {
+ if scope.has_user_permission(&user, &action) {
+ return Ok(true);
+ }
+ }
+
+ Ok(false)
+}
+
+pub fn has_client_permission(
+ user: User,
+ action: Action<Client>,
+ client_scopes: &str,
+) -> Result<bool, ParseScopeError> {
+ let scopes = parse_scopes(client_scopes)?;
+
+ for scope in scopes {
+ if scope.has_client_permission(&user, &action) {
+ return Ok(true);
+ }
+ }
+
+ Ok(false)
+}
diff --git a/src/services/authorization.rs b/src/services/authorization.rs
index bfbbb5a..4e6ef35 100644
--- a/src/services/authorization.rs
+++ b/src/services/authorization.rs
@@ -1,82 +1,82 @@
-use actix_web::{
- error::ParseError,
- http::header::{self, Header, HeaderName, HeaderValue, InvalidHeaderValue, TryIntoHeaderValue},
-};
-use base64::Engine;
-use raise::yeet;
-
-#[derive(Clone)]
-pub struct BasicAuthorization {
- username: Box<str>,
- password: Box<str>,
-}
-
-impl TryIntoHeaderValue for BasicAuthorization {
- type Error = InvalidHeaderValue;
-
- fn try_into_value(self) -> Result<HeaderValue, Self::Error> {
- let username = self.username;
- let password = self.password;
- let utf8 = format!("{username}:{password}");
- let b64 = base64::engine::general_purpose::STANDARD.encode(utf8);
- let value = format!("Basic {b64}");
- HeaderValue::from_str(&value)
- }
-}
-
-impl Header for BasicAuthorization {
- fn name() -> HeaderName {
- header::AUTHORIZATION
- }
-
- fn parse<M: actix_web::HttpMessage>(msg: &M) -> Result<Self, actix_web::error::ParseError> {
- let Some(value) = msg.headers().get(Self::name()) else {
- yeet!(ParseError::Header)
- };
-
- let Ok(value) = value.to_str() else {
- yeet!(ParseError::Header)
- };
-
- if !value.starts_with("Basic") {
- yeet!(ParseError::Header);
- }
-
- let value: String = value
- .chars()
- .skip(5)
- .skip_while(|ch| ch.is_whitespace())
- .collect();
-
- if value.is_empty() {
- yeet!(ParseError::Header);
- }
-
- let Ok(bytes) = base64::engine::general_purpose::STANDARD.decode(value) else {
- yeet!(ParseError::Header)
- };
-
- let Ok(value) = String::from_utf8(bytes) else {
- yeet!(ParseError::Header)
- };
-
- let mut parts = value.split(':');
- let username = Box::from(parts.next().unwrap());
- let Some(password) = parts.next() else {
- yeet!(ParseError::Header)
- };
- let password = Box::from(password);
-
- Ok(Self { username, password })
- }
-}
-
-impl BasicAuthorization {
- pub fn username(&self) -> &str {
- &self.username
- }
-
- pub fn password(&self) -> &str {
- &self.password
- }
-}
+use actix_web::{
+ error::ParseError,
+ http::header::{self, Header, HeaderName, HeaderValue, InvalidHeaderValue, TryIntoHeaderValue},
+};
+use base64::Engine;
+use raise::yeet;
+
+#[derive(Clone)]
+pub struct BasicAuthorization {
+ username: Box<str>,
+ password: Box<str>,
+}
+
+impl TryIntoHeaderValue for BasicAuthorization {
+ type Error = InvalidHeaderValue;
+
+ fn try_into_value(self) -> Result<HeaderValue, Self::Error> {
+ let username = self.username;
+ let password = self.password;
+ let utf8 = format!("{username}:{password}");
+ let b64 = base64::engine::general_purpose::STANDARD.encode(utf8);
+ let value = format!("Basic {b64}");
+ HeaderValue::from_str(&value)
+ }
+}
+
+impl Header for BasicAuthorization {
+ fn name() -> HeaderName {
+ header::AUTHORIZATION
+ }
+
+ fn parse<M: actix_web::HttpMessage>(msg: &M) -> Result<Self, actix_web::error::ParseError> {
+ let Some(value) = msg.headers().get(Self::name()) else {
+ yeet!(ParseError::Header)
+ };
+
+ let Ok(value) = value.to_str() else {
+ yeet!(ParseError::Header)
+ };
+
+ if !value.starts_with("Basic") {
+ yeet!(ParseError::Header);
+ }
+
+ let value: String = value
+ .chars()
+ .skip(5)
+ .skip_while(|ch| ch.is_whitespace())
+ .collect();
+
+ if value.is_empty() {
+ yeet!(ParseError::Header);
+ }
+
+ let Ok(bytes) = base64::engine::general_purpose::STANDARD.decode(value) else {
+ yeet!(ParseError::Header)
+ };
+
+ let Ok(value) = String::from_utf8(bytes) else {
+ yeet!(ParseError::Header)
+ };
+
+ let mut parts = value.split(':');
+ let username = Box::from(parts.next().unwrap());
+ let Some(password) = parts.next() else {
+ yeet!(ParseError::Header)
+ };
+ let password = Box::from(password);
+
+ Ok(Self { username, password })
+ }
+}
+
+impl BasicAuthorization {
+ pub fn username(&self) -> &str {
+ &self.username
+ }
+
+ pub fn password(&self) -> &str {
+ &self.password
+ }
+}
diff --git a/src/services/config.rs b/src/services/config.rs
index 6468126..932f38f 100644
--- a/src/services/config.rs
+++ b/src/services/config.rs
@@ -1,74 +1,74 @@
-use std::{
- fmt::{self, Display},
- str::FromStr,
-};
-
-use exun::RawUnexpected;
-use parking_lot::RwLock;
-use serde::Deserialize;
-use thiserror::Error;
-use url::Url;
-
-static ENVIRONMENT: RwLock<Environment> = RwLock::new(Environment::Local);
-
-#[derive(Debug, Clone, Deserialize)]
-pub struct Config {
- pub id: Box<str>,
- pub url: Url,
-}
-
-pub fn get_config() -> Result<Config, RawUnexpected> {
- let env = get_environment();
- let path = format!("static/config/{env}.toml");
- let string = std::fs::read_to_string(path)?;
- let config = toml::from_str(&string)?;
- Ok(config)
-}
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
-pub enum Environment {
- Local,
- Dev,
- Staging,
- Production,
-}
-
-impl Display for Environment {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- match self {
- Self::Local => f.write_str("local"),
- Self::Dev => f.write_str("dev"),
- Self::Staging => f.write_str("staging"),
- Self::Production => f.write_str("prod"),
- }
- }
-}
-
-#[derive(Debug, Clone, Error)]
-#[error("Expected one of the following environments: local, dev, staging, prod. Found {string}")]
-pub struct ParseEnvironmentError {
- string: Box<str>,
-}
-
-impl FromStr for Environment {
- type Err = ParseEnvironmentError;
-
- fn from_str(s: &str) -> Result<Self, Self::Err> {
- match s {
- "local" => Ok(Self::Local),
- "dev" => Ok(Self::Dev),
- "staging" => Ok(Self::Staging),
- "prod" => Ok(Self::Production),
- _ => Err(ParseEnvironmentError { string: s.into() }),
- }
- }
-}
-
-pub fn set_environment(env: Environment) {
- let mut env_ptr = ENVIRONMENT.write();
- *env_ptr = env;
-}
-
-fn get_environment() -> Environment {
- ENVIRONMENT.read().clone()
-}
+use std::{
+ fmt::{self, Display},
+ str::FromStr,
+};
+
+use exun::RawUnexpected;
+use parking_lot::RwLock;
+use serde::Deserialize;
+use thiserror::Error;
+use url::Url;
+
+static ENVIRONMENT: RwLock<Environment> = RwLock::new(Environment::Local);
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct Config {
+ pub id: Box<str>,
+ pub url: Url,
+}
+
+pub fn get_config() -> Result<Config, RawUnexpected> {
+ let env = get_environment();
+ let path = format!("static/config/{env}.toml");
+ let string = std::fs::read_to_string(path)?;
+ let config = toml::from_str(&string)?;
+ Ok(config)
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum Environment {
+ Local,
+ Dev,
+ Staging,
+ Production,
+}
+
+impl Display for Environment {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ Self::Local => f.write_str("local"),
+ Self::Dev => f.write_str("dev"),
+ Self::Staging => f.write_str("staging"),
+ Self::Production => f.write_str("prod"),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Error)]
+#[error("Expected one of the following environments: local, dev, staging, prod. Found {string}")]
+pub struct ParseEnvironmentError {
+ string: Box<str>,
+}
+
+impl FromStr for Environment {
+ type Err = ParseEnvironmentError;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ match s {
+ "local" => Ok(Self::Local),
+ "dev" => Ok(Self::Dev),
+ "staging" => Ok(Self::Staging),
+ "prod" => Ok(Self::Production),
+ _ => Err(ParseEnvironmentError { string: s.into() }),
+ }
+ }
+}
+
+pub fn set_environment(env: Environment) {
+ let mut env_ptr = ENVIRONMENT.write();
+ *env_ptr = env;
+}
+
+fn get_environment() -> Environment {
+ ENVIRONMENT.read().clone()
+}
diff --git a/src/services/crypto.rs b/src/services/crypto.rs
index 5fce403..0107374 100644
--- a/src/services/crypto.rs
+++ b/src/services/crypto.rs
@@ -1,97 +1,97 @@
-use std::hash::Hash;
-
-use argon2::{hash_raw, verify_raw};
-use exun::RawUnexpected;
-
-use crate::services::secrets::pepper;
-
-/// The configuration used for hashing and verifying passwords
-///
-/// # Example
-///
-/// ```
-/// use crate::services::secrets;
-///
-/// let pepper = secrets::pepper();
-/// let config = config(&pepper);
-/// ```
-fn config<'a>(pepper: &'a [u8]) -> argon2::Config<'a> {
- argon2::Config {
- hash_length: 32,
- lanes: 4,
- mem_cost: 5333,
- time_cost: 4,
- secret: pepper,
-
- ad: &[],
- thread_mode: argon2::ThreadMode::Sequential,
- variant: argon2::Variant::Argon2i,
- version: argon2::Version::Version13,
- }
-}
-
-/// A password hash and salt for a user
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct PasswordHash {
- hash: Box<[u8]>,
- salt: Box<[u8]>,
- version: u8,
-}
-
-impl Hash for PasswordHash {
- fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
- state.write(&self.hash)
- }
-}
-
-impl PasswordHash {
- /// Hash a password using Argon2
- pub fn new(password: &str) -> Result<Self, RawUnexpected> {
- let password = password.as_bytes();
-
- let salt: [u8; 32] = rand::random();
- let salt = Box::from(salt);
- let pepper = pepper()?;
- let hash = hash_raw(password, &salt, &config(&pepper))?.into_boxed_slice();
-
- Ok(Self {
- hash,
- salt,
- version: 0,
- })
- }
-
- /// Create this structure from a given hash and salt
- pub fn from_fields(hash: &[u8], salt: &[u8], version: u8) -> Self {
- Self {
- hash: Box::from(hash),
- salt: Box::from(salt),
- version,
- }
- }
-
- /// Get the password hash
- pub fn hash(&self) -> &[u8] {
- &self.hash
- }
-
- /// Get the salt used for the hash
- pub fn salt(&self) -> &[u8] {
- &self.salt
- }
-
- pub fn version(&self) -> u8 {
- self.version
- }
-
- /// Check if the given password is the one that was hashed
- pub fn check_password(&self, password: &str) -> Result<bool, RawUnexpected> {
- let pepper = pepper()?;
- Ok(verify_raw(
- password.as_bytes(),
- &self.salt,
- &self.hash,
- &config(&pepper),
- )?)
- }
-}
+use std::hash::Hash;
+
+use argon2::{hash_raw, verify_raw};
+use exun::RawUnexpected;
+
+use crate::services::secrets::pepper;
+
+/// The configuration used for hashing and verifying passwords
+///
+/// # Example
+///
+/// ```
+/// use crate::services::secrets;
+///
+/// let pepper = secrets::pepper();
+/// let config = config(&pepper);
+/// ```
+fn config<'a>(pepper: &'a [u8]) -> argon2::Config<'a> {
+ argon2::Config {
+ hash_length: 32,
+ lanes: 4,
+ mem_cost: 5333,
+ time_cost: 4,
+ secret: pepper,
+
+ ad: &[],
+ thread_mode: argon2::ThreadMode::Sequential,
+ variant: argon2::Variant::Argon2i,
+ version: argon2::Version::Version13,
+ }
+}
+
+/// A password hash and salt for a user
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct PasswordHash {
+ hash: Box<[u8]>,
+ salt: Box<[u8]>,
+ version: u8,
+}
+
+impl Hash for PasswordHash {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ state.write(&self.hash)
+ }
+}
+
+impl PasswordHash {
+ /// Hash a password using Argon2
+ pub fn new(password: &str) -> Result<Self, RawUnexpected> {
+ let password = password.as_bytes();
+
+ let salt: [u8; 32] = rand::random();
+ let salt = Box::from(salt);
+ let pepper = pepper()?;
+ let hash = hash_raw(password, &salt, &config(&pepper))?.into_boxed_slice();
+
+ Ok(Self {
+ hash,
+ salt,
+ version: 0,
+ })
+ }
+
+ /// Create this structure from a given hash and salt
+ pub fn from_fields(hash: &[u8], salt: &[u8], version: u8) -> Self {
+ Self {
+ hash: Box::from(hash),
+ salt: Box::from(salt),
+ version,
+ }
+ }
+
+ /// Get the password hash
+ pub fn hash(&self) -> &[u8] {
+ &self.hash
+ }
+
+ /// Get the salt used for the hash
+ pub fn salt(&self) -> &[u8] {
+ &self.salt
+ }
+
+ pub fn version(&self) -> u8 {
+ self.version
+ }
+
+ /// Check if the given password is the one that was hashed
+ pub fn check_password(&self, password: &str) -> Result<bool, RawUnexpected> {
+ let pepper = pepper()?;
+ Ok(verify_raw(
+ password.as_bytes(),
+ &self.salt,
+ &self.hash,
+ &config(&pepper),
+ )?)
+ }
+}
diff --git a/src/services/db.rs b/src/services/db.rs
index f811d79..e3cb48b 100644
--- a/src/services/db.rs
+++ b/src/services/db.rs
@@ -1,15 +1,15 @@
-use exun::{RawUnexpected, ResultErrorExt};
-use sqlx::MySqlPool;
-
-mod client;
-mod jwt;
-mod user;
-
-pub use self::jwt::*;
-pub use client::*;
-pub use user::*;
-
-/// Intialize the connection pool
-pub async fn initialize(db_url: &str) -> Result<MySqlPool, RawUnexpected> {
- MySqlPool::connect(db_url).await.unexpect()
-}
+use exun::{RawUnexpected, ResultErrorExt};
+use sqlx::MySqlPool;
+
+mod client;
+mod jwt;
+mod user;
+
+pub use self::jwt::*;
+pub use client::*;
+pub use user::*;
+
+/// Intialize the connection pool
+pub async fn initialize(db_url: &str) -> Result<MySqlPool, RawUnexpected> {
+ MySqlPool::connect(db_url).await.unexpect()
+}
diff --git a/src/services/db/client.rs b/src/services/db/client.rs
index b8942e9..1ad97b1 100644
--- a/src/services/db/client.rs
+++ b/src/services/db/client.rs
@@ -1,392 +1,392 @@
-use std::str::FromStr;
-
-use exun::{RawUnexpected, ResultErrorExt};
-use sqlx::{
- mysql::MySqlQueryResult, query, query_as, query_scalar, Executor, FromRow, MySql, Transaction,
-};
-use url::Url;
-use uuid::Uuid;
-
-use crate::{
- models::client::{Client, ClientType},
- services::crypto::PasswordHash,
-};
-
-#[derive(Debug, Clone, FromRow)]
-pub struct ClientRow {
- pub id: Uuid,
- pub alias: String,
- pub client_type: ClientType,
- pub allowed_scopes: String,
- pub default_scopes: Option<String>,
- pub is_trusted: bool,
-}
-
-#[derive(Clone, FromRow)]
-struct HashRow {
- secret_hash: Option<Vec<u8>>,
- secret_salt: Option<Vec<u8>>,
- secret_version: Option<u32>,
-}
-
-pub async fn client_id_exists<'c>(
- executor: impl Executor<'c, Database = MySql>,
- id: Uuid,
-) -> Result<bool, RawUnexpected> {
- query_scalar!(
- r"SELECT EXISTS(SELECT id FROM clients WHERE id = ?) as `e: bool`",
- id
- )
- .fetch_one(executor)
- .await
- .unexpect()
-}
-
-pub async fn client_alias_exists<'c>(
- executor: impl Executor<'c, Database = MySql>,
- alias: &str,
-) -> Result<bool, RawUnexpected> {
- query_scalar!(
- "SELECT EXISTS(SELECT alias FROM clients WHERE alias = ?) as `e: bool`",
- alias
- )
- .fetch_one(executor)
- .await
- .unexpect()
-}
-
-pub async fn get_client_id_by_alias<'c>(
- executor: impl Executor<'c, Database = MySql>,
- alias: &str,
-) -> Result<Option<Uuid>, RawUnexpected> {
- query_scalar!(
- "SELECT id as `id: Uuid` FROM clients WHERE alias = ?",
- alias
- )
- .fetch_optional(executor)
- .await
- .unexpect()
-}
-
-pub async fn get_client_response<'c>(
- executor: impl Executor<'c, Database = MySql>,
- id: Uuid,
-) -> Result<Option<ClientRow>, RawUnexpected> {
- let record = query_as!(
- ClientRow,
- r"SELECT id as `id: Uuid`,
- alias,
- type as `client_type: ClientType`,
- allowed_scopes,
- default_scopes,
- trusted as `is_trusted: bool`
- FROM clients WHERE id = ?",
- id
- )
- .fetch_optional(executor)
- .await?;
-
- Ok(record)
-}
-
-pub async fn get_client_alias<'c>(
- executor: impl Executor<'c, Database = MySql>,
- id: Uuid,
-) -> Result<Option<Box<str>>, RawUnexpected> {
- let alias = query_scalar!("SELECT alias FROM clients WHERE id = ?", id)
- .fetch_optional(executor)
- .await
- .unexpect()?;
-
- Ok(alias.map(String::into_boxed_str))
-}
-
-pub async fn get_client_type<'c>(
- executor: impl Executor<'c, Database = MySql>,
- id: Uuid,
-) -> Result<Option<ClientType>, RawUnexpected> {
- let ty = query_scalar!(
- "SELECT type as `type: ClientType` FROM clients WHERE id = ?",
- id
- )
- .fetch_optional(executor)
- .await
- .unexpect()?;
-
- Ok(ty)
-}
-
-pub async fn get_client_allowed_scopes<'c>(
- executor: impl Executor<'c, Database = MySql>,
- id: Uuid,
-) -> Result<Option<Box<str>>, RawUnexpected> {
- let scopes = query_scalar!("SELECT allowed_scopes FROM clients WHERE id = ?", id)
- .fetch_optional(executor)
- .await?;
-
- Ok(scopes.map(Box::from))
-}
-
-pub async fn get_client_default_scopes<'c>(
- executor: impl Executor<'c, Database = MySql>,
- id: Uuid,
-) -> Result<Option<Option<Box<str>>>, RawUnexpected> {
- let scopes = query_scalar!("SELECT default_scopes FROM clients WHERE id = ?", id)
- .fetch_optional(executor)
- .await?;
-
- Ok(scopes.map(|s| s.map(Box::from)))
-}
-
-pub async fn get_client_secret<'c>(
- executor: impl Executor<'c, Database = MySql>,
- id: Uuid,
-) -> Result<Option<PasswordHash>, RawUnexpected> {
- let hash = query_as!(
- HashRow,
- r"SELECT secret_hash, secret_salt, secret_version
- FROM clients WHERE id = ?",
- id
- )
- .fetch_optional(executor)
- .await?;
-
- let Some(hash) = hash else { return Ok(None) };
- let Some(version) = hash.secret_version else { return Ok(None) };
- let Some(salt) = hash.secret_hash else { return Ok(None) };
- let Some(hash) = hash.secret_salt else { return Ok(None) };
-
- let hash = PasswordHash::from_fields(&hash, &salt, version as u8);
- Ok(Some(hash))
-}
-
-pub async fn is_client_trusted<'c>(
- executor: impl Executor<'c, Database = MySql>,
- id: Uuid,
-) -> Result<Option<bool>, RawUnexpected> {
- query_scalar!("SELECT trusted as `t: bool` FROM clients WHERE id = ?", id)
- .fetch_optional(executor)
- .await
- .unexpect()
-}
-
-pub async fn get_client_redirect_uris<'c>(
- executor: impl Executor<'c, Database = MySql>,
- id: Uuid,
-) -> Result<Box<[Url]>, RawUnexpected> {
- let uris = query_scalar!(
- "SELECT redirect_uri FROM client_redirect_uris WHERE client_id = ?",
- id
- )
- .fetch_all(executor)
- .await
- .unexpect()?;
-
- uris.into_iter()
- .map(|s| Url::from_str(&s).unexpect())
- .collect()
-}
-
-pub async fn client_has_redirect_uri<'c>(
- executor: impl Executor<'c, Database = MySql>,
- id: Uuid,
- url: &Url,
-) -> Result<bool, RawUnexpected> {
- query_scalar!(
- r"SELECT EXISTS(
- SELECT redirect_uri
- FROM client_redirect_uris
- WHERE client_id = ? AND redirect_uri = ?
- ) as `e: bool`",
- id,
- url.to_string()
- )
- .fetch_one(executor)
- .await
- .unexpect()
-}
-
-async fn delete_client_redirect_uris<'c>(
- executor: impl Executor<'c, Database = MySql>,
- id: Uuid,
-) -> Result<(), sqlx::Error> {
- query!("DELETE FROM client_redirect_uris WHERE client_id = ?", id)
- .execute(executor)
- .await?;
- Ok(())
-}
-
-async fn create_client_redirect_uris<'c>(
- mut transaction: Transaction<'c, MySql>,
- client_id: Uuid,
- uris: &[Url],
-) -> Result<(), sqlx::Error> {
- for uri in uris {
- query!(
- r"INSERT INTO client_redirect_uris (client_id, redirect_uri)
- VALUES ( ?, ?)",
- client_id,
- uri.to_string()
- )
- .execute(&mut transaction)
- .await?;
- }
-
- transaction.commit().await?;
-
- Ok(())
-}
-
-pub async fn create_client<'c>(
- mut transaction: Transaction<'c, MySql>,
- client: &Client,
-) -> Result<(), sqlx::Error> {
- query!(
- r"INSERT INTO clients (id, alias, type, secret_hash, secret_salt, secret_version, allowed_scopes, default_scopes)
- VALUES ( ?, ?, ?, ?, ?, ?, ?, ?)",
- client.id(),
- client.alias(),
- client.client_type(),
- client.secret_hash(),
- client.secret_salt(),
- client.secret_version(),
- client.allowed_scopes(),
- client.default_scopes()
- )
- .execute(&mut transaction)
- .await?;
-
- create_client_redirect_uris(transaction, client.id(), client.redirect_uris()).await?;
-
- Ok(())
-}
-
-pub async fn update_client<'c>(
- mut transaction: Transaction<'c, MySql>,
- client: &Client,
-) -> Result<(), sqlx::Error> {
- query!(
- r"UPDATE clients SET
- alias = ?,
- type = ?,
- secret_hash = ?,
- secret_salt = ?,
- secret_version = ?,
- allowed_scopes = ?,
- default_scopes = ?
- WHERE id = ?",
- client.client_type(),
- client.alias(),
- client.secret_hash(),
- client.secret_salt(),
- client.secret_version(),
- client.allowed_scopes(),
- client.default_scopes(),
- client.id()
- )
- .execute(&mut transaction)
- .await?;
-
- update_client_redirect_uris(transaction, client.id(), client.redirect_uris()).await?;
-
- Ok(())
-}
-
-pub async fn update_client_alias<'c>(
- executor: impl Executor<'c, Database = MySql>,
- id: Uuid,
- alias: &str,
-) -> Result<MySqlQueryResult, sqlx::Error> {
- query!("UPDATE clients SET alias = ? WHERE id = ?", alias, id)
- .execute(executor)
- .await
-}
-
-pub async fn update_client_type<'c>(
- executor: impl Executor<'c, Database = MySql>,
- id: Uuid,
- ty: ClientType,
-) -> Result<MySqlQueryResult, sqlx::Error> {
- query!("UPDATE clients SET type = ? WHERE id = ?", ty, id)
- .execute(executor)
- .await
-}
-
-pub async fn update_client_allowed_scopes<'c>(
- executor: impl Executor<'c, Database = MySql>,
- id: Uuid,
- allowed_scopes: &str,
-) -> Result<MySqlQueryResult, sqlx::Error> {
- query!(
- "UPDATE clients SET allowed_scopes = ? WHERE id = ?",
- allowed_scopes,
- id
- )
- .execute(executor)
- .await
-}
-
-pub async fn update_client_default_scopes<'c>(
- executor: impl Executor<'c, Database = MySql>,
- id: Uuid,
- default_scopes: Option<String>,
-) -> Result<MySqlQueryResult, sqlx::Error> {
- query!(
- "UPDATE clients SET default_scopes = ? WHERE id = ?",
- default_scopes,
- id
- )
- .execute(executor)
- .await
-}
-
-pub async fn update_client_trusted<'c>(
- executor: impl Executor<'c, Database = MySql>,
- id: Uuid,
- is_trusted: bool,
-) -> Result<MySqlQueryResult, sqlx::Error> {
- query!(
- "UPDATE clients SET trusted = ? WHERE id = ?",
- is_trusted,
- id
- )
- .execute(executor)
- .await
-}
-
-pub async fn update_client_redirect_uris<'c>(
- mut transaction: Transaction<'c, MySql>,
- id: Uuid,
- uris: &[Url],
-) -> Result<(), sqlx::Error> {
- delete_client_redirect_uris(&mut transaction, id).await?;
- create_client_redirect_uris(transaction, id, uris).await?;
- Ok(())
-}
-
-pub async fn update_client_secret<'c>(
- executor: impl Executor<'c, Database = MySql>,
- id: Uuid,
- secret: Option<PasswordHash>,
-) -> Result<MySqlQueryResult, sqlx::Error> {
- if let Some(secret) = secret {
- query!(
- "UPDATE clients SET secret_hash = ?, secret_salt = ?, secret_version = ? WHERE id = ?",
- secret.hash(),
- secret.salt(),
- secret.version(),
- id
- )
- .execute(executor)
- .await
- } else {
- query!(
- r"UPDATE clients
- SET secret_hash = NULL, secret_salt = NULL, secret_version = NULL
- WHERE id = ?",
- id
- )
- .execute(executor)
- .await
- }
-}
+use std::str::FromStr;
+
+use exun::{RawUnexpected, ResultErrorExt};
+use sqlx::{
+ mysql::MySqlQueryResult, query, query_as, query_scalar, Executor, FromRow, MySql, Transaction,
+};
+use url::Url;
+use uuid::Uuid;
+
+use crate::{
+ models::client::{Client, ClientType},
+ services::crypto::PasswordHash,
+};
+
+#[derive(Debug, Clone, FromRow)]
+pub struct ClientRow {
+ pub id: Uuid,
+ pub alias: String,
+ pub client_type: ClientType,
+ pub allowed_scopes: String,
+ pub default_scopes: Option<String>,
+ pub is_trusted: bool,
+}
+
+#[derive(Clone, FromRow)]
+struct HashRow {
+ secret_hash: Option<Vec<u8>>,
+ secret_salt: Option<Vec<u8>>,
+ secret_version: Option<u32>,
+}
+
+pub async fn client_id_exists<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<bool, RawUnexpected> {
+ query_scalar!(
+ r"SELECT EXISTS(SELECT id FROM clients WHERE id = ?) as `e: bool`",
+ id
+ )
+ .fetch_one(executor)
+ .await
+ .unexpect()
+}
+
+pub async fn client_alias_exists<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ alias: &str,
+) -> Result<bool, RawUnexpected> {
+ query_scalar!(
+ "SELECT EXISTS(SELECT alias FROM clients WHERE alias = ?) as `e: bool`",
+ alias
+ )
+ .fetch_one(executor)
+ .await
+ .unexpect()
+}
+
+pub async fn get_client_id_by_alias<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ alias: &str,
+) -> Result<Option<Uuid>, RawUnexpected> {
+ query_scalar!(
+ "SELECT id as `id: Uuid` FROM clients WHERE alias = ?",
+ alias
+ )
+ .fetch_optional(executor)
+ .await
+ .unexpect()
+}
+
+pub async fn get_client_response<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<ClientRow>, RawUnexpected> {
+ let record = query_as!(
+ ClientRow,
+ r"SELECT id as `id: Uuid`,
+ alias,
+ type as `client_type: ClientType`,
+ allowed_scopes,
+ default_scopes,
+ trusted as `is_trusted: bool`
+ FROM clients WHERE id = ?",
+ id
+ )
+ .fetch_optional(executor)
+ .await?;
+
+ Ok(record)
+}
+
+pub async fn get_client_alias<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<Box<str>>, RawUnexpected> {
+ let alias = query_scalar!("SELECT alias FROM clients WHERE id = ?", id)
+ .fetch_optional(executor)
+ .await
+ .unexpect()?;
+
+ Ok(alias.map(String::into_boxed_str))
+}
+
+pub async fn get_client_type<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<ClientType>, RawUnexpected> {
+ let ty = query_scalar!(
+ "SELECT type as `type: ClientType` FROM clients WHERE id = ?",
+ id
+ )
+ .fetch_optional(executor)
+ .await
+ .unexpect()?;
+
+ Ok(ty)
+}
+
+pub async fn get_client_allowed_scopes<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<Box<str>>, RawUnexpected> {
+ let scopes = query_scalar!("SELECT allowed_scopes FROM clients WHERE id = ?", id)
+ .fetch_optional(executor)
+ .await?;
+
+ Ok(scopes.map(Box::from))
+}
+
+pub async fn get_client_default_scopes<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<Option<Box<str>>>, RawUnexpected> {
+ let scopes = query_scalar!("SELECT default_scopes FROM clients WHERE id = ?", id)
+ .fetch_optional(executor)
+ .await?;
+
+ Ok(scopes.map(|s| s.map(Box::from)))
+}
+
+pub async fn get_client_secret<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<PasswordHash>, RawUnexpected> {
+ let hash = query_as!(
+ HashRow,
+ r"SELECT secret_hash, secret_salt, secret_version
+ FROM clients WHERE id = ?",
+ id
+ )
+ .fetch_optional(executor)
+ .await?;
+
+ let Some(hash) = hash else { return Ok(None) };
+ let Some(version) = hash.secret_version else { return Ok(None) };
+ let Some(salt) = hash.secret_hash else { return Ok(None) };
+ let Some(hash) = hash.secret_salt else { return Ok(None) };
+
+ let hash = PasswordHash::from_fields(&hash, &salt, version as u8);
+ Ok(Some(hash))
+}
+
+pub async fn is_client_trusted<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<bool>, RawUnexpected> {
+ query_scalar!("SELECT trusted as `t: bool` FROM clients WHERE id = ?", id)
+ .fetch_optional(executor)
+ .await
+ .unexpect()
+}
+
+pub async fn get_client_redirect_uris<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Box<[Url]>, RawUnexpected> {
+ let uris = query_scalar!(
+ "SELECT redirect_uri FROM client_redirect_uris WHERE client_id = ?",
+ id
+ )
+ .fetch_all(executor)
+ .await
+ .unexpect()?;
+
+ uris.into_iter()
+ .map(|s| Url::from_str(&s).unexpect())
+ .collect()
+}
+
+pub async fn client_has_redirect_uri<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+ url: &Url,
+) -> Result<bool, RawUnexpected> {
+ query_scalar!(
+ r"SELECT EXISTS(
+ SELECT redirect_uri
+ FROM client_redirect_uris
+ WHERE client_id = ? AND redirect_uri = ?
+ ) as `e: bool`",
+ id,
+ url.to_string()
+ )
+ .fetch_one(executor)
+ .await
+ .unexpect()
+}
+
+async fn delete_client_redirect_uris<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<(), sqlx::Error> {
+ query!("DELETE FROM client_redirect_uris WHERE client_id = ?", id)
+ .execute(executor)
+ .await?;
+ Ok(())
+}
+
+async fn create_client_redirect_uris<'c>(
+ mut transaction: Transaction<'c, MySql>,
+ client_id: Uuid,
+ uris: &[Url],
+) -> Result<(), sqlx::Error> {
+ for uri in uris {
+ query!(
+ r"INSERT INTO client_redirect_uris (client_id, redirect_uri)
+ VALUES ( ?, ?)",
+ client_id,
+ uri.to_string()
+ )
+ .execute(&mut transaction)
+ .await?;
+ }
+
+ transaction.commit().await?;
+
+ Ok(())
+}
+
+pub async fn create_client<'c>(
+ mut transaction: Transaction<'c, MySql>,
+ client: &Client,
+) -> Result<(), sqlx::Error> {
+ query!(
+ r"INSERT INTO clients (id, alias, type, secret_hash, secret_salt, secret_version, allowed_scopes, default_scopes)
+ VALUES ( ?, ?, ?, ?, ?, ?, ?, ?)",
+ client.id(),
+ client.alias(),
+ client.client_type(),
+ client.secret_hash(),
+ client.secret_salt(),
+ client.secret_version(),
+ client.allowed_scopes(),
+ client.default_scopes()
+ )
+ .execute(&mut transaction)
+ .await?;
+
+ create_client_redirect_uris(transaction, client.id(), client.redirect_uris()).await?;
+
+ Ok(())
+}
+
+pub async fn update_client<'c>(
+ mut transaction: Transaction<'c, MySql>,
+ client: &Client,
+) -> Result<(), sqlx::Error> {
+ query!(
+ r"UPDATE clients SET
+ alias = ?,
+ type = ?,
+ secret_hash = ?,
+ secret_salt = ?,
+ secret_version = ?,
+ allowed_scopes = ?,
+ default_scopes = ?
+ WHERE id = ?",
+ client.client_type(),
+ client.alias(),
+ client.secret_hash(),
+ client.secret_salt(),
+ client.secret_version(),
+ client.allowed_scopes(),
+ client.default_scopes(),
+ client.id()
+ )
+ .execute(&mut transaction)
+ .await?;
+
+ update_client_redirect_uris(transaction, client.id(), client.redirect_uris()).await?;
+
+ Ok(())
+}
+
+pub async fn update_client_alias<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+ alias: &str,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!("UPDATE clients SET alias = ? WHERE id = ?", alias, id)
+ .execute(executor)
+ .await
+}
+
+pub async fn update_client_type<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+ ty: ClientType,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!("UPDATE clients SET type = ? WHERE id = ?", ty, id)
+ .execute(executor)
+ .await
+}
+
+pub async fn update_client_allowed_scopes<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+ allowed_scopes: &str,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ "UPDATE clients SET allowed_scopes = ? WHERE id = ?",
+ allowed_scopes,
+ id
+ )
+ .execute(executor)
+ .await
+}
+
+pub async fn update_client_default_scopes<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+ default_scopes: Option<String>,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ "UPDATE clients SET default_scopes = ? WHERE id = ?",
+ default_scopes,
+ id
+ )
+ .execute(executor)
+ .await
+}
+
+pub async fn update_client_trusted<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+ is_trusted: bool,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ "UPDATE clients SET trusted = ? WHERE id = ?",
+ is_trusted,
+ id
+ )
+ .execute(executor)
+ .await
+}
+
+pub async fn update_client_redirect_uris<'c>(
+ mut transaction: Transaction<'c, MySql>,
+ id: Uuid,
+ uris: &[Url],
+) -> Result<(), sqlx::Error> {
+ delete_client_redirect_uris(&mut transaction, id).await?;
+ create_client_redirect_uris(transaction, id, uris).await?;
+ Ok(())
+}
+
+pub async fn update_client_secret<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+ secret: Option<PasswordHash>,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ if let Some(secret) = secret {
+ query!(
+ "UPDATE clients SET secret_hash = ?, secret_salt = ?, secret_version = ? WHERE id = ?",
+ secret.hash(),
+ secret.salt(),
+ secret.version(),
+ id
+ )
+ .execute(executor)
+ .await
+ } else {
+ query!(
+ r"UPDATE clients
+ SET secret_hash = NULL, secret_salt = NULL, secret_version = NULL
+ WHERE id = ?",
+ id
+ )
+ .execute(executor)
+ .await
+ }
+}
diff --git a/src/services/db/jwt.rs b/src/services/db/jwt.rs
index b2f1367..73d6902 100644
--- a/src/services/db/jwt.rs
+++ b/src/services/db/jwt.rs
@@ -1,199 +1,199 @@
-use chrono::{DateTime, Utc};
-use exun::{RawUnexpected, ResultErrorExt};
-use sqlx::{query, query_scalar, Executor, MySql};
-use uuid::Uuid;
-
-use crate::services::jwt::RevokedRefreshTokenReason;
-
-pub async fn auth_code_exists<'c>(
- executor: impl Executor<'c, Database = MySql>,
- jti: Uuid,
-) -> Result<bool, RawUnexpected> {
- query_scalar!(
- "SELECT EXISTS(SELECT jti FROM auth_codes WHERE jti = ?) as `e: bool`",
- jti
- )
- .fetch_one(executor)
- .await
- .unexpect()
-}
-
-pub async fn access_token_exists<'c>(
- executor: impl Executor<'c, Database = MySql>,
- jti: Uuid,
-) -> Result<bool, RawUnexpected> {
- query_scalar!(
- "SELECT EXISTS(SELECT jti FROM access_tokens WHERE jti = ?) as `e: bool`",
- jti
- )
- .fetch_one(executor)
- .await
- .unexpect()
-}
-
-pub async fn refresh_token_exists<'c>(
- executor: impl Executor<'c, Database = MySql>,
- jti: Uuid,
-) -> Result<bool, RawUnexpected> {
- query_scalar!(
- "SELECT EXISTS(SELECT jti FROM refresh_tokens WHERE jti = ?) as `e: bool`",
- jti
- )
- .fetch_one(executor)
- .await
- .unexpect()
-}
-
-pub async fn refresh_token_revoked<'c>(
- executor: impl Executor<'c, Database = MySql>,
- jti: Uuid,
-) -> Result<bool, RawUnexpected> {
- let result = query_scalar!(
- r"SELECT EXISTS(
- SELECT revoked_reason FROM refresh_tokens WHERE jti = ? and revoked_reason IS NOT NULL
- ) as `e: bool`",
- jti
- )
- .fetch_one(executor)
- .await?
- .unwrap_or(true);
-
- Ok(result)
-}
-
-pub async fn create_auth_code<'c>(
- executor: impl Executor<'c, Database = MySql>,
- jti: Uuid,
- exp: DateTime<Utc>,
-) -> Result<(), sqlx::Error> {
- query!(
- r"INSERT INTO auth_codes (jti, exp)
- VALUES ( ?, ?)",
- jti,
- exp
- )
- .execute(executor)
- .await?;
-
- Ok(())
-}
-
-pub async fn create_access_token<'c>(
- executor: impl Executor<'c, Database = MySql>,
- jti: Uuid,
- auth_code: Option<Uuid>,
- exp: DateTime<Utc>,
-) -> Result<(), sqlx::Error> {
- query!(
- r"INSERT INTO access_tokens (jti, auth_code, exp)
- VALUES ( ?, ?, ?)",
- jti,
- auth_code,
- exp
- )
- .execute(executor)
- .await?;
-
- Ok(())
-}
-
-pub async fn create_refresh_token<'c>(
- executor: impl Executor<'c, Database = MySql>,
- jti: Uuid,
- auth_code: Option<Uuid>,
- exp: DateTime<Utc>,
-) -> Result<(), sqlx::Error> {
- query!(
- r"INSERT INTO access_tokens (jti, auth_code, exp)
- VALUES ( ?, ?, ?)",
- jti,
- auth_code,
- exp
- )
- .execute(executor)
- .await?;
-
- Ok(())
-}
-
-pub async fn delete_auth_code<'c>(
- executor: impl Executor<'c, Database = MySql>,
- auth_code: Uuid,
-) -> Result<bool, RawUnexpected> {
- let result = query!("DELETE FROM auth_codes WHERE jti = ?", auth_code)
- .execute(executor)
- .await?;
-
- Ok(result.rows_affected() != 0)
-}
-
-pub async fn delete_expired_auth_codes<'c>(
- executor: impl Executor<'c, Database = MySql>,
-) -> Result<(), RawUnexpected> {
- query!("DELETE FROM auth_codes WHERE exp < ?", Utc::now())
- .execute(executor)
- .await?;
-
- Ok(())
-}
-
-pub async fn delete_access_tokens_with_auth_code<'c>(
- executor: impl Executor<'c, Database = MySql>,
- auth_code: Uuid,
-) -> Result<bool, RawUnexpected> {
- let result = query!("DELETE FROM access_tokens WHERE auth_code = ?", auth_code)
- .execute(executor)
- .await?;
-
- Ok(result.rows_affected() != 0)
-}
-
-pub async fn delete_expired_access_tokens<'c>(
- executor: impl Executor<'c, Database = MySql>,
-) -> Result<(), RawUnexpected> {
- query!("DELETE FROM access_tokens WHERE exp < ?", Utc::now())
- .execute(executor)
- .await?;
-
- Ok(())
-}
-
-pub async fn revoke_refresh_token<'c>(
- executor: impl Executor<'c, Database = MySql>,
- jti: Uuid,
-) -> Result<bool, RawUnexpected> {
- let result = query!(
- "UPDATE refresh_tokens SET revoked_reason = ? WHERE jti = ?",
- RevokedRefreshTokenReason::NewRefreshToken,
- jti
- )
- .execute(executor)
- .await?;
-
- Ok(result.rows_affected() != 0)
-}
-
-pub async fn revoke_refresh_tokens_with_auth_code<'c>(
- executor: impl Executor<'c, Database = MySql>,
- auth_code: Uuid,
-) -> Result<bool, RawUnexpected> {
- let result = query!(
- "UPDATE refresh_tokens SET revoked_reason = ? WHERE auth_code = ?",
- RevokedRefreshTokenReason::ReusedAuthorizationCode,
- auth_code
- )
- .execute(executor)
- .await?;
-
- Ok(result.rows_affected() != 0)
-}
-
-pub async fn delete_expired_refresh_tokens<'c>(
- executor: impl Executor<'c, Database = MySql>,
-) -> Result<(), RawUnexpected> {
- query!("DELETE FROM refresh_tokens WHERE exp < ?", Utc::now())
- .execute(executor)
- .await?;
-
- Ok(())
-}
+use chrono::{DateTime, Utc};
+use exun::{RawUnexpected, ResultErrorExt};
+use sqlx::{query, query_scalar, Executor, MySql};
+use uuid::Uuid;
+
+use crate::services::jwt::RevokedRefreshTokenReason;
+
+pub async fn auth_code_exists<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ jti: Uuid,
+) -> Result<bool, RawUnexpected> {
+ query_scalar!(
+ "SELECT EXISTS(SELECT jti FROM auth_codes WHERE jti = ?) as `e: bool`",
+ jti
+ )
+ .fetch_one(executor)
+ .await
+ .unexpect()
+}
+
+pub async fn access_token_exists<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ jti: Uuid,
+) -> Result<bool, RawUnexpected> {
+ query_scalar!(
+ "SELECT EXISTS(SELECT jti FROM access_tokens WHERE jti = ?) as `e: bool`",
+ jti
+ )
+ .fetch_one(executor)
+ .await
+ .unexpect()
+}
+
+pub async fn refresh_token_exists<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ jti: Uuid,
+) -> Result<bool, RawUnexpected> {
+ query_scalar!(
+ "SELECT EXISTS(SELECT jti FROM refresh_tokens WHERE jti = ?) as `e: bool`",
+ jti
+ )
+ .fetch_one(executor)
+ .await
+ .unexpect()
+}
+
+pub async fn refresh_token_revoked<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ jti: Uuid,
+) -> Result<bool, RawUnexpected> {
+ let result = query_scalar!(
+ r"SELECT EXISTS(
+ SELECT revoked_reason FROM refresh_tokens WHERE jti = ? and revoked_reason IS NOT NULL
+ ) as `e: bool`",
+ jti
+ )
+ .fetch_one(executor)
+ .await?
+ .unwrap_or(true);
+
+ Ok(result)
+}
+
+pub async fn create_auth_code<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ jti: Uuid,
+ exp: DateTime<Utc>,
+) -> Result<(), sqlx::Error> {
+ query!(
+ r"INSERT INTO auth_codes (jti, exp)
+ VALUES ( ?, ?)",
+ jti,
+ exp
+ )
+ .execute(executor)
+ .await?;
+
+ Ok(())
+}
+
+pub async fn create_access_token<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ jti: Uuid,
+ auth_code: Option<Uuid>,
+ exp: DateTime<Utc>,
+) -> Result<(), sqlx::Error> {
+ query!(
+ r"INSERT INTO access_tokens (jti, auth_code, exp)
+ VALUES ( ?, ?, ?)",
+ jti,
+ auth_code,
+ exp
+ )
+ .execute(executor)
+ .await?;
+
+ Ok(())
+}
+
+pub async fn create_refresh_token<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ jti: Uuid,
+ auth_code: Option<Uuid>,
+ exp: DateTime<Utc>,
+) -> Result<(), sqlx::Error> {
+ query!(
+ r"INSERT INTO access_tokens (jti, auth_code, exp)
+ VALUES ( ?, ?, ?)",
+ jti,
+ auth_code,
+ exp
+ )
+ .execute(executor)
+ .await?;
+
+ Ok(())
+}
+
+pub async fn delete_auth_code<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ auth_code: Uuid,
+) -> Result<bool, RawUnexpected> {
+ let result = query!("DELETE FROM auth_codes WHERE jti = ?", auth_code)
+ .execute(executor)
+ .await?;
+
+ Ok(result.rows_affected() != 0)
+}
+
+pub async fn delete_expired_auth_codes<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+) -> Result<(), RawUnexpected> {
+ query!("DELETE FROM auth_codes WHERE exp < ?", Utc::now())
+ .execute(executor)
+ .await?;
+
+ Ok(())
+}
+
+pub async fn delete_access_tokens_with_auth_code<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ auth_code: Uuid,
+) -> Result<bool, RawUnexpected> {
+ let result = query!("DELETE FROM access_tokens WHERE auth_code = ?", auth_code)
+ .execute(executor)
+ .await?;
+
+ Ok(result.rows_affected() != 0)
+}
+
+pub async fn delete_expired_access_tokens<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+) -> Result<(), RawUnexpected> {
+ query!("DELETE FROM access_tokens WHERE exp < ?", Utc::now())
+ .execute(executor)
+ .await?;
+
+ Ok(())
+}
+
+pub async fn revoke_refresh_token<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ jti: Uuid,
+) -> Result<bool, RawUnexpected> {
+ let result = query!(
+ "UPDATE refresh_tokens SET revoked_reason = ? WHERE jti = ?",
+ RevokedRefreshTokenReason::NewRefreshToken,
+ jti
+ )
+ .execute(executor)
+ .await?;
+
+ Ok(result.rows_affected() != 0)
+}
+
+pub async fn revoke_refresh_tokens_with_auth_code<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ auth_code: Uuid,
+) -> Result<bool, RawUnexpected> {
+ let result = query!(
+ "UPDATE refresh_tokens SET revoked_reason = ? WHERE auth_code = ?",
+ RevokedRefreshTokenReason::ReusedAuthorizationCode,
+ auth_code
+ )
+ .execute(executor)
+ .await?;
+
+ Ok(result.rows_affected() != 0)
+}
+
+pub async fn delete_expired_refresh_tokens<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+) -> Result<(), RawUnexpected> {
+ query!("DELETE FROM refresh_tokens WHERE exp < ?", Utc::now())
+ .execute(executor)
+ .await?;
+
+ Ok(())
+}
diff --git a/src/services/db/user.rs b/src/services/db/user.rs
index 09a09da..f85047a 100644
--- a/src/services/db/user.rs
+++ b/src/services/db/user.rs
@@ -1,236 +1,236 @@
-use exun::RawUnexpected;
-use sqlx::{mysql::MySqlQueryResult, query, query_as, query_scalar, Executor, MySql};
-use uuid::Uuid;
-
-use crate::{models::user::User, services::crypto::PasswordHash};
-
-struct UserRow {
- id: Uuid,
- username: String,
- password_hash: Vec<u8>,
- password_salt: Vec<u8>,
- password_version: u32,
-}
-
-impl TryFrom<UserRow> for User {
- type Error = RawUnexpected;
-
- fn try_from(row: UserRow) -> Result<Self, Self::Error> {
- let password = PasswordHash::from_fields(
- &row.password_hash,
- &row.password_salt,
- row.password_version as u8,
- );
- let user = User {
- id: row.id,
- username: row.username.into_boxed_str(),
- password,
- };
- Ok(user)
- }
-}
-
-/// Check if a user with a given user ID exists
-pub async fn user_id_exists<'c>(
- conn: impl Executor<'c, Database = MySql>,
- id: Uuid,
-) -> Result<bool, RawUnexpected> {
- let exists = query_scalar!(
- r#"SELECT EXISTS(SELECT id FROM users WHERE id = ?) as `e: bool`"#,
- id
- )
- .fetch_one(conn)
- .await?;
-
- Ok(exists)
-}
-
-/// Check if a given username is taken
-pub async fn username_is_used<'c>(
- conn: impl Executor<'c, Database = MySql>,
- username: &str,
-) -> Result<bool, RawUnexpected> {
- let exists = query_scalar!(
- r#"SELECT EXISTS(SELECT id FROM users WHERE username = ?) as "e: bool""#,
- username
- )
- .fetch_one(conn)
- .await?;
-
- Ok(exists)
-}
-
-/// Get a user from their ID
-pub async fn get_user<'c>(
- conn: impl Executor<'c, Database = MySql>,
- user_id: Uuid,
-) -> Result<Option<User>, RawUnexpected> {
- let record = query_as!(
- UserRow,
- r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version
- FROM users WHERE id = ?",
- user_id
- )
- .fetch_optional(conn)
- .await?;
-
- let Some(record) = record else { return Ok(None) };
-
- Ok(Some(record.try_into()?))
-}
-
-/// Get a user from their username
-pub async fn get_user_by_username<'c>(
- conn: impl Executor<'c, Database = MySql>,
- username: &str,
-) -> Result<Option<User>, RawUnexpected> {
- let record = query_as!(
- UserRow,
- r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version
- FROM users WHERE username = ?",
- username
- )
- .fetch_optional(conn)
- .await?;
-
- let Some(record) = record else { return Ok(None) };
-
- Ok(Some(record.try_into()?))
-}
-
-/// Search the list of users for a given username
-pub async fn search_users<'c>(
- conn: impl Executor<'c, Database = MySql>,
- username: &str,
-) -> Result<Box<[User]>, RawUnexpected> {
- let records = query_as!(
- UserRow,
- r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version
- FROM users
- WHERE LOCATE(?, username) != 0",
- username,
- )
- .fetch_all(conn)
- .await?;
-
- Ok(records
- .into_iter()
- .map(|u| u.try_into())
- .collect::<Result<Box<[User]>, RawUnexpected>>()?)
-}
-
-/// Search the list of users, only returning a certain range of results
-pub async fn search_users_limit<'c>(
- conn: impl Executor<'c, Database = MySql>,
- username: &str,
- offset: u32,
- limit: u32,
-) -> Result<Box<[User]>, RawUnexpected> {
- let records = query_as!(
- UserRow,
- r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version
- FROM users
- WHERE LOCATE(?, username) != 0
- LIMIT ?
- OFFSET ?",
- username,
- offset,
- limit
- )
- .fetch_all(conn)
- .await?;
-
- Ok(records
- .into_iter()
- .map(|u| u.try_into())
- .collect::<Result<Box<[User]>, RawUnexpected>>()?)
-}
-
-/// Get the username of a user with a certain ID
-pub async fn get_username<'c>(
- conn: impl Executor<'c, Database = MySql>,
- user_id: Uuid,
-) -> Result<Option<Box<str>>, RawUnexpected> {
- let username = query_scalar!(r"SELECT username FROM users where id = ?", user_id)
- .fetch_optional(conn)
- .await?
- .map(String::into_boxed_str);
-
- Ok(username)
-}
-
-/// Create a new user
-pub async fn create_user<'c>(
- conn: impl Executor<'c, Database = MySql>,
- user: &User,
-) -> Result<MySqlQueryResult, sqlx::Error> {
- query!(
- r"INSERT INTO users (id, username, password_hash, password_salt, password_version)
- VALUES ( ?, ?, ?, ?, ?)",
- user.id,
- user.username(),
- user.password_hash(),
- user.password_salt(),
- user.password_version()
- )
- .execute(conn)
- .await
-}
-
-/// Update a user
-pub async fn update_user<'c>(
- conn: impl Executor<'c, Database = MySql>,
- user: &User,
-) -> Result<MySqlQueryResult, sqlx::Error> {
- query!(
- r"UPDATE users SET
- username = ?,
- password_hash = ?,
- password_salt = ?,
- password_version = ?
- WHERE id = ?",
- user.username(),
- user.password_hash(),
- user.password_salt(),
- user.password_version(),
- user.id
- )
- .execute(conn)
- .await
-}
-
-/// Update the username of a user with the given ID
-pub async fn update_username<'c>(
- conn: impl Executor<'c, Database = MySql>,
- user_id: Uuid,
- username: &str,
-) -> Result<MySqlQueryResult, sqlx::Error> {
- query!(
- r"UPDATE users SET username = ? WHERE id = ?",
- username,
- user_id
- )
- .execute(conn)
- .await
-}
-
-/// Update the password of a user with the given ID
-pub async fn update_password<'c>(
- conn: impl Executor<'c, Database = MySql>,
- user_id: Uuid,
- password: &PasswordHash,
-) -> Result<MySqlQueryResult, sqlx::Error> {
- query!(
- r"UPDATE users SET
- password_hash = ?,
- password_salt = ?,
- password_version = ?
- WHERE id = ?",
- password.hash(),
- password.salt(),
- password.version(),
- user_id
- )
- .execute(conn)
- .await
-}
+use exun::RawUnexpected;
+use sqlx::{mysql::MySqlQueryResult, query, query_as, query_scalar, Executor, MySql};
+use uuid::Uuid;
+
+use crate::{models::user::User, services::crypto::PasswordHash};
+
+struct UserRow {
+ id: Uuid,
+ username: String,
+ password_hash: Vec<u8>,
+ password_salt: Vec<u8>,
+ password_version: u32,
+}
+
+impl TryFrom<UserRow> for User {
+ type Error = RawUnexpected;
+
+ fn try_from(row: UserRow) -> Result<Self, Self::Error> {
+ let password = PasswordHash::from_fields(
+ &row.password_hash,
+ &row.password_salt,
+ row.password_version as u8,
+ );
+ let user = User {
+ id: row.id,
+ username: row.username.into_boxed_str(),
+ password,
+ };
+ Ok(user)
+ }
+}
+
+/// Check if a user with a given user ID exists
+pub async fn user_id_exists<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<bool, RawUnexpected> {
+ let exists = query_scalar!(
+ r#"SELECT EXISTS(SELECT id FROM users WHERE id = ?) as `e: bool`"#,
+ id
+ )
+ .fetch_one(conn)
+ .await?;
+
+ Ok(exists)
+}
+
+/// Check if a given username is taken
+pub async fn username_is_used<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ username: &str,
+) -> Result<bool, RawUnexpected> {
+ let exists = query_scalar!(
+ r#"SELECT EXISTS(SELECT id FROM users WHERE username = ?) as "e: bool""#,
+ username
+ )
+ .fetch_one(conn)
+ .await?;
+
+ Ok(exists)
+}
+
+/// Get a user from their ID
+pub async fn get_user<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ user_id: Uuid,
+) -> Result<Option<User>, RawUnexpected> {
+ let record = query_as!(
+ UserRow,
+ r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version
+ FROM users WHERE id = ?",
+ user_id
+ )
+ .fetch_optional(conn)
+ .await?;
+
+ let Some(record) = record else { return Ok(None) };
+
+ Ok(Some(record.try_into()?))
+}
+
+/// Get a user from their username
+pub async fn get_user_by_username<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ username: &str,
+) -> Result<Option<User>, RawUnexpected> {
+ let record = query_as!(
+ UserRow,
+ r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version
+ FROM users WHERE username = ?",
+ username
+ )
+ .fetch_optional(conn)
+ .await?;
+
+ let Some(record) = record else { return Ok(None) };
+
+ Ok(Some(record.try_into()?))
+}
+
+/// Search the list of users for a given username
+pub async fn search_users<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ username: &str,
+) -> Result<Box<[User]>, RawUnexpected> {
+ let records = query_as!(
+ UserRow,
+ r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version
+ FROM users
+ WHERE LOCATE(?, username) != 0",
+ username,
+ )
+ .fetch_all(conn)
+ .await?;
+
+ Ok(records
+ .into_iter()
+ .map(|u| u.try_into())
+ .collect::<Result<Box<[User]>, RawUnexpected>>()?)
+}
+
+/// Search the list of users, only returning a certain range of results
+pub async fn search_users_limit<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ username: &str,
+ offset: u32,
+ limit: u32,
+) -> Result<Box<[User]>, RawUnexpected> {
+ let records = query_as!(
+ UserRow,
+ r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version
+ FROM users
+ WHERE LOCATE(?, username) != 0
+ LIMIT ?
+ OFFSET ?",
+ username,
+ offset,
+ limit
+ )
+ .fetch_all(conn)
+ .await?;
+
+ Ok(records
+ .into_iter()
+ .map(|u| u.try_into())
+ .collect::<Result<Box<[User]>, RawUnexpected>>()?)
+}
+
+/// Get the username of a user with a certain ID
+pub async fn get_username<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ user_id: Uuid,
+) -> Result<Option<Box<str>>, RawUnexpected> {
+ let username = query_scalar!(r"SELECT username FROM users where id = ?", user_id)
+ .fetch_optional(conn)
+ .await?
+ .map(String::into_boxed_str);
+
+ Ok(username)
+}
+
+/// Create a new user
+pub async fn create_user<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ user: &User,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ r"INSERT INTO users (id, username, password_hash, password_salt, password_version)
+ VALUES ( ?, ?, ?, ?, ?)",
+ user.id,
+ user.username(),
+ user.password_hash(),
+ user.password_salt(),
+ user.password_version()
+ )
+ .execute(conn)
+ .await
+}
+
+/// Update a user
+pub async fn update_user<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ user: &User,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ r"UPDATE users SET
+ username = ?,
+ password_hash = ?,
+ password_salt = ?,
+ password_version = ?
+ WHERE id = ?",
+ user.username(),
+ user.password_hash(),
+ user.password_salt(),
+ user.password_version(),
+ user.id
+ )
+ .execute(conn)
+ .await
+}
+
+/// Update the username of a user with the given ID
+pub async fn update_username<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ user_id: Uuid,
+ username: &str,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ r"UPDATE users SET username = ? WHERE id = ?",
+ username,
+ user_id
+ )
+ .execute(conn)
+ .await
+}
+
+/// Update the password of a user with the given ID
+pub async fn update_password<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ user_id: Uuid,
+ password: &PasswordHash,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ r"UPDATE users SET
+ password_hash = ?,
+ password_salt = ?,
+ password_version = ?
+ WHERE id = ?",
+ password.hash(),
+ password.salt(),
+ password.version(),
+ user_id
+ )
+ .execute(conn)
+ .await
+}
diff --git a/src/services/id.rs b/src/services/id.rs
index 0c665ed..e1227e4 100644
--- a/src/services/id.rs
+++ b/src/services/id.rs
@@ -1,27 +1,27 @@
-use std::future::Future;
-
-use exun::RawUnexpected;
-use sqlx::{Executor, MySql};
-use uuid::Uuid;
-
-/// Create a unique id, handling duplicate ID's.
-///
-/// The given `unique_check` parameter returns `true` if the ID is used and
-/// `false` otherwise.
-pub async fn new_id<
- 'c,
- E: Executor<'c, Database = MySql> + Clone,
- F: Future<Output = Result<bool, RawUnexpected>>,
->(
- conn: E,
- unique_check: impl Fn(E, Uuid) -> F,
-) -> Result<Uuid, RawUnexpected> {
- let uuid = loop {
- let uuid = Uuid::new_v4();
- if !unique_check(conn.clone(), uuid).await? {
- break uuid;
- }
- };
-
- Ok(uuid)
-}
+use std::future::Future;
+
+use exun::RawUnexpected;
+use sqlx::{Executor, MySql};
+use uuid::Uuid;
+
+/// Create a unique id, handling duplicate ID's.
+///
+/// The given `unique_check` parameter returns `true` if the ID is used and
+/// `false` otherwise.
+pub async fn new_id<
+ 'c,
+ E: Executor<'c, Database = MySql> + Clone,
+ F: Future<Output = Result<bool, RawUnexpected>>,
+>(
+ conn: E,
+ unique_check: impl Fn(E, Uuid) -> F,
+) -> Result<Uuid, RawUnexpected> {
+ let uuid = loop {
+ let uuid = Uuid::new_v4();
+ if !unique_check(conn.clone(), uuid).await? {
+ break uuid;
+ }
+ };
+
+ Ok(uuid)
+}
diff --git a/src/services/jwt.rs b/src/services/jwt.rs
index 16f5fa6..863eb83 100644
--- a/src/services/jwt.rs
+++ b/src/services/jwt.rs
@@ -1,291 +1,291 @@
-use chrono::{serde::ts_milliseconds, serde::ts_milliseconds_option, DateTime, Duration, Utc};
-use exun::{Expect, RawUnexpected, ResultErrorExt};
-use jwt::{SignWithKey, VerifyWithKey};
-use raise::yeet;
-use serde::{Deserialize, Serialize};
-use sqlx::{Executor, MySql, MySqlPool};
-use thiserror::Error;
-use url::Url;
-use uuid::Uuid;
-
-use super::{db, id::new_id, secrets};
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
-pub enum TokenType {
- Authorization,
- Access,
- Refresh,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct Claims {
- iss: Url,
- sub: Uuid,
- aud: Box<[String]>,
- #[serde(with = "ts_milliseconds")]
- exp: DateTime<Utc>,
- #[serde(with = "ts_milliseconds_option")]
- nbf: Option<DateTime<Utc>>,
- #[serde(with = "ts_milliseconds")]
- iat: DateTime<Utc>,
- jti: Uuid,
- scope: Box<str>,
- client_id: Uuid,
- token_type: TokenType,
- auth_code_id: Option<Uuid>,
- redirect_uri: Option<Url>,
-}
-
-#[derive(Debug, Clone, Copy, sqlx::Type)]
-#[sqlx(rename_all = "kebab-case")]
-pub enum RevokedRefreshTokenReason {
- ReusedAuthorizationCode,
- NewRefreshToken,
-}
-
-impl Claims {
- pub async fn auth_code<'c>(
- db: &MySqlPool,
- self_id: Url,
- client_id: Uuid,
- sub: Uuid,
- scopes: &str,
- redirect_uri: &Url,
- ) -> Result<Self, RawUnexpected> {
- let five_minutes = Duration::minutes(5);
-
- let id = new_id(db, db::auth_code_exists).await?;
- let iat = Utc::now();
- let exp = iat + five_minutes;
-
- db::create_auth_code(db, id, exp).await?;
-
- let aud = [self_id.to_string(), client_id.to_string()].into();
-
- Ok(Self {
- iss: self_id,
- sub,
- aud,
- exp,
- nbf: None,
- iat,
- jti: id,
- scope: scopes.into(),
- client_id,
- auth_code_id: Some(id),
- token_type: TokenType::Authorization,
- redirect_uri: Some(redirect_uri.clone()),
- })
- }
-
- pub async fn access_token<'c>(
- db: &MySqlPool,
- auth_code_id: Option<Uuid>,
- self_id: Url,
- client_id: Uuid,
- sub: Uuid,
- duration: Duration,
- scopes: &str,
- ) -> Result<Self, RawUnexpected> {
- let id = new_id(db, db::access_token_exists).await?;
- let iat = Utc::now();
- let exp = iat + duration;
-
- db::create_access_token(db, id, auth_code_id, exp)
- .await
- .unexpect()?;
-
- let aud = [self_id.to_string(), client_id.to_string()].into();
-
- Ok(Self {
- iss: self_id,
- sub,
- aud,
- exp,
- nbf: None,
- iat,
- jti: id,
- scope: scopes.into(),
- client_id,
- auth_code_id,
- token_type: TokenType::Access,
- redirect_uri: None,
- })
- }
-
- pub async fn refresh_token(
- db: &MySqlPool,
- other_token: &Claims,
- ) -> Result<Self, RawUnexpected> {
- let one_day = Duration::days(1);
-
- let id = new_id(db, db::refresh_token_exists).await?;
- let iat = Utc::now();
- let exp = other_token.exp + one_day;
-
- db::create_refresh_token(db, id, other_token.auth_code_id, exp).await?;
-
- let mut claims = other_token.clone();
- claims.exp = exp;
- claims.iat = iat;
- claims.jti = id;
- claims.token_type = TokenType::Refresh;
-
- Ok(claims)
- }
-
- pub async fn refreshed_access_token(
- db: &MySqlPool,
- refresh_token: &Claims,
- exp_time: Duration,
- ) -> Result<Self, RawUnexpected> {
- let id = new_id(db, db::access_token_exists).await?;
- let iat = Utc::now();
- let exp = iat + exp_time;
-
- db::create_access_token(db, id, refresh_token.auth_code_id, exp).await?;
-
- let mut claims = refresh_token.clone();
- claims.exp = exp;
- claims.iat = iat;
- claims.jti = id;
- claims.token_type = TokenType::Access;
-
- Ok(claims)
- }
-
- pub fn id(&self) -> Uuid {
- self.jti
- }
-
- pub fn subject(&self) -> Uuid {
- self.sub
- }
-
- pub fn expires_in(&self) -> i64 {
- (self.exp - Utc::now()).num_seconds()
- }
-
- pub fn scopes(&self) -> &str {
- &self.scope
- }
-
- pub fn to_jwt(&self) -> Result<Box<str>, RawUnexpected> {
- let key = secrets::signing_key()?;
- let jwt = self.sign_with_key(&key)?.into_boxed_str();
- Ok(jwt)
- }
-}
-
-#[derive(Debug, Error)]
-pub enum VerifyJwtError {
- #[error("{0}")]
- ParseJwtError(#[from] jwt::Error),
- #[error("The issuer for this token is incorrect")]
- IncorrectIssuer,
- #[error("This bearer token was intended for a different client")]
- WrongClient,
- #[error("The given audience parameter does not contain this issuer")]
- BadAudience,
- #[error("The redirect URI doesn't match what's in the token")]
- IncorrectRedirectUri,
- #[error("The token is expired")]
- ExpiredToken,
- #[error("The token cannot be used yet")]
- NotYet,
- #[error("The bearer token has been revoked")]
- JwtRevoked,
-}
-
-fn verify_jwt(
- token: &str,
- self_id: &Url,
- client_id: Option<Uuid>,
-) -> Result<Claims, Expect<VerifyJwtError>> {
- let key = secrets::signing_key()?;
- let claims: Claims = token
- .verify_with_key(&key)
- .map_err(|e| VerifyJwtError::from(e))?;
-
- if &claims.iss != self_id {
- yeet!(VerifyJwtError::IncorrectIssuer.into())
- }
-
- if let Some(client_id) = client_id {
- if claims.client_id != client_id {
- yeet!(VerifyJwtError::WrongClient.into())
- }
- }
-
- if !claims.aud.contains(&self_id.to_string()) {
- yeet!(VerifyJwtError::BadAudience.into())
- }
-
- let now = Utc::now();
-
- if now > claims.exp {
- yeet!(VerifyJwtError::ExpiredToken.into())
- }
-
- if let Some(nbf) = claims.nbf {
- if now < nbf {
- yeet!(VerifyJwtError::NotYet.into())
- }
- }
-
- Ok(claims)
-}
-
-pub async fn verify_auth_code<'c>(
- db: &MySqlPool,
- token: &str,
- self_id: &Url,
- client_id: Uuid,
- redirect_uri: Url,
-) -> Result<Claims, Expect<VerifyJwtError>> {
- let claims = verify_jwt(token, self_id, Some(client_id))?;
-
- if let Some(claimed_uri) = &claims.redirect_uri {
- if claimed_uri.clone() != redirect_uri {
- yeet!(VerifyJwtError::IncorrectRedirectUri.into());
- }
- }
-
- if db::delete_auth_code(db, claims.jti).await? {
- db::delete_access_tokens_with_auth_code(db, claims.jti).await?;
- db::revoke_refresh_tokens_with_auth_code(db, claims.jti).await?;
- yeet!(VerifyJwtError::JwtRevoked.into());
- }
-
- Ok(claims)
-}
-
-pub async fn verify_access_token<'c>(
- db: impl Executor<'c, Database = MySql>,
- token: &str,
- self_id: &Url,
- client_id: Uuid,
-) -> Result<Claims, Expect<VerifyJwtError>> {
- let claims = verify_jwt(token, self_id, Some(client_id))?;
-
- if !db::access_token_exists(db, claims.jti).await? {
- yeet!(VerifyJwtError::JwtRevoked.into())
- }
-
- Ok(claims)
-}
-
-pub async fn verify_refresh_token<'c>(
- db: impl Executor<'c, Database = MySql>,
- token: &str,
- self_id: &Url,
- client_id: Option<Uuid>,
-) -> Result<Claims, Expect<VerifyJwtError>> {
- let claims = verify_jwt(token, self_id, client_id)?;
-
- if db::refresh_token_revoked(db, claims.jti).await? {
- yeet!(VerifyJwtError::JwtRevoked.into())
- }
-
- Ok(claims)
-}
+use chrono::{serde::ts_milliseconds, serde::ts_milliseconds_option, DateTime, Duration, Utc};
+use exun::{Expect, RawUnexpected, ResultErrorExt};
+use jwt::{SignWithKey, VerifyWithKey};
+use raise::yeet;
+use serde::{Deserialize, Serialize};
+use sqlx::{Executor, MySql, MySqlPool};
+use thiserror::Error;
+use url::Url;
+use uuid::Uuid;
+
+use super::{db, id::new_id, secrets};
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub enum TokenType {
+ Authorization,
+ Access,
+ Refresh,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Claims {
+ iss: Url,
+ sub: Uuid,
+ aud: Box<[String]>,
+ #[serde(with = "ts_milliseconds")]
+ exp: DateTime<Utc>,
+ #[serde(with = "ts_milliseconds_option")]
+ nbf: Option<DateTime<Utc>>,
+ #[serde(with = "ts_milliseconds")]
+ iat: DateTime<Utc>,
+ jti: Uuid,
+ scope: Box<str>,
+ client_id: Uuid,
+ token_type: TokenType,
+ auth_code_id: Option<Uuid>,
+ redirect_uri: Option<Url>,
+}
+
+#[derive(Debug, Clone, Copy, sqlx::Type)]
+#[sqlx(rename_all = "kebab-case")]
+pub enum RevokedRefreshTokenReason {
+ ReusedAuthorizationCode,
+ NewRefreshToken,
+}
+
+impl Claims {
+ pub async fn auth_code<'c>(
+ db: &MySqlPool,
+ self_id: Url,
+ client_id: Uuid,
+ sub: Uuid,
+ scopes: &str,
+ redirect_uri: &Url,
+ ) -> Result<Self, RawUnexpected> {
+ let five_minutes = Duration::minutes(5);
+
+ let id = new_id(db, db::auth_code_exists).await?;
+ let iat = Utc::now();
+ let exp = iat + five_minutes;
+
+ db::create_auth_code(db, id, exp).await?;
+
+ let aud = [self_id.to_string(), client_id.to_string()].into();
+
+ Ok(Self {
+ iss: self_id,
+ sub,
+ aud,
+ exp,
+ nbf: None,
+ iat,
+ jti: id,
+ scope: scopes.into(),
+ client_id,
+ auth_code_id: Some(id),
+ token_type: TokenType::Authorization,
+ redirect_uri: Some(redirect_uri.clone()),
+ })
+ }
+
+ pub async fn access_token<'c>(
+ db: &MySqlPool,
+ auth_code_id: Option<Uuid>,
+ self_id: Url,
+ client_id: Uuid,
+ sub: Uuid,
+ duration: Duration,
+ scopes: &str,
+ ) -> Result<Self, RawUnexpected> {
+ let id = new_id(db, db::access_token_exists).await?;
+ let iat = Utc::now();
+ let exp = iat + duration;
+
+ db::create_access_token(db, id, auth_code_id, exp)
+ .await
+ .unexpect()?;
+
+ let aud = [self_id.to_string(), client_id.to_string()].into();
+
+ Ok(Self {
+ iss: self_id,
+ sub,
+ aud,
+ exp,
+ nbf: None,
+ iat,
+ jti: id,
+ scope: scopes.into(),
+ client_id,
+ auth_code_id,
+ token_type: TokenType::Access,
+ redirect_uri: None,
+ })
+ }
+
+ pub async fn refresh_token(
+ db: &MySqlPool,
+ other_token: &Claims,
+ ) -> Result<Self, RawUnexpected> {
+ let one_day = Duration::days(1);
+
+ let id = new_id(db, db::refresh_token_exists).await?;
+ let iat = Utc::now();
+ let exp = other_token.exp + one_day;
+
+ db::create_refresh_token(db, id, other_token.auth_code_id, exp).await?;
+
+ let mut claims = other_token.clone();
+ claims.exp = exp;
+ claims.iat = iat;
+ claims.jti = id;
+ claims.token_type = TokenType::Refresh;
+
+ Ok(claims)
+ }
+
+ pub async fn refreshed_access_token(
+ db: &MySqlPool,
+ refresh_token: &Claims,
+ exp_time: Duration,
+ ) -> Result<Self, RawUnexpected> {
+ let id = new_id(db, db::access_token_exists).await?;
+ let iat = Utc::now();
+ let exp = iat + exp_time;
+
+ db::create_access_token(db, id, refresh_token.auth_code_id, exp).await?;
+
+ let mut claims = refresh_token.clone();
+ claims.exp = exp;
+ claims.iat = iat;
+ claims.jti = id;
+ claims.token_type = TokenType::Access;
+
+ Ok(claims)
+ }
+
+ pub fn id(&self) -> Uuid {
+ self.jti
+ }
+
+ pub fn subject(&self) -> Uuid {
+ self.sub
+ }
+
+ pub fn expires_in(&self) -> i64 {
+ (self.exp - Utc::now()).num_seconds()
+ }
+
+ pub fn scopes(&self) -> &str {
+ &self.scope
+ }
+
+ pub fn to_jwt(&self) -> Result<Box<str>, RawUnexpected> {
+ let key = secrets::signing_key()?;
+ let jwt = self.sign_with_key(&key)?.into_boxed_str();
+ Ok(jwt)
+ }
+}
+
+#[derive(Debug, Error)]
+pub enum VerifyJwtError {
+ #[error("{0}")]
+ ParseJwtError(#[from] jwt::Error),
+ #[error("The issuer for this token is incorrect")]
+ IncorrectIssuer,
+ #[error("This bearer token was intended for a different client")]
+ WrongClient,
+ #[error("The given audience parameter does not contain this issuer")]
+ BadAudience,
+ #[error("The redirect URI doesn't match what's in the token")]
+ IncorrectRedirectUri,
+ #[error("The token is expired")]
+ ExpiredToken,
+ #[error("The token cannot be used yet")]
+ NotYet,
+ #[error("The bearer token has been revoked")]
+ JwtRevoked,
+}
+
+fn verify_jwt(
+ token: &str,
+ self_id: &Url,
+ client_id: Option<Uuid>,
+) -> Result<Claims, Expect<VerifyJwtError>> {
+ let key = secrets::signing_key()?;
+ let claims: Claims = token
+ .verify_with_key(&key)
+ .map_err(|e| VerifyJwtError::from(e))?;
+
+ if &claims.iss != self_id {
+ yeet!(VerifyJwtError::IncorrectIssuer.into())
+ }
+
+ if let Some(client_id) = client_id {
+ if claims.client_id != client_id {
+ yeet!(VerifyJwtError::WrongClient.into())
+ }
+ }
+
+ if !claims.aud.contains(&self_id.to_string()) {
+ yeet!(VerifyJwtError::BadAudience.into())
+ }
+
+ let now = Utc::now();
+
+ if now > claims.exp {
+ yeet!(VerifyJwtError::ExpiredToken.into())
+ }
+
+ if let Some(nbf) = claims.nbf {
+ if now < nbf {
+ yeet!(VerifyJwtError::NotYet.into())
+ }
+ }
+
+ Ok(claims)
+}
+
+pub async fn verify_auth_code<'c>(
+ db: &MySqlPool,
+ token: &str,
+ self_id: &Url,
+ client_id: Uuid,
+ redirect_uri: Url,
+) -> Result<Claims, Expect<VerifyJwtError>> {
+ let claims = verify_jwt(token, self_id, Some(client_id))?;
+
+ if let Some(claimed_uri) = &claims.redirect_uri {
+ if claimed_uri.clone() != redirect_uri {
+ yeet!(VerifyJwtError::IncorrectRedirectUri.into());
+ }
+ }
+
+ if db::delete_auth_code(db, claims.jti).await? {
+ db::delete_access_tokens_with_auth_code(db, claims.jti).await?;
+ db::revoke_refresh_tokens_with_auth_code(db, claims.jti).await?;
+ yeet!(VerifyJwtError::JwtRevoked.into());
+ }
+
+ Ok(claims)
+}
+
+pub async fn verify_access_token<'c>(
+ db: impl Executor<'c, Database = MySql>,
+ token: &str,
+ self_id: &Url,
+ client_id: Uuid,
+) -> Result<Claims, Expect<VerifyJwtError>> {
+ let claims = verify_jwt(token, self_id, Some(client_id))?;
+
+ if !db::access_token_exists(db, claims.jti).await? {
+ yeet!(VerifyJwtError::JwtRevoked.into())
+ }
+
+ Ok(claims)
+}
+
+pub async fn verify_refresh_token<'c>(
+ db: impl Executor<'c, Database = MySql>,
+ token: &str,
+ self_id: &Url,
+ client_id: Option<Uuid>,
+) -> Result<Claims, Expect<VerifyJwtError>> {
+ let claims = verify_jwt(token, self_id, client_id)?;
+
+ if db::refresh_token_revoked(db, claims.jti).await? {
+ yeet!(VerifyJwtError::JwtRevoked.into())
+ }
+
+ Ok(claims)
+}
diff --git a/src/services/mod.rs b/src/services/mod.rs
index de08b58..4c69367 100644
--- a/src/services/mod.rs
+++ b/src/services/mod.rs
@@ -1,7 +1,7 @@
-pub mod authorization;
-pub mod config;
-pub mod crypto;
-pub mod db;
-pub mod id;
-pub mod jwt;
-pub mod secrets;
+pub mod authorization;
+pub mod config;
+pub mod crypto;
+pub mod db;
+pub mod id;
+pub mod jwt;
+pub mod secrets;
diff --git a/src/services/secrets.rs b/src/services/secrets.rs
index 241b2c5..e1d4992 100644
--- a/src/services/secrets.rs
+++ b/src/services/secrets.rs
@@ -1,24 +1,24 @@
-use std::env;
-
-use exun::*;
-use hmac::{Hmac, Mac};
-use sha2::Sha256;
-
-/// This is a secret salt, needed for creating passwords. It's used as an extra
-/// layer of security, on top of the salt that's already used.
-pub fn pepper() -> Result<Box<[u8]>, RawUnexpected> {
- let pepper = env::var("SECRET_SALT")?;
- let pepper = hex::decode(pepper)?;
- Ok(pepper.into_boxed_slice())
-}
-
-/// The URL to the MySQL database
-pub fn database_url() -> Result<String, RawUnexpected> {
- env::var("DATABASE_URL").unexpect()
-}
-
-pub fn signing_key() -> Result<Hmac<Sha256>, RawUnexpected> {
- let key = env::var("PRIVATE_KEY")?;
- let key = Hmac::<Sha256>::new_from_slice(key.as_bytes())?;
- Ok(key)
-}
+use std::env;
+
+use exun::*;
+use hmac::{Hmac, Mac};
+use sha2::Sha256;
+
+/// This is a secret salt, needed for creating passwords. It's used as an extra
+/// layer of security, on top of the salt that's already used.
+pub fn pepper() -> Result<Box<[u8]>, RawUnexpected> {
+ let pepper = env::var("SECRET_SALT")?;
+ let pepper = hex::decode(pepper)?;
+ Ok(pepper.into_boxed_slice())
+}
+
+/// The URL to the MySQL database
+pub fn database_url() -> Result<String, RawUnexpected> {
+ env::var("DATABASE_URL").unexpect()
+}
+
+pub fn signing_key() -> Result<Hmac<Sha256>, RawUnexpected> {
+ let key = env::var("PRIVATE_KEY")?;
+ let key = Hmac::<Sha256>::new_from_slice(key.as_bytes())?;
+ Ok(key)
+}