summaryrefslogtreecommitdiff
path: root/src/api
diff options
context:
space:
mode:
Diffstat (limited to 'src/api')
-rw-r--r--src/api/clients.rs966
-rw-r--r--src/api/liveops.rs22
-rw-r--r--src/api/mod.rs26
-rw-r--r--src/api/oauth.rs1852
-rw-r--r--src/api/ops.rs140
-rw-r--r--src/api/users.rs544
6 files changed, 1775 insertions, 1775 deletions
diff --git a/src/api/clients.rs b/src/api/clients.rs
index 3f906bb..ded8b81 100644
--- a/src/api/clients.rs
+++ b/src/api/clients.rs
@@ -1,483 +1,483 @@
-use actix_web::http::{header, StatusCode};
-use actix_web::{get, post, put, web, HttpResponse, ResponseError, Scope};
-use raise::yeet;
-use serde::{Deserialize, Serialize};
-use sqlx::MySqlPool;
-use thiserror::Error;
-use url::Url;
-use uuid::Uuid;
-
-use crate::models::client::{Client, ClientType, CreateClientError};
-use crate::services::crypto::PasswordHash;
-use crate::services::db::ClientRow;
-use crate::services::{db, id};
-
-#[derive(Debug, Clone, Serialize)]
-#[serde(rename_all = "camelCase")]
-struct ClientResponse {
- client_id: Uuid,
- alias: Box<str>,
- client_type: ClientType,
- allowed_scopes: Box<[Box<str>]>,
- default_scopes: Option<Box<[Box<str>]>>,
- is_trusted: bool,
-}
-
-impl From<ClientRow> for ClientResponse {
- fn from(value: ClientRow) -> Self {
- Self {
- client_id: value.id,
- alias: value.alias.into_boxed_str(),
- client_type: value.client_type,
- allowed_scopes: value
- .allowed_scopes
- .split_whitespace()
- .map(Box::from)
- .collect(),
- default_scopes: value
- .default_scopes
- .map(|s| s.split_whitespace().map(Box::from).collect()),
- is_trusted: value.is_trusted,
- }
- }
-}
-
-#[derive(Debug, Clone, Copy, Error)]
-#[error("No client with the given client ID was found")]
-struct ClientNotFound {
- id: Uuid,
-}
-
-impl ResponseError for ClientNotFound {
- fn status_code(&self) -> StatusCode {
- StatusCode::NOT_FOUND
- }
-}
-
-impl ClientNotFound {
- fn new(id: Uuid) -> Self {
- Self { id }
- }
-}
-
-#[get("/{client_id}")]
-async fn get_client(
- client_id: web::Path<Uuid>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, ClientNotFound> {
- let db = db.as_ref();
- let id = *client_id;
-
- let Some(client) = db::get_client_response(db, id).await.unwrap() else {
- yeet!(ClientNotFound::new(id))
- };
-
- let redirect_uris_link = format!("</clients/{client_id}/redirect-uris>; rel=\"redirect-uris\"");
- let response: ClientResponse = client.into();
- let response = HttpResponse::Ok()
- .append_header((header::LINK, redirect_uris_link))
- .json(response);
- Ok(response)
-}
-
-#[get("/{client_id}/alias")]
-async fn get_client_alias(
- client_id: web::Path<Uuid>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, ClientNotFound> {
- let db = db.as_ref();
- let id = *client_id;
-
- let Some(alias) = db::get_client_alias(db, id).await.unwrap() else {
- yeet!(ClientNotFound::new(id))
- };
-
- Ok(HttpResponse::Ok().json(alias))
-}
-
-#[get("/{client_id}/type")]
-async fn get_client_type(
- client_id: web::Path<Uuid>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, ClientNotFound> {
- let db = db.as_ref();
- let id = *client_id;
-
- let Some(client_type) = db::get_client_type(db, id).await.unwrap() else {
- yeet!(ClientNotFound::new(id))
- };
-
- Ok(HttpResponse::Ok().json(client_type))
-}
-
-#[get("/{client_id}/redirect-uris")]
-async fn get_client_redirect_uris(
- client_id: web::Path<Uuid>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, ClientNotFound> {
- let db = db.as_ref();
- let id = *client_id;
-
- if !db::client_id_exists(db, id).await.unwrap() {
- yeet!(ClientNotFound::new(id))
- };
-
- let redirect_uris = db::get_client_redirect_uris(db, id).await.unwrap();
-
- Ok(HttpResponse::Ok().json(redirect_uris))
-}
-
-#[get("/{client_id}/allowed-scopes")]
-async fn get_client_allowed_scopes(
- client_id: web::Path<Uuid>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, ClientNotFound> {
- let db = db.as_ref();
- let id = *client_id;
-
- let Some(allowed_scopes) = db::get_client_allowed_scopes(db, id).await.unwrap() else {
- yeet!(ClientNotFound::new(id))
- };
-
- let allowed_scopes = allowed_scopes.split_whitespace().collect::<Box<[&str]>>();
-
- Ok(HttpResponse::Ok().json(allowed_scopes))
-}
-
-#[get("/{client_id}/default-scopes")]
-async fn get_client_default_scopes(
- client_id: web::Path<Uuid>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, ClientNotFound> {
- let db = db.as_ref();
- let id = *client_id;
-
- let Some(default_scopes) = db::get_client_default_scopes(db, id).await.unwrap() else {
- yeet!(ClientNotFound::new(id))
- };
-
- let default_scopes = default_scopes.map(|scopes| {
- scopes
- .split_whitespace()
- .map(Box::from)
- .collect::<Box<[Box<str>]>>()
- });
-
- Ok(HttpResponse::Ok().json(default_scopes))
-}
-
-#[get("/{client_id}/is-trusted")]
-async fn get_client_is_trusted(
- client_id: web::Path<Uuid>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, ClientNotFound> {
- let db = db.as_ref();
- let id = *client_id;
-
- let Some(is_trusted) = db::is_client_trusted(db, id).await.unwrap() else {
- yeet!(ClientNotFound::new(id))
- };
-
- Ok(HttpResponse::Ok().json(is_trusted))
-}
-
-#[derive(Clone, Deserialize)]
-#[serde(rename_all = "camelCase")]
-struct ClientRequest {
- alias: Box<str>,
- ty: ClientType,
- redirect_uris: Box<[Url]>,
- secret: Option<Box<str>>,
- allowed_scopes: Box<[Box<str>]>,
- default_scopes: Option<Box<[Box<str>]>>,
- trusted: bool,
-}
-
-#[derive(Debug, Clone, Error)]
-#[error("The given client alias is already taken")]
-struct AliasTakenError {
- alias: Box<str>,
-}
-
-impl ResponseError for AliasTakenError {
- fn status_code(&self) -> StatusCode {
- StatusCode::CONFLICT
- }
-}
-
-impl AliasTakenError {
- fn new(alias: &str) -> Self {
- Self {
- alias: Box::from(alias),
- }
- }
-}
-
-#[post("")]
-async fn create_client(
- body: web::Json<ClientRequest>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateClientError> {
- let db = db.get_ref();
- let alias = &body.alias;
-
- if db::client_alias_exists(db, &alias).await.unwrap() {
- yeet!(AliasTakenError::new(&alias).into());
- }
-
- let id = id::new_id(db, db::client_id_exists).await.unwrap();
- let client = Client::new(
- id,
- &alias,
- body.ty,
- body.secret.as_deref(),
- body.allowed_scopes.clone(),
- body.default_scopes.clone(),
- &body.redirect_uris,
- body.trusted,
- )
- .map_err(|e| e.unwrap())?;
-
- let transaction = db.begin().await.unwrap();
- db::create_client(transaction, &client).await.unwrap();
-
- let response = HttpResponse::Created()
- .insert_header((header::LOCATION, format!("clients/{id}")))
- .finish();
- Ok(response)
-}
-
-#[derive(Debug, Clone, Error)]
-enum UpdateClientError {
- #[error(transparent)]
- NotFound(#[from] ClientNotFound),
- #[error(transparent)]
- ClientError(#[from] CreateClientError),
- #[error(transparent)]
- AliasTaken(#[from] AliasTakenError),
-}
-
-impl ResponseError for UpdateClientError {
- fn status_code(&self) -> StatusCode {
- match self {
- Self::NotFound(e) => e.status_code(),
- Self::ClientError(e) => e.status_code(),
- Self::AliasTaken(e) => e.status_code(),
- }
- }
-}
-
-#[put("/{id}")]
-async fn update_client(
- id: web::Path<Uuid>,
- body: web::Json<ClientRequest>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateClientError> {
- let db = db.get_ref();
- let id = *id;
- let alias = &body.alias;
-
- let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else {
- yeet!(ClientNotFound::new(id).into())
- };
- if old_alias != alias.clone() && db::client_alias_exists(db, &alias).await.unwrap() {
- yeet!(AliasTakenError::new(&alias).into());
- }
-
- let client = Client::new(
- id,
- &alias,
- body.ty,
- body.secret.as_deref(),
- body.allowed_scopes.clone(),
- body.default_scopes.clone(),
- &body.redirect_uris,
- body.trusted,
- )
- .map_err(|e| e.unwrap())?;
-
- let transaction = db.begin().await.unwrap();
- db::update_client(transaction, &client).await.unwrap();
-
- let response = HttpResponse::NoContent().finish();
- Ok(response)
-}
-
-#[put("/{id}/alias")]
-async fn update_client_alias(
- id: web::Path<Uuid>,
- body: web::Json<Box<str>>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateClientError> {
- let db = db.get_ref();
- let id = *id;
- let alias = body.0;
-
- let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else {
- yeet!(ClientNotFound::new(id).into())
- };
- if old_alias == alias {
- return Ok(HttpResponse::NoContent().finish());
- }
- if db::client_alias_exists(db, &alias).await.unwrap() {
- yeet!(AliasTakenError::new(&alias).into());
- }
-
- db::update_client_alias(db, id, &alias).await.unwrap();
-
- let response = HttpResponse::NoContent().finish();
- Ok(response)
-}
-
-#[put("/{id}/type")]
-async fn update_client_type(
- id: web::Path<Uuid>,
- body: web::Json<ClientType>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateClientError> {
- let db = db.get_ref();
- let id = *id;
- let ty = body.0;
-
- if !db::client_id_exists(db, id).await.unwrap() {
- yeet!(ClientNotFound::new(id).into());
- }
-
- db::update_client_type(db, id, ty).await.unwrap();
-
- Ok(HttpResponse::NoContent().finish())
-}
-
-#[put("/{id}/allowed-scopes")]
-async fn update_client_allowed_scopes(
- id: web::Path<Uuid>,
- body: web::Json<Box<[Box<str>]>>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateClientError> {
- let db = db.get_ref();
- let id = *id;
- let allowed_scopes = body.0.join(" ");
-
- if !db::client_id_exists(db, id).await.unwrap() {
- yeet!(ClientNotFound::new(id).into());
- }
-
- db::update_client_allowed_scopes(db, id, &allowed_scopes)
- .await
- .unwrap();
-
- Ok(HttpResponse::NoContent().finish())
-}
-
-#[put("/{id}/default-scopes")]
-async fn update_client_default_scopes(
- id: web::Path<Uuid>,
- body: web::Json<Option<Box<[Box<str>]>>>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateClientError> {
- let db = db.get_ref();
- let id = *id;
- let default_scopes = body.0.map(|s| s.join(" "));
-
- if !db::client_id_exists(db, id).await.unwrap() {
- yeet!(ClientNotFound::new(id).into());
- }
-
- db::update_client_default_scopes(db, id, default_scopes)
- .await
- .unwrap();
-
- Ok(HttpResponse::NoContent().finish())
-}
-
-#[put("/{id}/is-trusted")]
-async fn update_client_is_trusted(
- id: web::Path<Uuid>,
- body: web::Json<bool>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateClientError> {
- let db = db.get_ref();
- let id = *id;
- let is_trusted = *body;
-
- if !db::client_id_exists(db, id).await.unwrap() {
- yeet!(ClientNotFound::new(id).into());
- }
-
- db::update_client_trusted(db, id, is_trusted).await.unwrap();
-
- Ok(HttpResponse::NoContent().finish())
-}
-
-#[put("/{id}/redirect-uris")]
-async fn update_client_redirect_uris(
- id: web::Path<Uuid>,
- body: web::Json<Box<[Url]>>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateClientError> {
- let db = db.get_ref();
- let id = *id;
-
- for uri in body.0.iter() {
- if uri.scheme() != "https" {
- yeet!(CreateClientError::NonHttpsUri.into());
- }
-
- if uri.fragment().is_some() {
- yeet!(CreateClientError::UriFragment.into())
- }
- }
-
- if !db::client_id_exists(db, id).await.unwrap() {
- yeet!(ClientNotFound::new(id).into());
- }
-
- let transaction = db.begin().await.unwrap();
- db::update_client_redirect_uris(transaction, id, &body.0)
- .await
- .unwrap();
-
- Ok(HttpResponse::NoContent().finish())
-}
-
-#[put("{id}/secret")]
-async fn update_client_secret(
- id: web::Path<Uuid>,
- body: web::Json<Option<Box<str>>>,
- db: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateClientError> {
- let db = db.get_ref();
- let id = *id;
-
- let Some(client_type) = db::get_client_type(db, id).await.unwrap() else {
- yeet!(ClientNotFound::new(id).into())
- };
-
- if client_type == ClientType::Confidential && body.is_none() {
- yeet!(CreateClientError::NoSecret.into())
- }
-
- let secret = body.0.map(|s| PasswordHash::new(&s).unwrap());
- db::update_client_secret(db, id, secret).await.unwrap();
-
- Ok(HttpResponse::NoContent().finish())
-}
-
-pub fn service() -> Scope {
- web::scope("/clients")
- .service(get_client)
- .service(get_client_alias)
- .service(get_client_type)
- .service(get_client_allowed_scopes)
- .service(get_client_default_scopes)
- .service(get_client_redirect_uris)
- .service(get_client_is_trusted)
- .service(create_client)
- .service(update_client)
- .service(update_client_alias)
- .service(update_client_type)
- .service(update_client_allowed_scopes)
- .service(update_client_default_scopes)
- .service(update_client_redirect_uris)
- .service(update_client_secret)
-}
+use actix_web::http::{header, StatusCode};
+use actix_web::{get, post, put, web, HttpResponse, ResponseError, Scope};
+use raise::yeet;
+use serde::{Deserialize, Serialize};
+use sqlx::MySqlPool;
+use thiserror::Error;
+use url::Url;
+use uuid::Uuid;
+
+use crate::models::client::{Client, ClientType, CreateClientError};
+use crate::services::crypto::PasswordHash;
+use crate::services::db::ClientRow;
+use crate::services::{db, id};
+
+#[derive(Debug, Clone, Serialize)]
+#[serde(rename_all = "camelCase")]
+struct ClientResponse {
+ client_id: Uuid,
+ alias: Box<str>,
+ client_type: ClientType,
+ allowed_scopes: Box<[Box<str>]>,
+ default_scopes: Option<Box<[Box<str>]>>,
+ is_trusted: bool,
+}
+
+impl From<ClientRow> for ClientResponse {
+ fn from(value: ClientRow) -> Self {
+ Self {
+ client_id: value.id,
+ alias: value.alias.into_boxed_str(),
+ client_type: value.client_type,
+ allowed_scopes: value
+ .allowed_scopes
+ .split_whitespace()
+ .map(Box::from)
+ .collect(),
+ default_scopes: value
+ .default_scopes
+ .map(|s| s.split_whitespace().map(Box::from).collect()),
+ is_trusted: value.is_trusted,
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy, Error)]
+#[error("No client with the given client ID was found")]
+struct ClientNotFound {
+ id: Uuid,
+}
+
+impl ResponseError for ClientNotFound {
+ fn status_code(&self) -> StatusCode {
+ StatusCode::NOT_FOUND
+ }
+}
+
+impl ClientNotFound {
+ fn new(id: Uuid) -> Self {
+ Self { id }
+ }
+}
+
+#[get("/{client_id}")]
+async fn get_client(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(client) = db::get_client_response(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ let redirect_uris_link = format!("</clients/{client_id}/redirect-uris>; rel=\"redirect-uris\"");
+ let response: ClientResponse = client.into();
+ let response = HttpResponse::Ok()
+ .append_header((header::LINK, redirect_uris_link))
+ .json(response);
+ Ok(response)
+}
+
+#[get("/{client_id}/alias")]
+async fn get_client_alias(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(alias) = db::get_client_alias(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ Ok(HttpResponse::Ok().json(alias))
+}
+
+#[get("/{client_id}/type")]
+async fn get_client_type(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(client_type) = db::get_client_type(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ Ok(HttpResponse::Ok().json(client_type))
+}
+
+#[get("/{client_id}/redirect-uris")]
+async fn get_client_redirect_uris(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ let redirect_uris = db::get_client_redirect_uris(db, id).await.unwrap();
+
+ Ok(HttpResponse::Ok().json(redirect_uris))
+}
+
+#[get("/{client_id}/allowed-scopes")]
+async fn get_client_allowed_scopes(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(allowed_scopes) = db::get_client_allowed_scopes(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ let allowed_scopes = allowed_scopes.split_whitespace().collect::<Box<[&str]>>();
+
+ Ok(HttpResponse::Ok().json(allowed_scopes))
+}
+
+#[get("/{client_id}/default-scopes")]
+async fn get_client_default_scopes(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(default_scopes) = db::get_client_default_scopes(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ let default_scopes = default_scopes.map(|scopes| {
+ scopes
+ .split_whitespace()
+ .map(Box::from)
+ .collect::<Box<[Box<str>]>>()
+ });
+
+ Ok(HttpResponse::Ok().json(default_scopes))
+}
+
+#[get("/{client_id}/is-trusted")]
+async fn get_client_is_trusted(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(is_trusted) = db::is_client_trusted(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ Ok(HttpResponse::Ok().json(is_trusted))
+}
+
+#[derive(Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct ClientRequest {
+ alias: Box<str>,
+ ty: ClientType,
+ redirect_uris: Box<[Url]>,
+ secret: Option<Box<str>>,
+ allowed_scopes: Box<[Box<str>]>,
+ default_scopes: Option<Box<[Box<str>]>>,
+ trusted: bool,
+}
+
+#[derive(Debug, Clone, Error)]
+#[error("The given client alias is already taken")]
+struct AliasTakenError {
+ alias: Box<str>,
+}
+
+impl ResponseError for AliasTakenError {
+ fn status_code(&self) -> StatusCode {
+ StatusCode::CONFLICT
+ }
+}
+
+impl AliasTakenError {
+ fn new(alias: &str) -> Self {
+ Self {
+ alias: Box::from(alias),
+ }
+ }
+}
+
+#[post("")]
+async fn create_client(
+ body: web::Json<ClientRequest>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let alias = &body.alias;
+
+ if db::client_alias_exists(db, &alias).await.unwrap() {
+ yeet!(AliasTakenError::new(&alias).into());
+ }
+
+ let id = id::new_id(db, db::client_id_exists).await.unwrap();
+ let client = Client::new(
+ id,
+ &alias,
+ body.ty,
+ body.secret.as_deref(),
+ body.allowed_scopes.clone(),
+ body.default_scopes.clone(),
+ &body.redirect_uris,
+ body.trusted,
+ )
+ .map_err(|e| e.unwrap())?;
+
+ let transaction = db.begin().await.unwrap();
+ db::create_client(transaction, &client).await.unwrap();
+
+ let response = HttpResponse::Created()
+ .insert_header((header::LOCATION, format!("clients/{id}")))
+ .finish();
+ Ok(response)
+}
+
+#[derive(Debug, Clone, Error)]
+enum UpdateClientError {
+ #[error(transparent)]
+ NotFound(#[from] ClientNotFound),
+ #[error(transparent)]
+ ClientError(#[from] CreateClientError),
+ #[error(transparent)]
+ AliasTaken(#[from] AliasTakenError),
+}
+
+impl ResponseError for UpdateClientError {
+ fn status_code(&self) -> StatusCode {
+ match self {
+ Self::NotFound(e) => e.status_code(),
+ Self::ClientError(e) => e.status_code(),
+ Self::AliasTaken(e) => e.status_code(),
+ }
+ }
+}
+
+#[put("/{id}")]
+async fn update_client(
+ id: web::Path<Uuid>,
+ body: web::Json<ClientRequest>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let alias = &body.alias;
+
+ let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id).into())
+ };
+ if old_alias != alias.clone() && db::client_alias_exists(db, &alias).await.unwrap() {
+ yeet!(AliasTakenError::new(&alias).into());
+ }
+
+ let client = Client::new(
+ id,
+ &alias,
+ body.ty,
+ body.secret.as_deref(),
+ body.allowed_scopes.clone(),
+ body.default_scopes.clone(),
+ &body.redirect_uris,
+ body.trusted,
+ )
+ .map_err(|e| e.unwrap())?;
+
+ let transaction = db.begin().await.unwrap();
+ db::update_client(transaction, &client).await.unwrap();
+
+ let response = HttpResponse::NoContent().finish();
+ Ok(response)
+}
+
+#[put("/{id}/alias")]
+async fn update_client_alias(
+ id: web::Path<Uuid>,
+ body: web::Json<Box<str>>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let alias = body.0;
+
+ let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id).into())
+ };
+ if old_alias == alias {
+ return Ok(HttpResponse::NoContent().finish());
+ }
+ if db::client_alias_exists(db, &alias).await.unwrap() {
+ yeet!(AliasTakenError::new(&alias).into());
+ }
+
+ db::update_client_alias(db, id, &alias).await.unwrap();
+
+ let response = HttpResponse::NoContent().finish();
+ Ok(response)
+}
+
+#[put("/{id}/type")]
+async fn update_client_type(
+ id: web::Path<Uuid>,
+ body: web::Json<ClientType>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let ty = body.0;
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id).into());
+ }
+
+ db::update_client_type(db, id, ty).await.unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+#[put("/{id}/allowed-scopes")]
+async fn update_client_allowed_scopes(
+ id: web::Path<Uuid>,
+ body: web::Json<Box<[Box<str>]>>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let allowed_scopes = body.0.join(" ");
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id).into());
+ }
+
+ db::update_client_allowed_scopes(db, id, &allowed_scopes)
+ .await
+ .unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+#[put("/{id}/default-scopes")]
+async fn update_client_default_scopes(
+ id: web::Path<Uuid>,
+ body: web::Json<Option<Box<[Box<str>]>>>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let default_scopes = body.0.map(|s| s.join(" "));
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id).into());
+ }
+
+ db::update_client_default_scopes(db, id, default_scopes)
+ .await
+ .unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+#[put("/{id}/is-trusted")]
+async fn update_client_is_trusted(
+ id: web::Path<Uuid>,
+ body: web::Json<bool>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let is_trusted = *body;
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id).into());
+ }
+
+ db::update_client_trusted(db, id, is_trusted).await.unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+#[put("/{id}/redirect-uris")]
+async fn update_client_redirect_uris(
+ id: web::Path<Uuid>,
+ body: web::Json<Box<[Url]>>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+
+ for uri in body.0.iter() {
+ if uri.scheme() != "https" {
+ yeet!(CreateClientError::NonHttpsUri.into());
+ }
+
+ if uri.fragment().is_some() {
+ yeet!(CreateClientError::UriFragment.into())
+ }
+ }
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id).into());
+ }
+
+ let transaction = db.begin().await.unwrap();
+ db::update_client_redirect_uris(transaction, id, &body.0)
+ .await
+ .unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+#[put("{id}/secret")]
+async fn update_client_secret(
+ id: web::Path<Uuid>,
+ body: web::Json<Option<Box<str>>>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+
+ let Some(client_type) = db::get_client_type(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id).into())
+ };
+
+ if client_type == ClientType::Confidential && body.is_none() {
+ yeet!(CreateClientError::NoSecret.into())
+ }
+
+ let secret = body.0.map(|s| PasswordHash::new(&s).unwrap());
+ db::update_client_secret(db, id, secret).await.unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+pub fn service() -> Scope {
+ web::scope("/clients")
+ .service(get_client)
+ .service(get_client_alias)
+ .service(get_client_type)
+ .service(get_client_allowed_scopes)
+ .service(get_client_default_scopes)
+ .service(get_client_redirect_uris)
+ .service(get_client_is_trusted)
+ .service(create_client)
+ .service(update_client)
+ .service(update_client_alias)
+ .service(update_client_type)
+ .service(update_client_allowed_scopes)
+ .service(update_client_default_scopes)
+ .service(update_client_redirect_uris)
+ .service(update_client_secret)
+}
diff --git a/src/api/liveops.rs b/src/api/liveops.rs
index d4bf129..2caf6e3 100644
--- a/src/api/liveops.rs
+++ b/src/api/liveops.rs
@@ -1,11 +1,11 @@
-use actix_web::{get, web, HttpResponse, Scope};
-
-/// Simple ping
-#[get("/ping")]
-async fn ping() -> HttpResponse {
- HttpResponse::Ok().finish()
-}
-
-pub fn service() -> Scope {
- web::scope("/liveops").service(ping)
-}
+use actix_web::{get, web, HttpResponse, Scope};
+
+/// Simple ping
+#[get("/ping")]
+async fn ping() -> HttpResponse {
+ HttpResponse::Ok().finish()
+}
+
+pub fn service() -> Scope {
+ web::scope("/liveops").service(ping)
+}
diff --git a/src/api/mod.rs b/src/api/mod.rs
index 0ab4037..9059e71 100644
--- a/src/api/mod.rs
+++ b/src/api/mod.rs
@@ -1,13 +1,13 @@
-mod clients;
-mod liveops;
-mod oauth;
-mod ops;
-mod users;
-
-pub use clients::service as clients;
-pub use liveops::service as liveops;
-pub use oauth::service as oauth;
-pub use ops::service as ops;
-pub use users::service as users;
-
-pub use oauth::AuthorizationParameters;
+mod clients;
+mod liveops;
+mod oauth;
+mod ops;
+mod users;
+
+pub use clients::service as clients;
+pub use liveops::service as liveops;
+pub use oauth::service as oauth;
+pub use ops::service as ops;
+pub use users::service as users;
+
+pub use oauth::AuthorizationParameters;
diff --git a/src/api/oauth.rs b/src/api/oauth.rs
index f1aa012..3422d2f 100644
--- a/src/api/oauth.rs
+++ b/src/api/oauth.rs
@@ -1,926 +1,926 @@
-use std::ops::Deref;
-use std::str::FromStr;
-
-use actix_web::http::{header, StatusCode};
-use actix_web::{
- get, post, web, HttpRequest, HttpResponse, HttpResponseBuilder, ResponseError, Scope,
-};
-use chrono::Duration;
-use exun::{Expect, RawUnexpected, ResultErrorExt, UnexpectedError};
-use raise::yeet;
-use serde::{Deserialize, Serialize};
-use sqlx::MySqlPool;
-use tera::Tera;
-use thiserror::Error;
-use unic_langid::subtags::Language;
-use url::Url;
-use uuid::Uuid;
-
-use crate::models::client::ClientType;
-use crate::resources::{languages, templates};
-use crate::scopes;
-use crate::services::jwt::VerifyJwtError;
-use crate::services::{authorization, config, db, jwt};
-
-const REALLY_BAD_ERROR_PAGE: &str = "<!DOCTYPE html><html><head><title>Internal Server Error</title></head><body>Internal Server Error</body></html>";
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
-#[serde(rename_all = "snake_case")]
-enum ResponseType {
- Code,
- Token,
- #[serde(other)]
- Unsupported,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct AuthorizationParameters {
- response_type: ResponseType,
- client_id: Box<str>,
- redirect_uri: Option<Url>,
- scope: Option<Box<str>>,
- state: Option<Box<str>>,
-}
-
-#[derive(Clone, Deserialize)]
-struct AuthorizeCredentials {
- username: Box<str>,
- password: Box<str>,
-}
-
-#[derive(Clone, Serialize)]
-struct AuthCodeResponse {
- code: Box<str>,
- state: Option<Box<str>>,
-}
-
-#[derive(Clone, Serialize)]
-struct AuthTokenResponse {
- access_token: Box<str>,
- token_type: &'static str,
- expires_in: i64,
- scope: Box<str>,
- state: Option<Box<str>>,
-}
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
-#[serde(rename_all = "camelCase")]
-enum AuthorizeErrorType {
- InvalidRequest,
- UnauthorizedClient,
- AccessDenied,
- UnsupportedResponseType,
- InvalidScope,
- ServerError,
- TemporarilyUnavailable,
-}
-
-#[derive(Debug, Clone, Error, Serialize)]
-#[error("{error_description}")]
-struct AuthorizeError {
- error: AuthorizeErrorType,
- error_description: Box<str>,
- // TODO error uri
- state: Option<Box<str>>,
- #[serde(skip)]
- redirect_uri: Url,
-}
-
-impl AuthorizeError {
- fn no_scope(redirect_uri: Url, state: Option<Box<str>>) -> Self {
- Self {
- error: AuthorizeErrorType::InvalidScope,
- error_description: Box::from(
- "No scope was provided, and the client does not have a default scope",
- ),
- state,
- redirect_uri,
- }
- }
-
- fn unsupported_response_type(redirect_uri: Url, state: Option<Box<str>>) -> Self {
- Self {
- error: AuthorizeErrorType::UnsupportedResponseType,
- error_description: Box::from("The given response type is not supported"),
- state,
- redirect_uri,
- }
- }
-
- fn invalid_scope(redirect_uri: Url, state: Option<Box<str>>) -> Self {
- Self {
- error: AuthorizeErrorType::InvalidScope,
- error_description: Box::from("The given scope exceeds what the client is allowed"),
- state,
- redirect_uri,
- }
- }
-
- fn internal_server_error(redirect_uri: Url, state: Option<Box<str>>) -> Self {
- Self {
- error: AuthorizeErrorType::ServerError,
- error_description: "An unexpected error occurred".into(),
- state,
- redirect_uri,
- }
- }
-}
-
-impl ResponseError for AuthorizeError {
- fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
- let query = Some(serde_urlencoded::to_string(self).unwrap());
- let query = query.as_deref();
- let mut url = self.redirect_uri.clone();
- url.set_query(query);
-
- HttpResponse::Found()
- .insert_header((header::LOCATION, url.as_str()))
- .finish()
- }
-}
-
-fn error_page(
- tera: &Tera,
- translations: &languages::Translations,
- error: templates::ErrorPage,
-) -> Result<String, RawUnexpected> {
- // TODO find a better way of doing languages
- let language = Language::from_str("en").unwrap();
- let translations = translations.clone();
- let page = templates::error_page(&tera, language, translations, error)?;
- Ok(page)
-}
-
-async fn get_redirect_uri(
- redirect_uri: &Option<Url>,
- db: &MySqlPool,
- client_id: Uuid,
-) -> Result<Url, Expect<templates::ErrorPage>> {
- if let Some(uri) = &redirect_uri {
- let redirect_uri = uri.clone();
- if !db::client_has_redirect_uri(db, client_id, &redirect_uri)
- .await
- .map_err(|e| UnexpectedError::from(e))
- .unexpect()?
- {
- yeet!(Expect::Expected(templates::ErrorPage::InvalidRedirectUri));
- }
-
- Ok(redirect_uri)
- } else {
- let redirect_uris = db::get_client_redirect_uris(db, client_id)
- .await
- .map_err(|e| UnexpectedError::from(e))
- .unexpect()?;
- if redirect_uris.len() != 1 {
- yeet!(Expect::Expected(templates::ErrorPage::MissingRedirectUri));
- }
-
- Ok(redirect_uris.get(0).unwrap().clone())
- }
-}
-
-async fn get_scope(
- scope: &Option<Box<str>>,
- db: &MySqlPool,
- client_id: Uuid,
- redirect_uri: &Url,
- state: &Option<Box<str>>,
-) -> Result<Box<str>, Expect<AuthorizeError>> {
- let scope = if let Some(scope) = &scope {
- scope.clone()
- } else {
- let default_scopes = db::get_client_default_scopes(db, client_id)
- .await
- .unwrap()
- .unwrap();
- let Some(scope) = default_scopes else {
- yeet!(AuthorizeError::no_scope(redirect_uri.clone(), state.clone()).into())
- };
- scope
- };
-
- // verify scope is valid
- let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
- .await
- .unwrap()
- .unwrap();
- if !scopes::is_subset_of(&scope, &allowed_scopes) {
- yeet!(AuthorizeError::invalid_scope(redirect_uri.clone(), state.clone()).into());
- }
-
- Ok(scope)
-}
-
-async fn authenticate_user(
- db: &MySqlPool,
- username: &str,
- password: &str,
-) -> Result<Option<Uuid>, RawUnexpected> {
- let Some(user) = db::get_user_by_username(db, username).await? else {
- return Ok(None);
- };
-
- if user.check_password(password)? {
- Ok(Some(user.id))
- } else {
- Ok(None)
- }
-}
-
-#[post("/authorize")]
-async fn authorize(
- db: web::Data<MySqlPool>,
- req: web::Query<AuthorizationParameters>,
- credentials: web::Json<AuthorizeCredentials>,
- tera: web::Data<Tera>,
- translations: web::Data<languages::Translations>,
-) -> Result<HttpResponse, AuthorizeError> {
- // TODO protect against brute force attacks
- let db = db.get_ref();
- let Ok(client_id) = db::get_client_id_by_alias(db, &req.client_id).await else {
- let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return Ok(HttpResponse::InternalServerError().content_type("text/html").body(page));
- };
- let Some(client_id) = client_id else {
- let page = error_page(&tera, &translations, templates::ErrorPage::ClientNotFound).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return Ok(HttpResponse::NotFound().content_type("text/html").body(page));
- };
- let Ok(config) = config::get_config() else {
- let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return Ok(HttpResponse::InternalServerError().content_type("text/html").body(page));
- };
-
- let self_id = config.url;
- let state = req.state.clone();
-
- // get redirect uri
- let mut redirect_uri = match get_redirect_uri(&req.redirect_uri, db, client_id).await {
- Ok(uri) => uri,
- Err(e) => {
- let e = e
- .expected()
- .unwrap_or(templates::ErrorPage::InternalServerError);
- let page = error_page(&tera, &translations, e)
- .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return Ok(HttpResponse::BadRequest()
- .content_type("text/html")
- .body(page));
- }
- };
-
- // authenticate user
- let Some(user_id) = authenticate_user(db, &credentials.username, &credentials.password)
- .await
- .unwrap() else
- {
- let language = Language::from_str("en").unwrap();
- let translations = translations.get_ref().clone();
- let page = templates::login_error_page(&tera, &req, language, translations).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return Ok(HttpResponse::Ok().content_type("text/html").body(page));
- };
-
- let internal_server_error =
- AuthorizeError::internal_server_error(redirect_uri.clone(), state.clone());
-
- // get scope
- let scope = match get_scope(&req.scope, db, client_id, &redirect_uri, &state).await {
- Ok(scope) => scope,
- Err(e) => {
- let e = e.expected().unwrap_or(internal_server_error);
- return Err(e);
- }
- };
-
- match req.response_type {
- ResponseType::Code => {
- // create auth code
- let code =
- jwt::Claims::auth_code(db, self_id, client_id, user_id, &scope, &redirect_uri)
- .await
- .map_err(|_| internal_server_error.clone())?;
- let code = code.to_jwt().map_err(|_| internal_server_error.clone())?;
-
- let response = AuthCodeResponse { code, state };
- let query =
- Some(serde_urlencoded::to_string(response).map_err(|_| internal_server_error)?);
- let query = query.as_deref();
- redirect_uri.set_query(query);
-
- Ok(HttpResponse::Found()
- .append_header((header::LOCATION, redirect_uri.as_str()))
- .finish())
- }
- ResponseType::Token => {
- // create access token
- let duration = Duration::hours(1);
- let access_token =
- jwt::Claims::access_token(db, None, self_id, client_id, user_id, duration, &scope)
- .await
- .map_err(|_| internal_server_error.clone())?;
-
- let access_token = access_token
- .to_jwt()
- .map_err(|_| internal_server_error.clone())?;
- let expires_in = duration.num_seconds();
- let token_type = "bearer";
- let response = AuthTokenResponse {
- access_token,
- expires_in,
- token_type,
- scope,
- state,
- };
-
- let fragment = Some(
- serde_urlencoded::to_string(response).map_err(|_| internal_server_error.clone())?,
- );
- let fragment = fragment.as_deref();
- redirect_uri.set_fragment(fragment);
-
- Ok(HttpResponse::Found()
- .append_header((header::LOCATION, redirect_uri.as_str()))
- .finish())
- }
- _ => Err(AuthorizeError::invalid_scope(redirect_uri, state)),
- }
-}
-
-#[get("/authorize")]
-async fn authorize_page(
- db: web::Data<MySqlPool>,
- tera: web::Data<Tera>,
- translations: web::Data<languages::Translations>,
- request: HttpRequest,
-) -> Result<HttpResponse, AuthorizeError> {
- let Ok(language) = Language::from_str("en") else {
- let page = String::from(REALLY_BAD_ERROR_PAGE);
- return Ok(HttpResponse::InternalServerError()
- .content_type("text/html")
- .body(page));
- };
- let translations = translations.get_ref().clone();
-
- let params = request.query_string();
- let params = serde_urlencoded::from_str::<AuthorizationParameters>(params);
- let Ok(params) = params else {
- let page = error_page(
- &tera,
- &translations,
- templates::ErrorPage::InvalidRequest,
- )
- .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return Ok(HttpResponse::BadRequest()
- .content_type("text/html")
- .body(page));
- };
-
- let db = db.get_ref();
- let Ok(client_id) = db::get_client_id_by_alias(db, &params.client_id).await else {
- let page = templates::error_page(
- &tera,
- language,
- translations,
- templates::ErrorPage::InternalServerError,
- )
- .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return Ok(HttpResponse::InternalServerError()
- .content_type("text/html")
- .body(page));
- };
- let Some(client_id) = client_id else {
- let page = templates::error_page(
- &tera,
- language,
- translations,
- templates::ErrorPage::ClientNotFound,
- )
- .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return Ok(HttpResponse::NotFound()
- .content_type("text/html")
- .body(page));
- };
-
- // verify redirect uri
- let redirect_uri = match get_redirect_uri(&params.redirect_uri, db, client_id).await {
- Ok(uri) => uri,
- Err(e) => {
- let e = e
- .expected()
- .unwrap_or(templates::ErrorPage::InternalServerError);
- let page = error_page(&tera, &translations, e)
- .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return Ok(HttpResponse::BadRequest()
- .content_type("text/html")
- .body(page));
- }
- };
-
- let state = &params.state;
- let internal_server_error =
- AuthorizeError::internal_server_error(redirect_uri.clone(), state.clone());
-
- // verify scope
- let _ = match get_scope(&params.scope, db, client_id, &redirect_uri, &params.state).await {
- Ok(scope) => scope,
- Err(e) => {
- let e = e.expected().unwrap_or(internal_server_error);
- return Err(e);
- }
- };
-
- // verify response type
- if params.response_type == ResponseType::Unsupported {
- return Err(AuthorizeError::unsupported_response_type(
- redirect_uri,
- params.state,
- ));
- }
-
- // TODO find a better way of doing languages
- let language = Language::from_str("en").unwrap();
- let page = templates::login_page(&tera, &params, language, translations).unwrap();
- Ok(HttpResponse::Ok().content_type("text/html").body(page))
-}
-
-#[derive(Clone, Deserialize)]
-#[serde(tag = "grant_type")]
-#[serde(rename_all = "snake_case")]
-enum GrantType {
- AuthorizationCode {
- code: Box<str>,
- redirect_uri: Url,
- #[serde(rename = "client_id")]
- client_alias: Box<str>,
- },
- Password {
- username: Box<str>,
- password: Box<str>,
- scope: Option<Box<str>>,
- },
- ClientCredentials {
- scope: Option<Box<str>>,
- },
- RefreshToken {
- refresh_token: Box<str>,
- scope: Option<Box<str>>,
- },
- #[serde(other)]
- Unsupported,
-}
-
-#[derive(Clone, Deserialize)]
-struct TokenRequest {
- #[serde(flatten)]
- grant_type: GrantType,
- // TODO support optional client credentials in here
-}
-
-#[derive(Clone, Serialize)]
-struct TokenResponse {
- access_token: Box<str>,
- token_type: Box<str>,
- expires_in: i64,
- refresh_token: Option<Box<str>>,
- scope: Box<str>,
-}
-
-#[derive(Debug, Clone, Serialize)]
-#[serde(rename_all = "snake_case")]
-enum TokenErrorType {
- InvalidRequest,
- InvalidClient,
- InvalidGrant,
- UnauthorizedClient,
- UnsupportedGrantType,
- InvalidScope,
-}
-
-#[derive(Debug, Clone, Error, Serialize)]
-#[error("{error_description}")]
-struct TokenError {
- #[serde(skip)]
- status_code: StatusCode,
- error: TokenErrorType,
- error_description: Box<str>,
- // TODO error uri
-}
-
-impl TokenError {
- fn invalid_request() -> Self {
- // TODO make this description better, and all the other ones while you're at it
- Self {
- status_code: StatusCode::BAD_REQUEST,
- error: TokenErrorType::InvalidRequest,
- error_description: "Invalid request".into(),
- }
- }
-
- fn unsupported_grant_type() -> Self {
- Self {
- status_code: StatusCode::BAD_REQUEST,
- error: TokenErrorType::UnsupportedGrantType,
- error_description: "The given grant type is not supported".into(),
- }
- }
-
- fn bad_auth_code(error: VerifyJwtError) -> Self {
- Self {
- status_code: StatusCode::BAD_REQUEST,
- error: TokenErrorType::InvalidGrant,
- error_description: error.to_string().into_boxed_str(),
- }
- }
-
- fn no_authorization() -> Self {
- Self {
- status_code: StatusCode::UNAUTHORIZED,
- error: TokenErrorType::InvalidClient,
- error_description: Box::from(
- "Client credentials must be provided in the HTTP Authorization header",
- ),
- }
- }
-
- fn client_not_found(alias: &str) -> Self {
- Self {
- status_code: StatusCode::UNAUTHORIZED,
- error: TokenErrorType::InvalidClient,
- error_description: format!("No client with the client id: {alias} was found")
- .into_boxed_str(),
- }
- }
-
- fn mismatch_client_id() -> Self {
- Self {
- status_code: StatusCode::UNAUTHORIZED,
- error: TokenErrorType::InvalidClient,
- error_description: Box::from("The client ID in the Authorization header is not the same as the client ID in the request body"),
- }
- }
-
- fn incorrect_client_secret() -> Self {
- Self {
- status_code: StatusCode::UNAUTHORIZED,
- error: TokenErrorType::InvalidClient,
- error_description: "The client secret is incorrect".into(),
- }
- }
-
- fn client_not_confidential(alias: &str) -> Self {
- Self {
- status_code: StatusCode::BAD_REQUEST,
- error: TokenErrorType::UnauthorizedClient,
- error_description: format!("Only a confidential client may be used with this endpoint. The {alias} client is a public client.")
- .into_boxed_str(),
- }
- }
-
- fn no_scope() -> Self {
- Self {
- status_code: StatusCode::BAD_REQUEST,
- error: TokenErrorType::InvalidScope,
- error_description: Box::from(
- "No scope was provided, and the client doesn't have a default scope",
- ),
- }
- }
-
- fn excessive_scope() -> Self {
- Self {
- status_code: StatusCode::BAD_REQUEST,
- error: TokenErrorType::InvalidScope,
- error_description: Box::from(
- "The given scope exceeds what the client is allowed to have",
- ),
- }
- }
-
- fn bad_refresh_token(err: VerifyJwtError) -> Self {
- Self {
- status_code: StatusCode::BAD_REQUEST,
- error: TokenErrorType::InvalidGrant,
- error_description: err.to_string().into_boxed_str(),
- }
- }
-
- fn untrusted_client() -> Self {
- Self {
- status_code: StatusCode::UNAUTHORIZED,
- error: TokenErrorType::InvalidClient,
- error_description: "Only trusted clients may use this grant".into(),
- }
- }
-
- fn incorrect_user_credentials() -> Self {
- Self {
- status_code: StatusCode::BAD_REQUEST,
- error: TokenErrorType::InvalidRequest,
- error_description: "The given credentials are incorrect".into(),
- }
- }
-}
-
-impl ResponseError for TokenError {
- fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
- let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]);
-
- let mut builder = HttpResponseBuilder::new(self.status_code);
-
- if self.status_code.as_u16() == 401 {
- builder.insert_header((header::WWW_AUTHENTICATE, "Basic charset=\"UTF-8\""));
- }
-
- builder
- .insert_header(cache_control)
- .insert_header((header::PRAGMA, "no-cache"))
- .json(self.clone())
- }
-}
-
-#[post("/token")]
-async fn token(
- db: web::Data<MySqlPool>,
- req: web::Bytes,
- authorization: Option<web::Header<authorization::BasicAuthorization>>,
-) -> HttpResponse {
- // TODO protect against brute force attacks
- let db = db.get_ref();
- let request = serde_json::from_slice::<TokenRequest>(&req);
- let Ok(request) = request else {
- return TokenError::invalid_request().error_response();
- };
- let config = config::get_config().unwrap();
-
- let self_id = config.url;
- let duration = Duration::hours(1);
- let token_type = Box::from("bearer");
- let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]);
-
- match request.grant_type {
- GrantType::AuthorizationCode {
- code,
- redirect_uri,
- client_alias,
- } => {
- let Some(client_id) = db::get_client_id_by_alias(db, &client_alias).await.unwrap() else {
- return TokenError::client_not_found(&client_alias).error_response();
- };
-
- // validate auth code
- let claims =
- match jwt::verify_auth_code(db, &code, &self_id, client_id, redirect_uri).await {
- Ok(claims) => claims,
- Err(err) => {
- let err = err.unwrap();
- return TokenError::bad_auth_code(err).error_response();
- }
- };
-
- // verify client, if the client has credentials
- if let Some(hash) = db::get_client_secret(db, client_id).await.unwrap() {
- let Some(authorization) = authorization else {
- return TokenError::no_authorization().error_response();
- };
-
- if authorization.username() != client_alias.deref() {
- return TokenError::mismatch_client_id().error_response();
- }
- if !hash.check_password(authorization.password()).unwrap() {
- return TokenError::incorrect_client_secret().error_response();
- }
- }
-
- let access_token = jwt::Claims::access_token(
- db,
- Some(claims.id()),
- self_id,
- client_id,
- claims.subject(),
- duration,
- claims.scopes(),
- )
- .await
- .unwrap();
-
- let expires_in = access_token.expires_in();
- let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap();
- let scope = access_token.scopes().into();
-
- let access_token = access_token.to_jwt().unwrap();
- let refresh_token = Some(refresh_token.to_jwt().unwrap());
-
- let response = TokenResponse {
- access_token,
- token_type,
- expires_in,
- refresh_token,
- scope,
- };
- HttpResponse::Ok()
- .insert_header(cache_control)
- .insert_header((header::PRAGMA, "no-cache"))
- .json(response)
- }
- GrantType::Password {
- username,
- password,
- scope,
- } => {
- let Some(authorization) = authorization else {
- return TokenError::no_authorization().error_response();
- };
- let client_alias = authorization.username();
- let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else {
- return TokenError::client_not_found(client_alias).error_response();
- };
-
- let trusted = db::is_client_trusted(db, client_id).await.unwrap().unwrap();
- if !trusted {
- return TokenError::untrusted_client().error_response();
- }
-
- // verify client
- let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap();
- if !hash.check_password(authorization.password()).unwrap() {
- return TokenError::incorrect_client_secret().error_response();
- }
-
- // authenticate user
- let Some(user_id) = authenticate_user(db, &username, &password).await.unwrap() else {
- return TokenError::incorrect_user_credentials().error_response();
- };
-
- // verify scope
- let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
- .await
- .unwrap()
- .unwrap();
- let scope = if let Some(scope) = &scope {
- scope.clone()
- } else {
- let default_scopes = db::get_client_default_scopes(db, client_id)
- .await
- .unwrap()
- .unwrap();
- let Some(scope) = default_scopes else {
- return TokenError::no_scope().error_response();
- };
- scope
- };
- if !scopes::is_subset_of(&scope, &allowed_scopes) {
- return TokenError::excessive_scope().error_response();
- }
-
- let access_token =
- jwt::Claims::access_token(db, None, self_id, client_id, user_id, duration, &scope)
- .await
- .unwrap();
- let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap();
-
- let expires_in = access_token.expires_in();
- let scope = access_token.scopes().into();
- let access_token = access_token.to_jwt().unwrap();
- let refresh_token = Some(refresh_token.to_jwt().unwrap());
-
- let response = TokenResponse {
- access_token,
- token_type,
- expires_in,
- refresh_token,
- scope,
- };
- HttpResponse::Ok()
- .insert_header(cache_control)
- .insert_header((header::PRAGMA, "no-cache"))
- .json(response)
- }
- GrantType::ClientCredentials { scope } => {
- let Some(authorization) = authorization else {
- return TokenError::no_authorization().error_response();
- };
- let client_alias = authorization.username();
- let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else {
- return TokenError::client_not_found(client_alias).error_response();
- };
-
- let ty = db::get_client_type(db, client_id).await.unwrap().unwrap();
- if ty != ClientType::Confidential {
- return TokenError::client_not_confidential(client_alias).error_response();
- }
-
- // verify client
- let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap();
- if !hash.check_password(authorization.password()).unwrap() {
- return TokenError::incorrect_client_secret().error_response();
- }
-
- // verify scope
- let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
- .await
- .unwrap()
- .unwrap();
- let scope = if let Some(scope) = &scope {
- scope.clone()
- } else {
- let default_scopes = db::get_client_default_scopes(db, client_id)
- .await
- .unwrap()
- .unwrap();
- let Some(scope) = default_scopes else {
- return TokenError::no_scope().error_response();
- };
- scope
- };
- if !scopes::is_subset_of(&scope, &allowed_scopes) {
- return TokenError::excessive_scope().error_response();
- }
-
- let access_token = jwt::Claims::access_token(
- db, None, self_id, client_id, client_id, duration, &scope,
- )
- .await
- .unwrap();
-
- let expires_in = access_token.expires_in();
- let scope = access_token.scopes().into();
- let access_token = access_token.to_jwt().unwrap();
-
- let response = TokenResponse {
- access_token,
- token_type,
- expires_in,
- refresh_token: None,
- scope,
- };
- HttpResponse::Ok()
- .insert_header(cache_control)
- .insert_header((header::PRAGMA, "no-cache"))
- .json(response)
- }
- GrantType::RefreshToken {
- refresh_token,
- scope,
- } => {
- let client_id: Option<Uuid>;
- if let Some(authorization) = authorization {
- let client_alias = authorization.username();
- let Some(id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else {
- return TokenError::client_not_found(client_alias).error_response();
- };
- client_id = Some(id);
- } else {
- client_id = None;
- }
-
- let claims =
- match jwt::verify_refresh_token(db, &refresh_token, &self_id, client_id).await {
- Ok(claims) => claims,
- Err(e) => {
- let e = e.unwrap();
- return TokenError::bad_refresh_token(e).error_response();
- }
- };
-
- let scope = if let Some(scope) = scope {
- if !scopes::is_subset_of(&scope, claims.scopes()) {
- return TokenError::excessive_scope().error_response();
- }
-
- scope
- } else {
- claims.scopes().into()
- };
-
- let exp_time = Duration::hours(1);
- let access_token = jwt::Claims::refreshed_access_token(db, &claims, exp_time)
- .await
- .unwrap();
- let refresh_token = jwt::Claims::refresh_token(db, &claims).await.unwrap();
-
- let access_token = access_token.to_jwt().unwrap();
- let refresh_token = Some(refresh_token.to_jwt().unwrap());
- let expires_in = exp_time.num_seconds();
-
- let response = TokenResponse {
- access_token,
- token_type,
- expires_in,
- refresh_token,
- scope,
- };
- HttpResponse::Ok()
- .insert_header(cache_control)
- .insert_header((header::PRAGMA, "no-cache"))
- .json(response)
- }
- _ => TokenError::unsupported_grant_type().error_response(),
- }
-}
-
-pub fn service() -> Scope {
- web::scope("/oauth")
- .service(authorize_page)
- .service(authorize)
- .service(token)
-}
+use std::ops::Deref;
+use std::str::FromStr;
+
+use actix_web::http::{header, StatusCode};
+use actix_web::{
+ get, post, web, HttpRequest, HttpResponse, HttpResponseBuilder, ResponseError, Scope,
+};
+use chrono::Duration;
+use exun::{Expect, RawUnexpected, ResultErrorExt, UnexpectedError};
+use raise::yeet;
+use serde::{Deserialize, Serialize};
+use sqlx::MySqlPool;
+use tera::Tera;
+use thiserror::Error;
+use unic_langid::subtags::Language;
+use url::Url;
+use uuid::Uuid;
+
+use crate::models::client::ClientType;
+use crate::resources::{languages, templates};
+use crate::scopes;
+use crate::services::jwt::VerifyJwtError;
+use crate::services::{authorization, config, db, jwt};
+
+const REALLY_BAD_ERROR_PAGE: &str = "<!DOCTYPE html><html><head><title>Internal Server Error</title></head><body>Internal Server Error</body></html>";
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+enum ResponseType {
+ Code,
+ Token,
+ #[serde(other)]
+ Unsupported,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct AuthorizationParameters {
+ response_type: ResponseType,
+ client_id: Box<str>,
+ redirect_uri: Option<Url>,
+ scope: Option<Box<str>>,
+ state: Option<Box<str>>,
+}
+
+#[derive(Clone, Deserialize)]
+struct AuthorizeCredentials {
+ username: Box<str>,
+ password: Box<str>,
+}
+
+#[derive(Clone, Serialize)]
+struct AuthCodeResponse {
+ code: Box<str>,
+ state: Option<Box<str>>,
+}
+
+#[derive(Clone, Serialize)]
+struct AuthTokenResponse {
+ access_token: Box<str>,
+ token_type: &'static str,
+ expires_in: i64,
+ scope: Box<str>,
+ state: Option<Box<str>>,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
+#[serde(rename_all = "camelCase")]
+enum AuthorizeErrorType {
+ InvalidRequest,
+ UnauthorizedClient,
+ AccessDenied,
+ UnsupportedResponseType,
+ InvalidScope,
+ ServerError,
+ TemporarilyUnavailable,
+}
+
+#[derive(Debug, Clone, Error, Serialize)]
+#[error("{error_description}")]
+struct AuthorizeError {
+ error: AuthorizeErrorType,
+ error_description: Box<str>,
+ // TODO error uri
+ state: Option<Box<str>>,
+ #[serde(skip)]
+ redirect_uri: Url,
+}
+
+impl AuthorizeError {
+ fn no_scope(redirect_uri: Url, state: Option<Box<str>>) -> Self {
+ Self {
+ error: AuthorizeErrorType::InvalidScope,
+ error_description: Box::from(
+ "No scope was provided, and the client does not have a default scope",
+ ),
+ state,
+ redirect_uri,
+ }
+ }
+
+ fn unsupported_response_type(redirect_uri: Url, state: Option<Box<str>>) -> Self {
+ Self {
+ error: AuthorizeErrorType::UnsupportedResponseType,
+ error_description: Box::from("The given response type is not supported"),
+ state,
+ redirect_uri,
+ }
+ }
+
+ fn invalid_scope(redirect_uri: Url, state: Option<Box<str>>) -> Self {
+ Self {
+ error: AuthorizeErrorType::InvalidScope,
+ error_description: Box::from("The given scope exceeds what the client is allowed"),
+ state,
+ redirect_uri,
+ }
+ }
+
+ fn internal_server_error(redirect_uri: Url, state: Option<Box<str>>) -> Self {
+ Self {
+ error: AuthorizeErrorType::ServerError,
+ error_description: "An unexpected error occurred".into(),
+ state,
+ redirect_uri,
+ }
+ }
+}
+
+impl ResponseError for AuthorizeError {
+ fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
+ let query = Some(serde_urlencoded::to_string(self).unwrap());
+ let query = query.as_deref();
+ let mut url = self.redirect_uri.clone();
+ url.set_query(query);
+
+ HttpResponse::Found()
+ .insert_header((header::LOCATION, url.as_str()))
+ .finish()
+ }
+}
+
+fn error_page(
+ tera: &Tera,
+ translations: &languages::Translations,
+ error: templates::ErrorPage,
+) -> Result<String, RawUnexpected> {
+ // TODO find a better way of doing languages
+ let language = Language::from_str("en").unwrap();
+ let translations = translations.clone();
+ let page = templates::error_page(&tera, language, translations, error)?;
+ Ok(page)
+}
+
+async fn get_redirect_uri(
+ redirect_uri: &Option<Url>,
+ db: &MySqlPool,
+ client_id: Uuid,
+) -> Result<Url, Expect<templates::ErrorPage>> {
+ if let Some(uri) = &redirect_uri {
+ let redirect_uri = uri.clone();
+ if !db::client_has_redirect_uri(db, client_id, &redirect_uri)
+ .await
+ .map_err(|e| UnexpectedError::from(e))
+ .unexpect()?
+ {
+ yeet!(Expect::Expected(templates::ErrorPage::InvalidRedirectUri));
+ }
+
+ Ok(redirect_uri)
+ } else {
+ let redirect_uris = db::get_client_redirect_uris(db, client_id)
+ .await
+ .map_err(|e| UnexpectedError::from(e))
+ .unexpect()?;
+ if redirect_uris.len() != 1 {
+ yeet!(Expect::Expected(templates::ErrorPage::MissingRedirectUri));
+ }
+
+ Ok(redirect_uris.get(0).unwrap().clone())
+ }
+}
+
+async fn get_scope(
+ scope: &Option<Box<str>>,
+ db: &MySqlPool,
+ client_id: Uuid,
+ redirect_uri: &Url,
+ state: &Option<Box<str>>,
+) -> Result<Box<str>, Expect<AuthorizeError>> {
+ let scope = if let Some(scope) = &scope {
+ scope.clone()
+ } else {
+ let default_scopes = db::get_client_default_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let Some(scope) = default_scopes else {
+ yeet!(AuthorizeError::no_scope(redirect_uri.clone(), state.clone()).into())
+ };
+ scope
+ };
+
+ // verify scope is valid
+ let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ if !scopes::is_subset_of(&scope, &allowed_scopes) {
+ yeet!(AuthorizeError::invalid_scope(redirect_uri.clone(), state.clone()).into());
+ }
+
+ Ok(scope)
+}
+
+async fn authenticate_user(
+ db: &MySqlPool,
+ username: &str,
+ password: &str,
+) -> Result<Option<Uuid>, RawUnexpected> {
+ let Some(user) = db::get_user_by_username(db, username).await? else {
+ return Ok(None);
+ };
+
+ if user.check_password(password)? {
+ Ok(Some(user.id))
+ } else {
+ Ok(None)
+ }
+}
+
+#[post("/authorize")]
+async fn authorize(
+ db: web::Data<MySqlPool>,
+ req: web::Query<AuthorizationParameters>,
+ credentials: web::Json<AuthorizeCredentials>,
+ tera: web::Data<Tera>,
+ translations: web::Data<languages::Translations>,
+) -> Result<HttpResponse, AuthorizeError> {
+ // TODO protect against brute force attacks
+ let db = db.get_ref();
+ let Ok(client_id) = db::get_client_id_by_alias(db, &req.client_id).await else {
+ let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::InternalServerError().content_type("text/html").body(page));
+ };
+ let Some(client_id) = client_id else {
+ let page = error_page(&tera, &translations, templates::ErrorPage::ClientNotFound).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::NotFound().content_type("text/html").body(page));
+ };
+ let Ok(config) = config::get_config() else {
+ let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::InternalServerError().content_type("text/html").body(page));
+ };
+
+ let self_id = config.url;
+ let state = req.state.clone();
+
+ // get redirect uri
+ let mut redirect_uri = match get_redirect_uri(&req.redirect_uri, db, client_id).await {
+ Ok(uri) => uri,
+ Err(e) => {
+ let e = e
+ .expected()
+ .unwrap_or(templates::ErrorPage::InternalServerError);
+ let page = error_page(&tera, &translations, e)
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::BadRequest()
+ .content_type("text/html")
+ .body(page));
+ }
+ };
+
+ // authenticate user
+ let Some(user_id) = authenticate_user(db, &credentials.username, &credentials.password)
+ .await
+ .unwrap() else
+ {
+ let language = Language::from_str("en").unwrap();
+ let translations = translations.get_ref().clone();
+ let page = templates::login_error_page(&tera, &req, language, translations).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::Ok().content_type("text/html").body(page));
+ };
+
+ let internal_server_error =
+ AuthorizeError::internal_server_error(redirect_uri.clone(), state.clone());
+
+ // get scope
+ let scope = match get_scope(&req.scope, db, client_id, &redirect_uri, &state).await {
+ Ok(scope) => scope,
+ Err(e) => {
+ let e = e.expected().unwrap_or(internal_server_error);
+ return Err(e);
+ }
+ };
+
+ match req.response_type {
+ ResponseType::Code => {
+ // create auth code
+ let code =
+ jwt::Claims::auth_code(db, self_id, client_id, user_id, &scope, &redirect_uri)
+ .await
+ .map_err(|_| internal_server_error.clone())?;
+ let code = code.to_jwt().map_err(|_| internal_server_error.clone())?;
+
+ let response = AuthCodeResponse { code, state };
+ let query =
+ Some(serde_urlencoded::to_string(response).map_err(|_| internal_server_error)?);
+ let query = query.as_deref();
+ redirect_uri.set_query(query);
+
+ Ok(HttpResponse::Found()
+ .append_header((header::LOCATION, redirect_uri.as_str()))
+ .finish())
+ }
+ ResponseType::Token => {
+ // create access token
+ let duration = Duration::hours(1);
+ let access_token =
+ jwt::Claims::access_token(db, None, self_id, client_id, user_id, duration, &scope)
+ .await
+ .map_err(|_| internal_server_error.clone())?;
+
+ let access_token = access_token
+ .to_jwt()
+ .map_err(|_| internal_server_error.clone())?;
+ let expires_in = duration.num_seconds();
+ let token_type = "bearer";
+ let response = AuthTokenResponse {
+ access_token,
+ expires_in,
+ token_type,
+ scope,
+ state,
+ };
+
+ let fragment = Some(
+ serde_urlencoded::to_string(response).map_err(|_| internal_server_error.clone())?,
+ );
+ let fragment = fragment.as_deref();
+ redirect_uri.set_fragment(fragment);
+
+ Ok(HttpResponse::Found()
+ .append_header((header::LOCATION, redirect_uri.as_str()))
+ .finish())
+ }
+ _ => Err(AuthorizeError::invalid_scope(redirect_uri, state)),
+ }
+}
+
+#[get("/authorize")]
+async fn authorize_page(
+ db: web::Data<MySqlPool>,
+ tera: web::Data<Tera>,
+ translations: web::Data<languages::Translations>,
+ request: HttpRequest,
+) -> Result<HttpResponse, AuthorizeError> {
+ let Ok(language) = Language::from_str("en") else {
+ let page = String::from(REALLY_BAD_ERROR_PAGE);
+ return Ok(HttpResponse::InternalServerError()
+ .content_type("text/html")
+ .body(page));
+ };
+ let translations = translations.get_ref().clone();
+
+ let params = request.query_string();
+ let params = serde_urlencoded::from_str::<AuthorizationParameters>(params);
+ let Ok(params) = params else {
+ let page = error_page(
+ &tera,
+ &translations,
+ templates::ErrorPage::InvalidRequest,
+ )
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::BadRequest()
+ .content_type("text/html")
+ .body(page));
+ };
+
+ let db = db.get_ref();
+ let Ok(client_id) = db::get_client_id_by_alias(db, &params.client_id).await else {
+ let page = templates::error_page(
+ &tera,
+ language,
+ translations,
+ templates::ErrorPage::InternalServerError,
+ )
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::InternalServerError()
+ .content_type("text/html")
+ .body(page));
+ };
+ let Some(client_id) = client_id else {
+ let page = templates::error_page(
+ &tera,
+ language,
+ translations,
+ templates::ErrorPage::ClientNotFound,
+ )
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::NotFound()
+ .content_type("text/html")
+ .body(page));
+ };
+
+ // verify redirect uri
+ let redirect_uri = match get_redirect_uri(&params.redirect_uri, db, client_id).await {
+ Ok(uri) => uri,
+ Err(e) => {
+ let e = e
+ .expected()
+ .unwrap_or(templates::ErrorPage::InternalServerError);
+ let page = error_page(&tera, &translations, e)
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::BadRequest()
+ .content_type("text/html")
+ .body(page));
+ }
+ };
+
+ let state = &params.state;
+ let internal_server_error =
+ AuthorizeError::internal_server_error(redirect_uri.clone(), state.clone());
+
+ // verify scope
+ let _ = match get_scope(&params.scope, db, client_id, &redirect_uri, &params.state).await {
+ Ok(scope) => scope,
+ Err(e) => {
+ let e = e.expected().unwrap_or(internal_server_error);
+ return Err(e);
+ }
+ };
+
+ // verify response type
+ if params.response_type == ResponseType::Unsupported {
+ return Err(AuthorizeError::unsupported_response_type(
+ redirect_uri,
+ params.state,
+ ));
+ }
+
+ // TODO find a better way of doing languages
+ let language = Language::from_str("en").unwrap();
+ let page = templates::login_page(&tera, &params, language, translations).unwrap();
+ Ok(HttpResponse::Ok().content_type("text/html").body(page))
+}
+
+#[derive(Clone, Deserialize)]
+#[serde(tag = "grant_type")]
+#[serde(rename_all = "snake_case")]
+enum GrantType {
+ AuthorizationCode {
+ code: Box<str>,
+ redirect_uri: Url,
+ #[serde(rename = "client_id")]
+ client_alias: Box<str>,
+ },
+ Password {
+ username: Box<str>,
+ password: Box<str>,
+ scope: Option<Box<str>>,
+ },
+ ClientCredentials {
+ scope: Option<Box<str>>,
+ },
+ RefreshToken {
+ refresh_token: Box<str>,
+ scope: Option<Box<str>>,
+ },
+ #[serde(other)]
+ Unsupported,
+}
+
+#[derive(Clone, Deserialize)]
+struct TokenRequest {
+ #[serde(flatten)]
+ grant_type: GrantType,
+ // TODO support optional client credentials in here
+}
+
+#[derive(Clone, Serialize)]
+struct TokenResponse {
+ access_token: Box<str>,
+ token_type: Box<str>,
+ expires_in: i64,
+ refresh_token: Option<Box<str>>,
+ scope: Box<str>,
+}
+
+#[derive(Debug, Clone, Serialize)]
+#[serde(rename_all = "snake_case")]
+enum TokenErrorType {
+ InvalidRequest,
+ InvalidClient,
+ InvalidGrant,
+ UnauthorizedClient,
+ UnsupportedGrantType,
+ InvalidScope,
+}
+
+#[derive(Debug, Clone, Error, Serialize)]
+#[error("{error_description}")]
+struct TokenError {
+ #[serde(skip)]
+ status_code: StatusCode,
+ error: TokenErrorType,
+ error_description: Box<str>,
+ // TODO error uri
+}
+
+impl TokenError {
+ fn invalid_request() -> Self {
+ // TODO make this description better, and all the other ones while you're at it
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidRequest,
+ error_description: "Invalid request".into(),
+ }
+ }
+
+ fn unsupported_grant_type() -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::UnsupportedGrantType,
+ error_description: "The given grant type is not supported".into(),
+ }
+ }
+
+ fn bad_auth_code(error: VerifyJwtError) -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidGrant,
+ error_description: error.to_string().into_boxed_str(),
+ }
+ }
+
+ fn no_authorization() -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: Box::from(
+ "Client credentials must be provided in the HTTP Authorization header",
+ ),
+ }
+ }
+
+ fn client_not_found(alias: &str) -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: format!("No client with the client id: {alias} was found")
+ .into_boxed_str(),
+ }
+ }
+
+ fn mismatch_client_id() -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: Box::from("The client ID in the Authorization header is not the same as the client ID in the request body"),
+ }
+ }
+
+ fn incorrect_client_secret() -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: "The client secret is incorrect".into(),
+ }
+ }
+
+ fn client_not_confidential(alias: &str) -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::UnauthorizedClient,
+ error_description: format!("Only a confidential client may be used with this endpoint. The {alias} client is a public client.")
+ .into_boxed_str(),
+ }
+ }
+
+ fn no_scope() -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidScope,
+ error_description: Box::from(
+ "No scope was provided, and the client doesn't have a default scope",
+ ),
+ }
+ }
+
+ fn excessive_scope() -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidScope,
+ error_description: Box::from(
+ "The given scope exceeds what the client is allowed to have",
+ ),
+ }
+ }
+
+ fn bad_refresh_token(err: VerifyJwtError) -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidGrant,
+ error_description: err.to_string().into_boxed_str(),
+ }
+ }
+
+ fn untrusted_client() -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: "Only trusted clients may use this grant".into(),
+ }
+ }
+
+ fn incorrect_user_credentials() -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidRequest,
+ error_description: "The given credentials are incorrect".into(),
+ }
+ }
+}
+
+impl ResponseError for TokenError {
+ fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
+ let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]);
+
+ let mut builder = HttpResponseBuilder::new(self.status_code);
+
+ if self.status_code.as_u16() == 401 {
+ builder.insert_header((header::WWW_AUTHENTICATE, "Basic charset=\"UTF-8\""));
+ }
+
+ builder
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(self.clone())
+ }
+}
+
+#[post("/token")]
+async fn token(
+ db: web::Data<MySqlPool>,
+ req: web::Bytes,
+ authorization: Option<web::Header<authorization::BasicAuthorization>>,
+) -> HttpResponse {
+ // TODO protect against brute force attacks
+ let db = db.get_ref();
+ let request = serde_json::from_slice::<TokenRequest>(&req);
+ let Ok(request) = request else {
+ return TokenError::invalid_request().error_response();
+ };
+ let config = config::get_config().unwrap();
+
+ let self_id = config.url;
+ let duration = Duration::hours(1);
+ let token_type = Box::from("bearer");
+ let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]);
+
+ match request.grant_type {
+ GrantType::AuthorizationCode {
+ code,
+ redirect_uri,
+ client_alias,
+ } => {
+ let Some(client_id) = db::get_client_id_by_alias(db, &client_alias).await.unwrap() else {
+ return TokenError::client_not_found(&client_alias).error_response();
+ };
+
+ // validate auth code
+ let claims =
+ match jwt::verify_auth_code(db, &code, &self_id, client_id, redirect_uri).await {
+ Ok(claims) => claims,
+ Err(err) => {
+ let err = err.unwrap();
+ return TokenError::bad_auth_code(err).error_response();
+ }
+ };
+
+ // verify client, if the client has credentials
+ if let Some(hash) = db::get_client_secret(db, client_id).await.unwrap() {
+ let Some(authorization) = authorization else {
+ return TokenError::no_authorization().error_response();
+ };
+
+ if authorization.username() != client_alias.deref() {
+ return TokenError::mismatch_client_id().error_response();
+ }
+ if !hash.check_password(authorization.password()).unwrap() {
+ return TokenError::incorrect_client_secret().error_response();
+ }
+ }
+
+ let access_token = jwt::Claims::access_token(
+ db,
+ Some(claims.id()),
+ self_id,
+ client_id,
+ claims.subject(),
+ duration,
+ claims.scopes(),
+ )
+ .await
+ .unwrap();
+
+ let expires_in = access_token.expires_in();
+ let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap();
+ let scope = access_token.scopes().into();
+
+ let access_token = access_token.to_jwt().unwrap();
+ let refresh_token = Some(refresh_token.to_jwt().unwrap());
+
+ let response = TokenResponse {
+ access_token,
+ token_type,
+ expires_in,
+ refresh_token,
+ scope,
+ };
+ HttpResponse::Ok()
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(response)
+ }
+ GrantType::Password {
+ username,
+ password,
+ scope,
+ } => {
+ let Some(authorization) = authorization else {
+ return TokenError::no_authorization().error_response();
+ };
+ let client_alias = authorization.username();
+ let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else {
+ return TokenError::client_not_found(client_alias).error_response();
+ };
+
+ let trusted = db::is_client_trusted(db, client_id).await.unwrap().unwrap();
+ if !trusted {
+ return TokenError::untrusted_client().error_response();
+ }
+
+ // verify client
+ let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap();
+ if !hash.check_password(authorization.password()).unwrap() {
+ return TokenError::incorrect_client_secret().error_response();
+ }
+
+ // authenticate user
+ let Some(user_id) = authenticate_user(db, &username, &password).await.unwrap() else {
+ return TokenError::incorrect_user_credentials().error_response();
+ };
+
+ // verify scope
+ let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let scope = if let Some(scope) = &scope {
+ scope.clone()
+ } else {
+ let default_scopes = db::get_client_default_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let Some(scope) = default_scopes else {
+ return TokenError::no_scope().error_response();
+ };
+ scope
+ };
+ if !scopes::is_subset_of(&scope, &allowed_scopes) {
+ return TokenError::excessive_scope().error_response();
+ }
+
+ let access_token =
+ jwt::Claims::access_token(db, None, self_id, client_id, user_id, duration, &scope)
+ .await
+ .unwrap();
+ let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap();
+
+ let expires_in = access_token.expires_in();
+ let scope = access_token.scopes().into();
+ let access_token = access_token.to_jwt().unwrap();
+ let refresh_token = Some(refresh_token.to_jwt().unwrap());
+
+ let response = TokenResponse {
+ access_token,
+ token_type,
+ expires_in,
+ refresh_token,
+ scope,
+ };
+ HttpResponse::Ok()
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(response)
+ }
+ GrantType::ClientCredentials { scope } => {
+ let Some(authorization) = authorization else {
+ return TokenError::no_authorization().error_response();
+ };
+ let client_alias = authorization.username();
+ let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else {
+ return TokenError::client_not_found(client_alias).error_response();
+ };
+
+ let ty = db::get_client_type(db, client_id).await.unwrap().unwrap();
+ if ty != ClientType::Confidential {
+ return TokenError::client_not_confidential(client_alias).error_response();
+ }
+
+ // verify client
+ let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap();
+ if !hash.check_password(authorization.password()).unwrap() {
+ return TokenError::incorrect_client_secret().error_response();
+ }
+
+ // verify scope
+ let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let scope = if let Some(scope) = &scope {
+ scope.clone()
+ } else {
+ let default_scopes = db::get_client_default_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let Some(scope) = default_scopes else {
+ return TokenError::no_scope().error_response();
+ };
+ scope
+ };
+ if !scopes::is_subset_of(&scope, &allowed_scopes) {
+ return TokenError::excessive_scope().error_response();
+ }
+
+ let access_token = jwt::Claims::access_token(
+ db, None, self_id, client_id, client_id, duration, &scope,
+ )
+ .await
+ .unwrap();
+
+ let expires_in = access_token.expires_in();
+ let scope = access_token.scopes().into();
+ let access_token = access_token.to_jwt().unwrap();
+
+ let response = TokenResponse {
+ access_token,
+ token_type,
+ expires_in,
+ refresh_token: None,
+ scope,
+ };
+ HttpResponse::Ok()
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(response)
+ }
+ GrantType::RefreshToken {
+ refresh_token,
+ scope,
+ } => {
+ let client_id: Option<Uuid>;
+ if let Some(authorization) = authorization {
+ let client_alias = authorization.username();
+ let Some(id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else {
+ return TokenError::client_not_found(client_alias).error_response();
+ };
+ client_id = Some(id);
+ } else {
+ client_id = None;
+ }
+
+ let claims =
+ match jwt::verify_refresh_token(db, &refresh_token, &self_id, client_id).await {
+ Ok(claims) => claims,
+ Err(e) => {
+ let e = e.unwrap();
+ return TokenError::bad_refresh_token(e).error_response();
+ }
+ };
+
+ let scope = if let Some(scope) = scope {
+ if !scopes::is_subset_of(&scope, claims.scopes()) {
+ return TokenError::excessive_scope().error_response();
+ }
+
+ scope
+ } else {
+ claims.scopes().into()
+ };
+
+ let exp_time = Duration::hours(1);
+ let access_token = jwt::Claims::refreshed_access_token(db, &claims, exp_time)
+ .await
+ .unwrap();
+ let refresh_token = jwt::Claims::refresh_token(db, &claims).await.unwrap();
+
+ let access_token = access_token.to_jwt().unwrap();
+ let refresh_token = Some(refresh_token.to_jwt().unwrap());
+ let expires_in = exp_time.num_seconds();
+
+ let response = TokenResponse {
+ access_token,
+ token_type,
+ expires_in,
+ refresh_token,
+ scope,
+ };
+ HttpResponse::Ok()
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(response)
+ }
+ _ => TokenError::unsupported_grant_type().error_response(),
+ }
+}
+
+pub fn service() -> Scope {
+ web::scope("/oauth")
+ .service(authorize_page)
+ .service(authorize)
+ .service(token)
+}
diff --git a/src/api/ops.rs b/src/api/ops.rs
index 555bb1b..2164f1f 100644
--- a/src/api/ops.rs
+++ b/src/api/ops.rs
@@ -1,70 +1,70 @@
-use std::str::FromStr;
-
-use actix_web::{get, http::StatusCode, post, web, HttpResponse, ResponseError, Scope};
-use raise::yeet;
-use serde::Deserialize;
-use sqlx::MySqlPool;
-use tera::Tera;
-use thiserror::Error;
-use unic_langid::subtags::Language;
-
-use crate::resources::{languages, templates};
-use crate::services::db;
-
-/// A request to login
-#[derive(Debug, Clone, Deserialize)]
-struct LoginRequest {
- username: Box<str>,
- password: Box<str>,
-}
-
-/// An error occurred when authenticating, because either the username or
-/// password was invalid.
-#[derive(Debug, Clone, Error)]
-enum LoginFailure {
- #[error("No user found with the given username")]
- UserNotFound { username: Box<str> },
- #[error("The given password is incorrect")]
- IncorrectPassword { username: Box<str> },
-}
-
-impl ResponseError for LoginFailure {
- fn status_code(&self) -> actix_web::http::StatusCode {
- match self {
- Self::UserNotFound { .. } => StatusCode::NOT_FOUND,
- Self::IncorrectPassword { .. } => StatusCode::UNAUTHORIZED,
- }
- }
-}
-
-/// Returns `200` if login was successful.
-/// Returns `404` if the username is invalid.
-/// Returns `401` if the password was invalid.
-#[post("/login")]
-async fn login(
- body: web::Json<LoginRequest>,
- conn: web::Data<MySqlPool>,
-) -> Result<HttpResponse, LoginFailure> {
- let conn = conn.get_ref();
-
- let user = db::get_user_by_username(conn, &body.username)
- .await
- .unwrap();
- let Some(user) = user else {
- yeet!(LoginFailure::UserNotFound{ username: body.username.clone() });
- };
-
- let good_password = user.check_password(&body.password).unwrap();
- let response = if good_password {
- HttpResponse::Ok().finish()
- } else {
- yeet!(LoginFailure::IncorrectPassword {
- username: body.username.clone()
- });
- };
- Ok(response)
-}
-
-pub fn service() -> Scope {
- web::scope("").service(login)
-}
+use std::str::FromStr;
+
+use actix_web::{get, http::StatusCode, post, web, HttpResponse, ResponseError, Scope};
+use raise::yeet;
+use serde::Deserialize;
+use sqlx::MySqlPool;
+use tera::Tera;
+use thiserror::Error;
+use unic_langid::subtags::Language;
+
+use crate::resources::{languages, templates};
+use crate::services::db;
+
+/// A request to login
+#[derive(Debug, Clone, Deserialize)]
+struct LoginRequest {
+ username: Box<str>,
+ password: Box<str>,
+}
+
+/// An error occurred when authenticating, because either the username or
+/// password was invalid.
+#[derive(Debug, Clone, Error)]
+enum LoginFailure {
+ #[error("No user found with the given username")]
+ UserNotFound { username: Box<str> },
+ #[error("The given password is incorrect")]
+ IncorrectPassword { username: Box<str> },
+}
+
+impl ResponseError for LoginFailure {
+ fn status_code(&self) -> actix_web::http::StatusCode {
+ match self {
+ Self::UserNotFound { .. } => StatusCode::NOT_FOUND,
+ Self::IncorrectPassword { .. } => StatusCode::UNAUTHORIZED,
+ }
+ }
+}
+
+/// Returns `200` if login was successful.
+/// Returns `404` if the username is invalid.
+/// Returns `401` if the password was invalid.
+#[post("/login")]
+async fn login(
+ body: web::Json<LoginRequest>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, LoginFailure> {
+ let conn = conn.get_ref();
+
+ let user = db::get_user_by_username(conn, &body.username)
+ .await
+ .unwrap();
+ let Some(user) = user else {
+ yeet!(LoginFailure::UserNotFound{ username: body.username.clone() });
+ };
+
+ let good_password = user.check_password(&body.password).unwrap();
+ let response = if good_password {
+ HttpResponse::Ok().finish()
+ } else {
+ yeet!(LoginFailure::IncorrectPassword {
+ username: body.username.clone()
+ });
+ };
+ Ok(response)
+}
+
+pub fn service() -> Scope {
+ web::scope("").service(login)
+}
diff --git a/src/api/users.rs b/src/api/users.rs
index 391a059..da2a0d0 100644
--- a/src/api/users.rs
+++ b/src/api/users.rs
@@ -1,272 +1,272 @@
-use actix_web::http::{header, StatusCode};
-use actix_web::{get, post, put, web, HttpResponse, ResponseError, Scope};
-use raise::yeet;
-use serde::{Deserialize, Serialize};
-use sqlx::MySqlPool;
-use thiserror::Error;
-use uuid::Uuid;
-
-use crate::models::user::User;
-use crate::services::crypto::PasswordHash;
-use crate::services::{db, id};
-
-/// Just a username. No password hash, because that'd be tempting fate.
-#[derive(Debug, Clone, Serialize)]
-#[serde(rename_all = "camelCase")]
-struct UserResponse {
- id: Uuid,
- username: Box<str>,
-}
-
-impl From<User> for UserResponse {
- fn from(user: User) -> Self {
- Self {
- id: user.id,
- username: user.username,
- }
- }
-}
-
-#[derive(Debug, Clone, Deserialize)]
-#[serde(rename_all = "camelCase")]
-struct SearchUsers {
- username: Option<Box<str>>,
- limit: Option<u32>,
- offset: Option<u32>,
-}
-
-#[get("")]
-async fn search_users(params: web::Query<SearchUsers>, conn: web::Data<MySqlPool>) -> HttpResponse {
- let conn = conn.get_ref();
-
- let username = params.username.clone().unwrap_or_default();
- let offset = params.offset.unwrap_or_default();
-
- let results: Box<[UserResponse]> = if let Some(limit) = params.limit {
- db::search_users_limit(conn, &username, offset, limit)
- .await
- .unwrap()
- .iter()
- .cloned()
- .map(|u| u.into())
- .collect()
- } else {
- db::search_users(conn, &username)
- .await
- .unwrap()
- .into_iter()
- .skip(offset as usize)
- .cloned()
- .map(|u| u.into())
- .collect()
- };
-
- let response = HttpResponse::Ok().json(results);
- response
-}
-
-#[derive(Debug, Clone, Error)]
-#[error("No user with the given ID exists")]
-struct UserNotFoundError {
- user_id: Uuid,
-}
-
-impl ResponseError for UserNotFoundError {
- fn status_code(&self) -> StatusCode {
- StatusCode::NOT_FOUND
- }
-}
-
-#[get("/{user_id}")]
-async fn get_user(
- user_id: web::Path<Uuid>,
- conn: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UserNotFoundError> {
- let conn = conn.get_ref();
-
- let id = user_id.to_owned();
- let username = db::get_username(conn, id).await.unwrap();
-
- let Some(username) = username else {
- yeet!(UserNotFoundError { user_id: id });
- };
-
- let response = UserResponse { id, username };
- let response = HttpResponse::Ok().json(response);
- Ok(response)
-}
-
-#[get("/{user_id}/username")]
-async fn get_username(
- user_id: web::Path<Uuid>,
- conn: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UserNotFoundError> {
- let conn = conn.get_ref();
-
- let user_id = user_id.to_owned();
- let username = db::get_username(conn, user_id).await.unwrap();
-
- let Some(username) = username else {
- yeet!(UserNotFoundError { user_id });
- };
-
- let response = HttpResponse::Ok().json(username);
- Ok(response)
-}
-
-/// A request to create or update user information
-#[derive(Clone, Deserialize)]
-#[serde(rename_all = "camelCase")]
-struct UserRequest {
- username: Box<str>,
- password: Box<str>,
-}
-
-#[derive(Debug, Clone, Error)]
-#[error("An account with the given username already exists.")]
-struct UsernameTakenError {
- username: Box<str>,
-}
-
-impl ResponseError for UsernameTakenError {
- fn status_code(&self) -> StatusCode {
- StatusCode::CONFLICT
- }
-}
-
-#[post("")]
-async fn create_user(
- body: web::Json<UserRequest>,
- conn: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UsernameTakenError> {
- let conn = conn.get_ref();
-
- let user_id = id::new_id(conn, db::user_id_exists).await.unwrap();
- let username = body.username.clone();
- let password = PasswordHash::new(&body.password).unwrap();
-
- if db::username_is_used(conn, &body.username).await.unwrap() {
- yeet!(UsernameTakenError { username });
- }
-
- let user = User {
- id: user_id,
- username,
- password,
- };
-
- db::create_user(conn, &user).await.unwrap();
-
- let response = HttpResponse::Created()
- .insert_header((header::LOCATION, format!("users/{user_id}")))
- .finish();
- Ok(response)
-}
-
-#[derive(Debug, Clone, Error)]
-enum UpdateUserError {
- #[error(transparent)]
- UsernameTaken(#[from] UsernameTakenError),
- #[error(transparent)]
- NotFound(#[from] UserNotFoundError),
-}
-
-impl ResponseError for UpdateUserError {
- fn status_code(&self) -> StatusCode {
- match self {
- Self::UsernameTaken(e) => e.status_code(),
- Self::NotFound(e) => e.status_code(),
- }
- }
-}
-
-#[put("/{user_id}")]
-async fn update_user(
- user_id: web::Path<Uuid>,
- body: web::Json<UserRequest>,
- conn: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateUserError> {
- let conn = conn.get_ref();
-
- let user_id = user_id.to_owned();
- let username = body.username.clone();
- let password = PasswordHash::new(&body.password).unwrap();
-
- let old_username = db::get_username(conn, user_id).await.unwrap().unwrap();
- if username != old_username && db::username_is_used(conn, &body.username).await.unwrap() {
- yeet!(UsernameTakenError { username }.into())
- }
-
- if !db::user_id_exists(conn, user_id).await.unwrap() {
- yeet!(UserNotFoundError { user_id }.into())
- }
-
- let user = User {
- id: user_id,
- username,
- password,
- };
-
- db::update_user(conn, &user).await.unwrap();
-
- let response = HttpResponse::NoContent().finish();
- Ok(response)
-}
-
-#[put("/{user_id}/username")]
-async fn update_username(
- user_id: web::Path<Uuid>,
- body: web::Json<Box<str>>,
- conn: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UpdateUserError> {
- let conn = conn.get_ref();
-
- let user_id = user_id.to_owned();
- let username = body.clone();
-
- let old_username = db::get_username(conn, user_id).await.unwrap().unwrap();
- if username != old_username && db::username_is_used(conn, &body).await.unwrap() {
- yeet!(UsernameTakenError { username }.into())
- }
-
- if !db::user_id_exists(conn, user_id).await.unwrap() {
- yeet!(UserNotFoundError { user_id }.into())
- }
-
- db::update_username(conn, user_id, &body).await.unwrap();
-
- let response = HttpResponse::NoContent().finish();
- Ok(response)
-}
-
-#[put("/{user_id}/password")]
-async fn update_password(
- user_id: web::Path<Uuid>,
- body: web::Json<Box<str>>,
- conn: web::Data<MySqlPool>,
-) -> Result<HttpResponse, UserNotFoundError> {
- let conn = conn.get_ref();
-
- let user_id = user_id.to_owned();
- let password = PasswordHash::new(&body).unwrap();
-
- if !db::user_id_exists(conn, user_id).await.unwrap() {
- yeet!(UserNotFoundError { user_id })
- }
-
- db::update_password(conn, user_id, &password).await.unwrap();
-
- let response = HttpResponse::NoContent().finish();
- Ok(response)
-}
-
-pub fn service() -> Scope {
- web::scope("/users")
- .service(search_users)
- .service(get_user)
- .service(get_username)
- .service(create_user)
- .service(update_user)
- .service(update_username)
- .service(update_password)
-}
+use actix_web::http::{header, StatusCode};
+use actix_web::{get, post, put, web, HttpResponse, ResponseError, Scope};
+use raise::yeet;
+use serde::{Deserialize, Serialize};
+use sqlx::MySqlPool;
+use thiserror::Error;
+use uuid::Uuid;
+
+use crate::models::user::User;
+use crate::services::crypto::PasswordHash;
+use crate::services::{db, id};
+
+/// Just a username. No password hash, because that'd be tempting fate.
+#[derive(Debug, Clone, Serialize)]
+#[serde(rename_all = "camelCase")]
+struct UserResponse {
+ id: Uuid,
+ username: Box<str>,
+}
+
+impl From<User> for UserResponse {
+ fn from(user: User) -> Self {
+ Self {
+ id: user.id,
+ username: user.username,
+ }
+ }
+}
+
+#[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct SearchUsers {
+ username: Option<Box<str>>,
+ limit: Option<u32>,
+ offset: Option<u32>,
+}
+
+#[get("")]
+async fn search_users(params: web::Query<SearchUsers>, conn: web::Data<MySqlPool>) -> HttpResponse {
+ let conn = conn.get_ref();
+
+ let username = params.username.clone().unwrap_or_default();
+ let offset = params.offset.unwrap_or_default();
+
+ let results: Box<[UserResponse]> = if let Some(limit) = params.limit {
+ db::search_users_limit(conn, &username, offset, limit)
+ .await
+ .unwrap()
+ .iter()
+ .cloned()
+ .map(|u| u.into())
+ .collect()
+ } else {
+ db::search_users(conn, &username)
+ .await
+ .unwrap()
+ .into_iter()
+ .skip(offset as usize)
+ .cloned()
+ .map(|u| u.into())
+ .collect()
+ };
+
+ let response = HttpResponse::Ok().json(results);
+ response
+}
+
+#[derive(Debug, Clone, Error)]
+#[error("No user with the given ID exists")]
+struct UserNotFoundError {
+ user_id: Uuid,
+}
+
+impl ResponseError for UserNotFoundError {
+ fn status_code(&self) -> StatusCode {
+ StatusCode::NOT_FOUND
+ }
+}
+
+#[get("/{user_id}")]
+async fn get_user(
+ user_id: web::Path<Uuid>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UserNotFoundError> {
+ let conn = conn.get_ref();
+
+ let id = user_id.to_owned();
+ let username = db::get_username(conn, id).await.unwrap();
+
+ let Some(username) = username else {
+ yeet!(UserNotFoundError { user_id: id });
+ };
+
+ let response = UserResponse { id, username };
+ let response = HttpResponse::Ok().json(response);
+ Ok(response)
+}
+
+#[get("/{user_id}/username")]
+async fn get_username(
+ user_id: web::Path<Uuid>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UserNotFoundError> {
+ let conn = conn.get_ref();
+
+ let user_id = user_id.to_owned();
+ let username = db::get_username(conn, user_id).await.unwrap();
+
+ let Some(username) = username else {
+ yeet!(UserNotFoundError { user_id });
+ };
+
+ let response = HttpResponse::Ok().json(username);
+ Ok(response)
+}
+
+/// A request to create or update user information
+#[derive(Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct UserRequest {
+ username: Box<str>,
+ password: Box<str>,
+}
+
+#[derive(Debug, Clone, Error)]
+#[error("An account with the given username already exists.")]
+struct UsernameTakenError {
+ username: Box<str>,
+}
+
+impl ResponseError for UsernameTakenError {
+ fn status_code(&self) -> StatusCode {
+ StatusCode::CONFLICT
+ }
+}
+
+#[post("")]
+async fn create_user(
+ body: web::Json<UserRequest>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UsernameTakenError> {
+ let conn = conn.get_ref();
+
+ let user_id = id::new_id(conn, db::user_id_exists).await.unwrap();
+ let username = body.username.clone();
+ let password = PasswordHash::new(&body.password).unwrap();
+
+ if db::username_is_used(conn, &body.username).await.unwrap() {
+ yeet!(UsernameTakenError { username });
+ }
+
+ let user = User {
+ id: user_id,
+ username,
+ password,
+ };
+
+ db::create_user(conn, &user).await.unwrap();
+
+ let response = HttpResponse::Created()
+ .insert_header((header::LOCATION, format!("users/{user_id}")))
+ .finish();
+ Ok(response)
+}
+
+#[derive(Debug, Clone, Error)]
+enum UpdateUserError {
+ #[error(transparent)]
+ UsernameTaken(#[from] UsernameTakenError),
+ #[error(transparent)]
+ NotFound(#[from] UserNotFoundError),
+}
+
+impl ResponseError for UpdateUserError {
+ fn status_code(&self) -> StatusCode {
+ match self {
+ Self::UsernameTaken(e) => e.status_code(),
+ Self::NotFound(e) => e.status_code(),
+ }
+ }
+}
+
+#[put("/{user_id}")]
+async fn update_user(
+ user_id: web::Path<Uuid>,
+ body: web::Json<UserRequest>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateUserError> {
+ let conn = conn.get_ref();
+
+ let user_id = user_id.to_owned();
+ let username = body.username.clone();
+ let password = PasswordHash::new(&body.password).unwrap();
+
+ let old_username = db::get_username(conn, user_id).await.unwrap().unwrap();
+ if username != old_username && db::username_is_used(conn, &body.username).await.unwrap() {
+ yeet!(UsernameTakenError { username }.into())
+ }
+
+ if !db::user_id_exists(conn, user_id).await.unwrap() {
+ yeet!(UserNotFoundError { user_id }.into())
+ }
+
+ let user = User {
+ id: user_id,
+ username,
+ password,
+ };
+
+ db::update_user(conn, &user).await.unwrap();
+
+ let response = HttpResponse::NoContent().finish();
+ Ok(response)
+}
+
+#[put("/{user_id}/username")]
+async fn update_username(
+ user_id: web::Path<Uuid>,
+ body: web::Json<Box<str>>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateUserError> {
+ let conn = conn.get_ref();
+
+ let user_id = user_id.to_owned();
+ let username = body.clone();
+
+ let old_username = db::get_username(conn, user_id).await.unwrap().unwrap();
+ if username != old_username && db::username_is_used(conn, &body).await.unwrap() {
+ yeet!(UsernameTakenError { username }.into())
+ }
+
+ if !db::user_id_exists(conn, user_id).await.unwrap() {
+ yeet!(UserNotFoundError { user_id }.into())
+ }
+
+ db::update_username(conn, user_id, &body).await.unwrap();
+
+ let response = HttpResponse::NoContent().finish();
+ Ok(response)
+}
+
+#[put("/{user_id}/password")]
+async fn update_password(
+ user_id: web::Path<Uuid>,
+ body: web::Json<Box<str>>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UserNotFoundError> {
+ let conn = conn.get_ref();
+
+ let user_id = user_id.to_owned();
+ let password = PasswordHash::new(&body).unwrap();
+
+ if !db::user_id_exists(conn, user_id).await.unwrap() {
+ yeet!(UserNotFoundError { user_id })
+ }
+
+ db::update_password(conn, user_id, &password).await.unwrap();
+
+ let response = HttpResponse::NoContent().finish();
+ Ok(response)
+}
+
+pub fn service() -> Scope {
+ web::scope("/users")
+ .service(search_users)
+ .service(get_user)
+ .service(get_username)
+ .service(create_user)
+ .service(update_user)
+ .service(update_username)
+ .service(update_password)
+}