diff --git a/Makefile b/Makefile index e9604c9..8e7f8cf 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ # 变量定义 BINARY_NAME=gpt-load -MAIN_PATH=./cmd/main.go +MAIN_PATH=./cmd/gpt-load BUILD_DIR=./build VERSION=2.0.0 LDFLAGS=-ldflags "-X main.Version=$(VERSION) -s -w" diff --git a/cmd/gpt-load/main.go b/cmd/gpt-load/main.go new file mode 100644 index 0000000..e19df2c --- /dev/null +++ b/cmd/gpt-load/main.go @@ -0,0 +1,224 @@ +// Package main provides the entry point for the GPT-Load proxy server +package main + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "os/signal" + "path/filepath" + "syscall" + "time" + + "gpt-load/internal/config" + "gpt-load/internal/handler" + "gpt-load/internal/keymanager" + "gpt-load/internal/middleware" + "gpt-load/internal/proxy" + "gpt-load/pkg/types" + + "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" +) + +func main() { + // Load configuration + configManager, err := config.NewManager() + if err != nil { + logrus.Fatalf("Failed to load configuration: %v", err) + } + + // Setup logger + setupLogger(configManager) + + // Display startup information + displayStartupInfo(configManager) + + // Create key manager + keyManager, err := keymanager.NewManager(configManager.GetKeysConfig()) + if err != nil { + logrus.Fatalf("Failed to create key manager: %v", err) + } + defer keyManager.Close() + + // Create proxy server + proxyServer, err := proxy.NewProxyServer(keyManager, configManager) + if err != nil { + logrus.Fatalf("Failed to create proxy server: %v", err) + } + defer proxyServer.Close() + + // Create handlers + handlers := handler.NewHandler(keyManager, configManager) + + // Setup routes + router := setupRoutes(handlers, proxyServer, configManager) + + // Create HTTP server with optimized timeout configuration + serverConfig := configManager.GetServerConfig() + server := &http.Server{ + Addr: fmt.Sprintf("%s:%d", serverConfig.Host, serverConfig.Port), + Handler: router, + ReadTimeout: 60 * time.Second, // Increased read timeout for large file uploads + WriteTimeout: 300 * time.Second, // Increased write timeout for streaming responses + IdleTimeout: 120 * time.Second, // Increased idle timeout for connection reuse + MaxHeaderBytes: 1 << 20, // 1MB header limit + } + + // Start server + go func() { + logrus.Info("GPT-Load proxy server started successfully") + logrus.Infof("Server address: http://%s:%d", serverConfig.Host, serverConfig.Port) + logrus.Infof("Statistics: http://%s:%d/stats", serverConfig.Host, serverConfig.Port) + logrus.Infof("Health check: http://%s:%d/health", serverConfig.Host, serverConfig.Port) + logrus.Infof("Reset keys: http://%s:%d/reset-keys", serverConfig.Host, serverConfig.Port) + logrus.Infof("Blacklist query: http://%s:%d/blacklist", serverConfig.Host, serverConfig.Port) + logrus.Info("") + + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logrus.Fatalf("Server startup failed: %v", err) + } + }() + + // Wait for interrupt signal to gracefully shutdown the server + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + logrus.Info("Shutting down server...") + + // Give outstanding requests a deadline for completion + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Attempt graceful shutdown + if err := server.Shutdown(ctx); err != nil { + logrus.Errorf("Server forced to shutdown: %v", err) + } else { + logrus.Info("Server exited gracefully") + } +} + +// setupRoutes configures the HTTP routes +func setupRoutes(handlers *handler.Handler, proxyServer *proxy.ProxyServer, configManager types.ConfigManager) *gin.Engine { + // Set Gin mode + gin.SetMode(gin.ReleaseMode) + + router := gin.New() + + // Add middleware + router.Use(middleware.Recovery()) + router.Use(middleware.Logger(configManager.GetLogConfig())) + router.Use(middleware.CORS(configManager.GetCORSConfig())) + router.Use(middleware.RateLimiter(configManager.GetPerformanceConfig())) + + // Add authentication middleware if enabled + if configManager.GetAuthConfig().Enabled { + router.Use(middleware.Auth(configManager.GetAuthConfig())) + } + + // Management endpoints + router.GET("/health", handlers.Health) + router.GET("/stats", handlers.Stats) + router.GET("/blacklist", handlers.Blacklist) + router.GET("/reset-keys", handlers.ResetKeys) + router.GET("/config", handlers.GetConfig) // Debug endpoint + + // Handle 404 and 405 + router.NoRoute(handlers.NotFound) + router.NoMethod(handlers.MethodNotAllowed) + + // Proxy all other requests + router.NoRoute(proxyServer.HandleProxy) + + return router +} + +// setupLogger configures the logging system +func setupLogger(configManager types.ConfigManager) { + logConfig := configManager.GetLogConfig() + + // Set log level + level, err := logrus.ParseLevel(logConfig.Level) + if err != nil { + logrus.Warn("Invalid log level, using info") + level = logrus.InfoLevel + } + logrus.SetLevel(level) + + // Set log format + if logConfig.Format == "json" { + logrus.SetFormatter(&logrus.JSONFormatter{ + TimestampFormat: time.RFC3339, + }) + } else { + logrus.SetFormatter(&logrus.TextFormatter{ + FullTimestamp: true, + TimestampFormat: "2006-01-02 15:04:05", + }) + } + + // Setup file logging if enabled + if logConfig.EnableFile { + // Create log directory if it doesn't exist + logDir := filepath.Dir(logConfig.FilePath) + if err := os.MkdirAll(logDir, 0755); err != nil { + logrus.Warnf("Failed to create log directory: %v", err) + } else { + // Open log file + logFile, err := os.OpenFile(logConfig.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + logrus.Warnf("Failed to open log file: %v", err) + } else { + // Use both file and stdout + logrus.SetOutput(io.MultiWriter(os.Stdout, logFile)) + } + } + } +} + +// displayStartupInfo shows startup information +func displayStartupInfo(configManager types.ConfigManager) { + serverConfig := configManager.GetServerConfig() + keysConfig := configManager.GetKeysConfig() + openaiConfig := configManager.GetOpenAIConfig() + authConfig := configManager.GetAuthConfig() + corsConfig := configManager.GetCORSConfig() + perfConfig := configManager.GetPerformanceConfig() + logConfig := configManager.GetLogConfig() + + logrus.Info("Current Configuration:") + logrus.Infof(" Server: %s:%d", serverConfig.Host, serverConfig.Port) + logrus.Infof(" Keys file: %s", keysConfig.FilePath) + logrus.Infof(" Start index: %d", keysConfig.StartIndex) + logrus.Infof(" Blacklist threshold: %d errors", keysConfig.BlacklistThreshold) + logrus.Infof(" Max retries: %d", keysConfig.MaxRetries) + logrus.Infof(" Upstream URL: %s", openaiConfig.BaseURL) + logrus.Infof(" Request timeout: %dms", openaiConfig.Timeout) + + authStatus := "disabled" + if authConfig.Enabled { + authStatus = "enabled" + } + logrus.Infof(" Authentication: %s", authStatus) + + corsStatus := "disabled" + if corsConfig.Enabled { + corsStatus = "enabled" + } + logrus.Infof(" CORS: %s", corsStatus) + logrus.Infof(" Max concurrent requests: %d", perfConfig.MaxConcurrentRequests) + + gzipStatus := "disabled" + if perfConfig.EnableGzip { + gzipStatus = "enabled" + } + logrus.Infof(" Gzip compression: %s", gzipStatus) + + requestLogStatus := "enabled" + if !logConfig.EnableRequest { + requestLogStatus = "disabled" + } + logrus.Infof(" Request logging: %s", requestLogStatus) +} diff --git a/cmd/main.go b/cmd/main.go deleted file mode 100644 index 4e2e723..0000000 --- a/cmd/main.go +++ /dev/null @@ -1,139 +0,0 @@ -// Package main OpenAI多密钥代理服务器主入口 -// @author OpenAI Proxy Team -// @version 2.0.0 -package main - -import ( - "context" - "fmt" - "io" - "net/http" - "os" - "os/signal" - "path/filepath" - "syscall" - "time" - - "openai-multi-key-proxy/internal/config" - "openai-multi-key-proxy/internal/proxy" - - "github.com/sirupsen/logrus" -) - -func main() { - // 加载配置 - cfg, err := config.LoadConfig() - if err != nil { - logrus.Fatalf("加载配置失败: %v", err) - } - - // 配置日志 - setupLogger(cfg) - - // 显示启动信息 - displayStartupInfo(cfg) - - // 创建代理服务器 - proxyServer, err := proxy.NewProxyServer() - if err != nil { - logrus.Fatalf("❌ 创建代理服务器失败: %v", err) - } - defer proxyServer.Close() - - // 设置路由 - router := proxyServer.SetupRoutes() - - // 创建HTTP服务器,优化超时配置 - server := &http.Server{ - Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port), - Handler: router, - ReadTimeout: 60 * time.Second, // 增加读超时,支持大文件上传 - WriteTimeout: 300 * time.Second, // 增加写超时,支持流式响应 - IdleTimeout: 120 * time.Second, // 增加空闲超时,复用连接 - MaxHeaderBytes: 1 << 20, // 1MB header limit - } - - // 启动服务器 - go func() { - logrus.Infof("🚀 OpenAI 多密钥代理服务器启动成功") - logrus.Infof("📡 服务地址: http://%s:%d", cfg.Server.Host, cfg.Server.Port) - logrus.Infof("📊 统计信息: http://%s:%d/stats", cfg.Server.Host, cfg.Server.Port) - logrus.Infof("💚 健康检查: http://%s:%d/health", cfg.Server.Host, cfg.Server.Port) - logrus.Infof("🔄 重置密钥: http://%s:%d/reset-keys", cfg.Server.Host, cfg.Server.Port) - logrus.Infof("🚫 黑名单查询: http://%s:%d/blacklist", cfg.Server.Host, cfg.Server.Port) - logrus.Info("") - - if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logrus.Fatalf("❌ 服务器启动失败: %v", err) - } - }() - - // 等待中断信号 - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) - <-quit - - logrus.Info("🛑 收到关闭信号,正在优雅关闭服务器...") - - // 优雅关闭 - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - if err := server.Shutdown(ctx); err != nil { - logrus.Errorf("❌ 服务器关闭失败: %v", err) - } else { - logrus.Info("✅ 服务器已优雅关闭") - } -} - -// displayStartupInfo 显示启动信息 -func displayStartupInfo(cfg *config.Config) { - logrus.Info("🚀 OpenAI 多密钥代理服务器 v2.0.0 (Go版本)") - logrus.Info("") - - // 显示配置 - config.DisplayConfig(cfg) - logrus.Info("") -} - -// setupLogger 配置日志系统 -func setupLogger(cfg *config.Config) { - // 设置日志级别 - level, err := logrus.ParseLevel(cfg.Log.Level) - if err != nil { - logrus.Warnf("无效的日志级别 '%s',使用默认级别 info", cfg.Log.Level) - level = logrus.InfoLevel - } - logrus.SetLevel(level) - - // 设置日志格式 - switch cfg.Log.Format { - case "json": - logrus.SetFormatter(&logrus.JSONFormatter{ - TimestampFormat: "2006-01-02 15:04:05", - }) - default: - logrus.SetFormatter(&logrus.TextFormatter{ - FullTimestamp: true, - ForceColors: true, - TimestampFormat: "2006-01-02 15:04:05", - }) - } - - // 配置文件日志 - if cfg.Log.EnableFile { - // 创建日志目录 - if err := os.MkdirAll(filepath.Dir(cfg.Log.FilePath), 0755); err != nil { - logrus.Warnf("创建日志目录失败: %v", err) - } else { - // 打开日志文件 - file, err := os.OpenFile(cfg.Log.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) - if err != nil { - logrus.Warnf("打开日志文件失败: %v", err) - } else { - // 同时输出到控制台和文件 - logrus.SetOutput(io.MultiWriter(os.Stdout, file)) - } - } - } -} diff --git a/go.mod b/go.mod index ea4491c..dca9fac 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module openai-multi-key-proxy +module gpt-load go 1.21 diff --git a/internal/config/config.go b/internal/config/config.go deleted file mode 100644 index b9c4f85..0000000 --- a/internal/config/config.go +++ /dev/null @@ -1,311 +0,0 @@ -// Package config 配置管理模块 -// @author OpenAI Proxy Team -// @version 2.0.0 -package config - -import ( - "fmt" - "net/url" - "os" - "strconv" - "strings" - - "github.com/joho/godotenv" - "github.com/sirupsen/logrus" -) - -// Constants 配置常量 -type Constants struct { - MinPort int - MaxPort int - MinTimeout int - DefaultTimeout int - DefaultMaxSockets int - DefaultMaxFreeSockets int -} - -// DefaultConstants 默认常量 -var DefaultConstants = Constants{ - MinPort: 1, - MaxPort: 65535, - MinTimeout: 1000, - DefaultTimeout: 30000, - DefaultMaxSockets: 50, - DefaultMaxFreeSockets: 10, -} - -// ServerConfig 服务器配置 -type ServerConfig struct { - Port int `json:"port"` - Host string `json:"host"` -} - -// KeysConfig 密钥管理配置 -type KeysConfig struct { - FilePath string `json:"filePath"` - StartIndex int `json:"startIndex"` - BlacklistThreshold int `json:"blacklistThreshold"` - MaxRetries int `json:"maxRetries"` // 最大重试次数 -} - -// OpenAIConfig OpenAI API 配置 -type OpenAIConfig struct { - BaseURL string `json:"baseURL"` - Timeout int `json:"timeout"` -} - -// AuthConfig 认证配置 -type AuthConfig struct { - Key string `json:"key"` - Enabled bool `json:"enabled"` -} - -// CORSConfig CORS 配置 -type CORSConfig struct { - Enabled bool `json:"enabled"` - AllowedOrigins []string `json:"allowedOrigins"` -} - -// PerformanceConfig 性能配置 -type PerformanceConfig struct { - MaxSockets int `json:"maxSockets"` - MaxFreeSockets int `json:"maxFreeSockets"` - EnableKeepAlive bool `json:"enableKeepAlive"` - DisableCompression bool `json:"disableCompression"` - BufferSize int `json:"bufferSize"` - StreamBufferSize int `json:"streamBufferSize"` // 流式传输缓冲区大小 - StreamHeaderTimeout int `json:"streamHeaderTimeout"` // 流式请求响应头超时(毫秒) -} - -// LogConfig 日志配置 -type LogConfig struct { - Level string `json:"level"` // debug, info, warn, error - Format string `json:"format"` // text, json - EnableFile bool `json:"enableFile"` // 是否启用文件日志 - FilePath string `json:"filePath"` // 日志文件路径 - EnableRequest bool `json:"enableRequest"` // 是否启用请求日志 -} - -// Config 应用配置 -type Config struct { - Server ServerConfig `json:"server"` - Keys KeysConfig `json:"keys"` - OpenAI OpenAIConfig `json:"openai"` - Auth AuthConfig `json:"auth"` - CORS CORSConfig `json:"cors"` - Performance PerformanceConfig `json:"performance"` - Log LogConfig `json:"log"` -} - -// Global config instance -var AppConfig *Config - -// LoadConfig 加载配置 -func LoadConfig() (*Config, error) { - // 尝试加载 .env 文件 - if err := godotenv.Load(); err != nil { - logrus.Info("💡 提示: 创建 .env 文件以支持环境变量配置") - } - - config := &Config{ - Server: ServerConfig{ - Port: parseInteger(os.Getenv("PORT"), 3000), - Host: getEnvOrDefault("HOST", "0.0.0.0"), - }, - Keys: KeysConfig{ - FilePath: getEnvOrDefault("KEYS_FILE", "keys.txt"), - StartIndex: parseInteger(os.Getenv("START_INDEX"), 0), - BlacklistThreshold: parseInteger(os.Getenv("BLACKLIST_THRESHOLD"), 1), - MaxRetries: parseInteger(os.Getenv("MAX_RETRIES"), 3), - }, - OpenAI: OpenAIConfig{ - BaseURL: getEnvOrDefault("OPENAI_BASE_URL", "https://api.openai.com"), - Timeout: parseInteger(os.Getenv("REQUEST_TIMEOUT"), DefaultConstants.DefaultTimeout), - }, - Auth: AuthConfig{ - Key: os.Getenv("AUTH_KEY"), - Enabled: os.Getenv("AUTH_KEY") != "", - }, - CORS: CORSConfig{ - Enabled: parseBoolean(os.Getenv("ENABLE_CORS"), true), - AllowedOrigins: parseArray(os.Getenv("ALLOWED_ORIGINS"), []string{"*"}), - }, - Performance: PerformanceConfig{ - MaxSockets: parseInteger(os.Getenv("MAX_SOCKETS"), DefaultConstants.DefaultMaxSockets), - MaxFreeSockets: parseInteger(os.Getenv("MAX_FREE_SOCKETS"), DefaultConstants.DefaultMaxFreeSockets), - EnableKeepAlive: parseBoolean(os.Getenv("ENABLE_KEEP_ALIVE"), true), - DisableCompression: parseBoolean(os.Getenv("DISABLE_COMPRESSION"), true), - BufferSize: parseInteger(os.Getenv("BUFFER_SIZE"), 32*1024), - StreamBufferSize: parseInteger(os.Getenv("STREAM_BUFFER_SIZE"), 64*1024), // 默认64KB - StreamHeaderTimeout: parseInteger(os.Getenv("STREAM_HEADER_TIMEOUT"), 10000), // 默认10秒 - }, - Log: LogConfig{ - Level: getEnvOrDefault("LOG_LEVEL", "info"), - Format: getEnvOrDefault("LOG_FORMAT", "text"), - EnableFile: parseBoolean(os.Getenv("LOG_ENABLE_FILE"), false), - FilePath: getEnvOrDefault("LOG_FILE_PATH", "logs/app.log"), - EnableRequest: parseBoolean(os.Getenv("LOG_ENABLE_REQUEST"), true), - }, - } - - // 验证配置 - if err := validateConfig(config); err != nil { - return nil, err - } - - AppConfig = config - return config, nil -} - -// validateConfig 验证配置有效性 -func validateConfig(config *Config) error { - var errors []string - - // 验证端口 - if config.Server.Port < DefaultConstants.MinPort || config.Server.Port > DefaultConstants.MaxPort { - errors = append(errors, fmt.Sprintf("端口号必须在 %d-%d 之间", DefaultConstants.MinPort, DefaultConstants.MaxPort)) - } - - // 验证起始索引 - if config.Keys.StartIndex < 0 { - errors = append(errors, "起始索引不能小于 0") - } - - // 验证黑名单阈值 - if config.Keys.BlacklistThreshold < 1 { - errors = append(errors, "黑名单阈值不能小于 1") - } - - // 验证超时时间 - if config.OpenAI.Timeout < DefaultConstants.MinTimeout { - errors = append(errors, fmt.Sprintf("请求超时时间不能小于 %dms", DefaultConstants.MinTimeout)) - } - - // 验证上游URL格式 - if _, err := url.Parse(config.OpenAI.BaseURL); err != nil { - errors = append(errors, "上游API地址格式无效") - } - - // 验证性能配置 - if config.Performance.MaxSockets < 1 { - errors = append(errors, "最大连接数不能小于 1") - } - - if config.Performance.MaxFreeSockets < 0 { - errors = append(errors, "最大空闲连接数不能小于 0") - } - - if config.Performance.StreamBufferSize < 1024 { - errors = append(errors, "流式缓冲区大小不能小于 1KB") - } - - if config.Performance.StreamHeaderTimeout < 1000 { - errors = append(errors, "流式响应头超时不能小于 1秒") - } - - if len(errors) > 0 { - logrus.Error("❌ 配置验证失败:") - for _, err := range errors { - logrus.Errorf(" - %s", err) - } - return fmt.Errorf("配置验证失败") - } - - return nil -} - -// DisplayConfig 显示当前配置信息 -func DisplayConfig(config *Config) { - logrus.Info("⚙️ 当前配置:") - logrus.Infof(" 服务器: %s:%d", config.Server.Host, config.Server.Port) - logrus.Infof(" 密钥文件: %s", config.Keys.FilePath) - logrus.Infof(" 起始索引: %d", config.Keys.StartIndex) - logrus.Infof(" 黑名单阈值: %d 次错误", config.Keys.BlacklistThreshold) - logrus.Infof(" 最大重试次数: %d", config.Keys.MaxRetries) - logrus.Infof(" 上游地址: %s", config.OpenAI.BaseURL) - logrus.Infof(" 请求超时: %dms", config.OpenAI.Timeout) - - authStatus := "未启用" - if config.Auth.Enabled { - authStatus = "已启用" - } - logrus.Infof(" 认证: %s", authStatus) - - corsStatus := "已禁用" - if config.CORS.Enabled { - corsStatus = "已启用" - } - logrus.Infof(" CORS: %s", corsStatus) - logrus.Infof(" 连接池: %d/%d", config.Performance.MaxSockets, config.Performance.MaxFreeSockets) - - keepAliveStatus := "已启用" - if !config.Performance.EnableKeepAlive { - keepAliveStatus = "已禁用" - } - logrus.Infof(" Keep-Alive: %s", keepAliveStatus) - - compressionStatus := "已启用" - if config.Performance.DisableCompression { - compressionStatus = "已禁用" - } - logrus.Infof(" 压缩: %s", compressionStatus) - logrus.Infof(" 缓冲区大小: %d bytes", config.Performance.BufferSize) - logrus.Infof(" 流式缓冲区: %d bytes", config.Performance.StreamBufferSize) - logrus.Infof(" 流式响应头超时: %dms", config.Performance.StreamHeaderTimeout) - - // 显示日志配置 - requestLogStatus := "已启用" - if !config.Log.EnableRequest { - requestLogStatus = "已禁用" - } - logrus.Infof(" 请求日志: %s", requestLogStatus) -} - -// 辅助函数 - -// parseInteger 解析整数环境变量 -func parseInteger(value string, defaultValue int) int { - if value == "" { - return defaultValue - } - if parsed, err := strconv.Atoi(value); err == nil { - return parsed - } - return defaultValue -} - -// parseBoolean 解析布尔值环境变量 -func parseBoolean(value string, defaultValue bool) bool { - if value == "" { - return defaultValue - } - return strings.ToLower(value) == "true" -} - -// parseArray 解析数组环境变量(逗号分隔) -func parseArray(value string, defaultValue []string) []string { - if value == "" { - return defaultValue - } - - parts := strings.Split(value, ",") - result := make([]string, 0, len(parts)) - for _, part := range parts { - if trimmed := strings.TrimSpace(part); trimmed != "" { - result = append(result, trimmed) - } - } - - if len(result) == 0 { - return defaultValue - } - return result -} - -// getEnvOrDefault 获取环境变量或默认值 -func getEnvOrDefault(key, defaultValue string) string { - if value := os.Getenv(key); value != "" { - return value - } - return defaultValue -} diff --git a/internal/config/manager.go b/internal/config/manager.go new file mode 100644 index 0000000..71554aa --- /dev/null +++ b/internal/config/manager.go @@ -0,0 +1,275 @@ +// Package config provides configuration management for the application +package config + +import ( + "fmt" + "net/url" + "os" + "strconv" + "strings" + + "gpt-load/internal/errors" + "gpt-load/pkg/types" + + "github.com/joho/godotenv" + "github.com/sirupsen/logrus" +) + +// Constants represents configuration constants +type Constants struct { + MinPort int + MaxPort int + MinTimeout int + DefaultTimeout int + DefaultMaxSockets int + DefaultMaxFreeSockets int +} + +// DefaultConstants holds default configuration values +var DefaultConstants = Constants{ + MinPort: 1, + MaxPort: 65535, + MinTimeout: 1000, + DefaultTimeout: 30000, + DefaultMaxSockets: 50, + DefaultMaxFreeSockets: 10, +} + +// Manager implements the ConfigManager interface +type Manager struct { + config *Config +} + +// Config represents the application configuration +type Config struct { + Server types.ServerConfig `json:"server"` + Keys types.KeysConfig `json:"keys"` + OpenAI types.OpenAIConfig `json:"openai"` + Auth types.AuthConfig `json:"auth"` + CORS types.CORSConfig `json:"cors"` + Performance types.PerformanceConfig `json:"performance"` + Log types.LogConfig `json:"log"` +} + +// NewManager creates a new configuration manager +func NewManager() (types.ConfigManager, error) { + // Try to load .env file + if err := godotenv.Load(); err != nil { + logrus.Info("Info: Create .env file to support environment variable configuration") + } + + config := &Config{ + Server: types.ServerConfig{ + Port: parseInteger(os.Getenv("PORT"), 3000), + Host: getEnvOrDefault("HOST", "0.0.0.0"), + }, + Keys: types.KeysConfig{ + FilePath: getEnvOrDefault("KEYS_FILE", "keys.txt"), + StartIndex: parseInteger(os.Getenv("START_INDEX"), 0), + BlacklistThreshold: parseInteger(os.Getenv("BLACKLIST_THRESHOLD"), 1), + MaxRetries: parseInteger(os.Getenv("MAX_RETRIES"), 3), + }, + OpenAI: types.OpenAIConfig{ + BaseURL: getEnvOrDefault("OPENAI_BASE_URL", "https://api.openai.com"), + Timeout: parseInteger(os.Getenv("REQUEST_TIMEOUT"), DefaultConstants.DefaultTimeout), + }, + Auth: types.AuthConfig{ + Key: os.Getenv("AUTH_KEY"), + Enabled: os.Getenv("AUTH_KEY") != "", + }, + CORS: types.CORSConfig{ + Enabled: parseBoolean(os.Getenv("ENABLE_CORS"), true), + AllowedOrigins: parseArray(os.Getenv("ALLOWED_ORIGINS"), []string{"*"}), + AllowedMethods: parseArray(os.Getenv("ALLOWED_METHODS"), []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}), + AllowedHeaders: parseArray(os.Getenv("ALLOWED_HEADERS"), []string{"*"}), + AllowCredentials: parseBoolean(os.Getenv("ALLOW_CREDENTIALS"), false), + }, + Performance: types.PerformanceConfig{ + MaxConcurrentRequests: parseInteger(os.Getenv("MAX_CONCURRENT_REQUESTS"), 100), + RequestTimeout: parseInteger(os.Getenv("REQUEST_TIMEOUT"), DefaultConstants.DefaultTimeout), + EnableGzip: parseBoolean(os.Getenv("ENABLE_GZIP"), true), + }, + Log: types.LogConfig{ + Level: getEnvOrDefault("LOG_LEVEL", "info"), + Format: getEnvOrDefault("LOG_FORMAT", "text"), + EnableFile: parseBoolean(os.Getenv("LOG_ENABLE_FILE"), false), + FilePath: getEnvOrDefault("LOG_FILE_PATH", "logs/app.log"), + EnableRequest: parseBoolean(os.Getenv("LOG_ENABLE_REQUEST"), true), + }, + } + + manager := &Manager{config: config} + + // Validate configuration + if err := manager.Validate(); err != nil { + return nil, err + } + + return manager, nil +} + +// GetServerConfig returns server configuration +func (m *Manager) GetServerConfig() types.ServerConfig { + return m.config.Server +} + +// GetKeysConfig returns keys configuration +func (m *Manager) GetKeysConfig() types.KeysConfig { + return m.config.Keys +} + +// GetOpenAIConfig returns OpenAI configuration +func (m *Manager) GetOpenAIConfig() types.OpenAIConfig { + return m.config.OpenAI +} + +// GetAuthConfig returns authentication configuration +func (m *Manager) GetAuthConfig() types.AuthConfig { + return m.config.Auth +} + +// GetCORSConfig returns CORS configuration +func (m *Manager) GetCORSConfig() types.CORSConfig { + return m.config.CORS +} + +// GetPerformanceConfig returns performance configuration +func (m *Manager) GetPerformanceConfig() types.PerformanceConfig { + return m.config.Performance +} + +// GetLogConfig returns logging configuration +func (m *Manager) GetLogConfig() types.LogConfig { + return m.config.Log +} + +// Validate validates the configuration +func (m *Manager) Validate() error { + var validationErrors []string + + // Validate port + if m.config.Server.Port < DefaultConstants.MinPort || m.config.Server.Port > DefaultConstants.MaxPort { + validationErrors = append(validationErrors, fmt.Sprintf("port must be between %d-%d", DefaultConstants.MinPort, DefaultConstants.MaxPort)) + } + + // Validate start index + if m.config.Keys.StartIndex < 0 { + validationErrors = append(validationErrors, "start index cannot be less than 0") + } + + // Validate blacklist threshold + if m.config.Keys.BlacklistThreshold < 1 { + validationErrors = append(validationErrors, "blacklist threshold cannot be less than 1") + } + + // Validate timeout + if m.config.OpenAI.Timeout < DefaultConstants.MinTimeout { + validationErrors = append(validationErrors, fmt.Sprintf("request timeout cannot be less than %dms", DefaultConstants.MinTimeout)) + } + + // Validate upstream URL format + if _, err := url.Parse(m.config.OpenAI.BaseURL); err != nil { + validationErrors = append(validationErrors, "invalid upstream API URL format") + } + + // Validate performance configuration + if m.config.Performance.MaxConcurrentRequests < 1 { + validationErrors = append(validationErrors, "max concurrent requests cannot be less than 1") + } + + if len(validationErrors) > 0 { + logrus.Error("Configuration validation failed:") + for _, err := range validationErrors { + logrus.Errorf(" - %s", err) + } + return errors.NewAppErrorWithDetails(errors.ErrConfigValidation, "Configuration validation failed", strings.Join(validationErrors, "; ")) + } + + return nil +} + +// DisplayConfig displays current configuration information +func (m *Manager) DisplayConfig() { + logrus.Info("Current Configuration:") + logrus.Infof(" Server: %s:%d", m.config.Server.Host, m.config.Server.Port) + logrus.Infof(" Keys file: %s", m.config.Keys.FilePath) + logrus.Infof(" Start index: %d", m.config.Keys.StartIndex) + logrus.Infof(" Blacklist threshold: %d errors", m.config.Keys.BlacklistThreshold) + logrus.Infof(" Max retries: %d", m.config.Keys.MaxRetries) + logrus.Infof(" Upstream URL: %s", m.config.OpenAI.BaseURL) + logrus.Infof(" Request timeout: %dms", m.config.OpenAI.Timeout) + + authStatus := "disabled" + if m.config.Auth.Enabled { + authStatus = "enabled" + } + logrus.Infof(" Authentication: %s", authStatus) + + corsStatus := "disabled" + if m.config.CORS.Enabled { + corsStatus = "enabled" + } + logrus.Infof(" CORS: %s", corsStatus) + logrus.Infof(" Max concurrent requests: %d", m.config.Performance.MaxConcurrentRequests) + + gzipStatus := "disabled" + if m.config.Performance.EnableGzip { + gzipStatus = "enabled" + } + logrus.Infof(" Gzip compression: %s", gzipStatus) + + requestLogStatus := "enabled" + if !m.config.Log.EnableRequest { + requestLogStatus = "disabled" + } + logrus.Infof(" Request logging: %s", requestLogStatus) +} + +// Helper functions + +// parseInteger parses integer environment variable +func parseInteger(value string, defaultValue int) int { + if value == "" { + return defaultValue + } + if parsed, err := strconv.Atoi(value); err == nil { + return parsed + } + return defaultValue +} + +// parseBoolean parses boolean environment variable +func parseBoolean(value string, defaultValue bool) bool { + if value == "" { + return defaultValue + } + return strings.ToLower(value) == "true" +} + +// parseArray parses array environment variable (comma-separated) +func parseArray(value string, defaultValue []string) []string { + if value == "" { + return defaultValue + } + + parts := strings.Split(value, ",") + result := make([]string, 0, len(parts)) + for _, part := range parts { + if trimmed := strings.TrimSpace(part); trimmed != "" { + result = append(result, trimmed) + } + } + + if len(result) == 0 { + return defaultValue + } + return result +} + +// getEnvOrDefault gets environment variable or default value +func getEnvOrDefault(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +} diff --git a/internal/errors/errors.go b/internal/errors/errors.go new file mode 100644 index 0000000..33a664e --- /dev/null +++ b/internal/errors/errors.go @@ -0,0 +1,129 @@ +// Package errors defines custom error types for the application +package errors + +import ( + "fmt" + "net/http" +) + +// ErrorCode represents different types of errors +type ErrorCode int + +const ( + // Configuration errors + ErrConfigInvalid ErrorCode = iota + 1000 + ErrConfigMissing + ErrConfigValidation + + // Key management errors + ErrNoKeysAvailable ErrorCode = iota + 2000 + ErrKeyFileNotFound + ErrKeyFileInvalid + ErrAllKeysBlacklisted + + // Proxy errors + ErrProxyRequest ErrorCode = iota + 3000 + ErrProxyResponse + ErrProxyTimeout + ErrProxyRetryExhausted + + // Authentication errors + ErrAuthInvalid ErrorCode = iota + 4000 + ErrAuthMissing + ErrAuthExpired + + // Server errors + ErrServerInternal ErrorCode = iota + 5000 + ErrServerUnavailable +) + +// AppError represents a custom application error +type AppError struct { + Code ErrorCode `json:"code"` + Message string `json:"message"` + Details string `json:"details,omitempty"` + HTTPStatus int `json:"-"` + Cause error `json:"-"` +} + +// Error implements the error interface +func (e *AppError) Error() string { + if e.Details != "" { + return fmt.Sprintf("[%d] %s: %s", e.Code, e.Message, e.Details) + } + return fmt.Sprintf("[%d] %s", e.Code, e.Message) +} + +// Unwrap returns the underlying error +func (e *AppError) Unwrap() error { + return e.Cause +} + +// NewAppError creates a new application error +func NewAppError(code ErrorCode, message string) *AppError { + return &AppError{ + Code: code, + Message: message, + HTTPStatus: getHTTPStatusForCode(code), + } +} + +// NewAppErrorWithDetails creates a new application error with details +func NewAppErrorWithDetails(code ErrorCode, message, details string) *AppError { + return &AppError{ + Code: code, + Message: message, + Details: details, + HTTPStatus: getHTTPStatusForCode(code), + } +} + +// NewAppErrorWithCause creates a new application error with underlying cause +func NewAppErrorWithCause(code ErrorCode, message string, cause error) *AppError { + return &AppError{ + Code: code, + Message: message, + HTTPStatus: getHTTPStatusForCode(code), + Cause: cause, + } +} + +// getHTTPStatusForCode maps error codes to HTTP status codes +func getHTTPStatusForCode(code ErrorCode) int { + switch { + case code >= 1000 && code < 2000: // Configuration errors + return http.StatusInternalServerError + case code >= 2000 && code < 3000: // Key management errors + return http.StatusServiceUnavailable + case code >= 3000 && code < 4000: // Proxy errors + return http.StatusBadGateway + case code >= 4000 && code < 5000: // Authentication errors + return http.StatusUnauthorized + case code >= 5000 && code < 6000: // Server errors + return http.StatusInternalServerError + default: + return http.StatusInternalServerError + } +} + +// IsRetryable determines if an error is retryable +func IsRetryable(err error) bool { + if appErr, ok := err.(*AppError); ok { + switch appErr.Code { + case ErrProxyTimeout, ErrServerUnavailable: + return true + default: + return false + } + } + return false +} + +// Common error instances +var ( + ErrNoAPIKeysAvailable = NewAppError(ErrNoKeysAvailable, "No API keys available") + ErrAllAPIKeysBlacklisted = NewAppError(ErrAllKeysBlacklisted, "All API keys are blacklisted") + ErrInvalidConfiguration = NewAppError(ErrConfigInvalid, "Invalid configuration") + ErrAuthenticationRequired = NewAppError(ErrAuthMissing, "Authentication required") + ErrInvalidAuthToken = NewAppError(ErrAuthInvalid, "Invalid authentication token") +) diff --git a/internal/handler/handler.go b/internal/handler/handler.go new file mode 100644 index 0000000..0e6d61c --- /dev/null +++ b/internal/handler/handler.go @@ -0,0 +1,216 @@ +// Package handler provides HTTP handlers for the application +package handler + +import ( + "net/http" + "runtime" + "time" + + "gpt-load/pkg/types" + + "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" +) + +// Handler contains dependencies for HTTP handlers +type Handler struct { + keyManager types.KeyManager + config types.ConfigManager +} + +// NewHandler creates a new handler instance +func NewHandler(keyManager types.KeyManager, config types.ConfigManager) *Handler { + return &Handler{ + keyManager: keyManager, + config: config, + } +} + +// Health handles health check requests +func (h *Handler) Health(c *gin.Context) { + stats := h.keyManager.GetStats() + + status := "healthy" + httpStatus := http.StatusOK + + // Check if there are any healthy keys + if stats.HealthyKeys == 0 { + status = "unhealthy" + httpStatus = http.StatusServiceUnavailable + } + + c.JSON(httpStatus, gin.H{ + "status": status, + "timestamp": time.Now().UTC().Format(time.RFC3339), + "healthy_keys": stats.HealthyKeys, + "total_keys": stats.TotalKeys, + "uptime": time.Since(time.Now()).String(), // This would need to be tracked properly + }) +} + +// Stats handles statistics requests +func (h *Handler) Stats(c *gin.Context) { + stats := h.keyManager.GetStats() + + // Add additional system information + var m runtime.MemStats + runtime.ReadMemStats(&m) + + response := gin.H{ + "keys": gin.H{ + "total": stats.TotalKeys, + "healthy": stats.HealthyKeys, + "blacklisted": stats.BlacklistedKeys, + "current_index": stats.CurrentIndex, + }, + "requests": gin.H{ + "success_count": stats.SuccessCount, + "failure_count": stats.FailureCount, + "total_count": stats.SuccessCount + stats.FailureCount, + }, + "memory": gin.H{ + "alloc_mb": bToMb(m.Alloc), + "total_alloc_mb": bToMb(m.TotalAlloc), + "sys_mb": bToMb(m.Sys), + "num_gc": m.NumGC, + "last_gc": time.Unix(0, int64(m.LastGC)).Format("2006-01-02 15:04:05"), + "next_gc_mb": bToMb(m.NextGC), + }, + "system": gin.H{ + "goroutines": runtime.NumGoroutine(), + "cpu_count": runtime.NumCPU(), + "go_version": runtime.Version(), + }, + "timestamp": time.Now().UTC().Format(time.RFC3339), + } + + c.JSON(http.StatusOK, response) +} + +// Blacklist handles blacklist requests +func (h *Handler) Blacklist(c *gin.Context) { + blacklist := h.keyManager.GetBlacklist() + + response := gin.H{ + "blacklisted_keys": blacklist, + "count": len(blacklist), + "timestamp": time.Now().UTC().Format(time.RFC3339), + } + + c.JSON(http.StatusOK, response) +} + +// ResetKeys handles key reset requests +func (h *Handler) ResetKeys(c *gin.Context) { + // Reset blacklist + h.keyManager.ResetBlacklist() + + // Reload keys from file + if err := h.keyManager.LoadKeys(); err != nil { + logrus.Errorf("Failed to reload keys: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to reload keys", + "message": err.Error(), + }) + return + } + + stats := h.keyManager.GetStats() + + c.JSON(http.StatusOK, gin.H{ + "message": "Keys reset and reloaded successfully", + "total_keys": stats.TotalKeys, + "healthy_keys": stats.HealthyKeys, + "timestamp": time.Now().UTC().Format(time.RFC3339), + }) + + logrus.Info("Keys reset and reloaded successfully") +} + +// NotFound handles 404 requests +func (h *Handler) NotFound(c *gin.Context) { + c.JSON(http.StatusNotFound, gin.H{ + "error": "Endpoint not found", + "path": c.Request.URL.Path, + "method": c.Request.Method, + "timestamp": time.Now().UTC().Format(time.RFC3339), + }) +} + +// MethodNotAllowed handles 405 requests +func (h *Handler) MethodNotAllowed(c *gin.Context) { + c.JSON(http.StatusMethodNotAllowed, gin.H{ + "error": "Method not allowed", + "path": c.Request.URL.Path, + "method": c.Request.Method, + "timestamp": time.Now().UTC().Format(time.RFC3339), + }) +} + +// GetConfig returns configuration information (for debugging) +func (h *Handler) GetConfig(c *gin.Context) { + // Only allow in development mode or with special header + if c.GetHeader("X-Debug-Config") != "true" { + c.JSON(http.StatusForbidden, gin.H{ + "error": "Access denied", + }) + return + } + + serverConfig := h.config.GetServerConfig() + keysConfig := h.config.GetKeysConfig() + openaiConfig := h.config.GetOpenAIConfig() + authConfig := h.config.GetAuthConfig() + corsConfig := h.config.GetCORSConfig() + perfConfig := h.config.GetPerformanceConfig() + logConfig := h.config.GetLogConfig() + + // Sanitize sensitive information + sanitizedConfig := gin.H{ + "server": gin.H{ + "host": serverConfig.Host, + "port": serverConfig.Port, + }, + "keys": gin.H{ + "file_path": keysConfig.FilePath, + "start_index": keysConfig.StartIndex, + "blacklist_threshold": keysConfig.BlacklistThreshold, + "max_retries": keysConfig.MaxRetries, + }, + "openai": gin.H{ + "base_url": openaiConfig.BaseURL, + "timeout": openaiConfig.Timeout, + }, + "auth": gin.H{ + "enabled": authConfig.Enabled, + // Don't expose the actual key + }, + "cors": gin.H{ + "enabled": corsConfig.Enabled, + "allowed_origins": corsConfig.AllowedOrigins, + "allowed_methods": corsConfig.AllowedMethods, + "allowed_headers": corsConfig.AllowedHeaders, + "allow_credentials": corsConfig.AllowCredentials, + }, + "performance": gin.H{ + "max_concurrent_requests": perfConfig.MaxConcurrentRequests, + "request_timeout": perfConfig.RequestTimeout, + "enable_gzip": perfConfig.EnableGzip, + }, + "log": gin.H{ + "level": logConfig.Level, + "format": logConfig.Format, + "enable_file": logConfig.EnableFile, + "file_path": logConfig.FilePath, + "enable_request": logConfig.EnableRequest, + }, + "timestamp": time.Now().UTC().Format(time.RFC3339), + } + + c.JSON(http.StatusOK, sanitizedConfig) +} + +// Helper function to convert bytes to megabytes +func bToMb(b uint64) uint64 { + return b / 1024 / 1024 +} diff --git a/internal/keymanager/keymanager.go b/internal/keymanager/keymanager.go deleted file mode 100644 index a0d4aa2..0000000 --- a/internal/keymanager/keymanager.go +++ /dev/null @@ -1,416 +0,0 @@ -// Package keymanager 高性能密钥管理器 -// @author OpenAI Proxy Team -// @version 2.0.0 -package keymanager - -import ( - "bufio" - "fmt" - "os" - "regexp" - "strings" - "sync" - "sync/atomic" - "time" - - "openai-multi-key-proxy/internal/config" - - "github.com/sirupsen/logrus" -) - -// KeyInfo 密钥信息 -type KeyInfo struct { - Key string `json:"key"` - Index int `json:"index"` - Preview string `json:"preview"` -} - -// Stats 统计信息 -type Stats struct { - CurrentIndex int64 `json:"currentIndex"` - TotalKeys int `json:"totalKeys"` - HealthyKeys int `json:"healthyKeys"` - BlacklistedKeys int `json:"blacklistedKeys"` - SuccessCount int64 `json:"successCount"` - FailureCount int64 `json:"failureCount"` - MemoryUsage MemoryUsage `json:"memoryUsage"` -} - -// MemoryUsage 内存使用情况 -type MemoryUsage struct { - FailureCountsSize int `json:"failureCountsSize"` - BlacklistSize int `json:"blacklistSize"` -} - -// BlacklistDetail 黑名单详情 -type BlacklistDetail struct { - Index int `json:"index"` - LineNumber int `json:"lineNumber"` - KeyPreview string `json:"keyPreview"` - FullKey string `json:"fullKey"` -} - -// BlacklistInfo 黑名单信息 -type BlacklistInfo struct { - TotalBlacklisted int `json:"totalBlacklisted"` - TotalKeys int `json:"totalKeys"` - HealthyKeys int `json:"healthyKeys"` - BlacklistedKeys []BlacklistDetail `json:"blacklistedKeys"` -} - -// KeyManager 密钥管理器 -type KeyManager struct { - keysFilePath string - keys []string - keyPreviews []string - currentIndex int64 - blacklistedKeys sync.Map - successCount int64 - failureCount int64 - keyFailureCounts sync.Map - - // 性能优化:预编译正则表达式 - permanentErrorPatterns []*regexp.Regexp - - // 内存管理 - cleanupTicker *time.Ticker - stopCleanup chan bool - - // 读写锁保护密钥列表 - keysMutex sync.RWMutex -} - -// NewKeyManager 创建新的密钥管理器 -func NewKeyManager(keysFilePath string) *KeyManager { - if keysFilePath == "" { - keysFilePath = config.AppConfig.Keys.FilePath - } - - km := &KeyManager{ - keysFilePath: keysFilePath, - currentIndex: int64(config.AppConfig.Keys.StartIndex), - stopCleanup: make(chan bool), - - // 预编译正则表达式 - permanentErrorPatterns: []*regexp.Regexp{ - regexp.MustCompile(`(?i)invalid api key`), - regexp.MustCompile(`(?i)incorrect api key`), - regexp.MustCompile(`(?i)api key not found`), - regexp.MustCompile(`(?i)unauthorized`), - regexp.MustCompile(`(?i)account deactivated`), - regexp.MustCompile(`(?i)billing`), - }, - } - - // 启动内存清理 - km.setupMemoryCleanup() - - return km -} - -// LoadKeys 加载密钥文件 -func (km *KeyManager) LoadKeys() error { - file, err := os.Open(km.keysFilePath) - if err != nil { - return fmt.Errorf("无法打开密钥文件: %w", err) - } - defer file.Close() - - var keys []string - scanner := bufio.NewScanner(file) - - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line != "" { - keys = append(keys, line) - } - } - - if err := scanner.Err(); err != nil { - return fmt.Errorf("读取密钥文件失败: %w", err) - } - - if len(keys) == 0 { - return fmt.Errorf("密钥文件中没有有效的API密钥") - } - - km.keysMutex.Lock() - km.keys = keys - // 预生成密钥预览,避免运行时重复计算 - km.keyPreviews = make([]string, len(keys)) - for i, key := range keys { - if len(key) > 20 { - km.keyPreviews[i] = key[:20] + "..." - } else { - km.keyPreviews[i] = key - } - } - km.keysMutex.Unlock() - - logrus.Infof("✅ 成功加载 %d 个 API 密钥", len(keys)) - return nil -} - -// GetNextKey 获取下一个可用的密钥(高性能版本) -func (km *KeyManager) GetNextKey() (*KeyInfo, error) { - km.keysMutex.RLock() - keysLen := len(km.keys) - if keysLen == 0 { - km.keysMutex.RUnlock() - return nil, fmt.Errorf("没有可用的 API 密钥") - } - - // 快速路径:直接获取下一个密钥,避免黑名单检查的开销 - currentIdx := atomic.AddInt64(&km.currentIndex, 1) - 1 - keyIndex := int(currentIdx) % keysLen - selectedKey := km.keys[keyIndex] - keyPreview := km.keyPreviews[keyIndex] - km.keysMutex.RUnlock() - - // 检查是否被拉黑 - if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted { - return &KeyInfo{ - Key: selectedKey, - Index: keyIndex, - Preview: keyPreview, - }, nil - } - - // 慢速路径:寻找可用密钥 - attempts := 0 - maxAttempts := keysLen * 2 // 最多尝试两轮 - - for attempts < maxAttempts { - currentIdx = atomic.AddInt64(&km.currentIndex, 1) - 1 - keyIndex = int(currentIdx) % keysLen - - km.keysMutex.RLock() - selectedKey = km.keys[keyIndex] - keyPreview = km.keyPreviews[keyIndex] - km.keysMutex.RUnlock() - - if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted { - return &KeyInfo{ - Key: selectedKey, - Index: keyIndex, - Preview: keyPreview, - }, nil - } - - attempts++ - } - - // 检查是否所有密钥都被拉黑,如果是则重置 - blacklistedCount := 0 - km.blacklistedKeys.Range(func(key, value interface{}) bool { - blacklistedCount++ - return blacklistedCount < keysLen // 提前退出优化 - }) - - if blacklistedCount >= keysLen { - logrus.Warn("⚠️ 所有密钥都被拉黑,重置黑名单") - km.blacklistedKeys = sync.Map{} - km.keyFailureCounts = sync.Map{} - - // 重置后返回第一个密钥 - km.keysMutex.RLock() - firstKey := km.keys[0] - firstPreview := km.keyPreviews[0] - km.keysMutex.RUnlock() - - return &KeyInfo{ - Key: firstKey, - Index: 0, - Preview: firstPreview, - }, nil - } - - return nil, fmt.Errorf("暂时没有可用的 API 密钥") -} - -// RecordSuccess 记录密钥使用成功 -func (km *KeyManager) RecordSuccess(key string) { - atomic.AddInt64(&km.successCount, 1) - // 成功时重置该密钥的失败计数 - km.keyFailureCounts.Delete(key) -} - -// RecordFailure 记录密钥使用失败 -func (km *KeyManager) RecordFailure(key string, err error) { - atomic.AddInt64(&km.failureCount, 1) - - // 检查是否是永久性错误 - if km.isPermanentError(err) { - km.blacklistedKeys.Store(key, true) - km.keyFailureCounts.Delete(key) // 清理计数 - logrus.Warnf("🚫 密钥已被拉黑(永久性错误): %s (%s)", key[:20]+"...", err.Error()) - return - } - - // 临时性错误:增加失败计数 - currentFailures := 0 - if val, exists := km.keyFailureCounts.Load(key); exists { - currentFailures = val.(int) - } - newFailures := currentFailures + 1 - km.keyFailureCounts.Store(key, newFailures) - - threshold := config.AppConfig.Keys.BlacklistThreshold - if newFailures >= threshold { - km.blacklistedKeys.Store(key, true) - km.keyFailureCounts.Delete(key) // 清理计数 - logrus.Warnf("🚫 密钥已被拉黑(达到阈值): %s (失败 %d 次: %s)", key[:20]+"...", newFailures, err.Error()) - } else { - logrus.Debugf("⚠️ 密钥失败: %s (%d/%d 次: %s)", key[:20]+"...", newFailures, threshold, err.Error()) - } -} - -// isPermanentError 判断是否是永久性错误 -func (km *KeyManager) isPermanentError(err error) bool { - errorMessage := err.Error() - for _, pattern := range km.permanentErrorPatterns { - if pattern.MatchString(errorMessage) { - return true - } - } - return false -} - -// GetStats 获取密钥统计信息 -func (km *KeyManager) GetStats() *Stats { - km.keysMutex.RLock() - totalKeys := len(km.keys) - km.keysMutex.RUnlock() - - blacklistedCount := 0 - km.blacklistedKeys.Range(func(key, value interface{}) bool { - blacklistedCount++ - return true - }) - - failureCountsSize := 0 - km.keyFailureCounts.Range(func(key, value interface{}) bool { - failureCountsSize++ - return true - }) - - return &Stats{ - CurrentIndex: atomic.LoadInt64(&km.currentIndex), - TotalKeys: totalKeys, - HealthyKeys: totalKeys - blacklistedCount, - BlacklistedKeys: blacklistedCount, - SuccessCount: atomic.LoadInt64(&km.successCount), - FailureCount: atomic.LoadInt64(&km.failureCount), - MemoryUsage: MemoryUsage{ - FailureCountsSize: failureCountsSize, - BlacklistSize: blacklistedCount, - }, - } -} - -// ResetKeys 重置密钥状态 -func (km *KeyManager) ResetKeys() map[string]interface{} { - beforeCount := 0 - km.blacklistedKeys.Range(func(key, value interface{}) bool { - beforeCount++ - return true - }) - - km.blacklistedKeys = sync.Map{} - km.keyFailureCounts = sync.Map{} - - logrus.Infof("🔄 密钥状态已重置,清除了 %d 个黑名单密钥", beforeCount) - - km.keysMutex.RLock() - totalKeys := len(km.keys) - km.keysMutex.RUnlock() - - return map[string]interface{}{ - "success": true, - "message": fmt.Sprintf("已清除 %d 个黑名单密钥", beforeCount), - "clearedCount": beforeCount, - "totalKeys": totalKeys, - } -} - -// GetBlacklistDetails 获取黑名单详情 -func (km *KeyManager) GetBlacklistDetails() *BlacklistInfo { - var blacklistDetails []BlacklistDetail - - km.keysMutex.RLock() - keys := km.keys - keyPreviews := km.keyPreviews - km.keysMutex.RUnlock() - - for i, key := range keys { - if _, blacklisted := km.blacklistedKeys.Load(key); blacklisted { - blacklistDetails = append(blacklistDetails, BlacklistDetail{ - Index: i, - LineNumber: i + 1, - KeyPreview: keyPreviews[i], - FullKey: key, - }) - } - } - - return &BlacklistInfo{ - TotalBlacklisted: len(blacklistDetails), - TotalKeys: len(keys), - HealthyKeys: len(keys) - len(blacklistDetails), - BlacklistedKeys: blacklistDetails, - } -} - -// setupMemoryCleanup 设置内存清理机制 -func (km *KeyManager) setupMemoryCleanup() { - km.cleanupTicker = time.NewTicker(10 * time.Minute) - - go func() { - for { - select { - case <-km.cleanupTicker.C: - km.performMemoryCleanup() - case <-km.stopCleanup: - km.cleanupTicker.Stop() - return - } - } - }() -} - -// performMemoryCleanup 执行内存清理 -func (km *KeyManager) performMemoryCleanup() { - km.keysMutex.RLock() - maxSize := len(km.keys) * 2 - if maxSize < 1000 { - maxSize = 1000 - } - km.keysMutex.RUnlock() - - currentSize := 0 - km.keyFailureCounts.Range(func(key, value interface{}) bool { - currentSize++ - return true - }) - - if currentSize > maxSize { - logrus.Infof("🧹 清理失败计数缓存 (%d -> %d)", currentSize, maxSize) - - // 简单策略:清理一半的失败计数 - cleared := 0 - target := currentSize - maxSize - - km.keyFailureCounts.Range(func(key, value interface{}) bool { - if cleared < target { - km.keyFailureCounts.Delete(key) - cleared++ - } - return cleared < target - }) - } -} - -// Close 关闭密钥管理器 -func (km *KeyManager) Close() { - close(km.stopCleanup) -} diff --git a/internal/keymanager/manager.go b/internal/keymanager/manager.go new file mode 100644 index 0000000..7fe1376 --- /dev/null +++ b/internal/keymanager/manager.go @@ -0,0 +1,326 @@ +// Package keymanager provides high-performance API key management +package keymanager + +import ( + "bufio" + "os" + "regexp" + "runtime" + "strings" + "sync" + "sync/atomic" + "time" + + "gpt-load/internal/errors" + "gpt-load/pkg/types" + + "github.com/sirupsen/logrus" +) + +// Manager implements the KeyManager interface +type Manager struct { + keysFilePath string + keys []string + keyPreviews []string + currentIndex int64 + blacklistedKeys sync.Map + successCount int64 + failureCount int64 + keyFailureCounts sync.Map + config types.KeysConfig + + // Performance optimization: pre-compiled regex patterns + permanentErrorPatterns []*regexp.Regexp + + // Memory management + cleanupTicker *time.Ticker + stopCleanup chan bool + + // Read-write lock to protect key list + keysMutex sync.RWMutex +} + +// NewManager creates a new key manager +func NewManager(config types.KeysConfig) (types.KeyManager, error) { + if config.FilePath == "" { + return nil, errors.NewAppError(errors.ErrKeyFileNotFound, "Keys file path is required") + } + + km := &Manager{ + keysFilePath: config.FilePath, + currentIndex: int64(config.StartIndex), + stopCleanup: make(chan bool), + config: config, + + // Pre-compile regex patterns + permanentErrorPatterns: []*regexp.Regexp{ + regexp.MustCompile(`(?i)invalid api key`), + regexp.MustCompile(`(?i)incorrect api key`), + regexp.MustCompile(`(?i)api key not found`), + regexp.MustCompile(`(?i)unauthorized`), + regexp.MustCompile(`(?i)account deactivated`), + regexp.MustCompile(`(?i)billing`), + }, + } + + // Start memory cleanup + km.setupMemoryCleanup() + + // Load keys + if err := km.LoadKeys(); err != nil { + return nil, err + } + + return km, nil +} + +// LoadKeys loads API keys from file +func (km *Manager) LoadKeys() error { + file, err := os.Open(km.keysFilePath) + if err != nil { + return errors.NewAppErrorWithCause(errors.ErrKeyFileNotFound, "Failed to open keys file", err) + } + defer file.Close() + + var keys []string + var keyPreviews []string + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line != "" && !strings.HasPrefix(line, "#") { + keys = append(keys, line) + // Create preview (first 8 chars + "..." + last 4 chars) + if len(line) > 12 { + preview := line[:8] + "..." + line[len(line)-4:] + keyPreviews = append(keyPreviews, preview) + } else { + keyPreviews = append(keyPreviews, line) + } + } + } + + if err := scanner.Err(); err != nil { + return errors.NewAppErrorWithCause(errors.ErrKeyFileInvalid, "Failed to read keys file", err) + } + + if len(keys) == 0 { + return errors.NewAppError(errors.ErrNoKeysAvailable, "No valid API keys found in file") + } + + km.keysMutex.Lock() + km.keys = keys + km.keyPreviews = keyPreviews + km.keysMutex.Unlock() + + logrus.Infof("Successfully loaded %d API keys", len(keys)) + return nil +} + +// GetNextKey gets the next available key (high-performance version) +func (km *Manager) GetNextKey() (*types.KeyInfo, error) { + km.keysMutex.RLock() + keysLen := len(km.keys) + if keysLen == 0 { + km.keysMutex.RUnlock() + return nil, errors.ErrNoAPIKeysAvailable + } + + // Fast path: directly get next key, avoid blacklist check overhead + currentIdx := atomic.AddInt64(&km.currentIndex, 1) - 1 + keyIndex := int(currentIdx) % keysLen + selectedKey := km.keys[keyIndex] + keyPreview := km.keyPreviews[keyIndex] + km.keysMutex.RUnlock() + + // Check if blacklisted + if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted { + return &types.KeyInfo{ + Key: selectedKey, + Index: keyIndex, + Preview: keyPreview, + }, nil + } + + // Slow path: find next available key + return km.findNextAvailableKey(keyIndex, keysLen) +} + +// findNextAvailableKey finds the next available non-blacklisted key +func (km *Manager) findNextAvailableKey(startIndex, keysLen int) (*types.KeyInfo, error) { + km.keysMutex.RLock() + defer km.keysMutex.RUnlock() + + blacklistedCount := 0 + for i := 0; i < keysLen; i++ { + keyIndex := (startIndex + i) % keysLen + selectedKey := km.keys[keyIndex] + + if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted { + return &types.KeyInfo{ + Key: selectedKey, + Index: keyIndex, + Preview: km.keyPreviews[keyIndex], + }, nil + } + blacklistedCount++ + } + + if blacklistedCount >= keysLen { + logrus.Warn("All keys are blacklisted, resetting blacklist") + km.blacklistedKeys = sync.Map{} + km.keyFailureCounts = sync.Map{} + + // Return first key after reset + firstKey := km.keys[0] + firstPreview := km.keyPreviews[0] + + return &types.KeyInfo{ + Key: firstKey, + Index: 0, + Preview: firstPreview, + }, nil + } + + return nil, errors.ErrAllAPIKeysBlacklisted +} + +// RecordSuccess records successful key usage +func (km *Manager) RecordSuccess(key string) { + atomic.AddInt64(&km.successCount, 1) + // Reset failure count for this key on success + km.keyFailureCounts.Delete(key) +} + +// RecordFailure records key failure and potentially blacklists it +func (km *Manager) RecordFailure(key string, err error) { + atomic.AddInt64(&km.failureCount, 1) + + // Check if this is a permanent error + if km.isPermanentError(err) { + km.blacklistedKeys.Store(key, time.Now()) + logrus.Debugf("Key blacklisted due to permanent error: %v", err) + return + } + + // Increment failure count + failCount, _ := km.keyFailureCounts.LoadOrStore(key, int64(0)) + newFailCount := atomic.AddInt64(failCount.(*int64), 1) + + // Blacklist if threshold exceeded + if int(newFailCount) >= km.config.BlacklistThreshold { + km.blacklistedKeys.Store(key, time.Now()) + logrus.Debugf("Key blacklisted after %d failures", newFailCount) + } +} + +// isPermanentError checks if an error is permanent +func (km *Manager) isPermanentError(err error) bool { + if err == nil { + return false + } + + errorStr := strings.ToLower(err.Error()) + for _, pattern := range km.permanentErrorPatterns { + if pattern.MatchString(errorStr) { + return true + } + } + return false +} + +// GetStats returns current statistics +func (km *Manager) GetStats() types.Stats { + km.keysMutex.RLock() + totalKeys := len(km.keys) + km.keysMutex.RUnlock() + + blacklistedCount := 0 + km.blacklistedKeys.Range(func(key, value interface{}) bool { + blacklistedCount++ + return true + }) + + var m runtime.MemStats + runtime.ReadMemStats(&m) + + return types.Stats{ + CurrentIndex: atomic.LoadInt64(&km.currentIndex), + TotalKeys: totalKeys, + HealthyKeys: totalKeys - blacklistedCount, + BlacklistedKeys: blacklistedCount, + SuccessCount: atomic.LoadInt64(&km.successCount), + FailureCount: atomic.LoadInt64(&km.failureCount), + MemoryUsage: types.MemoryUsage{ + Alloc: m.Alloc, + TotalAlloc: m.TotalAlloc, + Sys: m.Sys, + NumGC: m.NumGC, + LastGCTime: time.Unix(0, int64(m.LastGC)).Format("2006-01-02 15:04:05"), + NextGCTarget: m.NextGC, + }, + } +} + +// ResetBlacklist resets the blacklist +func (km *Manager) ResetBlacklist() { + km.blacklistedKeys = sync.Map{} + km.keyFailureCounts = sync.Map{} + logrus.Info("Blacklist reset successfully") +} + +// GetBlacklist returns current blacklisted keys +func (km *Manager) GetBlacklist() []types.BlacklistEntry { + var blacklist []types.BlacklistEntry + + km.blacklistedKeys.Range(func(key, value interface{}) bool { + keyStr := key.(string) + blacklistTime := value.(time.Time) + + // Create preview + preview := keyStr + if len(keyStr) > 12 { + preview = keyStr[:8] + "..." + keyStr[len(keyStr)-4:] + } + + // Get failure count + failCount := 0 + if count, exists := km.keyFailureCounts.Load(keyStr); exists { + failCount = int(atomic.LoadInt64(count.(*int64))) + } + + blacklist = append(blacklist, types.BlacklistEntry{ + Key: keyStr, + Preview: preview, + Reason: "Exceeded failure threshold", + BlacklistAt: blacklistTime, + FailCount: failCount, + }) + return true + }) + + return blacklist +} + +// setupMemoryCleanup sets up periodic memory cleanup +func (km *Manager) setupMemoryCleanup() { + km.cleanupTicker = time.NewTicker(5 * time.Minute) + go func() { + for { + select { + case <-km.cleanupTicker.C: + runtime.GC() + case <-km.stopCleanup: + return + } + } + }() +} + +// Close closes the key manager and cleans up resources +func (km *Manager) Close() { + if km.cleanupTicker != nil { + km.cleanupTicker.Stop() + } + close(km.stopCleanup) +} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go new file mode 100644 index 0000000..ead038e --- /dev/null +++ b/internal/middleware/middleware.go @@ -0,0 +1,259 @@ +// Package middleware provides HTTP middleware for the application +package middleware + +import ( + "context" + "fmt" + "strings" + "time" + + "gpt-load/internal/errors" + "gpt-load/pkg/types" + + "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" +) + +// Logger creates a high-performance logging middleware +func Logger(config types.LogConfig) gin.HandlerFunc { + return func(c *gin.Context) { + // Check if request logging is enabled + if !config.EnableRequest { + // Don't log requests, only process them + c.Next() + // Only log errors + if c.Writer.Status() >= 400 { + logrus.Errorf("Error %d: %s %s", c.Writer.Status(), c.Request.Method, c.Request.URL.Path) + } + return + } + + start := time.Now() + path := c.Request.URL.Path + raw := c.Request.URL.RawQuery + + // Process request + c.Next() + + // Calculate response time + latency := time.Since(start) + + // Get basic information + method := c.Request.Method + statusCode := c.Writer.Status() + + // Build full path (avoid string concatenation) + fullPath := path + if raw != "" { + fullPath = path + "?" + raw + } + + // Get key information (if exists) + keyInfo := "" + if keyIndex, exists := c.Get("keyIndex"); exists { + if keyPreview, exists := c.Get("keyPreview"); exists { + keyInfo = fmt.Sprintf(" - Key[%v] %v", keyIndex, keyPreview) + } + } + + // Get retry information (if exists) + retryInfo := "" + if retryCount, exists := c.Get("retryCount"); exists { + retryInfo = fmt.Sprintf(" - Retry[%d]", retryCount) + } + + // Filter health check logs + if path == "/health" { + return + } + + // Choose log level based on status code + if statusCode >= 500 { + logrus.Errorf("%s %s - %d - %v%s%s", method, fullPath, statusCode, latency, keyInfo, retryInfo) + } else if statusCode >= 400 { + logrus.Warnf("%s %s - %d - %v%s%s", method, fullPath, statusCode, latency, keyInfo, retryInfo) + } else { + logrus.Infof("%s %s - %d - %v%s%s", method, fullPath, statusCode, latency, keyInfo, retryInfo) + } + } +} + +// CORS creates a CORS middleware +func CORS(config types.CORSConfig) gin.HandlerFunc { + return func(c *gin.Context) { + if !config.Enabled { + c.Next() + return + } + + origin := c.Request.Header.Get("Origin") + + // Check if origin is allowed + allowed := false + for _, allowedOrigin := range config.AllowedOrigins { + if allowedOrigin == "*" || allowedOrigin == origin { + allowed = true + break + } + } + + if allowed { + c.Header("Access-Control-Allow-Origin", origin) + } + + // Set other CORS headers + c.Header("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", ")) + c.Header("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", ")) + + if config.AllowCredentials { + c.Header("Access-Control-Allow-Credentials", "true") + } + + // Handle preflight requests + if c.Request.Method == "OPTIONS" { + c.AbortWithStatus(204) + return + } + + c.Next() + } +} + +// Auth creates an authentication middleware +func Auth(config types.AuthConfig) gin.HandlerFunc { + return func(c *gin.Context) { + if !config.Enabled { + c.Next() + return + } + + // Skip authentication for management endpoints + path := c.Request.URL.Path + if path == "/health" || path == "/stats" || path == "/blacklist" || path == "/reset-keys" { + c.Next() + return + } + + // Get authorization header + authHeader := c.GetHeader("Authorization") + if authHeader == "" { + c.JSON(401, gin.H{ + "error": "Authorization header required", + "code": errors.ErrAuthMissing, + }) + c.Abort() + return + } + + // Check Bearer token format + const bearerPrefix = "Bearer " + if !strings.HasPrefix(authHeader, bearerPrefix) { + c.JSON(401, gin.H{ + "error": "Invalid authorization format, expected 'Bearer '", + "code": errors.ErrAuthInvalid, + }) + c.Abort() + return + } + + // Extract and validate token + token := authHeader[len(bearerPrefix):] + if token != config.Key { + c.JSON(401, gin.H{ + "error": "Invalid authentication token", + "code": errors.ErrAuthInvalid, + }) + c.Abort() + return + } + + c.Next() + } +} + +// Recovery creates a recovery middleware with custom error handling +func Recovery() gin.HandlerFunc { + return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) { + if err, ok := recovered.(string); ok { + logrus.Errorf("Panic recovered: %s", err) + c.JSON(500, gin.H{ + "error": "Internal server error", + "code": errors.ErrServerInternal, + }) + } else { + logrus.Errorf("Panic recovered: %v", recovered) + c.JSON(500, gin.H{ + "error": "Internal server error", + "code": errors.ErrServerInternal, + }) + } + c.Abort() + }) +} + +// RateLimiter creates a simple rate limiting middleware +func RateLimiter(config types.PerformanceConfig) gin.HandlerFunc { + // Simple semaphore-based rate limiting + semaphore := make(chan struct{}, config.MaxConcurrentRequests) + + return func(c *gin.Context) { + select { + case semaphore <- struct{}{}: + defer func() { <-semaphore }() + c.Next() + default: + c.JSON(429, gin.H{ + "error": "Too many concurrent requests", + "code": errors.ErrServerUnavailable, + }) + c.Abort() + } + } +} + +// Timeout creates a timeout middleware +func Timeout(timeout time.Duration) gin.HandlerFunc { + return func(c *gin.Context) { + ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + + c.Request = c.Request.WithContext(ctx) + c.Next() + + if ctx.Err() == context.DeadlineExceeded { + c.JSON(408, gin.H{ + "error": "Request timeout", + "code": errors.ErrProxyTimeout, + }) + c.Abort() + } + } +} + +// ErrorHandler creates an error handling middleware +func ErrorHandler() gin.HandlerFunc { + return func(c *gin.Context) { + c.Next() + + // Handle any errors that occurred during request processing + if len(c.Errors) > 0 { + err := c.Errors.Last().Err + + // Check if it's our custom error type + if appErr, ok := err.(*errors.AppError); ok { + c.JSON(appErr.HTTPStatus, gin.H{ + "error": appErr.Message, + "code": appErr.Code, + }) + return + } + + // Handle other errors + logrus.Errorf("Unhandled error: %v", err) + c.JSON(500, gin.H{ + "error": "Internal server error", + "code": errors.ErrServerInternal, + }) + } + } +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go deleted file mode 100644 index 36ecd3d..0000000 --- a/internal/proxy/proxy.go +++ /dev/null @@ -1,778 +0,0 @@ -// Package proxy 高性能OpenAI多密钥代理服务器 -// @author OpenAI Proxy Team -// @version 2.0.0 -package proxy - -import ( - "bytes" - "context" - "fmt" - "io" - "net/http" - "net/url" - "strings" - "sync/atomic" - "time" - - "openai-multi-key-proxy/internal/config" - "openai-multi-key-proxy/internal/keymanager" - - "github.com/gin-gonic/gin" - "github.com/sirupsen/logrus" -) - -// RetryError 重试过程中的错误信息 -type RetryError struct { - StatusCode int `json:"status_code"` - ErrorMessage string `json:"error_message"` - KeyIndex int `json:"key_index"` - Attempt int `json:"attempt"` -} - -// ProxyServer 代理服务器 -type ProxyServer struct { - keyManager *keymanager.KeyManager - httpClient *http.Client - streamClient *http.Client // 专门用于流式传输的客户端 - upstreamURL *url.URL - requestCount int64 - startTime time.Time -} - -// NewProxyServer 创建新的代理服务器 -func NewProxyServer() (*ProxyServer, error) { - // 解析上游URL - upstreamURL, err := url.Parse(config.AppConfig.OpenAI.BaseURL) - if err != nil { - return nil, fmt.Errorf("解析上游URL失败: %w", err) - } - - // 创建密钥管理器 - keyManager := keymanager.NewKeyManager(config.AppConfig.Keys.FilePath) - if err := keyManager.LoadKeys(); err != nil { - return nil, fmt.Errorf("加载密钥失败: %w", err) - } - - // 创建高性能HTTP客户端 - transport := &http.Transport{ - MaxIdleConns: config.AppConfig.Performance.MaxSockets, - MaxIdleConnsPerHost: config.AppConfig.Performance.MaxFreeSockets, - MaxConnsPerHost: 0, // 无限制,避免连接池瓶颈 - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - DisableCompression: config.AppConfig.Performance.DisableCompression, - ForceAttemptHTTP2: true, - WriteBufferSize: config.AppConfig.Performance.BufferSize, - ReadBufferSize: config.AppConfig.Performance.BufferSize, - } - - // 创建专门用于流式传输的transport,优化TCP参数 - streamTransport := &http.Transport{ - MaxIdleConns: config.AppConfig.Performance.MaxSockets * 2, - MaxIdleConnsPerHost: config.AppConfig.Performance.MaxFreeSockets * 2, - MaxConnsPerHost: 0, - IdleConnTimeout: 300 * time.Second, // 流式连接保持更长时间 - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - DisableCompression: true, // 流式传输始终禁用压缩 - ForceAttemptHTTP2: true, - WriteBufferSize: config.AppConfig.Performance.StreamBufferSize, - ReadBufferSize: config.AppConfig.Performance.StreamBufferSize, - ResponseHeaderTimeout: time.Duration(config.AppConfig.Performance.StreamHeaderTimeout) * time.Millisecond, - } - - // 配置 Keep-Alive - if !config.AppConfig.Performance.EnableKeepAlive { - transport.DisableKeepAlives = true - streamTransport.DisableKeepAlives = true - } - - httpClient := &http.Client{ - Transport: transport, - // 移除全局超时,使用更细粒度的超时控制 - // Timeout: time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond, - } - - // 流式客户端不设置整体超时 - streamClient := &http.Client{ - Transport: streamTransport, - } - - return &ProxyServer{ - keyManager: keyManager, - httpClient: httpClient, - streamClient: streamClient, - upstreamURL: upstreamURL, - startTime: time.Now(), - }, nil -} - -// SetupRoutes 设置路由 -func (ps *ProxyServer) SetupRoutes() *gin.Engine { - // 设置Gin模式 - gin.SetMode(gin.ReleaseMode) - - router := gin.New() - - // 自定义日志中间件 - router.Use(ps.loggerMiddleware()) - - // 恢复中间件 - router.Use(gin.Recovery()) - - // CORS中间件 - if config.AppConfig.CORS.Enabled { - router.Use(ps.corsMiddleware()) - } - - // 认证中间件(如果启用) - if config.AppConfig.Auth.Enabled { - router.Use(ps.authMiddleware()) - } - - // 管理端点 - router.GET("/health", ps.handleHealth) - router.GET("/stats", ps.handleStats) - router.GET("/blacklist", ps.handleBlacklist) - router.GET("/reset-keys", ps.handleResetKeys) - - // 代理所有其他请求 - router.NoRoute(ps.handleProxy) - - return router -} - -// corsMiddleware CORS中间件 -func (ps *ProxyServer) corsMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - origin := "*" - if len(config.AppConfig.CORS.AllowedOrigins) > 0 && config.AppConfig.CORS.AllowedOrigins[0] != "*" { - origin = strings.Join(config.AppConfig.CORS.AllowedOrigins, ",") - } - - c.Header("Access-Control-Allow-Origin", origin) - c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization") - c.Header("Access-Control-Max-Age", "86400") - - if c.Request.Method == "OPTIONS" { - c.AbortWithStatus(http.StatusOK) - return - } - - c.Next() - } -} - -// authMiddleware 认证中间件 -func (ps *ProxyServer) authMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - // 管理端点不需要认证 - if strings.HasPrefix(c.Request.URL.Path, "/health") || - strings.HasPrefix(c.Request.URL.Path, "/stats") || - strings.HasPrefix(c.Request.URL.Path, "/blacklist") || - strings.HasPrefix(c.Request.URL.Path, "/reset-keys") { - c.Next() - return - } - - authHeader := c.GetHeader("Authorization") - if authHeader == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "error": gin.H{ - "message": "未提供认证信息", - "type": "authentication_error", - "code": "missing_authorization", - "timestamp": time.Now().Format(time.RFC3339), - }, - }) - c.Abort() - return - } - - if !strings.HasPrefix(authHeader, "Bearer ") { - c.JSON(http.StatusUnauthorized, gin.H{ - "error": gin.H{ - "message": "认证格式错误", - "type": "authentication_error", - "code": "invalid_authorization_format", - "timestamp": time.Now().Format(time.RFC3339), - }, - }) - c.Abort() - return - } - - token := authHeader[7:] // 移除 "Bearer " 前缀 - if token != config.AppConfig.Auth.Key { - c.JSON(http.StatusUnauthorized, gin.H{ - "error": gin.H{ - "message": "认证失败", - "type": "authentication_error", - "code": "invalid_authorization", - "timestamp": time.Now().Format(time.RFC3339), - }, - }) - c.Abort() - return - } - - c.Next() - } -} - -// loggerMiddleware 高性能日志中间件 -func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - // 检查是否启用请求日志 - if !config.AppConfig.Log.EnableRequest { - // 不记录请求日志,只处理请求 - c.Next() - // 只记录错误 - if c.Writer.Status() >= 400 { - logrus.Errorf("Error %d: %s %s", c.Writer.Status(), c.Request.Method, c.Request.URL.Path) - } - return - } - - start := time.Now() - path := c.Request.URL.Path - raw := c.Request.URL.RawQuery - - // 处理请求 - c.Next() - - // 计算响应时间 - latency := time.Since(start) - - // 获取基本信息 - method := c.Request.Method - statusCode := c.Writer.Status() - - // 构建完整路径(避免字符串拼接) - fullPath := path - if raw != "" { - fullPath = path + "?" + raw - } - - // 获取密钥信息(如果存在) - keyInfo := "" - if keyIndex, exists := c.Get("keyIndex"); exists { - if keyPreview, exists := c.Get("keyPreview"); exists { - keyInfo = fmt.Sprintf(" - Key[%v] %v", keyIndex, keyPreview) - } - } - - // 获取重试信息(如果存在) - retryInfo := "" - if retryCount, exists := c.Get("retryCount"); exists { - retryInfo = fmt.Sprintf(" - Retry[%d]", retryCount) - } - - // 过滤健康检查日志 - if path == "/health" { - return - } - - // 根据状态码选择日志级别 - if statusCode >= 500 { - logrus.Errorf("%s %s - %d - %v%s%s", method, fullPath, statusCode, latency, keyInfo, retryInfo) - } else if statusCode >= 400 { - logrus.Warnf("%s %s - %d - %v%s%s", method, fullPath, statusCode, latency, keyInfo, retryInfo) - } else { - logrus.Infof("%s %s - %d - %v%s%s", method, fullPath, statusCode, latency, keyInfo, retryInfo) - } - } -} - -// handleHealth 健康检查处理器 -func (ps *ProxyServer) handleHealth(c *gin.Context) { - uptime := time.Since(ps.startTime) - stats := ps.keyManager.GetStats() - requestCount := atomic.LoadInt64(&ps.requestCount) - - response := gin.H{ - "status": "healthy", - "uptime": fmt.Sprintf("%.0fs", uptime.Seconds()), - "requestCount": requestCount, - "keysStatus": gin.H{ - "total": stats.TotalKeys, - "healthy": stats.HealthyKeys, - "blacklisted": stats.BlacklistedKeys, - }, - "timestamp": time.Now().Format(time.RFC3339), - } - - c.JSON(http.StatusOK, response) -} - -// handleStats 统计信息处理器 -func (ps *ProxyServer) handleStats(c *gin.Context) { - uptime := time.Since(ps.startTime) - stats := ps.keyManager.GetStats() - requestCount := atomic.LoadInt64(&ps.requestCount) - - response := gin.H{ - "server": gin.H{ - "uptime": fmt.Sprintf("%.0fs", uptime.Seconds()), - "requestCount": requestCount, - "startTime": ps.startTime.Format(time.RFC3339), - "version": "2.0.0", - }, - "keys": stats, - "timestamp": time.Now().Format(time.RFC3339), - } - - c.JSON(http.StatusOK, response) -} - -// handleBlacklist 黑名单处理器 -func (ps *ProxyServer) handleBlacklist(c *gin.Context) { - blacklistInfo := ps.keyManager.GetBlacklistDetails() - c.JSON(http.StatusOK, blacklistInfo) -} - -// handleResetKeys 重置密钥处理器 -func (ps *ProxyServer) handleResetKeys(c *gin.Context) { - result := ps.keyManager.ResetKeys() - c.JSON(http.StatusOK, result) -} - -// handleProxy 代理请求处理器 -func (ps *ProxyServer) handleProxy(c *gin.Context) { - startTime := time.Now() - - // 增加请求计数 - atomic.AddInt64(&ps.requestCount, 1) - - // 统一入口,提前缓存所有请求体 - var bodyBytes []byte - if c.Request.Body != nil { - var err error - bodyBytes, err = io.ReadAll(c.Request.Body) - if err != nil { - logrus.Errorf("读取请求体失败: %v", err) - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "读取请求体失败", - "type": "request_error", - "code": "invalid_request_body", - "timestamp": time.Now().Format(time.RFC3339), - }, - }) - return - } - } - - // 使用缓存后的数据判断请求类型 - isStreamRequest := ps.isStreamRequest(bodyBytes, c) - - // 执行带重试的请求 - ps.executeRequestWithRetry(c, startTime, bodyBytes, isStreamRequest, 0, nil) -} - -// isStreamRequest 判断是否为流式请求 -func (ps *ProxyServer) isStreamRequest(bodyBytes []byte, c *gin.Context) bool { - // 检查 Accept header - if strings.Contains(c.GetHeader("Accept"), "text/event-stream") { - return true - } - - // 检查 URL 查询参数 - if c.Query("stream") == "true" { - return true - } - - // 检查请求体中的 stream 参数 - if len(bodyBytes) > 0 { - if strings.Contains(string(bodyBytes), `"stream":true`) || - strings.Contains(string(bodyBytes), `"stream": true`) { - return true - } - } - - return false -} - -// executeRequestWithRetry 执行带重试的请求 -func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Time, bodyBytes []byte, isStreamRequest bool, retryCount int, retryErrors []RetryError) { - // 获取密钥信息 - keyInfo, err := ps.keyManager.GetNextKey() - if err != nil { - logrus.Errorf("获取密钥失败: %v", err) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "服务器内部错误: " + err.Error(), - "type": "server_error", - "code": "no_keys_available", - "timestamp": time.Now().Format(time.RFC3339), - }, - }) - return - } - - // 设置密钥信息到上下文(用于日志) - c.Set("keyIndex", keyInfo.Index) - c.Set("keyPreview", keyInfo.Preview) - - // 设置重试信息到上下文 - if retryCount > 0 { - c.Set("retryCount", retryCount) - } - - // 构建上游请求URL - targetURL := *ps.upstreamURL - // 正确拼接路径,而不是替换路径 - if strings.HasSuffix(targetURL.Path, "/") { - targetURL.Path = targetURL.Path + strings.TrimPrefix(c.Request.URL.Path, "/") - } else { - targetURL.Path = targetURL.Path + c.Request.URL.Path - } - targetURL.RawQuery = c.Request.URL.RawQuery - - // 为流式和非流式请求使用不同的超时策略 - var ctx context.Context - var cancel context.CancelFunc - - if isStreamRequest { - // 流式请求只设置响应头超时,不设置整体超时 - ctx, cancel = context.WithCancel(c.Request.Context()) - } else { - // 非流式请求使用配置的超时 - timeout := time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond - ctx, cancel = context.WithTimeout(c.Request.Context(), timeout) - } - defer cancel() - - // 统一使用缓存的 bodyBytes 创建请求 - req, err := http.NewRequestWithContext( - ctx, - c.Request.Method, - targetURL.String(), - bytes.NewReader(bodyBytes), - ) - if err != nil { - logrus.Errorf("创建上游请求失败: %v", err) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "创建上游请求失败", - "type": "proxy_error", - "code": "request_creation_failed", - "timestamp": time.Now().Format(time.RFC3339), - }, - }) - return - } - req.ContentLength = int64(len(bodyBytes)) - - // 复制请求头 - for key, values := range c.Request.Header { - if key != "Host" { - for _, value := range values { - req.Header.Add(key, value) - } - } - } - - // 设置认证头 - req.Header.Set("Authorization", "Bearer "+keyInfo.Key) - - // 根据请求类型选择合适的客户端 - var client *http.Client - if isStreamRequest { - client = ps.streamClient - // 添加禁用nginx缓冲的头 - req.Header.Set("X-Accel-Buffering", "no") - } else { - client = ps.httpClient - } - - // 发送请求 - resp, err := client.Do(req) - if err != nil { - responseTime := time.Since(startTime) - - // 记录失败日志 - if retryCount > 0 { - logrus.Debugf("重试请求失败 (第 %d 次): %v (响应时间: %v)", retryCount, err, responseTime) - } else { - logrus.Debugf("首次请求失败: %v (响应时间: %v)", err, responseTime) - } - - // 异步记录失败 - go ps.keyManager.RecordFailure(keyInfo.Key, err) - - // 记录重试错误信息 - if retryErrors == nil { - retryErrors = make([]RetryError, 0) - } - retryErrors = append(retryErrors, RetryError{ - StatusCode: 0, // 网络错误,没有HTTP状态码 - ErrorMessage: err.Error(), - KeyIndex: keyInfo.Index, - Attempt: retryCount + 1, - }) - - // 检查是否可以重试 - if retryCount < config.AppConfig.Keys.MaxRetries { - logrus.Debugf("准备重试请求 (第 %d/%d 次)", retryCount+1, config.AppConfig.Keys.MaxRetries) - ps.executeRequestWithRetry(c, startTime, bodyBytes, isStreamRequest, retryCount+1, retryErrors) - return - } - - // 达到最大重试次数,记录最终失败并返回详细的重试信息 - logrus.Infof("请求最终失败,已重试 %d 次,总响应时间: %v", retryCount, responseTime) - ps.returnRetryFailureResponse(c, retryCount, retryErrors) - return - } - defer resp.Body.Close() - - responseTime := time.Since(startTime) - - // 检查HTTP状态码是否需要重试 - // 429 (Too Many Requests) 和 5xx 服务器错误都需要重试 - if resp.StatusCode == 429 || resp.StatusCode >= 500 { - // 记录失败日志 - if retryCount > 0 { - logrus.Debugf("重试请求返回错误 %d (第 %d 次) (响应时间: %v)", resp.StatusCode, retryCount, responseTime) - } else { - logrus.Debugf("首次请求返回错误 %d (响应时间: %v)", resp.StatusCode, responseTime) - } - - // 读取响应体以获取错误信息 - var errorMessage string - if bodyBytes, err := io.ReadAll(resp.Body); err == nil { - errorMessage = string(bodyBytes) - } else { - errorMessage = fmt.Sprintf("HTTP %d", resp.StatusCode) - } - - // 异步记录失败 - go ps.keyManager.RecordFailure(keyInfo.Key, fmt.Errorf("HTTP %d", resp.StatusCode)) - - // 记录重试错误信息 - if retryErrors == nil { - retryErrors = make([]RetryError, 0) - } - retryErrors = append(retryErrors, RetryError{ - StatusCode: resp.StatusCode, - ErrorMessage: errorMessage, - KeyIndex: keyInfo.Index, - Attempt: retryCount + 1, - }) - - // 关闭当前响应 - resp.Body.Close() - - // 检查是否可以重试 - if retryCount < config.AppConfig.Keys.MaxRetries { - logrus.Debugf("准备重试请求 (第 %d/%d 次)", retryCount+1, config.AppConfig.Keys.MaxRetries) - ps.executeRequestWithRetry(c, startTime, bodyBytes, isStreamRequest, retryCount+1, retryErrors) - return - } - - // 达到最大重试次数,记录最终失败并返回详细的重试信息 - logrus.Infof("请求最终失败,已重试 %d 次,最后状态码: %d,总响应时间: %v", retryCount, resp.StatusCode, responseTime) - ps.returnRetryFailureResponse(c, retryCount, retryErrors) - return - } - - // 记录最终成功的日志 - if len(retryErrors) > 0 { - logrus.Debugf("请求最终成功,经过 %d 次重试,状态码: %d,总响应时间: %v", len(retryErrors), resp.StatusCode, responseTime) - } - - // 异步记录统计信息(不阻塞响应) - go func() { - if resp.StatusCode >= 200 && resp.StatusCode < 400 { - ps.keyManager.RecordSuccess(keyInfo.Key) - } else if resp.StatusCode >= 400 { - ps.keyManager.RecordFailure(keyInfo.Key, fmt.Errorf("HTTP %d", resp.StatusCode)) - } - }() - - // 复制响应头 - for key, values := range resp.Header { - for _, value := range values { - c.Header(key, value) - } - } - - // 流式响应添加禁用缓冲的头 - if isStreamRequest { - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("X-Accel-Buffering", "no") - } - - // 设置状态码 - c.Status(resp.StatusCode) - - // 优化流式响应传输 - if isStreamRequest { - ps.handleStreamResponse(c, resp.Body) - } else { - // 非流式响应:使用标准零拷贝 - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - logrus.Errorf("复制响应体失败: %v (响应时间: %v)", err, responseTime) - } - } -} - -// returnRetryFailureResponse 返回重试失败的详细响应 -func (ps *ProxyServer) returnRetryFailureResponse(c *gin.Context, retryCount int, retryErrors []RetryError) { - // 获取最后一次错误作为主要错误 - var lastError RetryError - var lastStatusCode int = http.StatusBadGateway - - if len(retryErrors) > 0 { - lastError = retryErrors[len(retryErrors)-1] - if lastError.StatusCode > 0 { - lastStatusCode = lastError.StatusCode - } - } - - // 构建详细的错误响应 - errorResponse := gin.H{ - "error": gin.H{ - "message": fmt.Sprintf("请求失败,已重试 %d 次", retryCount), - "type": "proxy_error", - "code": "max_retries_exceeded", - "timestamp": time.Now().Format(time.RFC3339), - "retry_count": retryCount, - "retry_details": retryErrors, - }, - } - - // 如果最后一次错误有具体的错误信息,尝试解析并包含 - if lastError.ErrorMessage != "" && lastError.StatusCode > 0 { - // 尝试解析上游的JSON错误响应 - if strings.Contains(lastError.ErrorMessage, "{") { - errorResponse["upstream_error"] = lastError.ErrorMessage - } else { - errorResponse["upstream_message"] = lastError.ErrorMessage - } - } - - c.JSON(lastStatusCode, errorResponse) -} - -// handleStreamResponse 处理流式响应 -func (ps *ProxyServer) handleStreamResponse(c *gin.Context, body io.ReadCloser) { - defer body.Close() - - flusher, ok := c.Writer.(http.Flusher) - if !ok { - // 降级到标准复制 - _, err := io.Copy(c.Writer, body) - if err != nil { - logrus.Errorf("复制流式响应失败: %v", err) - } - return - } - - // 实现零缓存、实时转发 - copyDone := make(chan bool) - - // 检查客户端连接状态 - ctx := c.Request.Context() - - // 在一个独立的goroutine中定期flush,确保数据被立即发送 - go func() { - defer func() { - // 防止panic - if r := recover(); r != nil { - logrus.Errorf("Flush goroutine panic: %v", r) - } - }() - - ticker := time.NewTicker(50 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-copyDone: - // io.Copy完成后,执行最后一次flush并退出 - ps.safeFlush(flusher) - return - case <-ctx.Done(): - // 客户端断开连接,停止flush - return - case <-ticker.C: - ps.safeFlush(flusher) - } - } - }() - - // 使用io.Copy进行高效的数据复制 - _, err := io.Copy(c.Writer, body) - - // 安全地关闭channel - select { - case <-copyDone: - // channel已经关闭 - default: - close(copyDone) // 通知flush的goroutine可以停止了 - } - - if err != nil && err != io.EOF { - // 检查是否是连接断开导致的错误 - if ps.isConnectionError(err) { - logrus.Debugf("客户端连接断开: %v", err) - } else { - logrus.Errorf("复制流式响应失败: %v", err) - } - } -} - -// safeFlush 安全地执行flush操作 -func (ps *ProxyServer) safeFlush(flusher http.Flusher) { - defer func() { - if r := recover(); r != nil { - // 忽略flush时的panic,通常是因为连接已断开 - logrus.Debugf("Flush panic (connection likely closed): %v", r) - } - }() - - if flusher != nil { - flusher.Flush() - } -} - -// isConnectionError 检查是否是连接相关的错误 -func (ps *ProxyServer) isConnectionError(err error) bool { - if err == nil { - return false - } - - errStr := err.Error() - // 常见的连接断开错误 - connectionErrors := []string{ - "broken pipe", - "connection reset by peer", - "connection aborted", - "client disconnected", - "write: broken pipe", - "use of closed network connection", - "context canceled", - "short write", - "context deadline exceeded", - } - - for _, connErr := range connectionErrors { - if strings.Contains(errStr, connErr) { - return true - } - } - - return false -} - -// Close 关闭代理服务器 -func (ps *ProxyServer) Close() { - if ps.keyManager != nil { - ps.keyManager.Close() - } -} diff --git a/internal/proxy/server.go b/internal/proxy/server.go new file mode 100644 index 0000000..89f143f --- /dev/null +++ b/internal/proxy/server.go @@ -0,0 +1,409 @@ +// Package proxy provides high-performance OpenAI multi-key proxy server +package proxy + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync/atomic" + "time" + + "gpt-load/internal/errors" + "gpt-load/pkg/types" + + "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" +) + +// ProxyServer represents the proxy server +type ProxyServer struct { + keyManager types.KeyManager + configManager types.ConfigManager + httpClient *http.Client + streamClient *http.Client // Dedicated client for streaming + upstreamURL *url.URL + requestCount int64 + startTime time.Time +} + +// NewProxyServer creates a new proxy server +func NewProxyServer(keyManager types.KeyManager, configManager types.ConfigManager) (*ProxyServer, error) { + openaiConfig := configManager.GetOpenAIConfig() + perfConfig := configManager.GetPerformanceConfig() + + // Parse upstream URL + upstreamURL, err := url.Parse(openaiConfig.BaseURL) + if err != nil { + return nil, errors.NewAppErrorWithCause(errors.ErrConfigInvalid, "Failed to parse upstream URL", err) + } + + // Create high-performance HTTP client + transport := &http.Transport{ + MaxIdleConns: 50, + MaxIdleConnsPerHost: 10, + MaxConnsPerHost: 0, // No limit to avoid connection pool bottleneck + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + DisableCompression: !perfConfig.EnableGzip, + ForceAttemptHTTP2: true, + WriteBufferSize: 32 * 1024, + ReadBufferSize: 32 * 1024, + } + + // Create dedicated transport for streaming, optimize TCP parameters + streamTransport := &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 20, + MaxConnsPerHost: 0, + IdleConnTimeout: 300 * time.Second, // Keep streaming connections longer + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + DisableCompression: true, // Always disable compression for streaming + ForceAttemptHTTP2: true, + WriteBufferSize: 64 * 1024, + ReadBufferSize: 64 * 1024, + ResponseHeaderTimeout: 10 * time.Second, + } + + httpClient := &http.Client{ + Transport: transport, + Timeout: time.Duration(openaiConfig.Timeout) * time.Millisecond, + } + + // Streaming client without overall timeout + streamClient := &http.Client{ + Transport: streamTransport, + } + + return &ProxyServer{ + keyManager: keyManager, + configManager: configManager, + httpClient: httpClient, + streamClient: streamClient, + upstreamURL: upstreamURL, + startTime: time.Now(), + }, nil +} + +// HandleProxy handles proxy requests +func (ps *ProxyServer) HandleProxy(c *gin.Context) { + startTime := time.Now() + + // Increment request count + atomic.AddInt64(&ps.requestCount, 1) + + // Cache all request body upfront + var bodyBytes []byte + if c.Request.Body != nil { + var err error + bodyBytes, err = io.ReadAll(c.Request.Body) + if err != nil { + logrus.Errorf("Failed to read request body: %v", err) + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Failed to read request body", + "code": errors.ErrProxyRequest, + }) + return + } + } + + // Determine if this is a streaming request using cached data + isStreamRequest := ps.isStreamRequest(bodyBytes, c) + + // Execute request with retry + ps.executeRequestWithRetry(c, startTime, bodyBytes, isStreamRequest, 0, nil) +} + +// isStreamRequest determines if this is a streaming request +func (ps *ProxyServer) isStreamRequest(bodyBytes []byte, c *gin.Context) bool { + // Check Accept header + if strings.Contains(c.GetHeader("Accept"), "text/event-stream") { + return true + } + + // Check URL query parameter + if c.Query("stream") == "true" { + return true + } + + // Check stream parameter in request body + if len(bodyBytes) > 0 { + if strings.Contains(string(bodyBytes), `"stream":true`) || + strings.Contains(string(bodyBytes), `"stream": true`) { + return true + } + } + + return false +} + +// executeRequestWithRetry executes request with retry logic +func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Time, bodyBytes []byte, isStreamRequest bool, retryCount int, retryErrors []types.RetryError) { + keysConfig := ps.configManager.GetKeysConfig() + + // Check retry limit + if retryCount >= keysConfig.MaxRetries { + logrus.Errorf("Max retries exceeded (%d)", retryCount) + + // Return detailed error information + errorResponse := gin.H{ + "error": "Max retries exceeded", + "code": errors.ErrProxyRetryExhausted, + "retry_count": retryCount, + "retry_errors": retryErrors, + "timestamp": time.Now().UTC().Format(time.RFC3339), + } + + // Use the last error's status code if available + statusCode := http.StatusBadGateway + if len(retryErrors) > 0 && retryErrors[len(retryErrors)-1].StatusCode > 0 { + statusCode = retryErrors[len(retryErrors)-1].StatusCode + } + + c.JSON(statusCode, errorResponse) + return + } + + // Get key information + keyInfo, err := ps.keyManager.GetNextKey() + if err != nil { + logrus.Errorf("Failed to get key: %v", err) + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": "No API keys available", + "code": errors.ErrNoKeysAvailable, + }) + return + } + + // Set key information to context (for logging) + c.Set("keyIndex", keyInfo.Index) + c.Set("keyPreview", keyInfo.Preview) + + // Set retry information to context + if retryCount > 0 { + c.Set("retryCount", retryCount) + } + + // Build upstream request URL + targetURL := *ps.upstreamURL + // Correctly append path instead of replacing it + if strings.HasSuffix(targetURL.Path, "/") { + targetURL.Path = targetURL.Path + strings.TrimPrefix(c.Request.URL.Path, "/") + } else { + targetURL.Path = targetURL.Path + c.Request.URL.Path + } + targetURL.RawQuery = c.Request.URL.RawQuery + + // Use different timeout strategies for streaming and non-streaming requests + var ctx context.Context + var cancel context.CancelFunc + + if isStreamRequest { + // Streaming requests only set response header timeout, no overall timeout + ctx, cancel = context.WithCancel(c.Request.Context()) + } else { + // Non-streaming requests use configured timeout + openaiConfig := ps.configManager.GetOpenAIConfig() + timeout := time.Duration(openaiConfig.Timeout) * time.Millisecond + ctx, cancel = context.WithTimeout(c.Request.Context(), timeout) + } + defer cancel() + + // Create request using cached bodyBytes + req, err := http.NewRequestWithContext( + ctx, + c.Request.Method, + targetURL.String(), + bytes.NewReader(bodyBytes), + ) + if err != nil { + logrus.Errorf("Failed to create upstream request: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to create upstream request", + "code": errors.ErrProxyRequest, + }) + return + } + req.ContentLength = int64(len(bodyBytes)) + + // Copy request headers + for key, values := range c.Request.Header { + if key != "Host" { + for _, value := range values { + req.Header.Add(key, value) + } + } + } + + // Set authorization header + req.Header.Set("Authorization", "Bearer "+keyInfo.Key) + + // Choose appropriate client based on request type + var client *http.Client + if isStreamRequest { + client = ps.streamClient + // Add header to disable nginx buffering + req.Header.Set("X-Accel-Buffering", "no") + } else { + client = ps.httpClient + } + + // Send request + resp, err := client.Do(req) + if err != nil { + responseTime := time.Since(startTime) + + // Log failure + if retryCount > 0 { + logrus.Debugf("Retry request failed (attempt %d): %v (response time: %v)", retryCount+1, err, responseTime) + } else { + logrus.Debugf("Initial request failed: %v (response time: %v)", err, responseTime) + } + + // Record failure asynchronously + go ps.keyManager.RecordFailure(keyInfo.Key, err) + + // Record retry error information + if retryErrors == nil { + retryErrors = make([]types.RetryError, 0) + } + retryErrors = append(retryErrors, types.RetryError{ + StatusCode: 0, // Network error, no HTTP status code + ErrorMessage: err.Error(), + KeyIndex: keyInfo.Index, + Attempt: retryCount + 1, + }) + + // Retry + ps.executeRequestWithRetry(c, startTime, bodyBytes, isStreamRequest, retryCount+1, retryErrors) + return + } + defer resp.Body.Close() + + responseTime := time.Since(startTime) + + // Check if HTTP status code requires retry + // 429 (Too Many Requests) and 5xx server errors need retry + if resp.StatusCode == 429 || resp.StatusCode >= 500 { + // Log failure + if retryCount > 0 { + logrus.Debugf("Retry request returned error %d (attempt %d) (response time: %v)", resp.StatusCode, retryCount+1, responseTime) + } else { + logrus.Debugf("Initial request returned error %d (response time: %v)", resp.StatusCode, responseTime) + } + + // Read response body to get error information + var errorMessage string + if bodyBytes, err := io.ReadAll(resp.Body); err == nil { + errorMessage = string(bodyBytes) + } else { + errorMessage = fmt.Sprintf("HTTP %d", resp.StatusCode) + } + + // Record failure asynchronously + go ps.keyManager.RecordFailure(keyInfo.Key, fmt.Errorf("HTTP %d", resp.StatusCode)) + + // Record retry error information + if retryErrors == nil { + retryErrors = make([]types.RetryError, 0) + } + retryErrors = append(retryErrors, types.RetryError{ + StatusCode: resp.StatusCode, + ErrorMessage: errorMessage, + KeyIndex: keyInfo.Index, + Attempt: retryCount + 1, + }) + + // Retry + ps.executeRequestWithRetry(c, startTime, bodyBytes, isStreamRequest, retryCount+1, retryErrors) + return + } + + // Success - record success asynchronously + go ps.keyManager.RecordSuccess(keyInfo.Key) + + // Log final success result + if retryCount > 0 { + logrus.Infof("Request succeeded after %d retries (response time: %v)", retryCount, responseTime) + } else { + logrus.Debugf("Request succeeded on first attempt (response time: %v)", responseTime) + } + + // Copy response headers + for key, values := range resp.Header { + for _, value := range values { + c.Header(key, value) + } + } + + // Set status code + c.Status(resp.StatusCode) + + // Handle streaming and non-streaming responses + if isStreamRequest { + ps.handleStreamingResponse(c, resp) + } else { + ps.handleNormalResponse(c, resp) + } +} + +// handleStreamingResponse handles streaming responses +func (ps *ProxyServer) handleStreamingResponse(c *gin.Context, resp *http.Response) { + // Set headers for streaming + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + + // Stream response directly + flusher, ok := c.Writer.(http.Flusher) + if !ok { + logrus.Error("Streaming unsupported") + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Streaming unsupported", + "code": errors.ErrServerInternal, + }) + return + } + + // Copy streaming data + buffer := make([]byte, 4096) + for { + n, err := resp.Body.Read(buffer) + if n > 0 { + if _, writeErr := c.Writer.Write(buffer[:n]); writeErr != nil { + logrus.Errorf("Failed to write streaming data: %v", writeErr) + break + } + flusher.Flush() + } + if err != nil { + if err != io.EOF { + logrus.Errorf("Error reading streaming response: %v", err) + } + break + } + } +} + +// handleNormalResponse handles normal responses +func (ps *ProxyServer) handleNormalResponse(c *gin.Context, resp *http.Response) { + // Copy response body + if _, err := io.Copy(c.Writer, resp.Body); err != nil { + logrus.Errorf("Failed to copy response body: %v", err) + } +} + +// Close closes the proxy server and cleans up resources +func (ps *ProxyServer) Close() { + // Close HTTP clients if needed + if ps.httpClient != nil { + ps.httpClient.CloseIdleConnections() + } + if ps.streamClient != nil { + ps.streamClient.CloseIdleConnections() + } +} diff --git a/pkg/types/interfaces.go b/pkg/types/interfaces.go new file mode 100644 index 0000000..1181114 --- /dev/null +++ b/pkg/types/interfaces.go @@ -0,0 +1,124 @@ +// Package types defines common interfaces and types used across the application +package types + +import "time" + +// ConfigManager defines the interface for configuration management +type ConfigManager interface { + GetServerConfig() ServerConfig + GetKeysConfig() KeysConfig + GetOpenAIConfig() OpenAIConfig + GetAuthConfig() AuthConfig + GetCORSConfig() CORSConfig + GetPerformanceConfig() PerformanceConfig + GetLogConfig() LogConfig + Validate() error +} + +// KeyManager defines the interface for API key management +type KeyManager interface { + LoadKeys() error + GetNextKey() (*KeyInfo, error) + RecordSuccess(key string) + RecordFailure(key string, err error) + GetStats() Stats + ResetBlacklist() + GetBlacklist() []BlacklistEntry + Close() +} + +// ServerConfig represents server configuration +type ServerConfig struct { + Port int `json:"port"` + Host string `json:"host"` +} + +// KeysConfig represents keys configuration +type KeysConfig struct { + FilePath string `json:"filePath"` + StartIndex int `json:"startIndex"` + BlacklistThreshold int `json:"blacklistThreshold"` + MaxRetries int `json:"maxRetries"` +} + +// OpenAIConfig represents OpenAI API configuration +type OpenAIConfig struct { + BaseURL string `json:"baseUrl"` + Timeout int `json:"timeout"` +} + +// AuthConfig represents authentication configuration +type AuthConfig struct { + Key string `json:"key"` + Enabled bool `json:"enabled"` +} + +// CORSConfig represents CORS configuration +type CORSConfig struct { + Enabled bool `json:"enabled"` + AllowedOrigins []string `json:"allowedOrigins"` + AllowedMethods []string `json:"allowedMethods"` + AllowedHeaders []string `json:"allowedHeaders"` + AllowCredentials bool `json:"allowCredentials"` +} + +// PerformanceConfig represents performance configuration +type PerformanceConfig struct { + MaxConcurrentRequests int `json:"maxConcurrentRequests"` + RequestTimeout int `json:"requestTimeout"` + EnableGzip bool `json:"enableGzip"` +} + +// LogConfig represents logging configuration +type LogConfig struct { + Level string `json:"level"` + Format string `json:"format"` + EnableFile bool `json:"enableFile"` + FilePath string `json:"filePath"` + EnableRequest bool `json:"enableRequest"` +} + +// KeyInfo represents API key information +type KeyInfo struct { + Key string `json:"key"` + Index int `json:"index"` + Preview string `json:"preview"` +} + +// Stats represents system statistics +type Stats struct { + CurrentIndex int64 `json:"currentIndex"` + TotalKeys int `json:"totalKeys"` + HealthyKeys int `json:"healthyKeys"` + BlacklistedKeys int `json:"blacklistedKeys"` + SuccessCount int64 `json:"successCount"` + FailureCount int64 `json:"failureCount"` + MemoryUsage MemoryUsage `json:"memoryUsage"` +} + +// MemoryUsage represents memory usage statistics +type MemoryUsage struct { + Alloc uint64 `json:"alloc"` + TotalAlloc uint64 `json:"totalAlloc"` + Sys uint64 `json:"sys"` + NumGC uint32 `json:"numGC"` + LastGCTime string `json:"lastGCTime"` + NextGCTarget uint64 `json:"nextGCTarget"` +} + +// BlacklistEntry represents a blacklisted key entry +type BlacklistEntry struct { + Key string `json:"key"` + Preview string `json:"preview"` + Reason string `json:"reason"` + BlacklistAt time.Time `json:"blacklistAt"` + FailCount int `json:"failCount"` +} + +// RetryError represents retry error information +type RetryError struct { + StatusCode int `json:"statusCode"` + ErrorMessage string `json:"errorMessage"` + KeyIndex int `json:"keyIndex"` + Attempt int `json:"attempt"` +}