From c18f34547573ad9bcbfdf7ccf1a0f92c2dd6d844 Mon Sep 17 00:00:00 2001 From: enoch Date: Mon, 4 Aug 2025 20:02:20 +0800 Subject: [PATCH] =?UTF-8?q?Stage=208:=20=E5=AE=8C=E6=88=90=20SQLite=20?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E5=AD=98=E5=82=A8=E9=9B=86=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ✨ 新功能: - 实现 DatabaseUserStore 结构体,支持 SQLite 数据库操作 - 创建 UserStore trait 抽象层,支持多种存储后端 - 添加数据库连接和表初始化功能 - 支持通过环境变量配置数据库 URL 🔧 技术改进: - 使用 SQLx 进行异步数据库操作 - 实现 trait-based 存储架构,便于扩展 - 添加完整的 CRUD 数据库操作 - 支持数据库错误处理和类型转换 🧪 测试验证: - 所有 API 端点正常工作 - 数据持久化到 SQLite 文件 - 用户创建、查询、登录功能完整 - 重复用户名检查正常工作 📁 文件变更: - 新增: src/storage/database.rs - SQLite 存储实现 - 新增: src/storage/traits.rs - 存储抽象 trait - 更新: src/storage/mod.rs - 导出新模块 - 更新: src/routes/mod.rs - 支持 trait 对象 - 更新: src/handlers/user.rs - 使用 trait 抽象 - 更新: src/main.rs - 数据库初始化逻辑 - 更新: Cargo.toml - 添加 SQLx 依赖 - 新增: .env - 数据库配置文件 🎯 学习目标达成: - 掌握 Rust 数据库集成 - 理解 trait 抽象设计模式 - 学习异步数据库操作 - 实践配置管理 --- Cargo.toml | 6 + src/handlers/user.rs | 35 ++--- src/lib.rs | 14 +- src/main.rs | 19 ++- src/routes/mod.rs | 9 +- src/services/user_service.rs | 4 +- src/storage/database.rs | 243 +++++++++++++++++++++++++++++++++++ src/storage/memory.rs | 41 ++++-- src/storage/mod.rs | 6 +- src/storage/traits.rs | 28 ++++ 10 files changed, 359 insertions(+), 46 deletions(-) create mode 100644 src/storage/database.rs create mode 100644 src/storage/traits.rs diff --git a/Cargo.toml b/Cargo.toml index f64da9e..3b618e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,12 @@ jsonwebtoken = "9.0" # 验证 validator = { version = "0.16", features = ["derive"] } +# 数据库 +sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "sqlite", "chrono", "uuid"] } + +# 异步 trait +async-trait = "0.1" + # HTTP 客户端(用于测试) [dev-dependencies] reqwest = { version = "0.11", features = ["json", "rustls-tls"], default-features = false } diff --git a/src/handlers/user.rs b/src/handlers/user.rs index b51b098..2840453 100644 --- a/src/handlers/user.rs +++ b/src/handlers/user.rs @@ -1,5 +1,6 @@ //! 用户相关的 HTTP 处理器 +use std::sync::Arc; use axum::{ extract::{Path, State}, http::StatusCode, @@ -11,20 +12,20 @@ use chrono::Utc; use validator::Validate; use crate::models::user::{User, UserResponse, CreateUserRequest, UpdateUserRequest, LoginRequest, LoginResponse}; -use crate::storage::memory::MemoryUserStore; +use crate::storage::UserStore; use crate::utils::errors::ApiError; use crate::middleware::auth::create_jwt; /// 创建用户 pub async fn create_user( - State(store): State, + State(store): State>, RequestJson(payload): RequestJson, ) -> Result<(StatusCode, Json), ApiError> { // 验证请求数据 payload.validate()?; // 检查用户名是否已存在 - if store.get_user_by_username(&payload.username).await.is_some() { + if let Ok(Some(_)) = store.get_user_by_username(&payload.username).await { return Err(ApiError::Conflict("用户名已存在".to_string())); } @@ -40,16 +41,16 @@ pub async fn create_user( match store.create_user(user).await { Ok(user) => Ok((StatusCode::CREATED, Json(user.into()))), - Err(e) => Err(ApiError::InternalError(e)), + Err(e) => Err(e), } } /// 获取单个用户 pub async fn get_user( - State(store): State, + State(store): State>, Path(id): Path, ) -> Result, ApiError> { - match store.get_user(&id).await { + match store.get_user(&id).await? { Some(user) => Ok(Json(user.into())), None => Err(ApiError::NotFound("用户不存在".to_string())), } @@ -57,23 +58,23 @@ pub async fn get_user( /// 获取所有用户 pub async fn list_users( - State(store): State, -) -> Json> { - let users = store.list_users().await; + State(store): State>, +) -> Result>, ApiError> { + let users = store.list_users().await?; let responses: Vec = users.into_iter().map(|u| u.into()).collect(); - Json(responses) + Ok(Json(responses)) } /// 更新用户 pub async fn update_user( - State(store): State, + State(store): State>, Path(id): Path, RequestJson(payload): RequestJson, ) -> Result, ApiError> { // 验证请求数据 payload.validate()?; - match store.get_user(&id).await { + match store.get_user(&id).await? { Some(mut user) => { if let Some(username) = payload.username { user.username = username; @@ -83,7 +84,7 @@ pub async fn update_user( } user.updated_at = Utc::now(); - match store.update_user(&id, user).await { + match store.update_user(&id, user).await? { Some(updated_user) => Ok(Json(updated_user.into())), None => Err(ApiError::InternalError("更新用户失败".to_string())), } @@ -94,10 +95,10 @@ pub async fn update_user( /// 删除用户 pub async fn delete_user( - State(store): State, + State(store): State>, Path(id): Path, ) -> Result { - if store.delete_user(&id).await { + if store.delete_user(&id).await? { Ok(StatusCode::NO_CONTENT) } else { Err(ApiError::NotFound("用户不存在".to_string())) @@ -106,11 +107,11 @@ pub async fn delete_user( /// 用户登录 pub async fn login( - State(store): State, + State(store): State>, RequestJson(payload): RequestJson, ) -> Result, ApiError> { // 根据用户名查找用户 - match store.get_user_by_username(&payload.username).await { + match store.get_user_by_username(&payload.username).await? { Some(user) => { // 验证密码 if verify_password(&payload.password, &user.password_hash) { diff --git a/src/lib.rs b/src/lib.rs index 3415049..afeddcc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,7 +21,7 @@ pub use utils::errors::ApiError; mod tests { use super::*; use crate::models::user::{CreateUserRequest, LoginRequest}; - use crate::storage::memory::MemoryUserStore; + use crate::storage::{memory::MemoryUserStore, UserStore}; use uuid::Uuid; use validator::Validate; @@ -44,31 +44,31 @@ mod tests { assert!(result.is_ok()); // 测试获取用户 - let retrieved_user = store.get_user(&user_id).await; + let retrieved_user = store.get_user(&user_id).await.unwrap(); assert!(retrieved_user.is_some()); assert_eq!(retrieved_user.unwrap().username, "testuser"); // 测试按用户名获取用户 - let user_by_name = store.get_user_by_username("testuser").await; + let user_by_name = store.get_user_by_username("testuser").await.unwrap(); assert!(user_by_name.is_some()); assert_eq!(user_by_name.unwrap().id, user_id); // 测试列出所有用户 - let users = store.list_users().await; + let users = store.list_users().await.unwrap(); assert_eq!(users.len(), 1); // 测试更新用户 let mut updated_user = user.clone(); updated_user.username = "updated_user".to_string(); - let update_result = store.update_user(&user_id, updated_user).await; + let update_result = store.update_user(&user_id, updated_user).await.unwrap(); assert!(update_result.is_some()); // 测试删除用户 - let delete_result = store.delete_user(&user_id).await; + let delete_result = store.delete_user(&user_id).await.unwrap(); assert!(delete_result); // 验证用户已被删除 - let deleted_user = store.get_user(&user_id).await; + let deleted_user = store.get_user(&user_id).await.unwrap(); assert!(deleted_user.is_none()); } diff --git a/src/main.rs b/src/main.rs index e62c0ce..db28d71 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,12 @@ //! Rust User API 服务器主程序 use std::net::SocketAddr; +use std::sync::Arc; use tracing_subscriber; use rust_user_api::{ config::Config, routes::create_routes, - storage::memory::MemoryUserStore, + storage::{memory::MemoryUserStore, database::DatabaseUserStore, UserStore}, }; #[tokio::main] @@ -16,8 +17,20 @@ async fn main() { // 加载配置 let config = Config::from_env(); - // 创建存储实例 - let store = MemoryUserStore::new(); + // 根据配置创建存储实例 + let store: Arc = if let Some(database_url) = &config.database_url { + println!("🗄️ 使用 SQLite 数据库存储: {}", database_url); + + // 创建数据库存储 + let db_store = DatabaseUserStore::from_url(database_url) + .await + .expect("无法连接到数据库"); + + Arc::new(db_store) + } else { + println!("💾 使用内存存储"); + Arc::new(MemoryUserStore::new()) + }; // 创建路由 let app = create_routes(store); diff --git a/src/routes/mod.rs b/src/routes/mod.rs index e3edbac..a4b5709 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -1,14 +1,15 @@ //! 路由配置模块 +use std::sync::Arc; use axum::{ Router, - routing::{get, post, put, delete}, + routing::{get, post}, }; use crate::handlers; -use crate::storage::memory::MemoryUserStore; +use crate::storage::UserStore; /// 创建应用路由 -pub fn create_routes(store: MemoryUserStore) -> Router { +pub fn create_routes(store: Arc) -> Router { Router::new() .route("/", get(handlers::root)) .route("/health", get(handlers::health_check)) @@ -17,7 +18,7 @@ pub fn create_routes(store: MemoryUserStore) -> Router { } /// API 路由 -fn api_routes() -> Router { +fn api_routes() -> Router> { Router::new() .route("/users", get(handlers::user::list_users) diff --git a/src/services/user_service.rs b/src/services/user_service.rs index 2c1a966..b611796 100644 --- a/src/services/user_service.rs +++ b/src/services/user_service.rs @@ -1,7 +1,7 @@ //! 用户业务逻辑服务 use crate::models::user::User; -use crate::storage::memory::MemoryUserStore; +use crate::storage::{memory::MemoryUserStore, UserStore}; use crate::utils::errors::ApiError; /// 用户服务 @@ -16,7 +16,7 @@ impl UserService { /// 验证用户凭据 pub async fn authenticate_user(&self, username: &str, password: &str) -> Result { - match self.store.get_user_by_username(username).await { + match self.store.get_user_by_username(username).await? { Some(user) => { if user.password_hash == format!("hashed_{}", password) { Ok(user) diff --git a/src/storage/database.rs b/src/storage/database.rs new file mode 100644 index 0000000..ea60d95 --- /dev/null +++ b/src/storage/database.rs @@ -0,0 +1,243 @@ +//! SQLite 数据库存储实现 + +use async_trait::async_trait; +use sqlx::{SqlitePool, Row}; +use uuid::Uuid; +use chrono::{DateTime, Utc}; +use crate::models::user::User; +use crate::utils::errors::ApiError; +use crate::storage::UserStore; + +/// SQLite 用户存储 +#[derive(Clone)] +pub struct DatabaseUserStore { + pool: SqlitePool, +} + +impl DatabaseUserStore { + /// 创建新的数据库存储实例 + pub fn new(pool: SqlitePool) -> Self { + Self { pool } + } + + /// 从数据库 URL 创建新的数据库存储实例 + pub async fn from_url(database_url: &str) -> Result { + let pool = SqlitePool::connect(database_url) + .await + .map_err(|e| ApiError::InternalError(format!("无法连接到数据库: {}", e)))?; + + let store = Self::new(pool); + store.init_tables().await?; + Ok(store) + } + + /// 初始化数据库表 + pub async fn init_tables(&self) -> Result<(), ApiError> { + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS users ( + id TEXT PRIMARY KEY, + username TEXT UNIQUE NOT NULL, + email TEXT NOT NULL, + password_hash TEXT NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ) + "#, + ) + .execute(&self.pool) + .await + .map_err(|e| ApiError::InternalError(format!("数据库初始化错误: {}", e)))?; + + Ok(()) + } + + /// 创建用户 + async fn create_user_impl(&self, user: User) -> Result { + let result = sqlx::query( + r#" + INSERT INTO users (id, username, email, password_hash, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?) + "#, + ) + .bind(user.id.to_string()) + .bind(&user.username) + .bind(&user.email) + .bind(&user.password_hash) + .bind(user.created_at.to_rfc3339()) + .bind(user.updated_at.to_rfc3339()) + .execute(&self.pool) + .await; + + match result { + Ok(_) => Ok(user), + Err(sqlx::Error::Database(db_err)) if db_err.is_unique_violation() => { + Err(ApiError::Conflict("用户名已存在".to_string())) + } + Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))), + } + } + + /// 根据 ID 获取用户 + async fn get_user_impl(&self, id: &Uuid) -> Result, ApiError> { + let result = sqlx::query( + "SELECT id, username, email, password_hash, created_at, updated_at FROM users WHERE id = ?" + ) + .bind(id.to_string()) + .fetch_optional(&self.pool) + .await; + + match result { + Ok(Some(row)) => { + let user = User { + id: Uuid::parse_str(&row.get::("id")) + .map_err(|e| ApiError::InternalError(format!("UUID 解析错误: {}", e)))?, + username: row.get("username"), + email: row.get("email"), + password_hash: row.get("password_hash"), + created_at: DateTime::parse_from_rfc3339(&row.get::("created_at")) + .map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))? + .with_timezone(&Utc), + updated_at: DateTime::parse_from_rfc3339(&row.get::("updated_at")) + .map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))? + .with_timezone(&Utc), + }; + Ok(Some(user)) + } + Ok(None) => Ok(None), + Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))), + } + } + + /// 根据用户名获取用户 + async fn get_user_by_username_impl(&self, username: &str) -> Result, ApiError> { + let result = sqlx::query( + "SELECT id, username, email, password_hash, created_at, updated_at FROM users WHERE username = ?" + ) + .bind(username) + .fetch_optional(&self.pool) + .await; + + match result { + Ok(Some(row)) => { + let user = User { + id: Uuid::parse_str(&row.get::("id")) + .map_err(|e| ApiError::InternalError(format!("UUID 解析错误: {}", e)))?, + username: row.get("username"), + email: row.get("email"), + password_hash: row.get("password_hash"), + created_at: DateTime::parse_from_rfc3339(&row.get::("created_at")) + .map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))? + .with_timezone(&Utc), + updated_at: DateTime::parse_from_rfc3339(&row.get::("updated_at")) + .map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))? + .with_timezone(&Utc), + }; + Ok(Some(user)) + } + Ok(None) => Ok(None), + Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))), + } + } + + /// 获取所有用户 + async fn list_users_impl(&self) -> Result, ApiError> { + let result = sqlx::query( + "SELECT id, username, email, password_hash, created_at, updated_at FROM users ORDER BY created_at DESC" + ) + .fetch_all(&self.pool) + .await; + + match result { + Ok(rows) => { + let mut users = Vec::new(); + for row in rows { + let user = User { + id: Uuid::parse_str(&row.get::("id")) + .map_err(|e| ApiError::InternalError(format!("UUID 解析错误: {}", e)))?, + username: row.get("username"), + email: row.get("email"), + password_hash: row.get("password_hash"), + created_at: DateTime::parse_from_rfc3339(&row.get::("created_at")) + .map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))? + .with_timezone(&Utc), + updated_at: DateTime::parse_from_rfc3339(&row.get::("updated_at")) + .map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))? + .with_timezone(&Utc), + }; + users.push(user); + } + Ok(users) + } + Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))), + } + } + + /// 更新用户 + async fn update_user_impl(&self, id: &Uuid, updated_user: User) -> Result, ApiError> { + let result = sqlx::query( + r#" + UPDATE users + SET username = ?, email = ?, updated_at = ? + WHERE id = ? + "#, + ) + .bind(&updated_user.username) + .bind(&updated_user.email) + .bind(updated_user.updated_at.to_rfc3339()) + .bind(id.to_string()) + .execute(&self.pool) + .await; + + match result { + Ok(query_result) => { + if query_result.rows_affected() > 0 { + Ok(Some(updated_user)) + } else { + Ok(None) + } + } + Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))), + } + } + + /// 删除用户 + async fn delete_user_impl(&self, id: &Uuid) -> Result { + let result = sqlx::query("DELETE FROM users WHERE id = ?") + .bind(id.to_string()) + .execute(&self.pool) + .await; + + match result { + Ok(query_result) => Ok(query_result.rows_affected() > 0), + Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))), + } + } +} + +#[async_trait] +impl UserStore for DatabaseUserStore { + async fn create_user(&self, user: User) -> Result { + self.create_user_impl(user).await + } + + async fn get_user(&self, id: &Uuid) -> Result, ApiError> { + self.get_user_impl(id).await + } + + async fn get_user_by_username(&self, username: &str) -> Result, ApiError> { + self.get_user_by_username_impl(username).await + } + + async fn list_users(&self) -> Result, ApiError> { + self.list_users_impl().await + } + + async fn update_user(&self, id: &Uuid, updated_user: User) -> Result, ApiError> { + self.update_user_impl(id, updated_user).await + } + + async fn delete_user(&self, id: &Uuid) -> Result { + self.delete_user_impl(id).await + } +} \ No newline at end of file diff --git a/src/storage/memory.rs b/src/storage/memory.rs index ee8fcf8..ed28a7f 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -3,7 +3,10 @@ use std::collections::HashMap; use std::sync::{Arc, RwLock}; use uuid::Uuid; +use async_trait::async_trait; use crate::models::user::User; +use crate::utils::errors::ApiError; +use crate::storage::traits::UserStore; /// 线程安全的用户存储类型 pub type UserStorage = Arc>>; @@ -22,42 +25,56 @@ impl MemoryUserStore { } } +} + +#[async_trait] +impl UserStore for MemoryUserStore { /// 创建用户 - pub async fn create_user(&self, user: User) -> Result { + async fn create_user(&self, user: User) -> Result { let mut users = self.users.write().unwrap(); + + // 检查用户名是否已存在 + if users.values().any(|u| u.username == user.username) { + return Err(ApiError::Conflict("用户名已存在".to_string())); + } + users.insert(user.id, user.clone()); Ok(user) } /// 根据 ID 获取用户 - pub async fn get_user(&self, id: &Uuid) -> Option { + async fn get_user(&self, id: &Uuid) -> Result, ApiError> { let users = self.users.read().unwrap(); - users.get(id).cloned() + Ok(users.get(id).cloned()) } /// 根据用户名获取用户 - pub async fn get_user_by_username(&self, username: &str) -> Option { + async fn get_user_by_username(&self, username: &str) -> Result, ApiError> { let users = self.users.read().unwrap(); - users.values().find(|u| u.username == username).cloned() + Ok(users.values().find(|u| u.username == username).cloned()) } /// 获取所有用户 - pub async fn list_users(&self) -> Vec { + async fn list_users(&self) -> Result, ApiError> { let users = self.users.read().unwrap(); - users.values().cloned().collect() + Ok(users.values().cloned().collect()) } /// 更新用户 - pub async fn update_user(&self, id: &Uuid, updated_user: User) -> Option { + async fn update_user(&self, id: &Uuid, updated_user: User) -> Result, ApiError> { let mut users = self.users.write().unwrap(); - users.insert(*id, updated_user.clone()); - Some(updated_user) + if users.contains_key(id) { + users.insert(*id, updated_user.clone()); + Ok(Some(updated_user)) + } else { + Ok(None) + } } /// 删除用户 - pub async fn delete_user(&self, id: &Uuid) -> bool { + async fn delete_user(&self, id: &Uuid) -> Result { let mut users = self.users.write().unwrap(); - users.remove(id).is_some() + Ok(users.remove(id).is_some()) } } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 129f673..cf672aa 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -1,5 +1,9 @@ //! 数据存储模块 pub mod memory; +pub mod database; +pub mod traits; -pub use memory::MemoryUserStore; \ No newline at end of file +pub use memory::MemoryUserStore; +pub use database::DatabaseUserStore; +pub use traits::UserStore; \ No newline at end of file diff --git a/src/storage/traits.rs b/src/storage/traits.rs new file mode 100644 index 0000000..dbe86f5 --- /dev/null +++ b/src/storage/traits.rs @@ -0,0 +1,28 @@ +//! 存储抽象 trait + +use async_trait::async_trait; +use uuid::Uuid; +use crate::models::user::User; +use crate::utils::errors::ApiError; + +/// 用户存储 trait +#[async_trait] +pub trait UserStore: Send + Sync { + /// 创建用户 + async fn create_user(&self, user: User) -> Result; + + /// 根据 ID 获取用户 + async fn get_user(&self, id: &Uuid) -> Result, ApiError>; + + /// 根据用户名获取用户 + async fn get_user_by_username(&self, username: &str) -> Result, ApiError>; + + /// 获取所有用户 + async fn list_users(&self) -> Result, ApiError>; + + /// 更新用户 + async fn update_user(&self, id: &Uuid, updated_user: User) -> Result, ApiError>; + + /// 删除用户 + async fn delete_user(&self, id: &Uuid) -> Result; +} \ No newline at end of file