//! 日志中间件模块 use axum::{ extract::{Request, ConnectInfo}, middleware::Next, response::Response, http::{Method, Uri, HeaderMap}, }; use std::{ net::SocketAddr, time::{Duration, Instant}, }; use tracing::{info, warn, error, Span}; use uuid::Uuid; /// 请求日志中间件 pub async fn request_logging_middleware( ConnectInfo(addr): ConnectInfo, request: Request, next: Next, ) -> Response { let start_time = Instant::now(); let method = request.method().clone(); let uri = request.uri().clone(); let version = request.version(); let headers = request.headers().clone(); // 生成请求ID let request_id = Uuid::new_v4(); // 提取用户代理和其他有用的头信息 let user_agent = headers .get("user-agent") .and_then(|v| v.to_str().ok()) .unwrap_or("unknown"); let content_length = headers .get("content-length") .and_then(|v| v.to_str().ok()) .and_then(|v| v.parse::().ok()) .unwrap_or(0); // 创建 span 用于结构化日志 let span = tracing::info_span!( "http_request", request_id = %request_id, method = %method, uri = %uri, version = ?version, remote_addr = %addr, user_agent = user_agent, content_length = content_length, ); let _enter = span.enter(); info!( "开始处理请求: {} {} from {}", method, uri, addr ); // 处理请求 let response = next.run(request).await; let duration = start_time.elapsed(); let status = response.status(); let response_size = response .headers() .get("content-length") .and_then(|v| v.to_str().ok()) .and_then(|v| v.parse::().ok()) .unwrap_or(0); // 根据状态码选择日志级别 match status.as_u16() { 200..=299 => { info!( status = status.as_u16(), duration_ms = duration.as_millis(), response_size = response_size, "请求处理完成" ); } 300..=399 => { info!( status = status.as_u16(), duration_ms = duration.as_millis(), response_size = response_size, "请求重定向" ); } 400..=499 => { warn!( status = status.as_u16(), duration_ms = duration.as_millis(), response_size = response_size, "客户端错误" ); } 500..=599 => { error!( status = status.as_u16(), duration_ms = duration.as_millis(), response_size = response_size, "服务器错误" ); } _ => { info!( status = status.as_u16(), duration_ms = duration.as_millis(), response_size = response_size, "请求处理完成" ); } } response } /// 性能监控中间件 pub async fn performance_middleware( request: Request, next: Next, ) -> Response { let start_time = Instant::now(); let method = request.method().clone(); let uri = request.uri().clone(); let response = next.run(request).await; let duration = start_time.elapsed(); // 记录慢请求 if duration > Duration::from_millis(1000) { warn!( method = %method, uri = %uri, duration_ms = duration.as_millis(), "慢请求检测" ); } // 记录性能指标 tracing::debug!( method = %method, uri = %uri, duration_ms = duration.as_millis(), status = response.status().as_u16(), "请求性能指标" ); response } /// 错误日志中间件 pub async fn error_logging_middleware( request: Request, next: Next, ) -> Response { let method = request.method().clone(); let uri = request.uri().clone(); let response = next.run(request).await; // 记录错误响应的详细信息 if response.status().is_server_error() { error!( method = %method, uri = %uri, status = response.status().as_u16(), "服务器内部错误" ); } else if response.status().is_client_error() { warn!( method = %method, uri = %uri, status = response.status().as_u16(), "客户端请求错误" ); } response } /// 安全日志中间件 pub async fn security_logging_middleware( ConnectInfo(addr): ConnectInfo, request: Request, next: Next, ) -> Response { let method = request.method().clone(); let uri = request.uri().clone(); let headers = request.headers().clone(); // 检测可疑的请求模式 let suspicious_patterns = [ "admin", "login", "auth", "password", "token", "sql", "script", "exec", "cmd", "shell", "..", "etc/passwd", "proc/", "sys/", ]; let uri_str = uri.to_string().to_lowercase(); let is_suspicious = suspicious_patterns.iter().any(|&pattern| uri_str.contains(pattern)); if is_suspicious { warn!( remote_addr = %addr, method = %method, uri = %uri, user_agent = headers.get("user-agent") .and_then(|v| v.to_str().ok()) .unwrap_or("unknown"), "检测到可疑请求" ); } // 检测暴力破解尝试 if uri.path().contains("login") || uri.path().contains("auth") { info!( remote_addr = %addr, method = %method, uri = %uri, "认证尝试" ); } let response = next.run(request).await; // 记录认证失败 if response.status() == 401 || response.status() == 403 { warn!( remote_addr = %addr, method = %method, uri = %uri, status = response.status().as_u16(), "认证或授权失败" ); } response } #[cfg(test)] mod tests { use super::*; use axum::{ body::Body, http::{Request, StatusCode}, response::Response, middleware::Next, }; use std::convert::Infallible; async fn dummy_handler(_req: Request) -> Result, Infallible> { Ok(Response::builder() .status(StatusCode::OK) .body(Body::empty()) .unwrap()) } #[tokio::test] async fn test_request_logging_middleware() { let request = Request::builder() .method("GET") .uri("/test") .body(Body::empty()) .unwrap(); let addr = "127.0.0.1:8080".parse().unwrap(); let next = Next::new(dummy_handler); let response = request_logging_middleware( ConnectInfo(addr), request, next, ).await; assert_eq!(response.status(), StatusCode::OK); } #[tokio::test] async fn test_performance_middleware() { let request = Request::builder() .method("GET") .uri("/test") .body(Body::empty()) .unwrap(); let next = Next::new(dummy_handler); let response = performance_middleware(request, next).await; assert_eq!(response.status(), StatusCode::OK); } }