diff --git a/.env.example b/.env.example index 065f69f..1b1667e 100644 --- a/.env.example +++ b/.env.example @@ -23,6 +23,9 @@ START_INDEX=0 # 黑名单阈值(错误多少次后拉黑密钥) BLACKLIST_THRESHOLD=1 +# 最大重试次数(换key重试) +MAX_RETRIES=3 + # =========================================== # OpenAI API 配置 # =========================================== @@ -65,6 +68,9 @@ LOG_ENABLE_FILE=false # 日志文件路径 LOG_FILE_PATH=logs/app.log +# 启用请求日志(生产环境可设为 false 以提高性能) +LOG_ENABLE_REQUEST=true + # =========================================== # 认证配置 # =========================================== diff --git a/internal/config/config.go b/internal/config/config.go index b3217be..e7f9ed4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -45,6 +45,7 @@ type KeysConfig struct { FilePath string `json:"filePath"` StartIndex int `json:"startIndex"` BlacklistThreshold int `json:"blacklistThreshold"` + MaxRetries int `json:"maxRetries"` // 最大重试次数 } // OpenAIConfig OpenAI API 配置 @@ -76,10 +77,11 @@ type PerformanceConfig struct { // LogConfig 日志配置 type LogConfig struct { - Level string `json:"level"` // debug, info, warn, error - Format string `json:"format"` // text, json - EnableFile bool `json:"enableFile"` // 是否启用文件日志 - FilePath string `json:"filePath"` // 日志文件路径 + Level string `json:"level"` // debug, info, warn, error + Format string `json:"format"` // text, json + EnableFile bool `json:"enableFile"` // 是否启用文件日志 + FilePath string `json:"filePath"` // 日志文件路径 + EnableRequest bool `json:"enableRequest"` // 是否启用请求日志 } // Config 应用配置 @@ -112,6 +114,7 @@ func LoadConfig() (*Config, error) { FilePath: getEnvOrDefault("KEYS_FILE", "keys.txt"), StartIndex: parseInteger(os.Getenv("START_INDEX"), 0), BlacklistThreshold: parseInteger(os.Getenv("BLACKLIST_THRESHOLD"), 1), + MaxRetries: parseInteger(os.Getenv("MAX_RETRIES"), 3), }, OpenAI: OpenAIConfig{ BaseURL: getEnvOrDefault("OPENAI_BASE_URL", "https://api.openai.com"), @@ -133,10 +136,11 @@ func LoadConfig() (*Config, error) { BufferSize: parseInteger(os.Getenv("BUFFER_SIZE"), 32*1024), }, Log: LogConfig{ - Level: getEnvOrDefault("LOG_LEVEL", "info"), - Format: getEnvOrDefault("LOG_FORMAT", "text"), - EnableFile: parseBoolean(os.Getenv("LOG_ENABLE_FILE"), false), - FilePath: getEnvOrDefault("LOG_FILE_PATH", "logs/app.log"), + Level: getEnvOrDefault("LOG_LEVEL", "info"), + Format: getEnvOrDefault("LOG_FORMAT", "text"), + EnableFile: parseBoolean(os.Getenv("LOG_ENABLE_FILE"), false), + FilePath: getEnvOrDefault("LOG_FILE_PATH", "logs/app.log"), + EnableRequest: parseBoolean(os.Getenv("LOG_ENABLE_REQUEST"), true), }, } @@ -205,6 +209,7 @@ func DisplayConfig(config *Config) { logrus.Infof(" 密钥文件: %s", config.Keys.FilePath) logrus.Infof(" 起始索引: %d", config.Keys.StartIndex) logrus.Infof(" 黑名单阈值: %d 次错误", config.Keys.BlacklistThreshold) + logrus.Infof(" 最大重试次数: %d", config.Keys.MaxRetries) logrus.Infof(" 上游地址: %s", config.OpenAI.BaseURL) logrus.Infof(" 请求超时: %dms", config.OpenAI.Timeout) @@ -233,6 +238,13 @@ func DisplayConfig(config *Config) { } logrus.Infof(" 压缩: %s", compressionStatus) logrus.Infof(" 缓冲区大小: %d bytes", config.Performance.BufferSize) + + // 显示日志配置 + requestLogStatus := "已启用" + if !config.Log.EnableRequest { + requestLogStatus = "已禁用" + } + logrus.Infof(" 请求日志: %s", requestLogStatus) } // 辅助函数 diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index bc46ba0..bc7130a 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -193,10 +193,11 @@ func (ps *ProxyServer) authMiddleware() gin.HandlerFunc { // loggerMiddleware 高性能日志中间件 func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc { return func(c *gin.Context) { - // 只在非生产环境或调试模式下记录详细日志 - if gin.Mode() == gin.ReleaseMode { - // 生产模式下只记录错误和关键信息 + // 检查是否启用请求日志 + if !config.AppConfig.Log.EnableRequest { + // 不记录请求日志,只处理请求 c.Next() + // 只记录错误 if c.Writer.Status() >= 400 { logrus.Errorf("Error %d: %s %s", c.Writer.Status(), c.Request.Method, c.Request.URL.Path) } @@ -231,8 +232,20 @@ func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc { } } - // 简化日志输出,减少格式化开销 - logrus.Infof("%s %s - %d - %v%s", method, fullPath, statusCode, latency, keyInfo) + // 获取重试信息(如果存在) + retryInfo := "" + if retryCount, exists := c.Get("retryCount"); exists { + retryInfo = fmt.Sprintf(" - Retry[%d]", retryCount) + } + + // 根据状态码选择日志级别 + if statusCode >= 500 { + logrus.Errorf("%s %s - %d - %v%s%s", method, fullPath, statusCode, latency, keyInfo, retryInfo) + } else if statusCode >= 400 { + logrus.Warnf("%s %s - %d - %v%s%s", method, fullPath, statusCode, latency, keyInfo, retryInfo) + } else { + logrus.Infof("%s %s - %d - %v%s%s", method, fullPath, statusCode, latency, keyInfo, retryInfo) + } } } @@ -296,6 +309,31 @@ func (ps *ProxyServer) handleProxy(c *gin.Context) { // 增加请求计数 atomic.AddInt64(&ps.requestCount, 1) + // 读取请求体用于重试 + var bodyBytes []byte + if c.Request.Body != nil { + var err error + bodyBytes, err = io.ReadAll(c.Request.Body) + if err != nil { + logrus.Errorf("读取请求体失败: %v", err) + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": "读取请求体失败", + "type": "request_error", + "code": "invalid_request_body", + "timestamp": time.Now().Format(time.RFC3339), + }, + }) + return + } + } + + // 执行带重试的请求 + ps.executeRequestWithRetry(c, startTime, bodyBytes, 0) +} + +// executeRequestWithRetry 执行带重试的请求 +func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Time, bodyBytes []byte, retryCount int) { // 获取密钥信息 keyInfo, err := ps.keyManager.GetNextKey() if err != nil { @@ -315,9 +353,18 @@ func (ps *ProxyServer) handleProxy(c *gin.Context) { c.Set("keyIndex", keyInfo.Index) c.Set("keyPreview", keyInfo.Preview) - // 直接使用请求体,避免完整读取到内存 - var requestBody io.Reader = c.Request.Body - var contentLength int64 = c.Request.ContentLength + // 设置重试信息到上下文 + if retryCount > 0 { + c.Set("retryCount", retryCount) + } + + // 使用缓存的请求体 + var requestBody io.Reader + var contentLength int64 + if len(bodyBytes) > 0 { + requestBody = strings.NewReader(string(bodyBytes)) + contentLength = int64(len(bodyBytes)) + } // 构建上游请求URL targetURL := *ps.upstreamURL @@ -368,14 +415,28 @@ func (ps *ProxyServer) handleProxy(c *gin.Context) { resp, err := ps.httpClient.Do(req) if err != nil { responseTime := time.Since(startTime) - logrus.Errorf("代理请求失败: %v (响应时间: %v)", err, responseTime) + + // 记录失败日志 + if retryCount > 0 { + logrus.Warnf("重试请求失败 (第 %d 次): %v (响应时间: %v)", retryCount, err, responseTime) + } else { + logrus.Warnf("请求失败: %v (响应时间: %v)", err, responseTime) + } // 异步记录失败 go ps.keyManager.RecordFailure(keyInfo.Key, err) + // 检查是否可以重试 + if retryCount < config.AppConfig.Keys.MaxRetries { + logrus.Infof("准备重试请求 (第 %d/%d 次)", retryCount+1, config.AppConfig.Keys.MaxRetries) + ps.executeRequestWithRetry(c, startTime, bodyBytes, retryCount+1) + return + } + + // 达到最大重试次数 c.JSON(http.StatusBadGateway, gin.H{ "error": gin.H{ - "message": "代理请求失败: " + err.Error(), + "message": fmt.Sprintf("代理请求失败 (已重试 %d 次): %s", retryCount, err.Error()), "type": "proxy_error", "code": "request_failed", "timestamp": time.Now().Format(time.RFC3339), @@ -387,6 +448,27 @@ func (ps *ProxyServer) handleProxy(c *gin.Context) { responseTime := time.Since(startTime) + // 检查HTTP状态码是否需要重试 + // 429 (Too Many Requests) 和 5xx 服务器错误都需要重试 + if (resp.StatusCode == 429 || resp.StatusCode >= 500) && retryCount < config.AppConfig.Keys.MaxRetries { + // 记录失败日志 + if retryCount > 0 { + logrus.Warnf("重试请求返回错误 %d (第 %d 次) (响应时间: %v)", resp.StatusCode, retryCount, responseTime) + } else { + logrus.Warnf("请求返回错误 %d (响应时间: %v)", resp.StatusCode, responseTime) + } + + // 异步记录失败 + go ps.keyManager.RecordFailure(keyInfo.Key, fmt.Errorf("HTTP %d", resp.StatusCode)) + + // 关闭当前响应 + resp.Body.Close() + + logrus.Infof("准备重试请求 (第 %d/%d 次)", retryCount+1, config.AppConfig.Keys.MaxRetries) + ps.executeRequestWithRetry(c, startTime, bodyBytes, retryCount+1) + return + } + // 异步记录统计信息(不阻塞响应) go func() { if resp.StatusCode >= 200 && resp.StatusCode < 400 { diff --git a/scripts/validate_config.go b/scripts/validate_config.go deleted file mode 100644 index 7aacfc0..0000000 --- a/scripts/validate_config.go +++ /dev/null @@ -1,116 +0,0 @@ -// 配置验证脚本 -package main - -import ( - "fmt" - "reflect" - - "github.com/joho/godotenv" - "github.com/sirupsen/logrus" - - "gpt-load/internal/config" -) - -func main() { - // 加载测试配置 - if err := godotenv.Load("test_config.env"); err != nil { - logrus.Warnf("无法加载测试配置文件: %v", err) - } - - // 加载配置 - cfg, err := config.LoadConfig() - if err != nil { - logrus.Fatalf("配置加载失败: %v", err) - } - - fmt.Println("🔍 配置验证报告") - fmt.Println("=" * 50) - - // 验证服务器配置 - fmt.Printf("📡 服务器配置:\n") - fmt.Printf(" Host: %s\n", cfg.Server.Host) - fmt.Printf(" Port: %d\n", cfg.Server.Port) - fmt.Println() - - // 验证密钥配置 - fmt.Printf("🔑 密钥配置:\n") - fmt.Printf(" 文件路径: %s\n", cfg.Keys.FilePath) - fmt.Printf(" 起始索引: %d\n", cfg.Keys.StartIndex) - fmt.Printf(" 黑名单阈值: %d\n", cfg.Keys.BlacklistThreshold) - fmt.Println() - - // 验证 OpenAI 配置 - fmt.Printf("🤖 OpenAI 配置:\n") - fmt.Printf(" Base URL: %s\n", cfg.OpenAI.BaseURL) - fmt.Printf(" 超时时间: %dms\n", cfg.OpenAI.Timeout) - fmt.Println() - - // 验证认证配置 - fmt.Printf("🔐 认证配置:\n") - fmt.Printf(" 启用状态: %t\n", cfg.Auth.Enabled) - if cfg.Auth.Enabled { - fmt.Printf(" 密钥长度: %d\n", len(cfg.Auth.Key)) - } - fmt.Println() - - // 验证 CORS 配置 - fmt.Printf("🌐 CORS 配置:\n") - fmt.Printf(" 启用状态: %t\n", cfg.CORS.Enabled) - fmt.Printf(" 允许来源: %v\n", cfg.CORS.AllowedOrigins) - fmt.Println() - - // 验证性能配置 - fmt.Printf("⚡ 性能配置:\n") - fmt.Printf(" 最大连接数: %d\n", cfg.Performance.MaxSockets) - fmt.Printf(" 最大空闲连接数: %d\n", cfg.Performance.MaxFreeSockets) - fmt.Printf(" Keep-Alive: %t\n", cfg.Performance.EnableKeepAlive) - fmt.Printf(" 禁用压缩: %t\n", cfg.Performance.DisableCompression) - fmt.Printf(" 缓冲区大小: %d bytes\n", cfg.Performance.BufferSize) - fmt.Println() - - // 验证日志配置 - fmt.Printf("📝 日志配置:\n") - fmt.Printf(" 日志级别: %s\n", cfg.Log.Level) - fmt.Printf(" 日志格式: %s\n", cfg.Log.Format) - fmt.Printf(" 文件日志: %t\n", cfg.Log.EnableFile) - if cfg.Log.EnableFile { - fmt.Printf(" 文件路径: %s\n", cfg.Log.FilePath) - } - fmt.Println() - - // 检查配置完整性 - fmt.Printf("✅ 配置完整性检查:\n") - checkConfigCompleteness(cfg) - - fmt.Println("🎉 配置验证完成!") -} - -func checkConfigCompleteness(cfg *config.Config) { - v := reflect.ValueOf(cfg).Elem() - t := reflect.TypeOf(cfg).Elem() - - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - fieldType := t.Field(i) - - if field.Kind() == reflect.Struct { - checkStruct(field, fieldType.Name) - } - } -} - -func checkStruct(v reflect.Value, name string) { - t := v.Type() - - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - fieldType := t.Field(i) - - // 检查字段是否为零值 - if field.IsZero() && fieldType.Name != "Enabled" { - fmt.Printf(" ⚠️ %s.%s 为零值\n", name, fieldType.Name) - } else { - fmt.Printf(" ✅ %s.%s 已配置\n", name, fieldType.Name) - } - } -}