diff --git a/.env.example b/.env.example index a9e0520..fce39a9 100644 --- a/.env.example +++ b/.env.example @@ -2,11 +2,15 @@ PORT=3000 HOST=0.0.0.0 -# 认证配置 -AUTH_KEY=sk-123456 +# 服务器读取、写入和空闲连接的超时时间(秒) +SERVER_READ_TIMEOUT=120 +SERVER_WRITE_TIMEOUT=1800 +SERVER_GRACEFUL_SHUTDOWN_TIMEOUT=60 +SERVER_IDLE_TIMEOUT=120 -# 应用地址 -APP_URL=http://localhost:3000 +# 认证配置 +# AUTH_KEY 是必需的,用于保护管理 API 和 UI 界面 +AUTH_KEY=sk-123456 # CORS配置 ENABLE_CORS=true @@ -18,17 +22,17 @@ ALLOW_CREDENTIALS=false # 性能配置 MAX_CONCURRENT_REQUESTS=100 KEY_VALIDATION_POOL_SIZE=50 -ENABLE_GZIP=true # 数据库配置 -DATABASE_DSN=user:password@tcp(localhost:3306)/gpt_load?charset=utf8mb4&parseTime=True&loc=Local +# 示例 DSN: user:password@tcp(localhost:3306)/gpt_load?charset=utf8mb4&parseTime=True&loc=Local +DATABASE_DSN= # Redis配置 -REDIS_DSN=redis://:password@localhost:6379/1 +# 示例 DSN: redis://:password@localhost:6379/1 +REDIS_DSN= # 日志配置 LOG_LEVEL=info LOG_FORMAT=text LOG_ENABLE_FILE=false LOG_FILE_PATH=logs/app.log -LOG_ENABLE_REQUEST=true diff --git a/internal/app/app.go b/internal/app/app.go index a3f2902..8556756 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -135,7 +135,7 @@ func (a *App) Start() error { } // 显示配置并启动所有后台服务 - a.configManager.DisplayConfig() + a.configManager.DisplayServerConfig() a.groupManager.Initialize() diff --git a/internal/channel/factory.go b/internal/channel/factory.go index 133216d..07ccbfb 100644 --- a/internal/channel/factory.go +++ b/internal/channel/factory.go @@ -117,7 +117,7 @@ func (f *Factory) newBaseChannel(name string, group *models.Group) (*BaseChannel MaxIdleConns: group.EffectiveConfig.MaxIdleConns, MaxIdleConnsPerHost: group.EffectiveConfig.MaxIdleConnsPerHost, ResponseHeaderTimeout: time.Duration(group.EffectiveConfig.ResponseHeaderTimeout) * time.Second, - DisableCompression: group.EffectiveConfig.DisableCompression, + DisableCompression: false, WriteBufferSize: 32 * 1024, ReadBufferSize: 32 * 1024, ForceAttemptHTTP2: true, @@ -146,11 +146,11 @@ func (f *Factory) newBaseChannel(name string, group *models.Group) (*BaseChannel streamClient := f.clientManager.GetClient(&streamConfig) return &BaseChannel{ - Name: name, - Upstreams: upstreamInfos, - HTTPClient: httpClient, - StreamClient: streamClient, - TestModel: group.TestModel, + Name: name, + Upstreams: upstreamInfos, + HTTPClient: httpClient, + StreamClient: streamClient, + TestModel: group.TestModel, groupUpstreams: group.Upstreams, effectiveConfig: &group.EffectiveConfig, }, nil diff --git a/internal/config/manager.go b/internal/config/manager.go index b52bd88..ea7ceec 100644 --- a/internal/config/manager.go +++ b/internal/config/manager.go @@ -3,15 +3,12 @@ package config import ( "fmt" - "io" "os" - "path/filepath" - "reflect" - "strconv" "strings" "gpt-load/internal/errors" "gpt-load/internal/types" + "gpt-load/internal/utils" "github.com/joho/godotenv" "github.com/sirupsen/logrus" @@ -67,45 +64,38 @@ func NewManager(settingsManager *SystemSettingsManager) (types.ConfigManager, er // ReloadConfig reloads the configuration from environment variables func (m *Manager) ReloadConfig() error { - // Try to load .env file if err := godotenv.Load(); err != nil { logrus.Info("Info: Create .env file to support environment variable configuration") } - // Get business logic defaults from the single source of truth - defaultSettings := DefaultSystemSettings() - config := &Config{ Server: types.ServerConfig{ - Port: parseInteger(os.Getenv("PORT"), 3000), - Host: getEnvOrDefault("HOST", "0.0.0.0"), - ReadTimeout: defaultSettings.ServerReadTimeout, - WriteTimeout: defaultSettings.ServerWriteTimeout, - IdleTimeout: defaultSettings.ServerIdleTimeout, - GracefulShutdownTimeout: defaultSettings.ServerGracefulShutdownTimeout, + Port: utils.ParseInteger(os.Getenv("PORT"), 3000), + Host: utils.GetEnvOrDefault("HOST", "0.0.0.0"), + ReadTimeout: utils.ParseInteger(os.Getenv("SERVER_READ_TIMEOUT"), 120), + WriteTimeout: utils.ParseInteger(os.Getenv("SERVER_WRITE_TIMEOUT"), 1800), + IdleTimeout: utils.ParseInteger(os.Getenv("SERVER_IDLE_TIMEOUT"), 120), + GracefulShutdownTimeout: utils.ParseInteger(os.Getenv("SERVER_GRACEFUL_SHUTDOWN_TIMEOUT"), 60), }, Auth: types.AuthConfig{ - Key: os.Getenv("AUTH_KEY"), - Enabled: os.Getenv("AUTH_KEY") != "", + Key: os.Getenv("AUTH_KEY"), }, CORS: types.CORSConfig{ - Enabled: parseBoolean(os.Getenv("ENABLE_CORS"), true), - AllowedOrigins: parseArray(os.Getenv("ALLOWED_ORIGINS"), []string{"*"}), - AllowedMethods: parseArray(os.Getenv("ALLOWED_METHODS"), []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}), - AllowedHeaders: parseArray(os.Getenv("ALLOWED_HEADERS"), []string{"*"}), - AllowCredentials: parseBoolean(os.Getenv("ALLOW_CREDENTIALS"), false), + Enabled: utils.ParseBoolean(os.Getenv("ENABLE_CORS"), true), + AllowedOrigins: utils.ParseArray(os.Getenv("ALLOWED_ORIGINS"), []string{"*"}), + AllowedMethods: utils.ParseArray(os.Getenv("ALLOWED_METHODS"), []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}), + AllowedHeaders: utils.ParseArray(os.Getenv("ALLOWED_HEADERS"), []string{"*"}), + AllowCredentials: utils.ParseBoolean(os.Getenv("ALLOW_CREDENTIALS"), false), }, Performance: types.PerformanceConfig{ - MaxConcurrentRequests: parseInteger(os.Getenv("MAX_CONCURRENT_REQUESTS"), 100), - EnableGzip: parseBoolean(os.Getenv("ENABLE_GZIP"), true), - KeyValidationPoolSize: parseInteger(os.Getenv("KEY_VALIDATION_POOL_SIZE"), 10), + MaxConcurrentRequests: utils.ParseInteger(os.Getenv("MAX_CONCURRENT_REQUESTS"), 100), + KeyValidationPoolSize: utils.ParseInteger(os.Getenv("KEY_VALIDATION_POOL_SIZE"), 10), }, Log: types.LogConfig{ - Level: getEnvOrDefault("LOG_LEVEL", "info"), - Format: getEnvOrDefault("LOG_FORMAT", "text"), - EnableFile: parseBoolean(os.Getenv("LOG_ENABLE_FILE"), false), - FilePath: getEnvOrDefault("LOG_FILE_PATH", "logs/app.log"), - EnableRequest: parseBoolean(os.Getenv("LOG_ENABLE_REQUEST"), true), + Level: utils.GetEnvOrDefault("LOG_LEVEL", "info"), + Format: utils.GetEnvOrDefault("LOG_FORMAT", "text"), + EnableFile: utils.ParseBoolean(os.Getenv("LOG_ENABLE_FILE"), false), + FilePath: utils.GetEnvOrDefault("LOG_FILE_PATH", "logs/app.log"), }, Database: types.DatabaseConfig{ DSN: os.Getenv("DATABASE_DSN"), @@ -154,17 +144,7 @@ func (m *Manager) GetDatabaseConfig() types.DatabaseConfig { // GetEffectiveServerConfig returns server configuration merged with system settings func (m *Manager) GetEffectiveServerConfig() types.ServerConfig { - config := m.config.Server - - // Merge with system settings from database - systemSettings := m.settingsManager.GetSettings() - - config.ReadTimeout = systemSettings.ServerReadTimeout - config.WriteTimeout = systemSettings.ServerWriteTimeout - config.IdleTimeout = systemSettings.ServerIdleTimeout - config.GracefulShutdownTimeout = systemSettings.ServerGracefulShutdownTimeout - - return config + return m.config.Server } // Validate validates the configuration @@ -180,6 +160,11 @@ func (m *Manager) Validate() error { validationErrors = append(validationErrors, "max concurrent requests cannot be less than 1") } + // Validate auth key + if m.config.Auth.Key == "" { + validationErrors = append(validationErrors, "AUTH_KEY is required and cannot be empty") + } + if len(validationErrors) > 0 { logrus.Error("Configuration validation failed:") for _, err := range validationErrors { @@ -191,160 +176,51 @@ func (m *Manager) Validate() error { return nil } -// DisplayConfig displays current configuration information -func (m *Manager) DisplayConfig() { +// DisplayServerConfig displays current server-related configuration information +func (m *Manager) DisplayServerConfig() { serverConfig := m.GetEffectiveServerConfig() - authConfig := m.GetAuthConfig() corsConfig := m.GetCORSConfig() perfConfig := m.GetPerformanceConfig() logConfig := m.GetLogConfig() + dbConfig := m.GetDatabaseConfig() - logrus.Info("Current Server Configuration:") - logrus.Infof(" Server: %s:%d", serverConfig.Host, serverConfig.Port) + logrus.Info("--- Server Configuration ---") + logrus.Infof(" Listen Address: %s:%d", serverConfig.Host, serverConfig.Port) + logrus.Infof(" Graceful Shutdown Timeout: %d seconds", serverConfig.GracefulShutdownTimeout) + logrus.Infof(" Read Timeout: %d seconds", serverConfig.ReadTimeout) + logrus.Infof(" Write Timeout: %d seconds", serverConfig.WriteTimeout) + logrus.Infof(" Idle Timeout: %d seconds", serverConfig.IdleTimeout) - authStatus := "disabled" - if authConfig.Enabled { - authStatus = "enabled" - } - logrus.Infof(" Authentication: %s", authStatus) + logrus.Info("--- Performance ---") + logrus.Infof(" Max Concurrent Requests: %d", perfConfig.MaxConcurrentRequests) + logrus.Infof(" Key Validation Pool Size: %d", perfConfig.KeyValidationPoolSize) + logrus.Info("--- Security ---") + logrus.Infof(" Authentication: enabled (key loaded)") corsStatus := "disabled" if corsConfig.Enabled { - corsStatus = "enabled" + corsStatus = fmt.Sprintf("enabled (Origins: %s)", strings.Join(corsConfig.AllowedOrigins, ", ")) } - logrus.Infof(" CORS: %s", corsStatus) - logrus.Infof(" Max concurrent requests: %d", perfConfig.MaxConcurrentRequests) - logrus.Infof(" Concurrency pool size: %d", perfConfig.KeyValidationPoolSize) + logrus.Infof(" CORS: %s", corsStatus) - gzipStatus := "disabled" - if perfConfig.EnableGzip { - gzipStatus = "enabled" - } - logrus.Infof(" Gzip compression: %s", gzipStatus) - - requestLogStatus := "enabled" - if !logConfig.EnableRequest { - requestLogStatus = "disabled" - } - logrus.Infof(" Request logging: %s", requestLogStatus) -} - -// Helper functions - -// parseInteger parses integer environment variable -func parseInteger(value string, defaultValue int) int { - if value == "" { - return defaultValue - } - if parsed, err := strconv.Atoi(value); err == nil { - return parsed - } - return defaultValue -} - -// parseBoolean parses boolean environment variable -func parseBoolean(value string, defaultValue bool) bool { - if value == "" { - return defaultValue - } - - lowerValue := strings.ToLower(value) - switch lowerValue { - case "true", "1", "yes", "on": - return true - case "false", "0", "no", "off": - return false - default: - return defaultValue - } -} - -// parseArray parses array environment variable (comma-separated) -func parseArray(value string, defaultValue []string) []string { - if value == "" { - return defaultValue - } - - parts := strings.Split(value, ",") - result := make([]string, 0, len(parts)) - for _, part := range parts { - if trimmed := strings.TrimSpace(part); trimmed != "" { - result = append(result, trimmed) - } - } - - if len(result) == 0 { - return defaultValue - } - return result -} - -// getEnvOrDefault gets environment variable or default value -func getEnvOrDefault(key, defaultValue string) string { - if value := os.Getenv(key); value != "" { - return value - } - return defaultValue -} - -// GetInt is a helper function for SystemSettingsManager to get an integer value with a default. -func (s *SystemSettingsManager) GetInt(key string, defaultValue int) int { - settings := s.GetSettings() - v := reflect.ValueOf(settings) - t := v.Type() - - for i := 0; i < t.NumField(); i++ { - structField := t.Field(i) - jsonTag := strings.Split(structField.Tag.Get("json"), ",")[0] - if jsonTag == key { - valueField := v.Field(i) - if valueField.Kind() == reflect.Int { - return int(valueField.Int()) - } - break - } - } - - return defaultValue -} - -// SetupLogger configures the logging system based on the provided configuration. -func SetupLogger(configManager types.ConfigManager) { - logConfig := configManager.GetLogConfig() - - // Set log level - level, err := logrus.ParseLevel(logConfig.Level) - if err != nil { - logrus.Warn("Invalid log level, using info") - level = logrus.InfoLevel - } - logrus.SetLevel(level) - - // Set log format - if logConfig.Format == "json" { - logrus.SetFormatter(&logrus.JSONFormatter{ - TimestampFormat: "2006-01-02T15:04:05.000Z07:00", // ISO 8601 format - }) - } else { - logrus.SetFormatter(&logrus.TextFormatter{ - ForceColors: true, - FullTimestamp: true, - TimestampFormat: "2006-01-02 15:04:05", - }) - } - - // Setup file logging if enabled + logrus.Info("--- Logging ---") + logrus.Infof(" Log Level: %s", logConfig.Level) + logrus.Infof(" Log Format: %s", logConfig.Format) + logrus.Infof(" File Logging: %t", logConfig.EnableFile) if logConfig.EnableFile { - logDir := filepath.Dir(logConfig.FilePath) - if err := os.MkdirAll(logDir, 0755); err != nil { - logrus.Warnf("Failed to create log directory: %v", err) - } else { - logFile, err := os.OpenFile(logConfig.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) - if err != nil { - logrus.Warnf("Failed to open log file: %v", err) - } else { - logrus.SetOutput(io.MultiWriter(os.Stdout, logFile)) - } - } + logrus.Infof(" Log File Path: %s", logConfig.FilePath) } + + logrus.Info("--- Dependencies ---") + if dbConfig.DSN != "" { + logrus.Info(" Database: configured") + } else { + logrus.Info(" Database: not configured") + } + if m.config.RedisDSN != "" { + logrus.Info(" Redis: configured") + } else { + logrus.Info(" Redis: not configured") + } + logrus.Info("--------------------------") } diff --git a/internal/config/system_settings.go b/internal/config/system_settings.go index 62e5209..d3b67f9 100644 --- a/internal/config/system_settings.go +++ b/internal/config/system_settings.go @@ -8,6 +8,7 @@ import ( "gpt-load/internal/store" "gpt-load/internal/syncer" "gpt-load/internal/types" + "gpt-load/internal/utils" "os" "reflect" "strconv" @@ -20,73 +21,6 @@ import ( const SettingsUpdateChannel = "system_settings:updated" -// GenerateSettingsMetadata 使用反射从 SystemSettings 结构体动态生成元数据 -func GenerateSettingsMetadata(s *types.SystemSettings) []models.SystemSettingInfo { - var settingsInfo []models.SystemSettingInfo - v := reflect.ValueOf(s).Elem() - t := v.Type() - - for i := range t.NumField() { - field := t.Field(i) - fieldValue := v.Field(i) - - jsonTag := field.Tag.Get("json") - if jsonTag == "" { - continue - } - - nameTag := field.Tag.Get("name") - descTag := field.Tag.Get("desc") - defaultTag := field.Tag.Get("default") - validateTag := field.Tag.Get("validate") - categoryTag := field.Tag.Get("category") - - var minValue *int - if strings.HasPrefix(validateTag, "min=") { - valStr := strings.TrimPrefix(validateTag, "min=") - if val, err := strconv.Atoi(valStr); err == nil { - minValue = &val - } - } - - info := models.SystemSettingInfo{ - Key: jsonTag, - Name: nameTag, - Value: fieldValue.Interface(), - Type: field.Type.String(), - DefaultValue: defaultTag, - Description: descTag, - Category: categoryTag, - MinValue: minValue, - } - settingsInfo = append(settingsInfo, info) - } - return settingsInfo -} - -// DefaultSystemSettings 返回默认的系统配置 -func DefaultSystemSettings() types.SystemSettings { - s := types.SystemSettings{} - v := reflect.ValueOf(&s).Elem() - t := v.Type() - - for i := range t.NumField() { - field := t.Field(i) - defaultTag := field.Tag.Get("default") - if defaultTag == "" { - continue - } - - fieldValue := v.Field(i) - if fieldValue.CanSet() { - if err := setFieldFromString(fieldValue, defaultTag); err != nil { - logrus.Warnf("Failed to set default value for field %s: %v", field.Name, err) - } - } - } - return s -} - // SystemSettingsManager 管理系统配置 type SystemSettingsManager struct { syncer *syncer.CacheSyncer[types.SystemSettings] @@ -97,16 +31,16 @@ func NewSystemSettingsManager() *SystemSettingsManager { return &SystemSettingsManager{} } -type gm interface { +type groupManager interface { Invalidate() error } -type leader interface { +type leaderService interface { IsLeader() bool } // Initialize initializes the SystemSettingsManager with database and store dependencies. -func (sm *SystemSettingsManager) Initialize(store store.Store, gm gm, leader leader) error { +func (sm *SystemSettingsManager) Initialize(store store.Store, gm groupManager, leader leaderService) error { settingsLoader := func() (types.SystemSettings, error) { var dbSettings []models.SystemSetting if err := db.DB.Find(&dbSettings).Error; err != nil { @@ -119,7 +53,7 @@ func (sm *SystemSettingsManager) Initialize(store store.Store, gm gm, leader lea } // Start with default settings, then override with values from the database. - settings := DefaultSystemSettings() + settings := utils.DefaultSystemSettings() v := reflect.ValueOf(&settings).Elem() t := v.Type() jsonToField := make(map[string]string) @@ -135,14 +69,14 @@ func (sm *SystemSettingsManager) Initialize(store store.Store, gm gm, leader lea if fieldName, ok := jsonToField[key]; ok { fieldValue := v.FieldByName(fieldName) if fieldValue.IsValid() && fieldValue.CanSet() { - if err := setFieldFromString(fieldValue, valStr); err != nil { + if err := utils.SetFieldFromString(fieldValue, valStr); err != nil { logrus.Warnf("Failed to set value from map for field %s: %v", fieldName, err) } } } } - sm.DisplayCurrentSettings(settings) + sm.DisplaySystemConfig(settings) return settings, nil } @@ -180,8 +114,8 @@ func (sm *SystemSettingsManager) Stop() { // EnsureSettingsInitialized 确保数据库中存在所有系统设置的记录。 func (sm *SystemSettingsManager) EnsureSettingsInitialized() error { - defaultSettings := DefaultSystemSettings() - metadata := GenerateSettingsMetadata(&defaultSettings) + defaultSettings := utils.DefaultSystemSettings() + metadata := utils.GenerateSettingsMetadata(&defaultSettings) for _, meta := range metadata { var existing models.SystemSetting @@ -189,20 +123,15 @@ func (sm *SystemSettingsManager) EnsureSettingsInitialized() error { if err != nil { value := fmt.Sprintf("%v", meta.DefaultValue) if meta.Key == "app_url" { - // Special handling for app_url initialization - if appURL := os.Getenv("APP_URL"); appURL != "" { - value = appURL - } else { - host := os.Getenv("HOST") - if host == "" || host == "0.0.0.0" { - host = "localhost" - } - port := os.Getenv("PORT") - if port == "" { - port = "3000" - } - value = fmt.Sprintf("http://%s:%s", host, port) + host := os.Getenv("HOST") + if host == "" || host == "0.0.0.0" { + host = "localhost" } + port := os.Getenv("PORT") + if port == "" { + port = "3000" + } + value = fmt.Sprintf("http://%s:%s", host, port) } setting := models.SystemSetting{ SettingKey: meta.Key, @@ -221,26 +150,30 @@ func (sm *SystemSettingsManager) EnsureSettingsInitialized() error { } // GetSettings 获取当前系统配置 -// If the syncer is not initialized, it returns default settings. func (sm *SystemSettingsManager) GetSettings() types.SystemSettings { if sm.syncer == nil { logrus.Warn("SystemSettingsManager is not initialized, returning default settings.") - return DefaultSystemSettings() + return utils.DefaultSystemSettings() } return sm.syncer.Get() } // GetAppUrl returns the effective App URL. -// It prioritizes the value from system settings (database) over the APP_URL environment variable. func (sm *SystemSettingsManager) GetAppUrl() string { - // 1. 优先级: 数据库中的系统配置 settings := sm.GetSettings() if settings.AppUrl != "" { return settings.AppUrl } - // 2. 回退: 环境变量 - return os.Getenv("APP_URL") + host := os.Getenv("HOST") + if host == "" || host == "0.0.0.0" { + host = "localhost" + } + port := os.Getenv("PORT") + if port == "" { + port = "3000" + } + return fmt.Sprintf("http://%s:%s", host, port) } // UpdateSettings 更新系统配置 @@ -273,40 +206,35 @@ func (sm *SystemSettingsManager) UpdateSettings(settingsMap map[string]any) erro } // GetEffectiveConfig 获取有效配置 (系统配置 + 分组覆盖) -func (sm *SystemSettingsManager) GetEffectiveConfig(groupConfig datatypes.JSONMap) types.SystemSettings { - // 从系统配置开始 +func (sm *SystemSettingsManager) GetEffectiveConfig(groupConfigJSON datatypes.JSONMap) types.SystemSettings { effectiveConfig := sm.GetSettings() - v := reflect.ValueOf(&effectiveConfig).Elem() - t := v.Type() - // 创建一个从 json 标签到字段名的映射 - jsonToField := make(map[string]string) - for i := range t.NumField() { - field := t.Field(i) - jsonTag := strings.Split(field.Tag.Get("json"), ",")[0] - if jsonTag != "" { - jsonToField[jsonTag] = field.Name - } + if groupConfigJSON == nil { + return effectiveConfig } - // 应用分组配置覆盖 - for key, val := range groupConfig { - if fieldName, ok := jsonToField[key]; ok { - fieldValue := v.FieldByName(fieldName) - if fieldValue.IsValid() && fieldValue.CanSet() { - switch fieldValue.Kind() { - case reflect.Int: - if intVal, err := interfaceToInt(val); err == nil { - fieldValue.SetInt(int64(intVal)) - } - case reflect.String: - if strVal, ok := interfaceToString(val); ok { - fieldValue.SetString(strVal) - } - case reflect.Bool: - if boolVal, ok := interfaceToBool(val); ok { - fieldValue.SetBool(boolVal) - } + var groupConfig models.GroupConfig + groupConfigBytes, err := groupConfigJSON.MarshalJSON() + if err != nil { + logrus.Warnf("Failed to marshal group config JSON, using system settings only. Error: %v", err) + return effectiveConfig + } + if err := json.Unmarshal(groupConfigBytes, &groupConfig); err != nil { + logrus.Warnf("Failed to unmarshal group config, using system settings only. Error: %v", err) + return effectiveConfig + } + + gcv := reflect.ValueOf(groupConfig) + ecv := reflect.ValueOf(&effectiveConfig).Elem() + + for i := 0; i < gcv.NumField(); i++ { + groupField := gcv.Field(i) + if groupField.Kind() == reflect.Ptr && !groupField.IsNil() { + groupFieldValue := groupField.Elem() + effectiveField := ecv.FieldByName(gcv.Type().Field(i).Name) + if effectiveField.IsValid() && effectiveField.CanSet() { + if effectiveField.Type() == groupFieldValue.Type() { + effectiveField.Set(groupFieldValue) } } } @@ -317,11 +245,11 @@ func (sm *SystemSettingsManager) GetEffectiveConfig(groupConfig datatypes.JSONMa // ValidateSettings 验证系统配置的有效性 func (sm *SystemSettingsManager) ValidateSettings(settingsMap map[string]any) error { - tempSettings := DefaultSystemSettings() + tempSettings := utils.DefaultSystemSettings() v := reflect.ValueOf(&tempSettings).Elem() t := v.Type() jsonToField := make(map[string]reflect.StructField) - for i := 0; i < t.NumField(); i++ { + for i := range t.NumField() { field := t.Field(i) jsonTag := field.Tag.Get("json") if jsonTag != "" { @@ -339,7 +267,6 @@ func (sm *SystemSettingsManager) ValidateSettings(settingsMap map[string]any) er switch field.Type.Kind() { case reflect.Int: - // JSON numbers are decoded as float64 floatVal, ok := value.(float64) if !ok { return fmt.Errorf("invalid type for %s: expected a number, got %T", key, value) @@ -372,97 +299,70 @@ func (sm *SystemSettingsManager) ValidateSettings(settingsMap map[string]any) er return nil } -// DisplayCurrentSettings 显示当前系统配置信息 -func (sm *SystemSettingsManager) DisplayCurrentSettings(settings types.SystemSettings) { - logrus.Info("Current System Settings:") - logrus.Infof(" App URL: %s", settings.AppUrl) - logrus.Infof(" Blacklist threshold: %d", settings.BlacklistThreshold) - logrus.Infof(" Max retries: %d", settings.MaxRetries) - logrus.Infof(" Server timeouts: read=%ds, write=%ds, idle=%ds, shutdown=%ds", - settings.ServerReadTimeout, settings.ServerWriteTimeout, - settings.ServerIdleTimeout, settings.ServerGracefulShutdownTimeout) - logrus.Infof(" Request timeouts: request=%ds, connect=%ds, idle_conn=%ds", - settings.RequestTimeout, settings.ConnectTimeout, settings.IdleConnTimeout) - logrus.Infof(" HTTP Client Pool: max_idle_conns=%d, max_idle_conns_per_host=%d", - settings.MaxIdleConns, settings.MaxIdleConnsPerHost) - logrus.Infof(" Request log retention: %d days", settings.RequestLogRetentionDays) - logrus.Infof(" Key validation: interval=%dmin, task_timeout=%dmin", - settings.KeyValidationIntervalMinutes, settings.KeyValidationTaskTimeoutMinutes) -} - -// setFieldFromString sets a struct field's value from a string, based on the field's kind. -func setFieldFromString(fieldValue reflect.Value, value string) error { - if !fieldValue.CanSet() { - return fmt.Errorf("field cannot be set") +// ValidateGroupConfigOverrides validates a map of group-level configuration overrides. +func (sm *SystemSettingsManager) ValidateGroupConfigOverrides(configMap map[string]any) error { + tempSettings := types.SystemSettings{} + v := reflect.ValueOf(&tempSettings).Elem() + t := v.Type() + jsonToField := make(map[string]reflect.StructField) + for i := range t.NumField() { + field := t.Field(i) + jsonTag := field.Tag.Get("json") + if jsonTag != "" { + jsonToField[jsonTag] = field + } } - switch fieldValue.Kind() { - case reflect.Int: - intVal, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return fmt.Errorf("invalid integer value '%s': %w", value, err) + for key, value := range configMap { + if value == nil { + continue } - fieldValue.SetInt(int64(intVal)) - case reflect.Bool: - boolVal, err := strconv.ParseBool(value) - if err != nil { - return fmt.Errorf("invalid boolean value '%s': %w", value, err) + + field, ok := jsonToField[key] + if !ok { + return fmt.Errorf("invalid setting key: %s", key) + } + + validateTag := field.Tag.Get("validate") + + floatVal, isFloat := value.(float64) + if !isFloat { + continue + } + intVal := int(floatVal) + if floatVal != float64(intVal) { + return fmt.Errorf("invalid value for %s: must be an integer", key) + } + + if strings.HasPrefix(validateTag, "min=") { + minValStr := strings.TrimPrefix(validateTag, "min=") + minVal, _ := strconv.Atoi(minValStr) + if intVal < minVal { + return fmt.Errorf("value for %s (%d) is below minimum value (%d)", key, intVal, minVal) + } } - fieldValue.SetBool(boolVal) - case reflect.String: - fieldValue.SetString(value) - default: - return fmt.Errorf("unsupported field kind: %s", fieldValue.Kind()) } + return nil } -// 工具函数 +// DisplaySystemConfig displays the current system settings. +func (sm *SystemSettingsManager) DisplaySystemConfig(settings types.SystemSettings) { + logrus.Info("--- System Settings ---") + logrus.Infof(" App URL: %s", settings.AppUrl) + logrus.Infof(" Request Log Retention: %d days", settings.RequestLogRetentionDays) -func interfaceToInt(val any) (int, error) { - switch v := val.(type) { - case json.Number: - i64, err := v.Int64() - if err != nil { - return 0, err - } - return int(i64), nil - case int: - return v, nil - case float64: - if v != float64(int(v)) { - return 0, fmt.Errorf("value is a float, not an integer: %v", v) - } - return int(v), nil - case string: - return strconv.Atoi(v) - default: - return 0, fmt.Errorf("cannot convert %T to int", v) - } -} + logrus.Info("--- Request Behavior ---") + logrus.Infof(" Request Timeout: %d seconds", settings.RequestTimeout) + logrus.Infof(" Connect Timeout: %d seconds", settings.ConnectTimeout) + logrus.Infof(" Response Header Timeout: %d seconds", settings.ResponseHeaderTimeout) + logrus.Infof(" Idle Connection Timeout: %d seconds", settings.IdleConnTimeout) + logrus.Infof(" Max Idle Connections: %d", settings.MaxIdleConns) + logrus.Infof(" Max Idle Connections Per Host: %d", settings.MaxIdleConnsPerHost) -// interfaceToString is kept for GetEffectiveConfig -func interfaceToString(val any) (string, bool) { - s, ok := val.(string) - return s, ok -} - -// interfaceToBool is kept for GetEffectiveConfig -func interfaceToBool(val any) (bool, bool) { - switch v := val.(type) { - case json.Number: - if s := v.String(); s == "1" { - return true, true - } else if s == "0" { - return false, true - } - case bool: - return v, true - case string: - b, err := strconv.ParseBool(v) - if err == nil { - return b, true - } - } - return false, false + logrus.Info("--- Key & Group Behavior ---") + logrus.Infof(" Max Retries: %d", settings.MaxRetries) + logrus.Infof(" Blacklist Threshold: %d", settings.BlacklistThreshold) + logrus.Infof(" Key Validation Interval: %d minutes", settings.KeyValidationIntervalMinutes) + logrus.Info("-----------------------") } diff --git a/internal/handler/group_handler.go b/internal/handler/group_handler.go index 9ddf23b..c69a4df 100644 --- a/internal/handler/group_handler.go +++ b/internal/handler/group_handler.go @@ -6,10 +6,10 @@ import ( "fmt" "net/url" - "gpt-load/internal/config" app_errors "gpt-load/internal/errors" "gpt-load/internal/models" "gpt-load/internal/response" + "gpt-load/internal/utils" "reflect" "regexp" "strconv" @@ -87,33 +87,18 @@ func isValidGroupName(name string) bool { return match } -// validateAndCleanConfig validates the group config against the GroupConfig struct. -func validateAndCleanConfig(configMap map[string]any) (map[string]any, error) { +// validateAndCleanConfig validates the group config against the GroupConfig struct and system-defined rules. +func (s *Server) validateAndCleanConfig(configMap map[string]any) (map[string]any, error) { if configMap == nil { return nil, nil } - configBytes, err := json.Marshal(configMap) - if err != nil { - return nil, err - } - - var validatedConfig models.GroupConfig - if err := json.Unmarshal(configBytes, &validatedConfig); err != nil { - return nil, err - } - - // Strict check for unknown fields - var cleanedMap map[string]any - if err := json.Unmarshal(configBytes, &cleanedMap); err != nil { - return nil, err - } - - val := reflect.ValueOf(validatedConfig) - typ := val.Type() + // 1. Check for unknown fields by comparing against the GroupConfig struct definition. + var tempGroupConfig models.GroupConfig + groupConfigType := reflect.TypeOf(tempGroupConfig) validFields := make(map[string]bool) - for i := 0; i < typ.NumField(); i++ { - jsonTag := typ.Field(i).Tag.Get("json") + for i := 0; i < groupConfigType.NumField(); i++ { + jsonTag := groupConfigType.Field(i).Tag.Get("json") fieldName := strings.Split(jsonTag, ",")[0] if fieldName != "" && fieldName != "-" { validFields[fieldName] = true @@ -126,28 +111,29 @@ func validateAndCleanConfig(configMap map[string]any) (map[string]any, error) { } } - // 验证配置项的合理范围 - if validatedConfig.BlacklistThreshold != nil && *validatedConfig.BlacklistThreshold < 0 { - return nil, fmt.Errorf("blacklist_threshold must be >= 0") - } - if validatedConfig.MaxRetries != nil && (*validatedConfig.MaxRetries < 0 || *validatedConfig.MaxRetries > 10) { - return nil, fmt.Errorf("max_retries must be between 0 and 10") - } - if validatedConfig.RequestTimeout != nil && (*validatedConfig.RequestTimeout < 1 || *validatedConfig.RequestTimeout > 3600) { - return nil, fmt.Errorf("request_timeout must be between 1 and 3600 seconds") - } - if validatedConfig.KeyValidationIntervalMinutes != nil && (*validatedConfig.KeyValidationIntervalMinutes < 5 || *validatedConfig.KeyValidationIntervalMinutes > 1440) { - return nil, fmt.Errorf("key_validation_interval_minutes must be between 5 and 1440 minutes") + // 2. Validate the values of the provided fields using the central system settings validator. + if err := s.SettingsManager.ValidateGroupConfigOverrides(configMap); err != nil { + return nil, err + } + + // 3. Unmarshal and marshal back to clean the map and ensure correct types. + configBytes, err := json.Marshal(configMap) + if err != nil { + return nil, fmt.Errorf("failed to marshal config map: %w", err) + } + + var validatedConfig models.GroupConfig + if err := json.Unmarshal(configBytes, &validatedConfig); err != nil { + return nil, fmt.Errorf("failed to unmarshal into validated config: %w", err) } - // Marshal back to a map to ensure consistency validatedBytes, err := json.Marshal(validatedConfig) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to marshal validated config: %w", err) } var finalMap map[string]any if err := json.Unmarshal(validatedBytes, &finalMap); err != nil { - return nil, err + return nil, fmt.Errorf("failed to unmarshal into final map: %w", err) } return finalMap, nil @@ -187,7 +173,7 @@ func (s *Server) CreateGroup(c *gin.Context) { return } - cleanedConfig, err := validateAndCleanConfig(req.Config) + cleanedConfig, err := s.validateAndCleanConfig(req.Config) if err != nil { response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, fmt.Sprintf("Invalid config format: %v", err))) return @@ -325,7 +311,7 @@ func (s *Server) UpdateGroup(c *gin.Context) { group.ParamOverrides = req.ParamOverrides } if req.Config != nil { - cleanedConfig, err := validateAndCleanConfig(req.Config) + cleanedConfig, err := s.validateAndCleanConfig(req.Config) if err != nil { response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, fmt.Sprintf("Invalid config format: %v", err))) return @@ -496,8 +482,8 @@ func (s *Server) GetGroupConfigOptions(c *gin.Context) { var options []ConfigOption // 1. Get all system setting definitions from the struct tags - defaultSettings := config.DefaultSystemSettings() - settingDefinitions := config.GenerateSettingsMetadata(&defaultSettings) + defaultSettings := utils.DefaultSystemSettings() + settingDefinitions := utils.GenerateSettingsMetadata(&defaultSettings) defMap := make(map[string]models.SystemSettingInfo) for _, def := range settingDefinitions { defMap[def.Key] = def diff --git a/internal/handler/handler.go b/internal/handler/handler.go index 2d80102..57a6ddc 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -78,14 +78,6 @@ func (s *Server) Login(c *gin.Context) { authConfig := s.config.GetAuthConfig() - if !authConfig.Enabled { - c.JSON(http.StatusOK, LoginResponse{ - Success: true, - Message: "Authentication disabled", - }) - return - } - if req.AuthKey == authConfig.Key { c.JSON(http.StatusOK, LoginResponse{ Success: true, diff --git a/internal/handler/settings_handler.go b/internal/handler/settings_handler.go index 27ae494..47fd3d6 100644 --- a/internal/handler/settings_handler.go +++ b/internal/handler/settings_handler.go @@ -1,10 +1,10 @@ package handler import ( - "gpt-load/internal/config" app_errors "gpt-load/internal/errors" "gpt-load/internal/models" "gpt-load/internal/response" + "gpt-load/internal/utils" "time" "github.com/gin-gonic/gin" @@ -14,7 +14,7 @@ import ( // It retrieves all system settings, groups them by category, and returns them. func (s *Server) GetSettings(c *gin.Context) { currentSettings := s.SettingsManager.GetSettings() - settingsInfo := config.GenerateSettingsMetadata(¤tSettings) + settingsInfo := utils.GenerateSettingsMetadata(¤tSettings) // Group settings by category while preserving order categorized := make(map[string][]models.SystemSettingInfo) diff --git a/internal/keypool/provider.go b/internal/keypool/provider.go index 85f605a..2057915 100644 --- a/internal/keypool/provider.go +++ b/internal/keypool/provider.go @@ -71,18 +71,18 @@ func (p *KeyProvider) SelectKey(groupID uint) (*models.APIKey, error) { } // UpdateStatus 异步地提交一个 Key 状态更新任务。 -func (p *KeyProvider) UpdateStatus(keyID uint, groupID uint, isSuccess bool) { +func (p *KeyProvider) UpdateStatus(apiKey *models.APIKey, group *models.Group, isSuccess bool) { go func() { - keyHashKey := fmt.Sprintf("key:%d", keyID) - activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID) + keyHashKey := fmt.Sprintf("key:%d", apiKey.ID) + activeKeysListKey := fmt.Sprintf("group:%d:active_keys", group.ID) if isSuccess { - if err := p.handleSuccess(keyID, keyHashKey, activeKeysListKey); err != nil { - logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to handle key success") + if err := p.handleSuccess(apiKey.ID, keyHashKey, activeKeysListKey); err != nil { + logrus.WithFields(logrus.Fields{"keyID": apiKey.ID, "error": err}).Error("Failed to handle key success") } } else { - if err := p.handleFailure(keyID, keyHashKey, activeKeysListKey); err != nil { - logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to handle key failure") + if err := p.handleFailure(apiKey, group, keyHashKey, activeKeysListKey); err != nil { + logrus.WithFields(logrus.Fields{"keyID": apiKey.ID, "error": err}).Error("Failed to handle key failure") } } }() @@ -134,7 +134,7 @@ func (p *KeyProvider) handleSuccess(keyID uint, keyHashKey, activeKeysListKey st }) } -func (p *KeyProvider) handleFailure(keyID uint, keyHashKey, activeKeysListKey string) error { +func (p *KeyProvider) handleFailure(apiKey *models.APIKey, group *models.Group, keyHashKey, activeKeysListKey string) error { keyDetails, err := p.store.HGetAll(keyHashKey) if err != nil { return fmt.Errorf("failed to get key details from store: %w", err) @@ -146,19 +146,19 @@ func (p *KeyProvider) handleFailure(keyID uint, keyHashKey, activeKeysListKey st return nil } - settings := p.settingsManager.GetSettings() - blacklistThreshold := settings.BlacklistThreshold + // 获取该分组的有效配置 + blacklistThreshold := group.EffectiveConfig.BlacklistThreshold return p.db.Transaction(func(tx *gorm.DB) error { var key models.APIKey - if err := tx.Set("gorm:query_option", "FOR UPDATE").First(&key, keyID).Error; err != nil { - return fmt.Errorf("failed to lock key %d for update: %w", keyID, err) + if err := tx.Set("gorm:query_option", "FOR UPDATE").First(&key, apiKey.ID).Error; err != nil { + return fmt.Errorf("failed to lock key %d for update: %w", apiKey.ID, err) } newFailureCount := failureCount + 1 updates := map[string]any{"failure_count": newFailureCount} - shouldBlacklist := newFailureCount >= int64(blacklistThreshold) + shouldBlacklist := blacklistThreshold > 0 && newFailureCount >= int64(blacklistThreshold) if shouldBlacklist { updates["status"] = models.KeyStatusInvalid } @@ -172,8 +172,8 @@ func (p *KeyProvider) handleFailure(keyID uint, keyHashKey, activeKeysListKey st } if shouldBlacklist { - logrus.WithFields(logrus.Fields{"keyID": keyID, "threshold": blacklistThreshold}).Warn("Key has reached blacklist threshold, disabling.") - if err := p.store.LRem(activeKeysListKey, 0, keyID); err != nil { + logrus.WithFields(logrus.Fields{"keyID": apiKey.ID, "threshold": blacklistThreshold}).Warn("Key has reached blacklist threshold, disabling.") + if err := p.store.LRem(activeKeysListKey, 0, apiKey.ID); err != nil { return fmt.Errorf("failed to LRem key from active list: %w", err) } if err := p.store.HSet(keyHashKey, map[string]any{"status": models.KeyStatusInvalid}); err != nil { diff --git a/internal/keypool/validator.go b/internal/keypool/validator.go index cecf4ff..bd57ae2 100644 --- a/internal/keypool/validator.go +++ b/internal/keypool/validator.go @@ -58,7 +58,8 @@ func (s *KeyValidator) ValidateSingleKey(ctx context.Context, key *models.APIKey isValid, validationErr := ch.ValidateKey(ctx, key.KeyValue) - s.keypoolProvider.UpdateStatus(key.ID, group.ID, isValid) + group.EffectiveConfig = s.SettingsManager.GetEffectiveConfig(group.Config) + s.keypoolProvider.UpdateStatus(key, group, isValid) if !isValid { logrus.WithFields(logrus.Fields{ diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index e6ccdb6..cc4c83c 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -18,16 +18,6 @@ import ( // Logger creates a high-performance logging middleware func Logger(config types.LogConfig) gin.HandlerFunc { return func(c *gin.Context) { - // Check if request logging is enabled - if !config.EnableRequest { - // Don't log requests, only process them - c.Next() - // Only log errors - if c.Writer.Status() >= 400 { - logrus.Errorf("Error %d: %s %s", c.Writer.Status(), c.Request.Method, c.Request.URL.Path) - } - return - } start := time.Now() path := c.Request.URL.Path @@ -127,10 +117,6 @@ func CORS(config types.CORSConfig) gin.HandlerFunc { // Auth creates an authentication middleware func Auth(config types.AuthConfig) gin.HandlerFunc { return func(c *gin.Context) { - if !config.Enabled { - c.Next() - return - } // Skip authentication for management endpoints path := c.Request.URL.Path diff --git a/internal/models/types.go b/internal/models/types.go index efbe8ea..ec33a6d 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -25,21 +25,15 @@ type SystemSetting struct { // GroupConfig 存储特定于分组的配置 type GroupConfig struct { - BlacklistThreshold *int `json:"blacklist_threshold,omitempty"` - MaxRetries *int `json:"max_retries,omitempty"` - ServerReadTimeout *int `json:"server_read_timeout,omitempty"` - ServerWriteTimeout *int `json:"server_write_timeout,omitempty"` - ServerIdleTimeout *int `json:"server_idle_timeout,omitempty"` - ServerGracefulShutdownTimeout *int `json:"server_graceful_shutdown_timeout,omitempty"` - RequestTimeout *int `json:"request_timeout,omitempty"` - ResponseTimeout *int `json:"response_timeout,omitempty"` - IdleConnTimeout *int `json:"idle_conn_timeout,omitempty"` - KeyValidationIntervalMinutes *int `json:"key_validation_interval_minutes,omitempty"` - ConnectTimeout *int `json:"connect_timeout,omitempty"` - MaxIdleConns *int `json:"max_idle_conns,omitempty"` - MaxIdleConnsPerHost *int `json:"max_idle_conns_per_host,omitempty"` - ResponseHeaderTimeout *int `json:"response_header_timeout,omitempty"` - DisableCompression *bool `json:"disable_compression,omitempty"` + RequestTimeout *int `json:"request_timeout,omitempty"` + IdleConnTimeout *int `json:"idle_conn_timeout,omitempty"` + ConnectTimeout *int `json:"connect_timeout,omitempty"` + MaxIdleConns *int `json:"max_idle_conns,omitempty"` + MaxIdleConnsPerHost *int `json:"max_idle_conns_per_host,omitempty"` + ResponseHeaderTimeout *int `json:"response_header_timeout,omitempty"` + MaxRetries *int `json:"max_retries,omitempty"` + BlacklistThreshold *int `json:"blacklist_threshold,omitempty"` + KeyValidationIntervalMinutes *int `json:"key_validation_interval_minutes,omitempty"` } // Group 对应 groups 表 diff --git a/internal/proxy/server.go b/internal/proxy/server.go index cdb4517..38e8e14 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -166,7 +166,7 @@ func (ps *ProxyServer) executeRequestWithRetry( return } - ps.keyProvider.UpdateStatus(apiKey.ID, group.ID, false) + ps.keyProvider.UpdateStatus(apiKey, group, false) var statusCode int var errorMessage string @@ -201,7 +201,7 @@ func (ps *ProxyServer) executeRequestWithRetry( return } - ps.keyProvider.UpdateStatus(apiKey.ID, group.ID, true) + ps.keyProvider.UpdateStatus(apiKey, group, true) logrus.Debugf("Request for group %s succeeded on attempt %d with key %s", group.Name, retryCount+1, utils.MaskAPIKey(apiKey.KeyValue)) for key, values := range resp.Header { diff --git a/internal/router/router.go b/internal/router/router.go index 2f4bbaa..79ac4b9 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -83,13 +83,9 @@ func registerAPIRoutes(router *gin.Engine, serverHandler *handler.Server, logCle registerPublicAPIRoutes(api, serverHandler) // 认证 - if authConfig.Enabled { - protectedAPI := api.Group("") - protectedAPI.Use(middleware.Auth(authConfig)) - registerProtectedAPIRoutes(protectedAPI, serverHandler, logCleanupHandler) - } else { - registerProtectedAPIRoutes(api, serverHandler, logCleanupHandler) - } + protectedAPI := api.Group("") + protectedAPI.Use(middleware.Auth(authConfig)) + registerProtectedAPIRoutes(protectedAPI, serverHandler, logCleanupHandler) } // registerPublicAPIRoutes 公开API路由 @@ -153,9 +149,7 @@ func registerProxyRoutes(router *gin.Engine, proxyServer *proxy.ProxyServer, con proxyGroup := router.Group("/proxy") authConfig := configManager.GetAuthConfig() - if authConfig.Enabled { - proxyGroup.Use(middleware.Auth(authConfig)) - } + proxyGroup.Use(middleware.Auth(authConfig)) proxyGroup.Any("/:group_name/*path", proxyServer.HandleProxy) } diff --git a/internal/services/group_manager.go b/internal/services/group_manager.go index a3d9ea3..bdf5f0b 100644 --- a/internal/services/group_manager.go +++ b/internal/services/group_manager.go @@ -47,7 +47,12 @@ func (gm *GroupManager) Initialize() error { g := *group g.EffectiveConfig = gm.settingsManager.GetEffectiveConfig(g.Config) groupMap[g.Name] = &g + logrus.WithFields(logrus.Fields{ + "group_name": g.Name, + "effective_config": g.EffectiveConfig, + }).Debug("Loaded group with effective config") } + return groupMap, nil } diff --git a/internal/services/key_cron_service.go b/internal/services/key_cron_service.go index 2f2e250..a5580a4 100644 --- a/internal/services/key_cron_service.go +++ b/internal/services/key_cron_service.go @@ -66,7 +66,7 @@ func (s *KeyCronService) runLoop() { select { case <-ticker.C: if s.LeaderService.IsLeader() { - logrus.Info("KeyCronService: Running as leader, submitting validation jobs.") + logrus.Debug("KeyCronService: Running as leader, submitting validation jobs.") s.submitValidationJobs() } else { logrus.Debug("KeyCronService: Not the leader. Standing by.") @@ -130,7 +130,9 @@ func (s *KeyCronService) submitValidationJobs() { s.updateGroupTimestamps(groupsToUpdateTimestamp, validationStartTime) } - logrus.Infof("KeyCronService: Submitted %d keys for validation across %d groups.", total, len(groupsToUpdateTimestamp)) + if total > 0 { + logrus.Infof("KeyCronService: Submitted %d keys for validation across %d groups.", total, len(groupsToUpdateTimestamp)) + } } func (s *KeyCronService) updateGroupTimestamps(groups map[uint]*models.Group, validationStartTime time.Time) { diff --git a/internal/services/key_manual_validation_service.go b/internal/services/key_manual_validation_service.go index acaf091..cefa37b 100644 --- a/internal/services/key_manual_validation_service.go +++ b/internal/services/key_manual_validation_service.go @@ -52,8 +52,7 @@ func (s *KeyManualValidationService) StartValidationTask(group *models.Group) (* return nil, fmt.Errorf("no keys to validate in group %s", group.Name) } - timeoutMinutes := s.SettingsManager.GetInt("key_validation_task_timeout_minutes", 60) - timeout := time.Duration(timeoutMinutes) * time.Minute + timeout := 30 * time.Minute taskStatus, err := s.TaskService.StartTask(group.Name, len(keys), timeout) if err != nil { diff --git a/internal/services/key_service.go b/internal/services/key_service.go index e145730..e5a8cdf 100644 --- a/internal/services/key_service.go +++ b/internal/services/key_service.go @@ -298,8 +298,6 @@ func (s *KeyService) TestMultipleKeys(ctx context.Context, group *models.Group, chunk := keysToTest[i:end] results, err := s.KeyValidator.TestMultipleKeys(ctx, group, chunk) if err != nil { - // If one chunk fails, we might want to stop or collect partial results. - // For now, let's stop and return the error. return nil, err } allResults = append(allResults, results...) diff --git a/internal/types/types.go b/internal/types/types.go index 7892c75..97d522e 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -10,7 +10,7 @@ type ConfigManager interface { GetEffectiveServerConfig() ServerConfig GetRedisDSN() string Validate() error - DisplayConfig() + DisplayServerConfig() ReloadConfig() error } @@ -18,28 +18,20 @@ type ConfigManager interface { type SystemSettings struct { // 基础参数 AppUrl string `json:"app_url" default:"http://localhost:3000" name:"项目地址" category:"基础参数" desc:"项目的基础 URL,用于拼接分组终端节点地址。系统配置优先于环境变量 APP_URL。"` - RequestLogRetentionDays int `json:"request_log_retention_days" default:"7" name:"日志保留天数" category:"基础参数" desc:"请求日志在数据库中的保留天数" validate:"min=1"` + RequestLogRetentionDays int `json:"request_log_retention_days" default:"7" name:"日志保留天数" category:"基础参数" desc:"请求日志在数据库中的保留天数" validate:"min=0"` - // 服务超时 - ServerReadTimeout int `json:"server_read_timeout" default:"120" name:"读取超时" category:"服务超时" desc:"HTTP 服务器读取超时时间(秒)" validate:"min=1"` - ServerWriteTimeout int `json:"server_write_timeout" default:"1800" name:"写入超时" category:"服务超时" desc:"HTTP 服务器写入超时时间(秒)" validate:"min=1"` - ServerIdleTimeout int `json:"server_idle_timeout" default:"120" name:"空闲超时" category:"服务超时" desc:"HTTP 服务器空闲超时时间(秒)" validate:"min=1"` - ServerGracefulShutdownTimeout int `json:"server_graceful_shutdown_timeout" default:"60" name:"优雅关闭超时" category:"服务超时" desc:"服务优雅关闭的等待超时时间(秒)" validate:"min=1"` - - // 请求超时 - RequestTimeout int `json:"request_timeout" default:"600" name:"请求超时" category:"请求超时" desc:"转发请求的完整生命周期超时(秒),包括连接、重试等。" validate:"min=1"` - ConnectTimeout int `json:"connect_timeout" default:"5" name:"连接超时" category:"请求超时" desc:"与上游服务建立新连接的超时时间(秒)。" validate:"min=1"` - IdleConnTimeout int `json:"idle_conn_timeout" default:"120" name:"空闲连接超时" category:"请求超时" desc:"HTTP 客户端中空闲连接的超时时间(秒)。" validate:"min=1"` - MaxIdleConns int `json:"max_idle_conns" default:"100" name:"最大空闲连接数" category:"请求超时" desc:"HTTP 客户端连接池中允许的最大空闲连接总数。" validate:"min=1"` - MaxIdleConnsPerHost int `json:"max_idle_conns_per_host" default:"10" name:"每主机最大空闲连接数" category:"请求超时" desc:"HTTP 客户端连接池对每个上游主机允许的最大空闲连接数。" validate:"min=1"` - ResponseHeaderTimeout int `json:"response_header_timeout" default:"120" name:"响应头超时" category:"请求超时" desc:"等待上游服务响应头的最长时间(秒),用于流式请求。" validate:"min=1"` - DisableCompression bool `json:"disable_compression" default:"false" name:"禁用压缩" category:"请求超时" desc:"是否禁用对上游请求的传输压缩(Gzip)。对于流式请求建议开启以降低延迟。"` + // 请求设置 + RequestTimeout int `json:"request_timeout" default:"600" name:"请求超时" category:"请求设置" desc:"转发请求的完整生命周期超时(秒),包括连接、重试等。" validate:"min=1"` + ConnectTimeout int `json:"connect_timeout" default:"30" name:"连接超时" category:"请求设置" desc:"与上游服务建立新连接的超时时间(秒)。" validate:"min=1"` + IdleConnTimeout int `json:"idle_conn_timeout" default:"120" name:"空闲连接超时" category:"请求设置" desc:"HTTP 客户端中空闲连接的超时时间(秒)。" validate:"min=1"` + ResponseHeaderTimeout int `json:"response_header_timeout" default:"120" name:"响应头超时" category:"请求设置" desc:"等待上游服务响应头的最长时间(秒),用于流式请求。" validate:"min=1"` + MaxIdleConns int `json:"max_idle_conns" default:"50" name:"最大空闲连接数" category:"请求设置" desc:"HTTP 客户端连接池中允许的最大空闲连接总数。" validate:"min=1"` + MaxIdleConnsPerHost int `json:"max_idle_conns_per_host" default:"20" name:"每主机最大空闲连接数" category:"请求设置" desc:"HTTP 客户端连接池对每个上游主机允许的最大空闲连接数。" validate:"min=1"` // 密钥配置 - MaxRetries int `json:"max_retries" default:"3" name:"最大重试次数" category:"密钥配置" desc:"单个请求使用不同 Key 的最大重试次数" validate:"min=0"` - BlacklistThreshold int `json:"blacklist_threshold" default:"1" name:"黑名单阈值" category:"密钥配置" desc:"一个 Key 连续失败多少次后进入黑名单" validate:"min=0"` - KeyValidationIntervalMinutes int `json:"key_validation_interval_minutes" default:"60" name:"定时验证周期" category:"密钥配置" desc:"后台定时验证密钥的默认周期(分钟)" validate:"min=5"` - KeyValidationTaskTimeoutMinutes int `json:"key_validation_task_timeout_minutes" default:"60" name:"手动验证超时" category:"密钥配置" desc:"手动触发的全量验证任务的超时时间(分钟)" validate:"min=10"` + MaxRetries int `json:"max_retries" default:"3" name:"最大重试次数" category:"密钥配置" desc:"单个请求使用不同 Key 的最大重试次数" validate:"min=0"` + BlacklistThreshold int `json:"blacklist_threshold" default:"3" name:"黑名单阈值" category:"密钥配置" desc:"一个 Key 连续失败多少次后进入黑名单" validate:"min=0"` + KeyValidationIntervalMinutes int `json:"key_validation_interval_minutes" default:"60" name:"定时验证周期" category:"密钥配置" desc:"后台定时验证密钥的默认周期(分钟)" validate:"min=30"` } // ServerConfig represents server configuration @@ -54,8 +46,7 @@ type ServerConfig struct { // AuthConfig represents authentication configuration type AuthConfig struct { - Key string `json:"key"` - Enabled bool `json:"enabled"` + Key string `json:"key"` } // CORSConfig represents CORS configuration @@ -69,18 +60,16 @@ type CORSConfig struct { // PerformanceConfig represents performance configuration type PerformanceConfig struct { - MaxConcurrentRequests int `json:"max_concurrent_requests"` - KeyValidationPoolSize int `json:"key_validation_pool_size"` - EnableGzip bool `json:"enable_gzip"` + MaxConcurrentRequests int `json:"max_concurrent_requests"` + KeyValidationPoolSize int `json:"key_validation_pool_size"` } // LogConfig represents logging configuration type LogConfig struct { - Level string `json:"level"` - Format string `json:"format"` - EnableFile bool `json:"enable_file"` - FilePath string `json:"file_path"` - EnableRequest bool `json:"enable_request"` + Level string `json:"level"` + Format string `json:"format"` + EnableFile bool `json:"enable_file"` + FilePath string `json:"file_path"` } // DatabaseConfig represents database configuration diff --git a/internal/utils/config_utils.go b/internal/utils/config_utils.go new file mode 100644 index 0000000..31d2641 --- /dev/null +++ b/internal/utils/config_utils.go @@ -0,0 +1,206 @@ +package utils + +import ( + "fmt" + "gpt-load/internal/models" + "gpt-load/internal/types" + "io" + "os" + "path/filepath" + "reflect" + "strconv" + "strings" + + "github.com/sirupsen/logrus" +) + +// GenerateSettingsMetadata 使用反射从 SystemSettings 结构体动态生成元数据 +func GenerateSettingsMetadata(s *types.SystemSettings) []models.SystemSettingInfo { + var settingsInfo []models.SystemSettingInfo + v := reflect.ValueOf(s).Elem() + t := v.Type() + + for i := range t.NumField() { + field := t.Field(i) + fieldValue := v.Field(i) + + jsonTag := field.Tag.Get("json") + if jsonTag == "" { + continue + } + + nameTag := field.Tag.Get("name") + descTag := field.Tag.Get("desc") + defaultTag := field.Tag.Get("default") + validateTag := field.Tag.Get("validate") + categoryTag := field.Tag.Get("category") + + var minValue *int + if strings.HasPrefix(validateTag, "min=") { + valStr := strings.TrimPrefix(validateTag, "min=") + if val, err := strconv.Atoi(valStr); err == nil { + minValue = &val + } + } + + info := models.SystemSettingInfo{ + Key: jsonTag, + Name: nameTag, + Value: fieldValue.Interface(), + Type: field.Type.String(), + DefaultValue: defaultTag, + Description: descTag, + Category: categoryTag, + MinValue: minValue, + } + settingsInfo = append(settingsInfo, info) + } + return settingsInfo +} + +// DefaultSystemSettings 返回默认的系统配置 +func DefaultSystemSettings() types.SystemSettings { + s := types.SystemSettings{} + v := reflect.ValueOf(&s).Elem() + t := v.Type() + + for i := range t.NumField() { + field := t.Field(i) + defaultTag := field.Tag.Get("default") + if defaultTag == "" { + continue + } + + fieldValue := v.Field(i) + if fieldValue.CanSet() { + if err := SetFieldFromString(fieldValue, defaultTag); err != nil { + logrus.Warnf("Failed to set default value for field %s: %v", field.Name, err) + } + } + } + return s +} + +// SetFieldFromString sets a struct field's value from a string, based on the field's kind. +func SetFieldFromString(fieldValue reflect.Value, value string) error { + if !fieldValue.CanSet() { + return fmt.Errorf("field cannot be set") + } + + switch fieldValue.Kind() { + case reflect.Int: + intVal, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return fmt.Errorf("invalid integer value '%s': %w", value, err) + } + fieldValue.SetInt(int64(intVal)) + case reflect.Bool: + boolVal, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid boolean value '%s': %w", value, err) + } + fieldValue.SetBool(boolVal) + case reflect.String: + fieldValue.SetString(value) + default: + return fmt.Errorf("unsupported field kind: %s", fieldValue.Kind()) + } + return nil +} + +// ParseInteger parses integer environment variable +func ParseInteger(value string, defaultValue int) int { + if value == "" { + return defaultValue + } + if parsed, err := strconv.Atoi(value); err == nil { + return parsed + } + return defaultValue +} + +// ParseBoolean parses boolean environment variable +func ParseBoolean(value string, defaultValue bool) bool { + if value == "" { + return defaultValue + } + + lowerValue := strings.ToLower(value) + switch lowerValue { + case "true", "1", "yes", "on": + return true + case "false", "0", "no", "off": + return false + default: + return defaultValue + } +} + +// ParseArray parses array environment variable (comma-separated) +func ParseArray(value string, defaultValue []string) []string { + if value == "" { + return defaultValue + } + + parts := strings.Split(value, ",") + result := make([]string, 0, len(parts)) + for _, part := range parts { + if trimmed := strings.TrimSpace(part); trimmed != "" { + result = append(result, trimmed) + } + } + + if len(result) == 0 { + return defaultValue + } + return result +} + +// GetEnvOrDefault gets environment variable or default value +func GetEnvOrDefault(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +} + +// SetupLogger configures the logging system based on the provided configuration. +func SetupLogger(configManager types.ConfigManager) { + logConfig := configManager.GetLogConfig() + + // Set log level + level, err := logrus.ParseLevel(logConfig.Level) + if err != nil { + logrus.Warn("Invalid log level, using info") + level = logrus.InfoLevel + } + logrus.SetLevel(level) + + // Set log format + if logConfig.Format == "json" { + logrus.SetFormatter(&logrus.JSONFormatter{ + TimestampFormat: "2006-01-02T15:04:05.000Z07:00", // ISO 8601 format + }) + } else { + logrus.SetFormatter(&logrus.TextFormatter{ + ForceColors: true, + FullTimestamp: true, + TimestampFormat: "2006-01-02 15:04:05", + }) + } + + // Setup file logging if enabled + if logConfig.EnableFile { + logDir := filepath.Dir(logConfig.FilePath) + if err := os.MkdirAll(logDir, 0755); err != nil { + logrus.Warnf("Failed to create log directory: %v", err) + } else { + logFile, err := os.OpenFile(logConfig.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + logrus.Warnf("Failed to open log file: %v", err) + } else { + logrus.SetOutput(io.MultiWriter(os.Stdout, logFile)) + } + } + } +}