summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormrw1593 <botahamec@outlook.com>2023-05-28 16:31:22 -0400
committermrw1593 <botahamec@outlook.com>2023-05-29 10:51:10 -0400
commit614c81c0f239940acb313e067dafc3213f399b10 (patch)
tree68835a73c225a3b4fefa590b173db1cd9d7a28b2
parente048d7d050f87e9e5bf06f01e39fd417d6868c7e (diff)
Add clients to the API
-rw-r--r--src/api/clients.rs311
-rw-r--r--src/api/mod.rs3
-rw-r--r--src/api/oauth.rs24
-rw-r--r--src/api/users.rs31
-rw-r--r--src/main.rs1
-rw-r--r--src/models/client.rs37
-rw-r--r--src/services/db.rs242
-rw-r--r--src/services/db/client.rs236
-rw-r--r--src/services/db/user.rs236
-rw-r--r--src/services/id.rs20
10 files changed, 878 insertions, 263 deletions
diff --git a/src/api/clients.rs b/src/api/clients.rs
new file mode 100644
index 0000000..7e8ca35
--- /dev/null
+++ b/src/api/clients.rs
@@ -0,0 +1,311 @@
+use actix_web::http::{header, StatusCode};
+use actix_web::{get, post, put, web, HttpResponse, ResponseError, Scope};
+use raise::yeet;
+use serde::Deserialize;
+use sqlx::MySqlPool;
+use thiserror::Error;
+use url::Url;
+use uuid::Uuid;
+
+use crate::models::client::{Client, ClientType, NoSecretError};
+use crate::services::crypto::PasswordHash;
+use crate::services::{db, id};
+
+#[derive(Debug, Clone, Copy, Error)]
+#[error("No client with the given client ID was found")]
+struct ClientNotFound {
+ id: Uuid,
+}
+
+impl ResponseError for ClientNotFound {
+ fn status_code(&self) -> StatusCode {
+ StatusCode::NOT_FOUND
+ }
+}
+
+impl ClientNotFound {
+ fn new(id: Uuid) -> Self {
+ Self { id }
+ }
+}
+
+#[get("/{client_id}")]
+async fn get_client(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(client) = db::get_client_response(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ let redirect_uris_link = format!("</clients/{client_id}/redirect-uris>; rel=\"redirect-uris\"");
+ let response = HttpResponse::Ok()
+ .append_header((header::LINK, redirect_uris_link))
+ .json(client);
+ Ok(response)
+}
+
+#[get("/{client_id}/alias")]
+async fn get_client_alias(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(alias) = db::get_client_alias(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ Ok(HttpResponse::Ok().json(alias))
+}
+
+#[get("/{client_id}/type")]
+async fn get_client_type(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(client_type) = db::get_client_type(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ Ok(HttpResponse::Ok().json(client_type))
+}
+
+#[get("/{client_id}/redirect-uris")]
+async fn get_client_redirect_uris(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ let redirect_uris = db::get_client_redirect_uris(db, id).await.unwrap();
+
+ Ok(HttpResponse::Ok().json(redirect_uris))
+}
+
+#[derive(Debug, Clone, Deserialize)]
+struct ClientRequest {
+ alias: Box<str>,
+ ty: ClientType,
+ redirect_uris: Box<[Url]>,
+ secret: Option<Box<str>>,
+}
+
+#[derive(Debug, Clone, Error)]
+#[error("The given client alias is already taken")]
+struct AliasTakenError {
+ alias: Box<str>,
+}
+
+impl ResponseError for AliasTakenError {
+ fn status_code(&self) -> StatusCode {
+ StatusCode::CONFLICT
+ }
+}
+
+impl AliasTakenError {
+ fn new(alias: &str) -> Self {
+ Self {
+ alias: Box::from(alias),
+ }
+ }
+}
+
+#[post("")]
+async fn create_client(
+ body: web::Json<ClientRequest>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let alias = &body.alias;
+
+ if db::client_alias_exists(db, &alias).await.unwrap() {
+ yeet!(AliasTakenError::new(&alias).into());
+ }
+
+ let id = id::new_id(db, db::client_id_exists).await.unwrap();
+ let client = Client::new(
+ id,
+ &alias,
+ body.ty,
+ body.secret.as_deref(),
+ &body.redirect_uris,
+ )
+ .map_err(|e| e.unwrap())?;
+
+ let transaction = db.begin().await.unwrap();
+ db::create_client(transaction, &client).await.unwrap();
+
+ let response = HttpResponse::Created()
+ .insert_header((header::LOCATION, format!("clients/{id}")))
+ .finish();
+ Ok(response)
+}
+
+#[derive(Debug, Clone, Error)]
+enum UpdateClientError {
+ #[error(transparent)]
+ NotFound(#[from] ClientNotFound),
+ #[error(transparent)]
+ NoSecret(#[from] NoSecretError),
+ #[error(transparent)]
+ AliasTaken(#[from] AliasTakenError),
+}
+
+impl ResponseError for UpdateClientError {
+ fn status_code(&self) -> StatusCode {
+ match self {
+ Self::NotFound(e) => e.status_code(),
+ Self::NoSecret(e) => e.status_code(),
+ Self::AliasTaken(e) => e.status_code(),
+ }
+ }
+}
+
+#[put("/{id}")]
+async fn update_client(
+ id: web::Path<Uuid>,
+ body: web::Json<ClientRequest>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let alias = &body.alias;
+
+ let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id).into())
+ };
+ if old_alias != alias.clone() && db::client_alias_exists(db, &alias).await.unwrap() {
+ yeet!(AliasTakenError::new(&alias).into());
+ }
+
+ let client = Client::new(
+ id,
+ &alias,
+ body.ty,
+ body.secret.as_deref(),
+ &body.redirect_uris,
+ )
+ .map_err(|e| e.unwrap())?;
+
+ let transaction = db.begin().await.unwrap();
+ db::update_client(transaction, &client).await.unwrap();
+
+ let response = HttpResponse::NoContent().finish();
+ Ok(response)
+}
+
+#[put("/{id}/alias")]
+async fn update_client_alias(
+ id: web::Path<Uuid>,
+ body: web::Json<Box<str>>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let alias = body.0;
+
+ let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id).into())
+ };
+ if old_alias == alias {
+ return Ok(HttpResponse::NoContent().finish());
+ }
+ if db::client_alias_exists(db, &alias).await.unwrap() {
+ yeet!(AliasTakenError::new(&alias).into());
+ }
+
+ db::update_client_alias(db, id, &alias).await.unwrap();
+
+ let response = HttpResponse::NoContent().finish();
+ Ok(response)
+}
+
+#[put("/{id}/type")]
+async fn update_client_type(
+ id: web::Path<Uuid>,
+ body: web::Json<ClientType>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let ty = body.0;
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id).into());
+ }
+
+ db::update_client_type(db, id, ty).await.unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+#[put("/{id}/redirect-uris")]
+async fn update_client_redirect_uris(
+ id: web::Path<Uuid>,
+ body: web::Json<Box<[Url]>>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id).into());
+ }
+
+ let transaction = db.begin().await.unwrap();
+ db::update_client_redirect_uris(transaction, id, &body.0)
+ .await
+ .unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+#[put("{id}/secret")]
+async fn update_client_secret(
+ id: web::Path<Uuid>,
+ body: web::Json<Option<Box<str>>>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+
+ let Some(client_type) = db::get_client_type(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id).into())
+ };
+
+ if client_type == ClientType::Confidential && body.is_none() {
+ yeet!(NoSecretError::new().into())
+ }
+
+ let secret = body.0.map(|s| PasswordHash::new(&s).unwrap());
+ db::update_client_secret(db, id, secret).await.unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+pub fn service() -> Scope {
+ web::scope("/clients")
+ .service(get_client)
+ .service(get_client_alias)
+ .service(get_client_type)
+ .service(get_client_redirect_uris)
+ .service(create_client)
+ .service(update_client)
+ .service(update_client_alias)
+ .service(update_client_type)
+ .service(update_client_redirect_uris)
+ .service(update_client_secret)
+}
diff --git a/src/api/mod.rs b/src/api/mod.rs
index 7627a60..3d74be8 100644
--- a/src/api/mod.rs
+++ b/src/api/mod.rs
@@ -1,7 +1,10 @@
mod liveops;
mod ops;
mod users;
+mod oauth;
+mod clients;
pub use liveops::service as liveops;
pub use ops::service as ops;
pub use users::service as users;
+pub use clients::service as clients;
diff --git a/src/api/oauth.rs b/src/api/oauth.rs
new file mode 100644
index 0000000..9e0e5c6
--- /dev/null
+++ b/src/api/oauth.rs
@@ -0,0 +1,24 @@
+use std::collections::HashMap;
+
+use actix_web::{web, HttpResponse};
+use serde::Deserialize;
+use url::Url;
+use uuid::Uuid;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
+#[serde(rename_all = "snake_case")]
+enum ResponseType {
+ Code,
+ Token,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+struct AuthorizationParameters {
+ response_type: ResponseType,
+ client_id: Uuid,
+ redirect_uri: Url,
+ state: Box<str>,
+
+ #[serde(flatten)]
+ additional_parameters: HashMap<Box<str>, Box<str>>,
+}
diff --git a/src/api/users.rs b/src/api/users.rs
index 2b67663..2cd70c0 100644
--- a/src/api/users.rs
+++ b/src/api/users.rs
@@ -12,6 +12,7 @@ use crate::services::{db, id};
/// Just a username. No password hash, because that'd be tempting fate.
#[derive(Debug, Clone, Serialize)]
+#[serde(rename_all = "camelCase")]
struct UserResponse {
id: Uuid,
username: Box<str>,
@@ -27,6 +28,7 @@ impl From<User> for UserResponse {
}
#[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
struct SearchUsers {
username: Option<Box<str>>,
limit: Option<u32>,
@@ -82,14 +84,14 @@ async fn get_user(
) -> Result<HttpResponse, UserNotFoundError> {
let conn = conn.get_ref();
- let user_id = user_id.to_owned();
- let user = db::get_user(conn, user_id).await.unwrap();
+ let id = user_id.to_owned();
+ let username = db::get_username(conn, id).await.unwrap();
- let Some(user) = user else {
- yeet!(UserNotFoundError {user_id});
+ let Some(username) = username else {
+ yeet!(UserNotFoundError { user_id: id });
};
- let response: UserResponse = user.into();
+ let response = UserResponse { id, username };
let response = HttpResponse::Ok().json(response);
Ok(response)
}
@@ -114,6 +116,7 @@ async fn get_username(
/// A request to create or update user information
#[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
struct UserRequest {
username: Box<str>,
password: Box<str>,
@@ -138,7 +141,7 @@ async fn create_user(
) -> Result<HttpResponse, UsernameTakenError> {
let conn = conn.get_ref();
- let user_id = id::new_user_id(conn).await.unwrap();
+ let user_id = id::new_id(conn, db::user_id_exists).await.unwrap();
let username = body.username.clone();
let password = PasswordHash::new(&body.password).unwrap();
@@ -152,7 +155,7 @@ async fn create_user(
password,
};
- db::new_user(conn, &user).await.unwrap();
+ db::create_user(conn, &user).await.unwrap();
let response = HttpResponse::Created()
.insert_header((header::LOCATION, format!("users/{user_id}")))
@@ -171,8 +174,8 @@ enum UpdateUserError {
impl ResponseError for UpdateUserError {
fn status_code(&self) -> StatusCode {
match self {
- Self::UsernameTaken(..) => StatusCode::CONFLICT,
- Self::NotFound(..) => StatusCode::NOT_FOUND,
+ Self::UsernameTaken(e) => e.status_code(),
+ Self::NotFound(e) => e.status_code(),
}
}
}
@@ -206,10 +209,7 @@ async fn update_user(
db::update_user(conn, &user).await.unwrap();
- let response = HttpResponse::NoContent()
- .insert_header((header::LOCATION, format!("users/{user_id}")))
- .finish();
-
+ let response = HttpResponse::NoContent().finish();
Ok(response)
}
@@ -235,10 +235,7 @@ async fn update_username(
db::update_username(conn, user_id, &body).await.unwrap();
- let response = HttpResponse::NoContent()
- .insert_header((header::LOCATION, format!("users/{user_id}/username")))
- .finish();
-
+ let response = HttpResponse::NoContent().finish();
Ok(response)
}
diff --git a/src/main.rs b/src/main.rs
index 7b25dd1..aca5977 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -56,6 +56,7 @@ async fn main() -> Result<(), RawUnexpected> {
// api services
.service(api::liveops())
.service(api::users())
+ .service(api::clients())
.service(api::ops())
})
.shutdown_timeout(1)
diff --git a/src/models/client.rs b/src/models/client.rs
index a7df936..44079de 100644
--- a/src/models/client.rs
+++ b/src/models/client.rs
@@ -1,7 +1,10 @@
use std::{hash::Hash, marker::PhantomData};
+use actix_web::{http::StatusCode, ResponseError};
use exun::{Expect, RawUnexpected};
use raise::yeet;
+use serde::{Deserialize, Serialize};
+use sqlx::FromRow;
use thiserror::Error;
use url::Url;
use uuid::Uuid;
@@ -10,8 +13,9 @@ use crate::services::crypto::PasswordHash;
/// There are two types of clients, based on their ability to maintain the
/// security of their client credentials.
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, sqlx::Type)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
#[sqlx(rename_all = "lowercase")]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ClientType {
/// A client that is capable of maintaining the confidentiality of their
/// credentials, or capable of secure client authentication using other
@@ -26,12 +30,21 @@ pub enum ClientType {
#[derive(Debug, Clone)]
pub struct Client {
- ty: ClientType,
id: Uuid,
+ ty: ClientType,
+ alias: Box<str>,
secret: Option<PasswordHash>,
redirect_uris: Box<[Url]>,
}
+#[derive(Debug, Clone, Serialize, FromRow)]
+#[serde(rename_all = "camelCase")]
+pub struct ClientResponse {
+ pub id: Uuid,
+ pub alias: String,
+ pub client_type: ClientType,
+}
+
impl PartialEq for Client {
fn eq(&self, other: &Self) -> bool {
self.id == other.id
@@ -52,8 +65,14 @@ pub struct NoSecretError {
_phantom: PhantomData<()>,
}
+impl ResponseError for NoSecretError {
+ fn status_code(&self) -> StatusCode {
+ StatusCode::BAD_REQUEST
+ }
+}
+
impl NoSecretError {
- fn new() -> Self {
+ pub(crate) fn new() -> Self {
Self {
_phantom: PhantomData,
}
@@ -61,8 +80,9 @@ impl NoSecretError {
}
impl Client {
- pub fn new_public(
+ pub fn new(
id: Uuid,
+ alias: &str,
ty: ClientType,
secret: Option<&str>,
redirect_uris: &[Url],
@@ -79,6 +99,7 @@ impl Client {
Ok(Self {
id,
+ alias: Box::from(alias),
ty: ClientType::Public,
secret,
redirect_uris: redirect_uris.into_iter().cloned().collect(),
@@ -89,10 +110,18 @@ impl Client {
self.id
}
+ pub fn alias(&self) -> &str {
+ &self.alias
+ }
+
pub fn client_type(&self) -> ClientType {
self.ty
}
+ pub fn redirect_uris(&self) -> &[Url] {
+ &self.redirect_uris
+ }
+
pub fn secret_hash(&self) -> Option<&[u8]> {
self.secret.as_ref().map(|s| s.hash())
}
diff --git a/src/services/db.rs b/src/services/db.rs
index 79df260..9789e51 100644
--- a/src/services/db.rs
+++ b/src/services/db.rs
@@ -1,243 +1,13 @@
-use exun::*;
-use sqlx::{mysql::MySqlQueryResult, query, query_as, query_scalar, Executor, MySql, MySqlPool};
-use uuid::Uuid;
+use exun::{RawUnexpected, ResultErrorExt};
+use sqlx::MySqlPool;
-use crate::models::user::User;
+mod client;
+mod user;
-use super::crypto::PasswordHash;
-
-struct UserRow {
- id: Vec<u8>,
- username: String,
- password_hash: Vec<u8>,
- password_salt: Vec<u8>,
- password_version: u32,
-}
-
-impl TryFrom<UserRow> for User {
- type Error = RawUnexpected;
-
- fn try_from(row: UserRow) -> Result<Self, Self::Error> {
- let password = PasswordHash::from_fields(
- &row.password_hash,
- &row.password_salt,
- row.password_version as u8,
- );
- let user = User {
- id: Uuid::from_slice(&row.id)?,
- username: row.username.into_boxed_str(),
- password,
- };
- Ok(user)
- }
-}
+pub use client::*;
+pub use user::*;
/// Intialize the connection pool
pub async fn initialize(db_url: &str) -> Result<MySqlPool, RawUnexpected> {
MySqlPool::connect(db_url).await.unexpect()
}
-
-/// Check if a user with a given user ID exists
-pub async fn user_id_exists<'c>(
- conn: impl Executor<'c, Database = MySql>,
- id: Uuid,
-) -> Result<bool, RawUnexpected> {
- let exists = query_scalar!(
- r#"SELECT EXISTS(SELECT id FROM users WHERE id = ?) as "e: bool""#,
- id
- )
- .fetch_one(conn)
- .await?;
-
- Ok(exists)
-}
-
-/// Check if a given username is taken
-pub async fn username_is_used<'c>(
- conn: impl Executor<'c, Database = MySql>,
- username: &str,
-) -> Result<bool, RawUnexpected> {
- let exists = query_scalar!(
- r#"SELECT EXISTS(SELECT id FROM users WHERE username = ?) as "e: bool""#,
- username
- )
- .fetch_one(conn)
- .await?;
-
- Ok(exists)
-}
-
-/// Get a user from their ID
-pub async fn get_user<'c>(
- conn: impl Executor<'c, Database = MySql>,
- user_id: Uuid,
-) -> Result<Option<User>, RawUnexpected> {
- let record = query_as!(
- UserRow,
- r"SELECT id, username, password_hash, password_salt, password_version
- FROM users WHERE id = ?",
- user_id
- )
- .fetch_optional(conn)
- .await?;
-
- let Some(record) = record else { return Ok(None) };
-
- Ok(Some(record.try_into()?))
-}
-
-/// Get a user from their username
-pub async fn get_user_by_username<'c>(
- conn: impl Executor<'c, Database = MySql>,
- username: &str,
-) -> Result<Option<User>, RawUnexpected> {
- let record = query_as!(
- UserRow,
- r"SELECT id, username, password_hash, password_salt, password_version
- FROM users WHERE username = ?",
- username
- )
- .fetch_optional(conn)
- .await?;
-
- let Some(record) = record else { return Ok(None) };
-
- Ok(Some(record.try_into()?))
-}
-
-/// Search the list of users for a given username
-pub async fn search_users<'c>(
- conn: impl Executor<'c, Database = MySql>,
- username: &str,
-) -> Result<Box<[User]>, RawUnexpected> {
- let records = query_as!(
- UserRow,
- r"SELECT id, username, password_hash, password_salt, password_version
- FROM users
- WHERE LOCATE(?, username) != 0",
- username,
- )
- .fetch_all(conn)
- .await?;
-
- Ok(records
- .into_iter()
- .map(|u| u.try_into())
- .collect::<Result<Box<[User]>, RawUnexpected>>()?)
-}
-
-/// Search the list of users, only returning a certain range of results
-pub async fn search_users_limit<'c>(
- conn: impl Executor<'c, Database = MySql>,
- username: &str,
- offset: u32,
- limit: u32,
-) -> Result<Box<[User]>, RawUnexpected> {
- let records = query_as!(
- UserRow,
- r"SELECT id, username, password_hash, password_salt, password_version
- FROM users
- WHERE LOCATE(?, username) != 0
- LIMIT ?
- OFFSET ?",
- username,
- offset,
- limit
- )
- .fetch_all(conn)
- .await?;
-
- Ok(records
- .into_iter()
- .map(|u| u.try_into())
- .collect::<Result<Box<[User]>, RawUnexpected>>()?)
-}
-
-/// Get the username of a user with a certain ID
-pub async fn get_username<'c>(
- conn: impl Executor<'c, Database = MySql>,
- user_id: Uuid,
-) -> Result<Option<Box<str>>, RawUnexpected> {
- let username = query_scalar!(r"SELECT username FROM users where id = ?", user_id)
- .fetch_optional(conn)
- .await?
- .map(String::into_boxed_str);
-
- Ok(username)
-}
-
-/// Create a new user
-pub async fn new_user<'c>(
- conn: impl Executor<'c, Database = MySql>,
- user: &User,
-) -> Result<MySqlQueryResult, sqlx::Error> {
- query!(
- r"INSERT INTO users (id, username, password_hash, password_salt, password_version)
- VALUES (?, ?, ?, ?, ?)",
- user.id,
- user.username(),
- user.password_hash(),
- user.password_salt(),
- user.password_version()
- )
- .execute(conn)
- .await
-}
-
-/// Update a user
-pub async fn update_user<'c>(
- conn: impl Executor<'c, Database = MySql>,
- user: &User,
-) -> Result<MySqlQueryResult, sqlx::Error> {
- query!(
- r"UPDATE users SET
- username = ?,
- password_hash = ?,
- password_salt = ?,
- password_version = ?
- WHERE id = ?",
- user.username(),
- user.password_hash(),
- user.password_salt(),
- user.password_version(),
- user.id
- )
- .execute(conn)
- .await
-}
-
-/// Update the username of a user with the given ID
-pub async fn update_username<'c>(
- conn: impl Executor<'c, Database = MySql>,
- user_id: Uuid,
- username: &str,
-) -> Result<MySqlQueryResult, sqlx::Error> {
- query!(
- r"UPDATE users SET username = ? WHERE id = ?",
- username,
- user_id
- )
- .execute(conn)
- .await
-}
-
-/// Update the password of a user with the given ID
-pub async fn update_password<'c>(
- conn: impl Executor<'c, Database = MySql>,
- user_id: Uuid,
- password: &PasswordHash,
-) -> Result<MySqlQueryResult, sqlx::Error> {
- query!(
- r"UPDATE users SET
- password_hash = ?,
- password_salt = ?,
- password_version = ?
- WHERE id = ?",
- password.hash(),
- password.salt(),
- password.version(),
- user_id
- )
- .execute(conn)
- .await
-}
diff --git a/src/services/db/client.rs b/src/services/db/client.rs
new file mode 100644
index 0000000..d1531be
--- /dev/null
+++ b/src/services/db/client.rs
@@ -0,0 +1,236 @@
+use std::str::FromStr;
+
+use exun::{RawUnexpected, ResultErrorExt};
+use sqlx::{mysql::MySqlQueryResult, query, query_as, query_scalar, Executor, MySql, Transaction};
+use url::Url;
+use uuid::Uuid;
+
+use crate::{
+ models::client::{Client, ClientResponse, ClientType},
+ services::crypto::PasswordHash,
+};
+
+pub async fn client_id_exists<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<bool, RawUnexpected> {
+ query_scalar!(
+ r"SELECT EXISTS(SELECT id FROM clients WHERE id = ?) as `e: bool`",
+ id
+ )
+ .fetch_one(executor)
+ .await
+ .unexpect()
+}
+
+pub async fn client_alias_exists<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ alias: &str,
+) -> Result<bool, RawUnexpected> {
+ query_scalar!(
+ "SELECT EXISTS(SELECT alias FROM clients WHERE alias = ?) as `e: bool`",
+ alias
+ )
+ .fetch_one(executor)
+ .await
+ .unexpect()
+}
+
+pub async fn get_client_response<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<ClientResponse>, RawUnexpected> {
+ let record = query_as!(
+ ClientResponse,
+ r"SELECT id as `id: Uuid`,
+ alias,
+ type as `client_type: ClientType`
+ FROM clients WHERE id = ?",
+ id
+ )
+ .fetch_optional(executor)
+ .await?;
+
+ Ok(record)
+}
+
+pub async fn get_client_alias<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<Box<str>>, RawUnexpected> {
+ let alias = query_scalar!("SELECT alias FROM clients WHERE id = ?", id)
+ .fetch_optional(executor)
+ .await
+ .unexpect()?;
+
+ Ok(alias.map(String::into_boxed_str))
+}
+
+pub async fn get_client_type<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<ClientType>, RawUnexpected> {
+ let ty = query_scalar!(
+ "SELECT type as `type: ClientType` FROM clients WHERE id = ?",
+ id
+ )
+ .fetch_optional(executor)
+ .await
+ .unexpect()?;
+
+ Ok(ty)
+}
+
+pub async fn get_client_redirect_uris<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Box<[Url]>, RawUnexpected> {
+ let uris = query_scalar!(
+ "SELECT redirect_uri FROM client_redirect_uris WHERE client_id = ?",
+ id
+ )
+ .fetch_all(executor)
+ .await
+ .unexpect()?;
+
+ uris.into_iter()
+ .map(|s| Url::from_str(&s).unexpect())
+ .collect()
+}
+
+async fn delete_client_redirect_uris<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<(), sqlx::Error> {
+ query!("DELETE FROM client_redirect_uris WHERE client_id = ?", id)
+ .execute(executor)
+ .await?;
+ Ok(())
+}
+
+async fn create_client_redirect_uris<'c>(
+ mut transaction: Transaction<'c, MySql>,
+ client_id: Uuid,
+ uris: &[Url],
+) -> Result<(), sqlx::Error> {
+ for uri in uris {
+ query!(
+ r"INSERT INTO client_redirect_uris (client_id, redirect_uri)
+ VALUES ( ?, ?)",
+ client_id,
+ uri.to_string()
+ )
+ .execute(&mut transaction)
+ .await?;
+ }
+
+ transaction.commit().await?;
+
+ Ok(())
+}
+
+pub async fn create_client<'c>(
+ mut transaction: Transaction<'c, MySql>,
+ client: &Client,
+) -> Result<(), sqlx::Error> {
+ query!(
+ r"INSERT INTO clients (id, alias, type, secret_hash, secret_salt, secret_version)
+ VALUES ( ?, ?, ?, ?, ?, ?)",
+ client.id(),
+ client.alias(),
+ client.client_type(),
+ client.secret_hash(),
+ client.secret_salt(),
+ client.secret_version(),
+ )
+ .execute(&mut transaction)
+ .await?;
+
+ create_client_redirect_uris(transaction, client.id(), client.redirect_uris()).await?;
+
+ Ok(())
+}
+
+pub async fn update_client<'c>(
+ mut transaction: Transaction<'c, MySql>,
+ client: &Client,
+) -> Result<(), sqlx::Error> {
+ query!(
+ r"UPDATE clients SET
+ alias = ?,
+ type = ?,
+ secret_hash = ?,
+ secret_salt = ?,
+ secret_version = ?
+ WHERE id = ?",
+ client.client_type(),
+ client.alias(),
+ client.secret_hash(),
+ client.secret_salt(),
+ client.secret_version(),
+ client.id()
+ )
+ .execute(&mut transaction)
+ .await?;
+
+ update_client_redirect_uris(transaction, client.id(), client.redirect_uris()).await?;
+
+ Ok(())
+}
+
+pub async fn update_client_alias<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+ alias: &str,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!("UPDATE clients SET alias = ? WHERE id = ?", alias, id)
+ .execute(executor)
+ .await
+}
+
+pub async fn update_client_type<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+ ty: ClientType,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!("UPDATE clients SET type = ? WHERE id = ?", ty, id)
+ .execute(executor)
+ .await
+}
+
+pub async fn update_client_redirect_uris<'c>(
+ mut transaction: Transaction<'c, MySql>,
+ id: Uuid,
+ uris: &[Url],
+) -> Result<(), sqlx::Error> {
+ delete_client_redirect_uris(&mut transaction, id).await?;
+ create_client_redirect_uris(transaction, id, uris).await?;
+ Ok(())
+}
+
+pub async fn update_client_secret<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+ secret: Option<PasswordHash>,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ if let Some(secret) = secret {
+ query!(
+ "UPDATE clients SET secret_hash = ?, secret_salt = ?, secret_version = ? WHERE id = ?",
+ secret.hash(),
+ secret.salt(),
+ secret.version(),
+ id
+ )
+ .execute(executor)
+ .await
+ } else {
+ query!(
+ r"UPDATE clients
+ SET secret_hash = NULL, secret_salt = NULL, secret_version = NULL
+ WHERE id = ?",
+ id
+ )
+ .execute(executor)
+ .await
+ }
+}
diff --git a/src/services/db/user.rs b/src/services/db/user.rs
new file mode 100644
index 0000000..09a09da
--- /dev/null
+++ b/src/services/db/user.rs
@@ -0,0 +1,236 @@
+use exun::RawUnexpected;
+use sqlx::{mysql::MySqlQueryResult, query, query_as, query_scalar, Executor, MySql};
+use uuid::Uuid;
+
+use crate::{models::user::User, services::crypto::PasswordHash};
+
+struct UserRow {
+ id: Uuid,
+ username: String,
+ password_hash: Vec<u8>,
+ password_salt: Vec<u8>,
+ password_version: u32,
+}
+
+impl TryFrom<UserRow> for User {
+ type Error = RawUnexpected;
+
+ fn try_from(row: UserRow) -> Result<Self, Self::Error> {
+ let password = PasswordHash::from_fields(
+ &row.password_hash,
+ &row.password_salt,
+ row.password_version as u8,
+ );
+ let user = User {
+ id: row.id,
+ username: row.username.into_boxed_str(),
+ password,
+ };
+ Ok(user)
+ }
+}
+
+/// Check if a user with a given user ID exists
+pub async fn user_id_exists<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<bool, RawUnexpected> {
+ let exists = query_scalar!(
+ r#"SELECT EXISTS(SELECT id FROM users WHERE id = ?) as `e: bool`"#,
+ id
+ )
+ .fetch_one(conn)
+ .await?;
+
+ Ok(exists)
+}
+
+/// Check if a given username is taken
+pub async fn username_is_used<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ username: &str,
+) -> Result<bool, RawUnexpected> {
+ let exists = query_scalar!(
+ r#"SELECT EXISTS(SELECT id FROM users WHERE username = ?) as "e: bool""#,
+ username
+ )
+ .fetch_one(conn)
+ .await?;
+
+ Ok(exists)
+}
+
+/// Get a user from their ID
+pub async fn get_user<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ user_id: Uuid,
+) -> Result<Option<User>, RawUnexpected> {
+ let record = query_as!(
+ UserRow,
+ r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version
+ FROM users WHERE id = ?",
+ user_id
+ )
+ .fetch_optional(conn)
+ .await?;
+
+ let Some(record) = record else { return Ok(None) };
+
+ Ok(Some(record.try_into()?))
+}
+
+/// Get a user from their username
+pub async fn get_user_by_username<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ username: &str,
+) -> Result<Option<User>, RawUnexpected> {
+ let record = query_as!(
+ UserRow,
+ r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version
+ FROM users WHERE username = ?",
+ username
+ )
+ .fetch_optional(conn)
+ .await?;
+
+ let Some(record) = record else { return Ok(None) };
+
+ Ok(Some(record.try_into()?))
+}
+
+/// Search the list of users for a given username
+pub async fn search_users<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ username: &str,
+) -> Result<Box<[User]>, RawUnexpected> {
+ let records = query_as!(
+ UserRow,
+ r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version
+ FROM users
+ WHERE LOCATE(?, username) != 0",
+ username,
+ )
+ .fetch_all(conn)
+ .await?;
+
+ Ok(records
+ .into_iter()
+ .map(|u| u.try_into())
+ .collect::<Result<Box<[User]>, RawUnexpected>>()?)
+}
+
+/// Search the list of users, only returning a certain range of results
+pub async fn search_users_limit<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ username: &str,
+ offset: u32,
+ limit: u32,
+) -> Result<Box<[User]>, RawUnexpected> {
+ let records = query_as!(
+ UserRow,
+ r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version
+ FROM users
+ WHERE LOCATE(?, username) != 0
+ LIMIT ?
+ OFFSET ?",
+ username,
+ offset,
+ limit
+ )
+ .fetch_all(conn)
+ .await?;
+
+ Ok(records
+ .into_iter()
+ .map(|u| u.try_into())
+ .collect::<Result<Box<[User]>, RawUnexpected>>()?)
+}
+
+/// Get the username of a user with a certain ID
+pub async fn get_username<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ user_id: Uuid,
+) -> Result<Option<Box<str>>, RawUnexpected> {
+ let username = query_scalar!(r"SELECT username FROM users where id = ?", user_id)
+ .fetch_optional(conn)
+ .await?
+ .map(String::into_boxed_str);
+
+ Ok(username)
+}
+
+/// Create a new user
+pub async fn create_user<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ user: &User,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ r"INSERT INTO users (id, username, password_hash, password_salt, password_version)
+ VALUES ( ?, ?, ?, ?, ?)",
+ user.id,
+ user.username(),
+ user.password_hash(),
+ user.password_salt(),
+ user.password_version()
+ )
+ .execute(conn)
+ .await
+}
+
+/// Update a user
+pub async fn update_user<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ user: &User,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ r"UPDATE users SET
+ username = ?,
+ password_hash = ?,
+ password_salt = ?,
+ password_version = ?
+ WHERE id = ?",
+ user.username(),
+ user.password_hash(),
+ user.password_salt(),
+ user.password_version(),
+ user.id
+ )
+ .execute(conn)
+ .await
+}
+
+/// Update the username of a user with the given ID
+pub async fn update_username<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ user_id: Uuid,
+ username: &str,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ r"UPDATE users SET username = ? WHERE id = ?",
+ username,
+ user_id
+ )
+ .execute(conn)
+ .await
+}
+
+/// Update the password of a user with the given ID
+pub async fn update_password<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ user_id: Uuid,
+ password: &PasswordHash,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ r"UPDATE users SET
+ password_hash = ?,
+ password_salt = ?,
+ password_version = ?
+ WHERE id = ?",
+ password.hash(),
+ password.salt(),
+ password.version(),
+ user_id
+ )
+ .execute(conn)
+ .await
+}
diff --git a/src/services/id.rs b/src/services/id.rs
index 7970c60..0c665ed 100644
--- a/src/services/id.rs
+++ b/src/services/id.rs
@@ -1,16 +1,24 @@
+use std::future::Future;
+
use exun::RawUnexpected;
use sqlx::{Executor, MySql};
use uuid::Uuid;
-use super::db;
-
-/// Create a unique user id, handling duplicate ID's
-pub async fn new_user_id<'c>(
- conn: impl Executor<'c, Database = MySql> + Clone,
+/// Create a unique id, handling duplicate ID's.
+///
+/// The given `unique_check` parameter returns `true` if the ID is used and
+/// `false` otherwise.
+pub async fn new_id<
+ 'c,
+ E: Executor<'c, Database = MySql> + Clone,
+ F: Future<Output = Result<bool, RawUnexpected>>,
+>(
+ conn: E,
+ unique_check: impl Fn(E, Uuid) -> F,
) -> Result<Uuid, RawUnexpected> {
let uuid = loop {
let uuid = Uuid::new_v4();
- if !db::user_id_exists(conn.clone(), uuid).await? {
+ if !unique_check(conn.clone(), uuid).await? {
break uuid;
}
};