// Package handler provides HTTP handlers for the application package handler import ( "encoding/json" "fmt" "net/url" "sync" app_errors "gpt-load/internal/errors" "gpt-load/internal/models" "gpt-load/internal/response" "gpt-load/internal/utils" "reflect" "regexp" "strconv" "strings" "time" "gpt-load/internal/channel" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" "gorm.io/datatypes" ) // isValidChannelType checks if the channel type is valid by checking against the registered channels. func isValidChannelType(channelType string) bool { channels := channel.GetChannels() for _, t := range channels { if t == channelType { return true } } return false } // UpstreamDefinition defines the structure for an upstream in the request. type UpstreamDefinition struct { URL string `json:"url"` Weight int `json:"weight"` } // validateAndCleanUpstreams validates and cleans the upstreams JSON. func validateAndCleanUpstreams(upstreams json.RawMessage) (datatypes.JSON, error) { if len(upstreams) == 0 { return nil, fmt.Errorf("upstreams field is required") } var defs []UpstreamDefinition if err := json.Unmarshal(upstreams, &defs); err != nil { return nil, fmt.Errorf("invalid format for upstreams: %w", err) } if len(defs) == 0 { return nil, fmt.Errorf("at least one upstream is required") } for i := range defs { defs[i].URL = strings.TrimSpace(defs[i].URL) if defs[i].URL == "" { return nil, fmt.Errorf("upstream URL cannot be empty") } // Basic URL format validation if !strings.HasPrefix(defs[i].URL, "http://") && !strings.HasPrefix(defs[i].URL, "https://") { return nil, fmt.Errorf("invalid URL format for upstream: %s", defs[i].URL) } if defs[i].Weight <= 0 { return nil, fmt.Errorf("upstream weight must be a positive integer") } } cleanedUpstreams, err := json.Marshal(defs) if err != nil { return nil, fmt.Errorf("failed to marshal cleaned upstreams: %w", err) } return cleanedUpstreams, nil } // isValidGroupName checks if the group name is valid. func isValidGroupName(name string) bool { if name == "" { return false } // 允许使用小写字母、数字、下划线和中划线,长度在 3 到 30 个字符之间 match, _ := regexp.MatchString("^[a-z0-9_-]{3,30}$", name) return match } // isValidValidationEndpoint checks if the validation endpoint is a valid path. func isValidValidationEndpoint(endpoint string) bool { if endpoint == "" { return true } if !strings.HasPrefix(endpoint, "/") { return false } if strings.Contains(endpoint, "://") { return false } return true } // validateAndCleanConfig validates the group config against the GroupConfig struct and system-defined rules. func (s *Server) validateAndCleanConfig(configMap map[string]any) (map[string]any, error) { if configMap == nil { return nil, nil } // 1. Check for unknown fields by comparing against the GroupConfig struct definition. var tempGroupConfig models.GroupConfig groupConfigType := reflect.TypeOf(tempGroupConfig) validFields := make(map[string]bool) for i := 0; i < groupConfigType.NumField(); i++ { jsonTag := groupConfigType.Field(i).Tag.Get("json") fieldName := strings.Split(jsonTag, ",")[0] if fieldName != "" && fieldName != "-" { validFields[fieldName] = true } } for key := range configMap { if !validFields[key] { return nil, fmt.Errorf("unknown config field: '%s'", key) } } // 2. Validate the values of the provided fields using the central system settings validator. if err := s.SettingsManager.ValidateGroupConfigOverrides(configMap); err != nil { return nil, err } // 3. Unmarshal and marshal back to clean the map and ensure correct types. configBytes, err := json.Marshal(configMap) if err != nil { return nil, fmt.Errorf("failed to marshal config map: %w", err) } var validatedConfig models.GroupConfig if err := json.Unmarshal(configBytes, &validatedConfig); err != nil { return nil, fmt.Errorf("failed to unmarshal into validated config: %w", err) } validatedBytes, err := json.Marshal(validatedConfig) if err != nil { return nil, fmt.Errorf("failed to marshal validated config: %w", err) } var finalMap map[string]any if err := json.Unmarshal(validatedBytes, &finalMap); err != nil { return nil, fmt.Errorf("failed to unmarshal into final map: %w", err) } return finalMap, nil } // CreateGroup handles the creation of a new group. func (s *Server) CreateGroup(c *gin.Context) { var req models.Group if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error())) return } // Data Cleaning and Validation name := strings.TrimSpace(req.Name) if !isValidGroupName(name) { response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "无效的分组名称。只能包含小写字母、数字、中划线或下划线,长度3-30位")) return } channelType := strings.TrimSpace(req.ChannelType) if !isValidChannelType(channelType) { supported := strings.Join(channel.GetChannels(), ", ") response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, fmt.Sprintf("Invalid channel type. Supported types are: %s", supported))) return } testModel := strings.TrimSpace(req.TestModel) if testModel == "" { response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Test model is required")) return } cleanedUpstreams, err := validateAndCleanUpstreams(json.RawMessage(req.Upstreams)) if err != nil { response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error())) return } cleanedConfig, err := s.validateAndCleanConfig(req.Config) if err != nil { response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, fmt.Sprintf("Invalid config format: %v", err))) return } validationEndpoint := strings.TrimSpace(req.ValidationEndpoint) if !isValidValidationEndpoint(validationEndpoint) { response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "无效的测试路径。如果提供,必须是以 / 开头的有效路径,且不能是完整的URL。")) return } group := models.Group{ Name: name, DisplayName: strings.TrimSpace(req.DisplayName), Description: strings.TrimSpace(req.Description), Upstreams: cleanedUpstreams, ChannelType: channelType, Sort: req.Sort, TestModel: testModel, ValidationEndpoint: validationEndpoint, ParamOverrides: req.ParamOverrides, Config: cleanedConfig, ProxyKeys: strings.TrimSpace(req.ProxyKeys), } if err := s.DB.Create(&group).Error; err != nil { response.Error(c, app_errors.ParseDBError(err)) return } if err := s.GroupManager.Invalidate(); err != nil { logrus.WithContext(c.Request.Context()).WithError(err).Error("failed to invalidate group cache") } response.Success(c, s.newGroupResponse(&group)) } // ListGroups handles listing all groups. func (s *Server) ListGroups(c *gin.Context) { var groups []models.Group if err := s.DB.Order("sort asc, id desc").Find(&groups).Error; err != nil { response.Error(c, app_errors.ParseDBError(err)) return } var groupResponses []GroupResponse for i := range groups { groupResponses = append(groupResponses, *s.newGroupResponse(&groups[i])) } response.Success(c, groupResponses) } // GroupUpdateRequest defines the payload for updating a group. // Using a dedicated struct avoids issues with zero values being ignored by GORM's Update. type GroupUpdateRequest struct { Name *string `json:"name,omitempty"` DisplayName *string `json:"display_name,omitempty"` Description *string `json:"description,omitempty"` Upstreams json.RawMessage `json:"upstreams"` ChannelType *string `json:"channel_type,omitempty"` Sort *int `json:"sort"` TestModel string `json:"test_model"` ValidationEndpoint *string `json:"validation_endpoint,omitempty"` ParamOverrides map[string]any `json:"param_overrides"` Config map[string]any `json:"config"` ProxyKeys *string `json:"proxy_keys,omitempty"` } // UpdateGroup handles updating an existing group. func (s *Server) UpdateGroup(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID format")) return } var group models.Group if err := s.DB.First(&group, id).Error; err != nil { response.Error(c, app_errors.ParseDBError(err)) return } var req GroupUpdateRequest if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error())) return } // Start a transaction tx := s.DB.Begin() if tx.Error != nil { response.Error(c, app_errors.ErrDatabase) return } defer tx.Rollback() // Rollback on panic // Apply updates from the request, with cleaning and validation if req.Name != nil { cleanedName := strings.TrimSpace(*req.Name) if !isValidGroupName(cleanedName) { response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "无效的分组名称格式。只能包含小写字母、数字、中划线或下划线,长度3-30位")) return } group.Name = cleanedName } if req.DisplayName != nil { group.DisplayName = strings.TrimSpace(*req.DisplayName) } if req.Description != nil { group.Description = strings.TrimSpace(*req.Description) } if req.Upstreams != nil { cleanedUpstreams, err := validateAndCleanUpstreams(req.Upstreams) if err != nil { response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error())) return } group.Upstreams = cleanedUpstreams } if req.ChannelType != nil { cleanedChannelType := strings.TrimSpace(*req.ChannelType) if !isValidChannelType(cleanedChannelType) { supported := strings.Join(channel.GetChannels(), ", ") response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, fmt.Sprintf("Invalid channel type. Supported types are: %s", supported))) return } group.ChannelType = cleanedChannelType } if req.Sort != nil { group.Sort = *req.Sort } if req.TestModel != "" { cleanedTestModel := strings.TrimSpace(req.TestModel) if cleanedTestModel == "" { response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Test model cannot be empty or just spaces.")) return } group.TestModel = cleanedTestModel } if req.ParamOverrides != nil { group.ParamOverrides = req.ParamOverrides } if req.ValidationEndpoint != nil { validationEndpoint := strings.TrimSpace(*req.ValidationEndpoint) if !isValidValidationEndpoint(validationEndpoint) { response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "无效的测试路径。如果提供,必须是以 / 开头的有效路径,且不能是完整的URL。")) return } group.ValidationEndpoint = validationEndpoint } if req.Config != nil { cleanedConfig, err := s.validateAndCleanConfig(req.Config) if err != nil { response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, fmt.Sprintf("Invalid config format: %v", err))) return } group.Config = cleanedConfig } if req.ProxyKeys != nil { group.ProxyKeys = strings.TrimSpace(*req.ProxyKeys) } // Save the updated group object if err := tx.Save(&group).Error; err != nil { response.Error(c, app_errors.ParseDBError(err)) return } if err := tx.Commit().Error; err != nil { response.Error(c, app_errors.ErrDatabase) return } if err := s.GroupManager.Invalidate(); err != nil { logrus.WithContext(c.Request.Context()).WithError(err).Error("failed to invalidate group cache") } response.Success(c, s.newGroupResponse(&group)) } // GroupResponse defines the structure for a group response, excluding sensitive or large fields. type GroupResponse struct { ID uint `json:"id"` Name string `json:"name"` Endpoint string `json:"endpoint"` DisplayName string `json:"display_name"` Description string `json:"description"` Upstreams datatypes.JSON `json:"upstreams"` ChannelType string `json:"channel_type"` Sort int `json:"sort"` TestModel string `json:"test_model"` ValidationEndpoint string `json:"validation_endpoint"` ParamOverrides datatypes.JSONMap `json:"param_overrides"` Config datatypes.JSONMap `json:"config"` ProxyKeys string `json:"proxy_keys"` LastValidatedAt *time.Time `json:"last_validated_at"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } // newGroupResponse creates a new GroupResponse from a models.Group. func (s *Server) newGroupResponse(group *models.Group) *GroupResponse { appURL := s.SettingsManager.GetAppUrl() endpoint := "" if appURL != "" { u, err := url.Parse(appURL) if err == nil { u.Path = strings.TrimRight(u.Path, "/") + "/proxy/" + group.Name endpoint = u.String() } } return &GroupResponse{ ID: group.ID, Name: group.Name, Endpoint: endpoint, DisplayName: group.DisplayName, Description: group.Description, Upstreams: group.Upstreams, ChannelType: group.ChannelType, Sort: group.Sort, TestModel: group.TestModel, ValidationEndpoint: group.ValidationEndpoint, ParamOverrides: group.ParamOverrides, Config: group.Config, ProxyKeys: group.ProxyKeys, LastValidatedAt: group.LastValidatedAt, CreatedAt: group.CreatedAt, UpdatedAt: group.UpdatedAt, } } // DeleteGroup handles deleting a group. func (s *Server) DeleteGroup(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID format")) return } // First, get all API keys for this group to clean up from memory store var apiKeys []models.APIKey if err := s.DB.Where("group_id = ?", id).Find(&apiKeys).Error; err != nil { response.Error(c, app_errors.ParseDBError(err)) return } // Extract key IDs for memory store cleanup var keyIDs []uint for _, key := range apiKeys { keyIDs = append(keyIDs, key.ID) } // Use a transaction to ensure atomicity tx := s.DB.Begin() if tx.Error != nil { response.Error(c, app_errors.ErrDatabase) return } defer func() { if r := recover(); r != nil { tx.Rollback() } }() // First check if the group exists var group models.Group if err := tx.First(&group, id).Error; err != nil { tx.Rollback() response.Error(c, app_errors.ParseDBError(err)) return } // Delete associated API keys first due to foreign key constraint if err := tx.Where("group_id = ?", id).Delete(&models.APIKey{}).Error; err != nil { tx.Rollback() response.Error(c, app_errors.ErrDatabase) return } // Then delete the group if err := tx.Delete(&models.Group{}, id).Error; err != nil { tx.Rollback() response.Error(c, app_errors.ParseDBError(err)) return } // Clean up memory store (Redis) within the transaction to ensure atomicity // If Redis cleanup fails, the entire transaction will be rolled back if len(keyIDs) > 0 { if err := s.KeyService.KeyProvider.RemoveKeysFromStore(uint(id), keyIDs); err != nil { tx.Rollback() logrus.WithFields(logrus.Fields{ "groupID": id, "keyCount": len(keyIDs), "error": err, }).Error("Failed to remove keys from memory store, rolling back transaction") response.Error(c, app_errors.NewAPIError(app_errors.ErrDatabase, "Failed to delete group: unable to clean up cache")) return } } // Commit the transaction only if both DB and Redis operations succeed if err := tx.Commit().Error; err != nil { tx.Rollback() response.Error(c, app_errors.ErrDatabase) return } if err := s.GroupManager.Invalidate(); err != nil { logrus.WithContext(c.Request.Context()).WithError(err).Error("failed to invalidate group cache") } response.Success(c, gin.H{"message": "Group and associated keys deleted successfully"}) } // ConfigOption represents a single configurable option for a group. type ConfigOption struct { Key string `json:"key"` Name string `json:"name"` Description string `json:"description"` DefaultValue any `json:"default_value"` } // GetGroupConfigOptions returns a list of available configuration options for groups. func (s *Server) GetGroupConfigOptions(c *gin.Context) { var options []ConfigOption // 1. Get all system setting definitions from the struct tags defaultSettings := utils.DefaultSystemSettings() settingDefinitions := utils.GenerateSettingsMetadata(&defaultSettings) defMap := make(map[string]models.SystemSettingInfo) for _, def := range settingDefinitions { defMap[def.Key] = def } // 2. Get current system setting values currentSettings := s.SettingsManager.GetSettings() currentSettingsValue := reflect.ValueOf(currentSettings) currentSettingsType := currentSettingsValue.Type() jsonToFieldMap := make(map[string]string) for i := 0; i < currentSettingsType.NumField(); i++ { field := currentSettingsType.Field(i) jsonTag := strings.Split(field.Tag.Get("json"), ",")[0] if jsonTag != "" { jsonToFieldMap[jsonTag] = field.Name } } // 3. Iterate over GroupConfig fields to maintain order and build the response groupConfigType := reflect.TypeOf(models.GroupConfig{}) for i := 0; i < groupConfigType.NumField(); i++ { field := groupConfigType.Field(i) jsonTag := field.Tag.Get("json") key := strings.Split(jsonTag, ",")[0] if key == "" || key == "-" { continue } if definition, ok := defMap[key]; ok { var defaultValue any if fieldName, ok := jsonToFieldMap[key]; ok { defaultValue = currentSettingsValue.FieldByName(fieldName).Interface() } option := ConfigOption{ Key: key, Name: definition.Name, Description: definition.Description, DefaultValue: defaultValue, } options = append(options, option) } } response.Success(c, options) } // KeyStats defines the statistics for API keys in a group. type KeyStats struct { TotalKeys int64 `json:"total_keys"` ActiveKeys int64 `json:"active_keys"` InvalidKeys int64 `json:"invalid_keys"` } // RequestStats defines the statistics for requests over a period. type RequestStats struct { TotalRequests int64 `json:"total_requests"` FailedRequests int64 `json:"failed_requests"` FailureRate float64 `json:"failure_rate"` } // GroupStatsResponse defines the complete statistics for a group. type GroupStatsResponse struct { KeyStats KeyStats `json:"key_stats"` HourlyStats RequestStats `json:"hourly_stats"` // 1 hour DailyStats RequestStats `json:"daily_stats"` // 24 hours WeeklyStats RequestStats `json:"weekly_stats"` // 7 days } // calculateRequestStats is a helper to compute request statistics. func calculateRequestStats(total, failed int64) RequestStats { stats := RequestStats{ TotalRequests: total, FailedRequests: failed, } if total > 0 { stats.FailureRate, _ = strconv.ParseFloat(fmt.Sprintf("%.4f", float64(failed)/float64(total)), 64) } return stats } // GetGroupStats handles retrieving detailed statistics for a specific group. func (s *Server) GetGroupStats(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID format")) return } groupID := uint(id) // 1. 验证分组是否存在 var group models.Group if err := s.DB.First(&group, groupID).Error; err != nil { response.Error(c, app_errors.ParseDBError(err)) return } var resp GroupStatsResponse var wg sync.WaitGroup var mu sync.Mutex var errors []error // 并发执行所有统计查询 // 2. Key 统计 wg.Add(1) go func() { defer wg.Done() var totalKeys, activeKeys int64 if err := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID).Count(&totalKeys).Error; err != nil { mu.Lock() errors = append(errors, fmt.Errorf("failed to get total keys: %w", err)) mu.Unlock() return } if err := s.DB.Model(&models.APIKey{}).Where("group_id = ? AND status = ?", groupID, models.KeyStatusActive).Count(&activeKeys).Error; err != nil { mu.Lock() errors = append(errors, fmt.Errorf("failed to get active keys: %w", err)) mu.Unlock() return } mu.Lock() resp.KeyStats = KeyStats{ TotalKeys: totalKeys, ActiveKeys: activeKeys, InvalidKeys: totalKeys - activeKeys, } mu.Unlock() }() // 3. 1小时请求统计 (查询 request_logs 表) wg.Add(1) go func() { defer wg.Done() var total, failed int64 now := time.Now() oneHourAgo := now.Add(-1 * time.Hour) if err := s.DB.Model(&models.RequestLog{}).Where("group_id = ? AND timestamp BETWEEN ? AND ?", groupID, oneHourAgo, now).Count(&total).Error; err != nil { mu.Lock() errors = append(errors, fmt.Errorf("failed to get hourly total requests: %w", err)) mu.Unlock() return } if err := s.DB.Model(&models.RequestLog{}).Where("group_id = ? AND timestamp BETWEEN ? AND ? AND is_success = ?", groupID, oneHourAgo, now, false).Count(&failed).Error; err != nil { mu.Lock() errors = append(errors, fmt.Errorf("failed to get hourly failed requests: %w", err)) mu.Unlock() return } mu.Lock() resp.HourlyStats = calculateRequestStats(total, failed) mu.Unlock() }() // 4. 24小时和7天统计 (查询 group_hourly_stats 表) // 辅助函数,用于从 group_hourly_stats 查询 queryHourlyStats := func(duration time.Duration) (RequestStats, error) { var result struct { SuccessCount int64 FailureCount int64 } now := time.Now() // 结束时间为当前小时的整点,查询时不包含该小时 // 开始时间为结束时间减去统计周期 endTime := now.Truncate(time.Hour) startTime := endTime.Add(-duration) err := s.DB.Model(&models.GroupHourlyStat{}). Select("SUM(success_count) as success_count, SUM(failure_count) as failure_count"). Where("group_id = ? AND time >= ? AND time < ?", groupID, startTime, endTime). Scan(&result).Error if err != nil { return RequestStats{}, err } return calculateRequestStats(result.SuccessCount+result.FailureCount, result.FailureCount), nil } // 24小时统计 wg.Add(1) go func() { defer wg.Done() stats, err := queryHourlyStats(24 * time.Hour) if err != nil { mu.Lock() errors = append(errors, fmt.Errorf("failed to get daily stats: %w", err)) mu.Unlock() return } mu.Lock() resp.DailyStats = stats mu.Unlock() }() // 7天统计 wg.Add(1) go func() { defer wg.Done() stats, err := queryHourlyStats(7 * 24 * time.Hour) if err != nil { mu.Lock() errors = append(errors, fmt.Errorf("failed to get weekly stats: %w", err)) mu.Unlock() return } mu.Lock() resp.WeeklyStats = stats mu.Unlock() }() wg.Wait() if len(errors) > 0 { // 只记录第一个错误,但表明可能存在多个错误 logrus.WithContext(c.Request.Context()).WithError(errors[0]).Error("Errors occurred while fetching group stats") response.Error(c, app_errors.NewAPIError(app_errors.ErrDatabase, "Failed to retrieve some statistics")) return } response.Success(c, resp) } // List godoc func (s *Server) List(c *gin.Context) { var groups []models.Group if err := s.DB.Select("id, name,display_name").Find(&groups).Error; err != nil { response.Error(c, app_errors.NewAPIError(app_errors.ErrDatabase, "无法获取分组列表")) return } response.Success(c, groups) }