Stage 8: 完成 SQLite 数据库存储集成

 新功能:
- 实现 DatabaseUserStore 结构体,支持 SQLite 数据库操作
- 创建 UserStore trait 抽象层,支持多种存储后端
- 添加数据库连接和表初始化功能
- 支持通过环境变量配置数据库 URL

🔧 技术改进:
- 使用 SQLx 进行异步数据库操作
- 实现 trait-based 存储架构,便于扩展
- 添加完整的 CRUD 数据库操作
- 支持数据库错误处理和类型转换

🧪 测试验证:
- 所有 API 端点正常工作
- 数据持久化到 SQLite 文件
- 用户创建、查询、登录功能完整
- 重复用户名检查正常工作

📁 文件变更:
- 新增: src/storage/database.rs - SQLite 存储实现
- 新增: src/storage/traits.rs - 存储抽象 trait
- 更新: src/storage/mod.rs - 导出新模块
- 更新: src/routes/mod.rs - 支持 trait 对象
- 更新: src/handlers/user.rs - 使用 trait 抽象
- 更新: src/main.rs - 数据库初始化逻辑
- 更新: Cargo.toml - 添加 SQLx 依赖
- 新增: .env - 数据库配置文件

🎯 学习目标达成:
- 掌握 Rust 数据库集成
- 理解 trait 抽象设计模式
- 学习异步数据库操作
- 实践配置管理
This commit is contained in:
2025-08-04 20:02:20 +08:00
parent 02540f17d3
commit c18f345475
10 changed files with 359 additions and 46 deletions

View File

@@ -38,6 +38,12 @@ jsonwebtoken = "9.0"
# 验证
validator = { version = "0.16", features = ["derive"] }
# 数据库
sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "sqlite", "chrono", "uuid"] }
# 异步 trait
async-trait = "0.1"
# HTTP 客户端(用于测试)
[dev-dependencies]
reqwest = { version = "0.11", features = ["json", "rustls-tls"], default-features = false }

View File

@@ -1,5 +1,6 @@
//! 用户相关的 HTTP 处理器
use std::sync::Arc;
use axum::{
extract::{Path, State},
http::StatusCode,
@@ -11,20 +12,20 @@ use chrono::Utc;
use validator::Validate;
use crate::models::user::{User, UserResponse, CreateUserRequest, UpdateUserRequest, LoginRequest, LoginResponse};
use crate::storage::memory::MemoryUserStore;
use crate::storage::UserStore;
use crate::utils::errors::ApiError;
use crate::middleware::auth::create_jwt;
/// 创建用户
pub async fn create_user(
State(store): State<MemoryUserStore>,
State(store): State<Arc<dyn UserStore>>,
RequestJson(payload): RequestJson<CreateUserRequest>,
) -> Result<(StatusCode, Json<UserResponse>), ApiError> {
// 验证请求数据
payload.validate()?;
// 检查用户名是否已存在
if store.get_user_by_username(&payload.username).await.is_some() {
if let Ok(Some(_)) = store.get_user_by_username(&payload.username).await {
return Err(ApiError::Conflict("用户名已存在".to_string()));
}
@@ -40,16 +41,16 @@ pub async fn create_user(
match store.create_user(user).await {
Ok(user) => Ok((StatusCode::CREATED, Json(user.into()))),
Err(e) => Err(ApiError::InternalError(e)),
Err(e) => Err(e),
}
}
/// 获取单个用户
pub async fn get_user(
State(store): State<MemoryUserStore>,
State(store): State<Arc<dyn UserStore>>,
Path(id): Path<Uuid>,
) -> Result<Json<UserResponse>, ApiError> {
match store.get_user(&id).await {
match store.get_user(&id).await? {
Some(user) => Ok(Json(user.into())),
None => Err(ApiError::NotFound("用户不存在".to_string())),
}
@@ -57,23 +58,23 @@ pub async fn get_user(
/// 获取所有用户
pub async fn list_users(
State(store): State<MemoryUserStore>,
) -> Json<Vec<UserResponse>> {
let users = store.list_users().await;
State(store): State<Arc<dyn UserStore>>,
) -> Result<Json<Vec<UserResponse>>, ApiError> {
let users = store.list_users().await?;
let responses: Vec<UserResponse> = users.into_iter().map(|u| u.into()).collect();
Json(responses)
Ok(Json(responses))
}
/// 更新用户
pub async fn update_user(
State(store): State<MemoryUserStore>,
State(store): State<Arc<dyn UserStore>>,
Path(id): Path<Uuid>,
RequestJson(payload): RequestJson<UpdateUserRequest>,
) -> Result<Json<UserResponse>, ApiError> {
// 验证请求数据
payload.validate()?;
match store.get_user(&id).await {
match store.get_user(&id).await? {
Some(mut user) => {
if let Some(username) = payload.username {
user.username = username;
@@ -83,7 +84,7 @@ pub async fn update_user(
}
user.updated_at = Utc::now();
match store.update_user(&id, user).await {
match store.update_user(&id, user).await? {
Some(updated_user) => Ok(Json(updated_user.into())),
None => Err(ApiError::InternalError("更新用户失败".to_string())),
}
@@ -94,10 +95,10 @@ pub async fn update_user(
/// 删除用户
pub async fn delete_user(
State(store): State<MemoryUserStore>,
State(store): State<Arc<dyn UserStore>>,
Path(id): Path<Uuid>,
) -> Result<StatusCode, ApiError> {
if store.delete_user(&id).await {
if store.delete_user(&id).await? {
Ok(StatusCode::NO_CONTENT)
} else {
Err(ApiError::NotFound("用户不存在".to_string()))
@@ -106,11 +107,11 @@ pub async fn delete_user(
/// 用户登录
pub async fn login(
State(store): State<MemoryUserStore>,
State(store): State<Arc<dyn UserStore>>,
RequestJson(payload): RequestJson<LoginRequest>,
) -> Result<Json<LoginResponse>, ApiError> {
// 根据用户名查找用户
match store.get_user_by_username(&payload.username).await {
match store.get_user_by_username(&payload.username).await? {
Some(user) => {
// 验证密码
if verify_password(&payload.password, &user.password_hash) {

View File

@@ -21,7 +21,7 @@ pub use utils::errors::ApiError;
mod tests {
use super::*;
use crate::models::user::{CreateUserRequest, LoginRequest};
use crate::storage::memory::MemoryUserStore;
use crate::storage::{memory::MemoryUserStore, UserStore};
use uuid::Uuid;
use validator::Validate;
@@ -44,31 +44,31 @@ mod tests {
assert!(result.is_ok());
// 测试获取用户
let retrieved_user = store.get_user(&user_id).await;
let retrieved_user = store.get_user(&user_id).await.unwrap();
assert!(retrieved_user.is_some());
assert_eq!(retrieved_user.unwrap().username, "testuser");
// 测试按用户名获取用户
let user_by_name = store.get_user_by_username("testuser").await;
let user_by_name = store.get_user_by_username("testuser").await.unwrap();
assert!(user_by_name.is_some());
assert_eq!(user_by_name.unwrap().id, user_id);
// 测试列出所有用户
let users = store.list_users().await;
let users = store.list_users().await.unwrap();
assert_eq!(users.len(), 1);
// 测试更新用户
let mut updated_user = user.clone();
updated_user.username = "updated_user".to_string();
let update_result = store.update_user(&user_id, updated_user).await;
let update_result = store.update_user(&user_id, updated_user).await.unwrap();
assert!(update_result.is_some());
// 测试删除用户
let delete_result = store.delete_user(&user_id).await;
let delete_result = store.delete_user(&user_id).await.unwrap();
assert!(delete_result);
// 验证用户已被删除
let deleted_user = store.get_user(&user_id).await;
let deleted_user = store.get_user(&user_id).await.unwrap();
assert!(deleted_user.is_none());
}

View File

@@ -1,11 +1,12 @@
//! Rust User API 服务器主程序
use std::net::SocketAddr;
use std::sync::Arc;
use tracing_subscriber;
use rust_user_api::{
config::Config,
routes::create_routes,
storage::memory::MemoryUserStore,
storage::{memory::MemoryUserStore, database::DatabaseUserStore, UserStore},
};
#[tokio::main]
@@ -16,8 +17,20 @@ async fn main() {
// 加载配置
let config = Config::from_env();
// 创建存储实例
let store = MemoryUserStore::new();
// 根据配置创建存储实例
let store: Arc<dyn UserStore> = if let Some(database_url) = &config.database_url {
println!("🗄️ 使用 SQLite 数据库存储: {}", database_url);
// 创建数据库存储
let db_store = DatabaseUserStore::from_url(database_url)
.await
.expect("无法连接到数据库");
Arc::new(db_store)
} else {
println!("💾 使用内存存储");
Arc::new(MemoryUserStore::new())
};
// 创建路由
let app = create_routes(store);

View File

@@ -1,14 +1,15 @@
//! 路由配置模块
use std::sync::Arc;
use axum::{
Router,
routing::{get, post, put, delete},
routing::{get, post},
};
use crate::handlers;
use crate::storage::memory::MemoryUserStore;
use crate::storage::UserStore;
/// 创建应用路由
pub fn create_routes(store: MemoryUserStore) -> Router {
pub fn create_routes(store: Arc<dyn UserStore>) -> Router {
Router::new()
.route("/", get(handlers::root))
.route("/health", get(handlers::health_check))
@@ -17,7 +18,7 @@ pub fn create_routes(store: MemoryUserStore) -> Router {
}
/// API 路由
fn api_routes() -> Router<MemoryUserStore> {
fn api_routes() -> Router<Arc<dyn UserStore>> {
Router::new()
.route("/users",
get(handlers::user::list_users)

View File

@@ -1,7 +1,7 @@
//! 用户业务逻辑服务
use crate::models::user::User;
use crate::storage::memory::MemoryUserStore;
use crate::storage::{memory::MemoryUserStore, UserStore};
use crate::utils::errors::ApiError;
/// 用户服务
@@ -16,7 +16,7 @@ impl UserService {
/// 验证用户凭据
pub async fn authenticate_user(&self, username: &str, password: &str) -> Result<User, ApiError> {
match self.store.get_user_by_username(username).await {
match self.store.get_user_by_username(username).await? {
Some(user) => {
if user.password_hash == format!("hashed_{}", password) {
Ok(user)

243
src/storage/database.rs Normal file
View File

@@ -0,0 +1,243 @@
//! SQLite 数据库存储实现
use async_trait::async_trait;
use sqlx::{SqlitePool, Row};
use uuid::Uuid;
use chrono::{DateTime, Utc};
use crate::models::user::User;
use crate::utils::errors::ApiError;
use crate::storage::UserStore;
/// SQLite 用户存储
#[derive(Clone)]
pub struct DatabaseUserStore {
pool: SqlitePool,
}
impl DatabaseUserStore {
/// 创建新的数据库存储实例
pub fn new(pool: SqlitePool) -> Self {
Self { pool }
}
/// 从数据库 URL 创建新的数据库存储实例
pub async fn from_url(database_url: &str) -> Result<Self, ApiError> {
let pool = SqlitePool::connect(database_url)
.await
.map_err(|e| ApiError::InternalError(format!("无法连接到数据库: {}", e)))?;
let store = Self::new(pool);
store.init_tables().await?;
Ok(store)
}
/// 初始化数据库表
pub async fn init_tables(&self) -> Result<(), ApiError> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
username TEXT UNIQUE NOT NULL,
email TEXT NOT NULL,
password_hash TEXT NOT NULL,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
)
"#,
)
.execute(&self.pool)
.await
.map_err(|e| ApiError::InternalError(format!("数据库初始化错误: {}", e)))?;
Ok(())
}
/// 创建用户
async fn create_user_impl(&self, user: User) -> Result<User, ApiError> {
let result = sqlx::query(
r#"
INSERT INTO users (id, username, email, password_hash, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?)
"#,
)
.bind(user.id.to_string())
.bind(&user.username)
.bind(&user.email)
.bind(&user.password_hash)
.bind(user.created_at.to_rfc3339())
.bind(user.updated_at.to_rfc3339())
.execute(&self.pool)
.await;
match result {
Ok(_) => Ok(user),
Err(sqlx::Error::Database(db_err)) if db_err.is_unique_violation() => {
Err(ApiError::Conflict("用户名已存在".to_string()))
}
Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))),
}
}
/// 根据 ID 获取用户
async fn get_user_impl(&self, id: &Uuid) -> Result<Option<User>, ApiError> {
let result = sqlx::query(
"SELECT id, username, email, password_hash, created_at, updated_at FROM users WHERE id = ?"
)
.bind(id.to_string())
.fetch_optional(&self.pool)
.await;
match result {
Ok(Some(row)) => {
let user = User {
id: Uuid::parse_str(&row.get::<String, _>("id"))
.map_err(|e| ApiError::InternalError(format!("UUID 解析错误: {}", e)))?,
username: row.get("username"),
email: row.get("email"),
password_hash: row.get("password_hash"),
created_at: DateTime::parse_from_rfc3339(&row.get::<String, _>("created_at"))
.map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))?
.with_timezone(&Utc),
updated_at: DateTime::parse_from_rfc3339(&row.get::<String, _>("updated_at"))
.map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))?
.with_timezone(&Utc),
};
Ok(Some(user))
}
Ok(None) => Ok(None),
Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))),
}
}
/// 根据用户名获取用户
async fn get_user_by_username_impl(&self, username: &str) -> Result<Option<User>, ApiError> {
let result = sqlx::query(
"SELECT id, username, email, password_hash, created_at, updated_at FROM users WHERE username = ?"
)
.bind(username)
.fetch_optional(&self.pool)
.await;
match result {
Ok(Some(row)) => {
let user = User {
id: Uuid::parse_str(&row.get::<String, _>("id"))
.map_err(|e| ApiError::InternalError(format!("UUID 解析错误: {}", e)))?,
username: row.get("username"),
email: row.get("email"),
password_hash: row.get("password_hash"),
created_at: DateTime::parse_from_rfc3339(&row.get::<String, _>("created_at"))
.map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))?
.with_timezone(&Utc),
updated_at: DateTime::parse_from_rfc3339(&row.get::<String, _>("updated_at"))
.map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))?
.with_timezone(&Utc),
};
Ok(Some(user))
}
Ok(None) => Ok(None),
Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))),
}
}
/// 获取所有用户
async fn list_users_impl(&self) -> Result<Vec<User>, ApiError> {
let result = sqlx::query(
"SELECT id, username, email, password_hash, created_at, updated_at FROM users ORDER BY created_at DESC"
)
.fetch_all(&self.pool)
.await;
match result {
Ok(rows) => {
let mut users = Vec::new();
for row in rows {
let user = User {
id: Uuid::parse_str(&row.get::<String, _>("id"))
.map_err(|e| ApiError::InternalError(format!("UUID 解析错误: {}", e)))?,
username: row.get("username"),
email: row.get("email"),
password_hash: row.get("password_hash"),
created_at: DateTime::parse_from_rfc3339(&row.get::<String, _>("created_at"))
.map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))?
.with_timezone(&Utc),
updated_at: DateTime::parse_from_rfc3339(&row.get::<String, _>("updated_at"))
.map_err(|e| ApiError::InternalError(format!("时间解析错误: {}", e)))?
.with_timezone(&Utc),
};
users.push(user);
}
Ok(users)
}
Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))),
}
}
/// 更新用户
async fn update_user_impl(&self, id: &Uuid, updated_user: User) -> Result<Option<User>, ApiError> {
let result = sqlx::query(
r#"
UPDATE users
SET username = ?, email = ?, updated_at = ?
WHERE id = ?
"#,
)
.bind(&updated_user.username)
.bind(&updated_user.email)
.bind(updated_user.updated_at.to_rfc3339())
.bind(id.to_string())
.execute(&self.pool)
.await;
match result {
Ok(query_result) => {
if query_result.rows_affected() > 0 {
Ok(Some(updated_user))
} else {
Ok(None)
}
}
Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))),
}
}
/// 删除用户
async fn delete_user_impl(&self, id: &Uuid) -> Result<bool, ApiError> {
let result = sqlx::query("DELETE FROM users WHERE id = ?")
.bind(id.to_string())
.execute(&self.pool)
.await;
match result {
Ok(query_result) => Ok(query_result.rows_affected() > 0),
Err(e) => Err(ApiError::InternalError(format!("数据库错误: {}", e))),
}
}
}
#[async_trait]
impl UserStore for DatabaseUserStore {
async fn create_user(&self, user: User) -> Result<User, ApiError> {
self.create_user_impl(user).await
}
async fn get_user(&self, id: &Uuid) -> Result<Option<User>, ApiError> {
self.get_user_impl(id).await
}
async fn get_user_by_username(&self, username: &str) -> Result<Option<User>, ApiError> {
self.get_user_by_username_impl(username).await
}
async fn list_users(&self) -> Result<Vec<User>, ApiError> {
self.list_users_impl().await
}
async fn update_user(&self, id: &Uuid, updated_user: User) -> Result<Option<User>, ApiError> {
self.update_user_impl(id, updated_user).await
}
async fn delete_user(&self, id: &Uuid) -> Result<bool, ApiError> {
self.delete_user_impl(id).await
}
}

View File

@@ -3,7 +3,10 @@
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use uuid::Uuid;
use async_trait::async_trait;
use crate::models::user::User;
use crate::utils::errors::ApiError;
use crate::storage::traits::UserStore;
/// 线程安全的用户存储类型
pub type UserStorage = Arc<RwLock<HashMap<Uuid, User>>>;
@@ -22,42 +25,56 @@ impl MemoryUserStore {
}
}
}
#[async_trait]
impl UserStore for MemoryUserStore {
/// 创建用户
pub async fn create_user(&self, user: User) -> Result<User, String> {
async fn create_user(&self, user: User) -> Result<User, ApiError> {
let mut users = self.users.write().unwrap();
// 检查用户名是否已存在
if users.values().any(|u| u.username == user.username) {
return Err(ApiError::Conflict("用户名已存在".to_string()));
}
users.insert(user.id, user.clone());
Ok(user)
}
/// 根据 ID 获取用户
pub async fn get_user(&self, id: &Uuid) -> Option<User> {
async fn get_user(&self, id: &Uuid) -> Result<Option<User>, ApiError> {
let users = self.users.read().unwrap();
users.get(id).cloned()
Ok(users.get(id).cloned())
}
/// 根据用户名获取用户
pub async fn get_user_by_username(&self, username: &str) -> Option<User> {
async fn get_user_by_username(&self, username: &str) -> Result<Option<User>, ApiError> {
let users = self.users.read().unwrap();
users.values().find(|u| u.username == username).cloned()
Ok(users.values().find(|u| u.username == username).cloned())
}
/// 获取所有用户
pub async fn list_users(&self) -> Vec<User> {
async fn list_users(&self) -> Result<Vec<User>, ApiError> {
let users = self.users.read().unwrap();
users.values().cloned().collect()
Ok(users.values().cloned().collect())
}
/// 更新用户
pub async fn update_user(&self, id: &Uuid, updated_user: User) -> Option<User> {
async fn update_user(&self, id: &Uuid, updated_user: User) -> Result<Option<User>, ApiError> {
let mut users = self.users.write().unwrap();
users.insert(*id, updated_user.clone());
Some(updated_user)
if users.contains_key(id) {
users.insert(*id, updated_user.clone());
Ok(Some(updated_user))
} else {
Ok(None)
}
}
/// 删除用户
pub async fn delete_user(&self, id: &Uuid) -> bool {
async fn delete_user(&self, id: &Uuid) -> Result<bool, ApiError> {
let mut users = self.users.write().unwrap();
users.remove(id).is_some()
Ok(users.remove(id).is_some())
}
}

View File

@@ -1,5 +1,9 @@
//! 数据存储模块
pub mod memory;
pub mod database;
pub mod traits;
pub use memory::MemoryUserStore;
pub use memory::MemoryUserStore;
pub use database::DatabaseUserStore;
pub use traits::UserStore;

28
src/storage/traits.rs Normal file
View File

@@ -0,0 +1,28 @@
//! 存储抽象 trait
use async_trait::async_trait;
use uuid::Uuid;
use crate::models::user::User;
use crate::utils::errors::ApiError;
/// 用户存储 trait
#[async_trait]
pub trait UserStore: Send + Sync {
/// 创建用户
async fn create_user(&self, user: User) -> Result<User, ApiError>;
/// 根据 ID 获取用户
async fn get_user(&self, id: &Uuid) -> Result<Option<User>, ApiError>;
/// 根据用户名获取用户
async fn get_user_by_username(&self, username: &str) -> Result<Option<User>, ApiError>;
/// 获取所有用户
async fn list_users(&self) -> Result<Vec<User>, ApiError>;
/// 更新用户
async fn update_user(&self, id: &Uuid, updated_user: User) -> Result<Option<User>, ApiError>;
/// 删除用户
async fn delete_user(&self, id: &Uuid) -> Result<bool, ApiError>;
}