diff --git a/cmd/gpt-load/main.go b/cmd/gpt-load/main.go index 9dec6d1..45e1d4a 100644 --- a/cmd/gpt-load/main.go +++ b/cmd/gpt-load/main.go @@ -96,7 +96,7 @@ func main() { defer proxyServer.Close() // Create handlers - serverHandler := handler.NewServer(database, configManager, keyValidatorService, keyManualValidationService, taskService, keyService) + serverHandler := handler.NewServer(database, configManager, settingsManager, keyValidatorService, keyManualValidationService, taskService, keyService) logCleanupHandler := handler.NewLogCleanupHandler(logCleanupService) // Setup routes using the new router package diff --git a/internal/handler/group_handler.go b/internal/handler/group_handler.go index e746306..da189b6 100644 --- a/internal/handler/group_handler.go +++ b/internal/handler/group_handler.go @@ -4,16 +4,76 @@ package handler import ( "encoding/json" "fmt" + + "gpt-load/internal/config" app_errors "gpt-load/internal/errors" "gpt-load/internal/models" "gpt-load/internal/response" + "reflect" "regexp" "strconv" + "strings" + + "gpt-load/internal/channel" "github.com/gin-gonic/gin" "gorm.io/datatypes" ) +// isValidChannelType checks if the channel type is valid by checking against the registered channels. +func isValidChannelType(channelType string) bool { + channels := channel.GetChannels() + for _, t := range channels { + if t == channelType { + return true + } + } + return false +} + +// UpstreamDefinition defines the structure for an upstream in the request. +type UpstreamDefinition struct { + URL string `json:"url"` + Weight int `json:"weight"` +} + +// validateAndCleanUpstreams validates and cleans the upstreams JSON. +func validateAndCleanUpstreams(upstreams json.RawMessage) (datatypes.JSON, error) { + if len(upstreams) == 0 { + return nil, fmt.Errorf("upstreams field is required") + } + + var defs []UpstreamDefinition + if err := json.Unmarshal(upstreams, &defs); err != nil { + return nil, fmt.Errorf("invalid format for upstreams: %w", err) + } + + if len(defs) == 0 { + return nil, fmt.Errorf("at least one upstream is required") + } + + for i := range defs { + defs[i].URL = strings.TrimSpace(defs[i].URL) + if defs[i].URL == "" { + return nil, fmt.Errorf("upstream URL cannot be empty") + } + // Basic URL format validation + if !strings.HasPrefix(defs[i].URL, "http://") && !strings.HasPrefix(defs[i].URL, "https://") { + return nil, fmt.Errorf("invalid URL format for upstream: %s", defs[i].URL) + } + if defs[i].Weight <= 0 { + return nil, fmt.Errorf("upstream weight must be a positive integer") + } + } + + cleanedUpstreams, err := json.Marshal(defs) + if err != nil { + return nil, fmt.Errorf("failed to marshal cleaned upstreams: %w", err) + } + + return cleanedUpstreams, nil +} + // isValidGroupName checks if the group name is valid. func isValidGroupName(name string) bool { if name == "" { @@ -40,6 +100,29 @@ func validateAndCleanConfig(configMap map[string]any) (map[string]any, error) { return nil, err } + // Strict check for unknown fields + var cleanedMap map[string]any + if err := json.Unmarshal(configBytes, &cleanedMap); err != nil { + return nil, err + } + + val := reflect.ValueOf(validatedConfig) + typ := val.Type() + validFields := make(map[string]bool) + for i := 0; i < typ.NumField(); i++ { + jsonTag := typ.Field(i).Tag.Get("json") + fieldName := strings.Split(jsonTag, ",")[0] + if fieldName != "" && fieldName != "-" { + validFields[fieldName] = true + } + } + + for key := range configMap { + if !validFields[key] { + return nil, fmt.Errorf("unknown config field: '%s'", key) + } + } + // 验证配置项的合理范围 if validatedConfig.BlacklistThreshold != nil && *validatedConfig.BlacklistThreshold < 0 { return nil, fmt.Errorf("blacklist_threshold must be >= 0") @@ -54,52 +137,70 @@ func validateAndCleanConfig(configMap map[string]any) (map[string]any, error) { return nil, fmt.Errorf("key_validation_interval_minutes must be between 5 and 1440 minutes") } - // Marshal back to a map to remove any fields not in GroupConfig + // Marshal back to a map to ensure consistency validatedBytes, err := json.Marshal(validatedConfig) if err != nil { return nil, err } - - var cleanedMap map[string]any - if err := json.Unmarshal(validatedBytes, &cleanedMap); err != nil { + var finalMap map[string]any + if err := json.Unmarshal(validatedBytes, &finalMap); err != nil { return nil, err } - return cleanedMap, nil + return finalMap, nil } // CreateGroup handles the creation of a new group. func (s *Server) CreateGroup(c *gin.Context) { - var group models.Group - if err := c.ShouldBindJSON(&group); err != nil { + var req models.Group + if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error())) return } - // Validation - if !isValidGroupName(group.Name) { + // Data Cleaning and Validation + name := strings.TrimSpace(req.Name) + if !isValidGroupName(name) { response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid group name format. Use 3-30 lowercase letters, numbers, and underscores.")) return } - if len(group.Upstreams) == 0 { - response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "At least one upstream is required")) + + channelType := strings.TrimSpace(req.ChannelType) + if !isValidChannelType(channelType) { + supported := strings.Join(channel.GetChannels(), ", ") + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, fmt.Sprintf("Invalid channel type. Supported types are: %s", supported))) return } - if group.ChannelType == "" { - response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Channel type is required")) - return - } - if group.TestModel == "" { + + testModel := strings.TrimSpace(req.TestModel) + if testModel == "" { response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Test model is required")) return } - cleanedConfig, err := validateAndCleanConfig(group.Config) + cleanedUpstreams, err := validateAndCleanUpstreams(json.RawMessage(req.Upstreams)) if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid config format")) + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error())) return } - group.Config = cleanedConfig + + cleanedConfig, err := validateAndCleanConfig(req.Config) + if err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, fmt.Sprintf("Invalid config format: %v", err))) + return + } + + group := models.Group{ + Name: name, + DisplayName: strings.TrimSpace(req.DisplayName), + Description: strings.TrimSpace(req.Description), + Upstreams: cleanedUpstreams, + ChannelType: channelType, + Sort: req.Sort, + TestModel: testModel, + ParamOverrides: req.ParamOverrides, + Config: cleanedConfig, + } if err := s.DB.Create(&group).Error; err != nil { response.Error(c, app_errors.ParseDBError(err)) @@ -122,11 +223,11 @@ func (s *Server) ListGroups(c *gin.Context) { // GroupUpdateRequest defines the payload for updating a group. // Using a dedicated struct avoids issues with zero values being ignored by GORM's Update. type GroupUpdateRequest struct { - Name string `json:"name"` - DisplayName string `json:"display_name"` - Description string `json:"description"` + Name *string `json:"name,omitempty"` + DisplayName *string `json:"display_name,omitempty"` + Description *string `json:"description,omitempty"` Upstreams json.RawMessage `json:"upstreams"` - ChannelType string `json:"channel_type"` + ChannelType *string `json:"channel_type,omitempty"` Sort *int `json:"sort"` TestModel string `json:"test_model"` ParamOverrides map[string]any `json:"param_overrides"` @@ -161,31 +262,52 @@ func (s *Server) UpdateGroup(c *gin.Context) { } defer tx.Rollback() // Rollback on panic - // Apply updates from the request - if req.Name != "" { - if !isValidGroupName(req.Name) { - response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid group name format.")) + // Apply updates from the request, with cleaning and validation + if req.Name != nil { + cleanedName := strings.TrimSpace(*req.Name) + if !isValidGroupName(cleanedName) { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid group name format. Name is required and must be 3-30 lowercase letters, numbers, or underscores.")) return } - group.Name = req.Name + group.Name = cleanedName } - if req.DisplayName != "" { - group.DisplayName = req.DisplayName + + if req.DisplayName != nil { + group.DisplayName = strings.TrimSpace(*req.DisplayName) } - if req.Description != "" { - group.Description = req.Description + + if req.Description != nil { + group.Description = strings.TrimSpace(*req.Description) } + if req.Upstreams != nil { - group.Upstreams = datatypes.JSON(req.Upstreams) + cleanedUpstreams, err := validateAndCleanUpstreams(req.Upstreams) + if err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error())) + return + } + group.Upstreams = cleanedUpstreams } - if req.ChannelType != "" { - group.ChannelType = req.ChannelType + + if req.ChannelType != nil { + cleanedChannelType := strings.TrimSpace(*req.ChannelType) + if !isValidChannelType(cleanedChannelType) { + supported := strings.Join(channel.GetChannels(), ", ") + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, fmt.Sprintf("Invalid channel type. Supported types are: %s", supported))) + return + } + group.ChannelType = cleanedChannelType } if req.Sort != nil { group.Sort = *req.Sort } if req.TestModel != "" { - group.TestModel = req.TestModel + cleanedTestModel := strings.TrimSpace(req.TestModel) + if cleanedTestModel == "" { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Test model cannot be empty or just spaces.")) + return + } + group.TestModel = cleanedTestModel } if req.ParamOverrides != nil { group.ParamOverrides = req.ParamOverrides @@ -193,7 +315,7 @@ func (s *Server) UpdateGroup(c *gin.Context) { if req.Config != nil { cleanedConfig, err := validateAndCleanConfig(req.Config) if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid config format")) + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, fmt.Sprintf("Invalid config format: %v", err))) return } group.Config = cleanedConfig @@ -258,3 +380,67 @@ func (s *Server) DeleteGroup(c *gin.Context) { response.Success(c, gin.H{"message": "Group and associated keys deleted successfully"}) } + +// ConfigOption represents a single configurable option for a group. +type ConfigOption struct { + Key string `json:"key"` + Name string `json:"name"` + Description string `json:"description"` + DefaultValue any `json:"default_value"` +} + +// GetGroupConfigOptions returns a list of available configuration options for groups. +func (s *Server) GetGroupConfigOptions(c *gin.Context) { + var options []ConfigOption + + // 1. Get all system setting definitions from the struct tags + defaultSettings := config.DefaultSystemSettings() + settingDefinitions := config.GenerateSettingsMetadata(&defaultSettings) + defMap := make(map[string]models.SystemSettingInfo) + for _, def := range settingDefinitions { + defMap[def.Key] = def + } + + // 2. Get current system setting values + currentSettings := s.SettingsManager.GetSettings() + currentSettingsValue := reflect.ValueOf(currentSettings) + currentSettingsType := currentSettingsValue.Type() + jsonToFieldMap := make(map[string]string) + for i := 0; i < currentSettingsType.NumField(); i++ { + field := currentSettingsType.Field(i) + jsonTag := strings.Split(field.Tag.Get("json"), ",")[0] + if jsonTag != "" { + jsonToFieldMap[jsonTag] = field.Name + } + } + + // 3. Iterate over GroupConfig fields to maintain order and build the response + groupConfigType := reflect.TypeOf(models.GroupConfig{}) + + for i := 0; i < groupConfigType.NumField(); i++ { + field := groupConfigType.Field(i) + jsonTag := field.Tag.Get("json") + key := strings.Split(jsonTag, ",")[0] + + if key == "" || key == "-" { + continue + } + + if definition, ok := defMap[key]; ok { + var defaultValue any + if fieldName, ok := jsonToFieldMap[key]; ok { + defaultValue = currentSettingsValue.FieldByName(fieldName).Interface() + } + + option := ConfigOption{ + Key: key, + Name: definition.Name, + Description: definition.Description, + DefaultValue: defaultValue, + } + options = append(options, option) + } + } + + response.Success(c, options) +} diff --git a/internal/handler/handler.go b/internal/handler/handler.go index b9b8bbf..138a462 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -5,6 +5,7 @@ import ( "net/http" "time" + "gpt-load/internal/config" "gpt-load/internal/models" "gpt-load/internal/services" "gpt-load/internal/types" @@ -17,6 +18,7 @@ import ( type Server struct { DB *gorm.DB config types.ConfigManager + SettingsManager *config.SystemSettingsManager KeyValidatorService *services.KeyValidatorService KeyManualValidationService *services.KeyManualValidationService TaskService *services.TaskService @@ -27,6 +29,7 @@ type Server struct { func NewServer( db *gorm.DB, config types.ConfigManager, + settingsManager *config.SystemSettingsManager, keyValidatorService *services.KeyValidatorService, keyManualValidationService *services.KeyManualValidationService, taskService *services.TaskService, @@ -35,6 +38,7 @@ func NewServer( return &Server{ DB: db, config: config, + SettingsManager: settingsManager, KeyValidatorService: keyValidatorService, KeyManualValidationService: keyManualValidationService, TaskService: taskService, diff --git a/internal/router/router.go b/internal/router/router.go index d2df176..a16ec51 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -103,6 +103,7 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser { groups.POST("", serverHandler.CreateGroup) groups.GET("", serverHandler.ListGroups) + groups.GET("/config-options", serverHandler.GetGroupConfigOptions) groups.PUT("/:id", serverHandler.UpdateGroup) groups.DELETE("/:id", serverHandler.DeleteGroup)