diff --git a/.env.example b/.env.example index 9bd9044..a0cc7ff 100644 --- a/.env.example +++ b/.env.example @@ -32,6 +32,24 @@ OPENAI_BASE_URL=https://api.openai.com # 请求超时时间(毫秒) REQUEST_TIMEOUT=30000 +# =========================================== +# 性能优化配置 +# =========================================== +# 最大连接数 +MAX_SOCKETS=100 + +# 最大空闲连接数 +MAX_FREE_SOCKETS=20 + +# 启用 Keep-Alive +ENABLE_KEEP_ALIVE=true + +# 禁用压缩(减少CPU开销) +DISABLE_COMPRESSION=true + +# 缓冲区大小(字节) +BUFFER_SIZE=32768 + # =========================================== # 认证配置 # =========================================== diff --git a/cmd/main.go b/cmd/main.go index d71ffcd..0ffb1d0 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -44,13 +44,14 @@ func main() { // 设置路由 router := proxyServer.SetupRoutes() - // 创建HTTP服务器 + // 创建HTTP服务器,优化超时配置 server := &http.Server{ - Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port), - Handler: router, - ReadTimeout: 30 * time.Second, - WriteTimeout: 30 * time.Second, - IdleTimeout: 60 * time.Second, + Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port), + Handler: router, + ReadTimeout: 60 * time.Second, // 增加读超时,支持大文件上传 + WriteTimeout: 300 * time.Second, // 增加写超时,支持流式响应 + IdleTimeout: 120 * time.Second, // 增加空闲超时,复用连接 + MaxHeaderBytes: 1 << 20, // 1MB header limit } // 启动服务器 @@ -90,7 +91,7 @@ func main() { func displayStartupInfo(cfg *config.Config) { logrus.Info("🚀 OpenAI 多密钥代理服务器 v2.0.0 (Go版本)") logrus.Info("") - + // 显示配置 config.DisplayConfig(cfg) logrus.Info("") diff --git a/internal/config/config.go b/internal/config/config.go index edf6a45..1c0663c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -16,21 +16,21 @@ import ( // Constants 配置常量 type Constants struct { - MinPort int - MaxPort int - MinTimeout int - DefaultTimeout int - DefaultMaxSockets int + MinPort int + MaxPort int + MinTimeout int + DefaultTimeout int + DefaultMaxSockets int DefaultMaxFreeSockets int } // DefaultConstants 默认常量 var DefaultConstants = Constants{ - MinPort: 1, - MaxPort: 65535, - MinTimeout: 1000, - DefaultTimeout: 30000, - DefaultMaxSockets: 50, + MinPort: 1, + MaxPort: 65535, + MinTimeout: 1000, + DefaultTimeout: 30000, + DefaultMaxSockets: 50, DefaultMaxFreeSockets: 10, } @@ -67,8 +67,11 @@ type CORSConfig struct { // PerformanceConfig 性能配置 type PerformanceConfig struct { - MaxSockets int `json:"maxSockets"` - MaxFreeSockets int `json:"maxFreeSockets"` + MaxSockets int `json:"maxSockets"` + MaxFreeSockets int `json:"maxFreeSockets"` + EnableKeepAlive bool `json:"enableKeepAlive"` + DisableCompression bool `json:"disableCompression"` + BufferSize int `json:"bufferSize"` } // Config 应用配置 @@ -114,8 +117,11 @@ func LoadConfig() (*Config, error) { AllowedOrigins: parseArray(os.Getenv("ALLOWED_ORIGINS"), []string{"*"}), }, Performance: PerformanceConfig{ - MaxSockets: parseInteger(os.Getenv("MAX_SOCKETS"), DefaultConstants.DefaultMaxSockets), - MaxFreeSockets: parseInteger(os.Getenv("MAX_FREE_SOCKETS"), DefaultConstants.DefaultMaxFreeSockets), + MaxSockets: parseInteger(os.Getenv("MAX_SOCKETS"), DefaultConstants.DefaultMaxSockets), + MaxFreeSockets: parseInteger(os.Getenv("MAX_FREE_SOCKETS"), DefaultConstants.DefaultMaxFreeSockets), + EnableKeepAlive: parseBoolean(os.Getenv("ENABLE_KEEP_ALIVE"), true), + DisableCompression: parseBoolean(os.Getenv("DISABLE_COMPRESSION"), true), + BufferSize: parseInteger(os.Getenv("BUFFER_SIZE"), 32*1024), }, } @@ -186,13 +192,13 @@ func DisplayConfig(config *Config) { logrus.Infof(" 黑名单阈值: %d 次错误", config.Keys.BlacklistThreshold) logrus.Infof(" 上游地址: %s", config.OpenAI.BaseURL) logrus.Infof(" 请求超时: %dms", config.OpenAI.Timeout) - + authStatus := "未启用" if config.Auth.Enabled { authStatus = "已启用" } logrus.Infof(" 认证: %s", authStatus) - + corsStatus := "已禁用" if config.CORS.Enabled { corsStatus = "已启用" @@ -227,7 +233,7 @@ func parseArray(value string, defaultValue []string) []string { if value == "" { return defaultValue } - + parts := strings.Split(value, ",") result := make([]string, 0, len(parts)) for _, part := range parts { @@ -235,7 +241,7 @@ func parseArray(value string, defaultValue []string) []string { result = append(result, trimmed) } } - + if len(result) == 0 { return defaultValue } diff --git a/internal/keymanager/keymanager.go b/internal/keymanager/keymanager.go index 4da7c58..02b89d3 100644 --- a/internal/keymanager/keymanager.go +++ b/internal/keymanager/keymanager.go @@ -27,13 +27,13 @@ type KeyInfo struct { // Stats 统计信息 type Stats struct { - CurrentIndex int64 `json:"currentIndex"` - TotalKeys int `json:"totalKeys"` - HealthyKeys int `json:"healthyKeys"` - BlacklistedKeys int `json:"blacklistedKeys"` - SuccessCount int64 `json:"successCount"` - FailureCount int64 `json:"failureCount"` - MemoryUsage MemoryUsage `json:"memoryUsage"` + CurrentIndex int64 `json:"currentIndex"` + TotalKeys int `json:"totalKeys"` + HealthyKeys int `json:"healthyKeys"` + BlacklistedKeys int `json:"blacklistedKeys"` + SuccessCount int64 `json:"successCount"` + FailureCount int64 `json:"failureCount"` + MemoryUsage MemoryUsage `json:"memoryUsage"` } // MemoryUsage 内存使用情况 @@ -60,22 +60,22 @@ type BlacklistInfo struct { // KeyManager 密钥管理器 type KeyManager struct { - keysFilePath string - keys []string - keyPreviews []string - currentIndex int64 - blacklistedKeys sync.Map - successCount int64 - failureCount int64 + keysFilePath string + keys []string + keyPreviews []string + currentIndex int64 + blacklistedKeys sync.Map + successCount int64 + failureCount int64 keyFailureCounts sync.Map - + // 性能优化:预编译正则表达式 permanentErrorPatterns []*regexp.Regexp - + // 内存管理 cleanupTicker *time.Ticker stopCleanup chan bool - + // 读写锁保护密钥列表 keysMutex sync.RWMutex } @@ -85,12 +85,12 @@ func NewKeyManager(keysFilePath string) *KeyManager { if keysFilePath == "" { keysFilePath = config.AppConfig.Keys.FilePath } - + km := &KeyManager{ keysFilePath: keysFilePath, currentIndex: int64(config.AppConfig.Keys.StartIndex), stopCleanup: make(chan bool), - + // 预编译正则表达式 permanentErrorPatterns: []*regexp.Regexp{ regexp.MustCompile(`(?i)invalid api key`), @@ -101,10 +101,10 @@ func NewKeyManager(keysFilePath string) *KeyManager { regexp.MustCompile(`(?i)billing`), }, } - + // 启动内存清理 km.setupMemoryCleanup() - + return km } @@ -115,25 +115,25 @@ func (km *KeyManager) LoadKeys() error { return fmt.Errorf("无法打开密钥文件: %w", err) } defer file.Close() - + var keys []string scanner := bufio.NewScanner(file) - + for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if line != "" && strings.HasPrefix(line, "sk-") { keys = append(keys, line) } } - + if err := scanner.Err(); err != nil { return fmt.Errorf("读取密钥文件失败: %w", err) } - + if len(keys) == 0 { return fmt.Errorf("密钥文件中没有有效的API密钥") } - + km.keysMutex.Lock() km.keys = keys // 预生成密钥预览,避免运行时重复计算 @@ -146,7 +146,7 @@ func (km *KeyManager) LoadKeys() error { } } km.keysMutex.Unlock() - + logrus.Infof("✅ 成功加载 %d 个 API 密钥", len(keys)) return nil } @@ -155,36 +155,40 @@ func (km *KeyManager) LoadKeys() error { func (km *KeyManager) GetNextKey() (*KeyInfo, error) { km.keysMutex.RLock() keysLen := len(km.keys) - km.keysMutex.RUnlock() - if keysLen == 0 { + km.keysMutex.RUnlock() return nil, fmt.Errorf("没有可用的 API 密钥") } - - // 检查是否所有密钥都被拉黑 - blacklistedCount := 0 - km.blacklistedKeys.Range(func(key, value interface{}) bool { - blacklistedCount++ - return true - }) - - if blacklistedCount >= keysLen { - logrus.Warn("⚠️ 所有密钥都被拉黑,重置黑名单") - km.blacklistedKeys = sync.Map{} - km.keyFailureCounts = sync.Map{} + + // 快速路径:直接获取下一个密钥,避免黑名单检查的开销 + currentIdx := atomic.AddInt64(&km.currentIndex, 1) - 1 + keyIndex := int(currentIdx) % keysLen + selectedKey := km.keys[keyIndex] + keyPreview := km.keyPreviews[keyIndex] + km.keysMutex.RUnlock() + + // 检查是否被拉黑 + if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted { + return &KeyInfo{ + Key: selectedKey, + Index: keyIndex, + Preview: keyPreview, + }, nil } - - // 使用原子操作避免竞态条件 + + // 慢速路径:寻找可用密钥 attempts := 0 - for attempts < keysLen { - currentIdx := atomic.AddInt64(&km.currentIndex, 1) - 1 - keyIndex := int(currentIdx) % keysLen - + maxAttempts := keysLen * 2 // 最多尝试两轮 + + for attempts < maxAttempts { + currentIdx = atomic.AddInt64(&km.currentIndex, 1) - 1 + keyIndex = int(currentIdx) % keysLen + km.keysMutex.RLock() - selectedKey := km.keys[keyIndex] - keyPreview := km.keyPreviews[keyIndex] + selectedKey = km.keys[keyIndex] + keyPreview = km.keyPreviews[keyIndex] km.keysMutex.RUnlock() - + if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted { return &KeyInfo{ Key: selectedKey, @@ -192,21 +196,36 @@ func (km *KeyManager) GetNextKey() (*KeyInfo, error) { Preview: keyPreview, }, nil } - + attempts++ } - - // 兜底:返回第一个密钥 - km.keysMutex.RLock() - firstKey := km.keys[0] - firstPreview := km.keyPreviews[0] - km.keysMutex.RUnlock() - - return &KeyInfo{ - Key: firstKey, - Index: 0, - Preview: firstPreview, - }, nil + + // 检查是否所有密钥都被拉黑,如果是则重置 + blacklistedCount := 0 + km.blacklistedKeys.Range(func(key, value interface{}) bool { + blacklistedCount++ + return blacklistedCount < keysLen // 提前退出优化 + }) + + if blacklistedCount >= keysLen { + logrus.Warn("⚠️ 所有密钥都被拉黑,重置黑名单") + km.blacklistedKeys = sync.Map{} + km.keyFailureCounts = sync.Map{} + + // 重置后返回第一个密钥 + km.keysMutex.RLock() + firstKey := km.keys[0] + firstPreview := km.keyPreviews[0] + km.keysMutex.RUnlock() + + return &KeyInfo{ + Key: firstKey, + Index: 0, + Preview: firstPreview, + }, nil + } + + return nil, fmt.Errorf("暂时没有可用的 API 密钥") } // RecordSuccess 记录密钥使用成功 @@ -219,7 +238,7 @@ func (km *KeyManager) RecordSuccess(key string) { // RecordFailure 记录密钥使用失败 func (km *KeyManager) RecordFailure(key string, err error) { atomic.AddInt64(&km.failureCount, 1) - + // 检查是否是永久性错误 if km.isPermanentError(err) { km.blacklistedKeys.Store(key, true) @@ -227,7 +246,7 @@ func (km *KeyManager) RecordFailure(key string, err error) { logrus.Warnf("🚫 密钥已被拉黑(永久性错误): %s (%s)", key[:20]+"...", err.Error()) return } - + // 临时性错误:增加失败计数 currentFailures := 0 if val, exists := km.keyFailureCounts.Load(key); exists { @@ -235,7 +254,7 @@ func (km *KeyManager) RecordFailure(key string, err error) { } newFailures := currentFailures + 1 km.keyFailureCounts.Store(key, newFailures) - + threshold := config.AppConfig.Keys.BlacklistThreshold if newFailures >= threshold { km.blacklistedKeys.Store(key, true) @@ -262,19 +281,19 @@ func (km *KeyManager) GetStats() *Stats { km.keysMutex.RLock() totalKeys := len(km.keys) km.keysMutex.RUnlock() - + blacklistedCount := 0 km.blacklistedKeys.Range(func(key, value interface{}) bool { blacklistedCount++ return true }) - + failureCountsSize := 0 km.keyFailureCounts.Range(func(key, value interface{}) bool { failureCountsSize++ return true }) - + return &Stats{ CurrentIndex: atomic.LoadInt64(&km.currentIndex), TotalKeys: totalKeys, @@ -296,16 +315,16 @@ func (km *KeyManager) ResetKeys() map[string]interface{} { beforeCount++ return true }) - + km.blacklistedKeys = sync.Map{} km.keyFailureCounts = sync.Map{} - + logrus.Infof("🔄 密钥状态已重置,清除了 %d 个黑名单密钥", beforeCount) - + km.keysMutex.RLock() totalKeys := len(km.keys) km.keysMutex.RUnlock() - + return map[string]interface{}{ "success": true, "message": fmt.Sprintf("已清除 %d 个黑名单密钥", beforeCount), @@ -317,12 +336,12 @@ func (km *KeyManager) ResetKeys() map[string]interface{} { // GetBlacklistDetails 获取黑名单详情 func (km *KeyManager) GetBlacklistDetails() *BlacklistInfo { var blacklistDetails []BlacklistDetail - + km.keysMutex.RLock() keys := km.keys keyPreviews := km.keyPreviews km.keysMutex.RUnlock() - + for i, key := range keys { if _, blacklisted := km.blacklistedKeys.Load(key); blacklisted { blacklistDetails = append(blacklistDetails, BlacklistDetail{ @@ -333,7 +352,7 @@ func (km *KeyManager) GetBlacklistDetails() *BlacklistInfo { }) } } - + return &BlacklistInfo{ TotalBlacklisted: len(blacklistDetails), TotalKeys: len(keys), @@ -345,7 +364,7 @@ func (km *KeyManager) GetBlacklistDetails() *BlacklistInfo { // setupMemoryCleanup 设置内存清理机制 func (km *KeyManager) setupMemoryCleanup() { km.cleanupTicker = time.NewTicker(10 * time.Minute) - + go func() { for { select { @@ -367,20 +386,20 @@ func (km *KeyManager) performMemoryCleanup() { maxSize = 1000 } km.keysMutex.RUnlock() - + currentSize := 0 km.keyFailureCounts.Range(func(key, value interface{}) bool { currentSize++ return true }) - + if currentSize > maxSize { logrus.Infof("🧹 清理失败计数缓存 (%d -> %d)", currentSize, maxSize) - + // 简单策略:清理一半的失败计数 cleared := 0 target := currentSize - maxSize - + km.keyFailureCounts.Range(func(key, value interface{}) bool { if cleared < target { km.keyFailureCounts.Delete(key) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index abf0477..2035462 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -4,7 +4,6 @@ package proxy import ( - "bytes" "context" "fmt" "io" @@ -46,16 +45,22 @@ func NewProxyServer() (*ProxyServer, error) { // 创建高性能HTTP客户端 transport := &http.Transport{ - MaxIdleConns: config.AppConfig.Performance.MaxSockets, - MaxIdleConnsPerHost: config.AppConfig.Performance.MaxFreeSockets, - IdleConnTimeout: 90 * time.Second, - DisableCompression: false, - ForceAttemptHTTP2: true, + MaxIdleConns: config.AppConfig.Performance.MaxSockets, + MaxIdleConnsPerHost: config.AppConfig.Performance.MaxFreeSockets, + MaxConnsPerHost: 0, // 无限制,避免连接池瓶颈 + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + DisableCompression: true, // 禁用压缩以减少CPU开销,让上游处理 + ForceAttemptHTTP2: true, + WriteBufferSize: 32 * 1024, // 32KB写缓冲 + ReadBufferSize: 32 * 1024, // 32KB读缓冲 } httpClient := &http.Client{ Transport: transport, - Timeout: time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond, + // 移除全局超时,使用更细粒度的超时控制 + // Timeout: time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond, } return &ProxyServer{ @@ -180,9 +185,19 @@ func (ps *ProxyServer) authMiddleware() gin.HandlerFunc { } } -// loggerMiddleware 自定义日志中间件 +// loggerMiddleware 高性能日志中间件 func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc { return func(c *gin.Context) { + // 只在非生产环境或调试模式下记录详细日志 + if gin.Mode() == gin.ReleaseMode { + // 生产模式下只记录错误和关键信息 + c.Next() + if c.Writer.Status() >= 400 { + logrus.Errorf("Error %d: %s %s", c.Writer.Status(), c.Request.Method, c.Request.URL.Path) + } + return + } + start := time.Now() path := c.Request.URL.Path raw := c.Request.URL.RawQuery @@ -193,16 +208,14 @@ func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc { // 计算响应时间 latency := time.Since(start) - // 获取客户端IP - clientIP := c.ClientIP() - - // 获取方法和状态码 + // 获取基本信息 method := c.Request.Method statusCode := c.Writer.Status() - // 构建完整路径 + // 构建完整路径(避免字符串拼接) + fullPath := path if raw != "" { - path = path + "?" + raw + fullPath = path + "?" + raw } // 获取密钥信息(如果存在) @@ -213,22 +226,8 @@ func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc { } } - // 根据状态码选择颜色 - var statusColor string - if statusCode >= 200 && statusCode < 300 { - statusColor = "\033[32m" // 绿色 - } else { - statusColor = "\033[31m" // 红色 - } - resetColor := "\033[0m" - keyColor := "\033[36m" // 青色 - - // 输出日志 - logrus.Infof("%s[%s] %s %s%s%s%s - %s%d%s - %v - %s", - statusColor, time.Now().Format(time.RFC3339), method, path, resetColor, - keyColor, keyInfo, resetColor, - statusColor, statusCode, resetColor, - latency, clientIP) + // 简化日志输出,减少格式化开销 + logrus.Infof("%s %s - %d - %v%s", method, fullPath, statusCode, latency, keyInfo) } } @@ -311,36 +310,30 @@ func (ps *ProxyServer) handleProxy(c *gin.Context) { c.Set("keyIndex", keyInfo.Index) c.Set("keyPreview", keyInfo.Preview) - // 读取请求体 - var bodyBytes []byte - if c.Request.Body != nil { - bodyBytes, err = io.ReadAll(c.Request.Body) - if err != nil { - logrus.Errorf("读取请求体失败: %v", err) - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "读取请求体失败", - "type": "request_error", - "code": "invalid_request_body", - "timestamp": time.Now().Format(time.RFC3339), - }, - }) - return - } - } + // 直接使用请求体,避免完整读取到内存 + var requestBody io.Reader = c.Request.Body + var contentLength int64 = c.Request.ContentLength // 构建上游请求URL targetURL := *ps.upstreamURL targetURL.Path = c.Request.URL.Path targetURL.RawQuery = c.Request.URL.RawQuery - // 创建上游请求 + // 创建带超时的上下文 + timeout := time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond + ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + + // 创建上游请求,直接使用原始请求体进行流式传输 req, err := http.NewRequestWithContext( - context.Background(), + ctx, c.Request.Method, targetURL.String(), - bytes.NewReader(bodyBytes), + requestBody, ) + if err == nil && contentLength > 0 { + req.ContentLength = contentLength + } if err != nil { logrus.Errorf("创建上游请求失败: %v", err) c.JSON(http.StatusInternalServerError, gin.H{