//! SQLite 数据库存储实现 use async_trait::async_trait; use sqlx::{SqlitePool, Row}; use uuid::Uuid; use chrono::{DateTime, Utc}; use crate::models::user::User; use crate::models::pagination::PaginationParams; use crate::models::search::UserSearchParams; use crate::models::role::UserRole; use crate::utils::errors::ApiError; use crate::storage::{UserStore, MigrationManager}; /// SQLite 用户存储 #[derive(Clone)] pub struct DatabaseUserStore { pool: SqlitePool, } impl DatabaseUserStore { /// 创建新的数据库存储实例 pub fn new(pool: SqlitePool) -> Self { Self { pool } } /// 从数据库 URL 创建新的数据库存储实例 pub async fn from_url(database_url: &str) -> Result { let pool = SqlitePool::connect(database_url) .await .map_err(|e| ApiError::InternalError(format!("无法连接到数据库: {}", e)))?; let store = Self::new(pool.clone()); // 使用迁移系统初始化数据库 let migration_manager = MigrationManager::new(pool); migration_manager.run_migrations().await?; Ok(store) } /// 初始化数据库表 (已弃用,现在使用迁移系统) #[allow(dead_code)] pub async fn init_tables(&self) -> Result<(), ApiError> { // 这个方法已被迁移系统替代 // 保留用于向后兼容,但不再使用 tracing::warn!("⚠️ init_tables 方法已弃用,请使用迁移系统"); Ok(()) } /// 创建用户 async fn create_user_impl(&self, user: User) -> Result { let result = sqlx::query( r#" INSERT INTO users (id, username, email, password_hash, role, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?) "#, ) .bind(user.id.to_string()) .bind(&user.username) .bind(&user.email) .bind(&user.password_hash) .bind(user.role.as_str()) .bind(user.created_at.to_rfc3339()) .bind(user.updated_at.to_rfc3339()) .execute(&self.pool) .await; match result { Ok(_) => Ok(user), Err(sqlx::Error::Database(db_err)) if db_err.is_unique_violation() => { Err(ApiError::Conflict("用户名已存在".to_string())) } Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))), } } /// 根据 ID 获取用户 async fn get_user_impl(&self, id: &Uuid) -> Result, ApiError> { let result = sqlx::query( "SELECT id, username, email, password_hash, role, created_at, updated_at FROM users WHERE id = ?" ) .bind(id.to_string()) .fetch_optional(&self.pool) .await; match result { Ok(Some(row)) => { let role_str: String = row.get("role"); let role = UserRole::from_str(&role_str).unwrap_or_default(); let user = User { id: Uuid::parse_str(&row.get::("id")) .map_err(|e| ApiError::InternalError(format!("UUID 解析错误: {}", e)))?, username: row.get("username"), email: row.get("email"), password_hash: row.get("password_hash"), role, created_at: DateTime::parse_from_rfc3339(&row.get::("created_at")) .map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))? .with_timezone(&Utc), updated_at: DateTime::parse_from_rfc3339(&row.get::("updated_at")) .map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))? .with_timezone(&Utc), }; Ok(Some(user)) } Ok(None) => Ok(None), Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))), } } /// 根据用户名获取用户 async fn get_user_by_username_impl(&self, username: &str) -> Result, ApiError> { let result = sqlx::query( "SELECT id, username, email, password_hash, role, created_at, updated_at FROM users WHERE username = ?" ) .bind(username) .fetch_optional(&self.pool) .await; match result { Ok(Some(row)) => { let role_str: String = row.get("role"); let role = UserRole::from_str(&role_str).unwrap_or_default(); let user = User { id: Uuid::parse_str(&row.get::("id")) .map_err(|e| ApiError::InternalError(format!("UUID 解析错误: {}", e)))?, username: row.get("username"), email: row.get("email"), password_hash: row.get("password_hash"), role, created_at: DateTime::parse_from_rfc3339(&row.get::("created_at")) .map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))? .with_timezone(&Utc), updated_at: DateTime::parse_from_rfc3339(&row.get::("updated_at")) .map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))? .with_timezone(&Utc), }; Ok(Some(user)) } Ok(None) => Ok(None), Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))), } } /// 获取所有用户 async fn list_users_impl(&self) -> Result, ApiError> { let result = sqlx::query( "SELECT id, username, email, password_hash, role, created_at, updated_at FROM users ORDER BY created_at DESC" ) .fetch_all(&self.pool) .await; match result { Ok(rows) => { let mut users = Vec::new(); for row in rows { let role_str: String = row.get("role"); let role = UserRole::from_str(&role_str).unwrap_or_default(); let user = User { id: Uuid::parse_str(&row.get::("id")) .map_err(|e| ApiError::InternalError(format!("UUID 解析错误: {}", e)))?, username: row.get("username"), email: row.get("email"), password_hash: row.get("password_hash"), role, created_at: DateTime::parse_from_rfc3339(&row.get::("created_at")) .map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))? .with_timezone(&Utc), updated_at: DateTime::parse_from_rfc3339(&row.get::("updated_at")) .map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))? .with_timezone(&Utc), }; users.push(user); } Ok(users) } Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))), } } /// 分页获取用户列表 async fn list_users_paginated_impl(&self, params: &PaginationParams) -> Result<(Vec, u64), ApiError> { // 首先获取总数 let count_result = sqlx::query("SELECT COUNT(*) as count FROM users") .fetch_one(&self.pool) .await; let total_count = match count_result { Ok(row) => row.get::("count") as u64, Err(e) => return Err(ApiError::InternalError(format!("获取用户总数失败: {}", e))), }; // 然后获取分页数据 let result = sqlx::query( "SELECT id, username, email, password_hash, role, created_at, updated_at FROM users ORDER BY created_at DESC LIMIT ? OFFSET ?" ) .bind(params.limit() as i64) .bind(params.offset() as i64) .fetch_all(&self.pool) .await; match result { Ok(rows) => { let mut users = Vec::new(); for row in rows { let role_str: String = row.get("role"); let role = UserRole::from_str(&role_str).unwrap_or_default(); let user = User { id: Uuid::parse_str(&row.get::("id")) .map_err(|e| ApiError::InternalError(format!("UUID 解析错误: {}", e)))?, username: row.get("username"), email: row.get("email"), password_hash: row.get("password_hash"), role, created_at: DateTime::parse_from_rfc3339(&row.get::("created_at")) .map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))? .with_timezone(&Utc), updated_at: DateTime::parse_from_rfc3339(&row.get::("updated_at")) .map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))? .with_timezone(&Utc), }; users.push(user); } Ok((users, total_count)) } Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))), } } /// 搜索和过滤用户(带分页) async fn search_users_impl(&self, search_params: &UserSearchParams, pagination_params: &PaginationParams) -> Result<(Vec, u64), ApiError> { // 构建 WHERE 子句和参数 let mut where_conditions = Vec::new(); let mut bind_values: Vec = Vec::new(); // 通用搜索(在用户名和邮箱中搜索) if let Some(q) = &search_params.q { where_conditions.push("(username LIKE ? OR email LIKE ?)".to_string()); let search_pattern = format!("%{}%", q); bind_values.push(search_pattern.clone()); bind_values.push(search_pattern); } // 用户名过滤 if let Some(username) = &search_params.username { where_conditions.push("username LIKE ?".to_string()); bind_values.push(format!("%{}%", username)); } // 邮箱过滤 if let Some(email) = &search_params.email { where_conditions.push("email LIKE ?".to_string()); bind_values.push(format!("%{}%", email)); } // 创建时间范围过滤 if let Some(created_after) = &search_params.created_after { if DateTime::parse_from_rfc3339(created_after).is_ok() { where_conditions.push("created_at >= ?".to_string()); bind_values.push(created_after.clone()); } } if let Some(created_before) = &search_params.created_before { if DateTime::parse_from_rfc3339(created_before).is_ok() { where_conditions.push("created_at <= ?".to_string()); bind_values.push(created_before.clone()); } } // 构建 WHERE 子句 let where_clause = if where_conditions.is_empty() { String::new() } else { format!("WHERE {}", where_conditions.join(" AND ")) }; // 构建 ORDER BY 子句 let sort_field = match search_params.get_sort_by() { "username" => "username", "email" => "email", _ => "created_at", // 默认按创建时间排序 }; let sort_order = if search_params.get_sort_order() == "asc" { "ASC" } else { "DESC" }; let order_clause = format!("ORDER BY {} {}", sort_field, sort_order); // 首先获取总数 let count_query = format!("SELECT COUNT(*) as count FROM users {}", where_clause); let mut count_query_builder = sqlx::query(&count_query); // 绑定参数到计数查询 for value in &bind_values { count_query_builder = count_query_builder.bind(value); } let count_result = count_query_builder.fetch_one(&self.pool).await; let total_count = match count_result { Ok(row) => row.get::("count") as u64, Err(e) => return Err(ApiError::InternalError(format!("获取搜索结果总数失败: {}", e))), }; // 然后获取分页数据 let data_query = format!( "SELECT id, username, email, password_hash, role, created_at, updated_at FROM users {} {} LIMIT ? OFFSET ?", where_clause, order_clause ); let mut data_query_builder = sqlx::query(&data_query); // 绑定搜索参数 for value in &bind_values { data_query_builder = data_query_builder.bind(value); } // 绑定分页参数 data_query_builder = data_query_builder .bind(pagination_params.limit() as i64) .bind(pagination_params.offset() as i64); let result = data_query_builder.fetch_all(&self.pool).await; match result { Ok(rows) => { let mut users = Vec::new(); for row in rows { let role_str: String = row.get("role"); let role = UserRole::from_str(&role_str).unwrap_or_default(); let user = User { id: Uuid::parse_str(&row.get::("id")) .map_err(|e| ApiError::InternalError(format!("UUID 解析错误: {}", e)))?, username: row.get("username"), email: row.get("email"), password_hash: row.get("password_hash"), role, created_at: DateTime::parse_from_rfc3339(&row.get::("created_at")) .map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))? .with_timezone(&Utc), updated_at: DateTime::parse_from_rfc3339(&row.get::("updated_at")) .map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))? .with_timezone(&Utc), }; users.push(user); } Ok((users, total_count)) } Err(e) => Err(ApiError::InternalError(format!("数据库搜索错误: {}", e))), } } /// 更新用户 async fn update_user_impl(&self, id: &Uuid, updated_user: User) -> Result, ApiError> { let result = sqlx::query( r#" UPDATE users SET username = ?, email = ?, role = ?, updated_at = ? WHERE id = ? "#, ) .bind(&updated_user.username) .bind(&updated_user.email) .bind(updated_user.role.as_str()) .bind(updated_user.updated_at.to_rfc3339()) .bind(id.to_string()) .execute(&self.pool) .await; match result { Ok(query_result) => { if query_result.rows_affected() > 0 { Ok(Some(updated_user)) } else { Ok(None) } } Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))), } } /// 删除用户 async fn delete_user_impl(&self, id: &Uuid) -> Result { let result = sqlx::query("DELETE FROM users WHERE id = ?") .bind(id.to_string()) .execute(&self.pool) .await; match result { Ok(query_result) => Ok(query_result.rows_affected() > 0), Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))), } } } #[async_trait] impl UserStore for DatabaseUserStore { async fn create_user(&self, user: User) -> Result { self.create_user_impl(user).await } async fn get_user(&self, id: &Uuid) -> Result, ApiError> { self.get_user_impl(id).await } async fn get_user_by_username(&self, username: &str) -> Result, ApiError> { self.get_user_by_username_impl(username).await } async fn list_users(&self) -> Result, ApiError> { self.list_users_impl().await } async fn list_users_paginated(&self, params: &PaginationParams) -> Result<(Vec, u64), ApiError> { self.list_users_paginated_impl(params).await } async fn search_users(&self, search_params: &UserSearchParams, pagination_params: &PaginationParams) -> Result<(Vec, u64), ApiError> { self.search_users_impl(search_params, pagination_params).await } async fn update_user(&self, id: &Uuid, updated_user: User) -> Result, ApiError> { self.update_user_impl(id, updated_user).await } async fn delete_user(&self, id: &Uuid) -> Result { self.delete_user_impl(id).await } }