Files
gpt-load/internal/handler/group_handler.go
2025-07-24 16:34:51 +08:00

765 lines
23 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package handler provides HTTP handlers for the application
package handler
import (
"encoding/json"
"fmt"
"net/url"
"sync"
app_errors "gpt-load/internal/errors"
"gpt-load/internal/models"
"gpt-load/internal/response"
"gpt-load/internal/utils"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"gpt-load/internal/channel"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"gorm.io/datatypes"
)
// isValidChannelType reports whether channelType exactly matches one of the
// channel type names returned by channel.GetChannels.
func isValidChannelType(channelType string) bool {
	registered := channel.GetChannels()
	for i := range registered {
		if registered[i] == channelType {
			return true
		}
	}
	return false
}
// UpstreamDefinition is a single entry of a group's "upstreams" JSON array:
// the upstream URL and its weight.
type UpstreamDefinition struct {
	URL    string `json:"url"`
	Weight int    `json:"weight"`
}
// validateAndCleanUpstreams parses the raw "upstreams" JSON array, trims every
// URL, validates each entry and returns the cleaned array re-marshaled to JSON.
//
// Rejected inputs:
//   - an empty payload ("upstreams field is required");
//   - malformed JSON, or JSON that is not an array of UpstreamDefinition;
//   - a JSON null or empty array ("at least one upstream is required");
//   - an entry whose trimmed URL is empty, is not http:// or https://, does not
//     parse as a URL, or has no host;
//   - an entry whose weight is not a positive integer.
func validateAndCleanUpstreams(upstreams json.RawMessage) (datatypes.JSON, error) {
	if len(upstreams) == 0 {
		return nil, fmt.Errorf("upstreams field is required")
	}
	var defs []UpstreamDefinition
	if err := json.Unmarshal(upstreams, &defs); err != nil {
		return nil, fmt.Errorf("invalid format for upstreams: %w", err)
	}
	// A JSON null unmarshals into a nil slice, so it is caught here as well.
	if len(defs) == 0 {
		return nil, fmt.Errorf("at least one upstream is required")
	}
	for i := range defs {
		defs[i].URL = strings.TrimSpace(defs[i].URL)
		if defs[i].URL == "" {
			return nil, fmt.Errorf("upstream URL cannot be empty")
		}
		if !isHTTPUpstreamURL(defs[i].URL) {
			return nil, fmt.Errorf("invalid URL format for upstream: %s", defs[i].URL)
		}
		if defs[i].Weight <= 0 {
			return nil, fmt.Errorf("upstream weight must be a positive integer")
		}
	}
	cleanedUpstreams, err := json.Marshal(defs)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal cleaned upstreams: %w", err)
	}
	return cleanedUpstreams, nil
}

// isHTTPUpstreamURL reports whether raw starts with "http://" or "https://",
// parses as a URL, and names a host. The host check rejects values such as a
// bare "http://" that pass a prefix-only test.
func isHTTPUpstreamURL(raw string) bool {
	if !strings.HasPrefix(raw, "http://") && !strings.HasPrefix(raw, "https://") {
		return false
	}
	u, err := url.Parse(raw)
	return err == nil && u.Host != ""
}
// groupNamePattern allows lowercase ASCII letters, digits, '_' and '-', with a
// total length of 3 to 30 characters. It is compiled once instead of on every
// validation call.
var groupNamePattern = regexp.MustCompile(`^[a-z0-9_-]{3,30}$`)

// isValidGroupName reports whether name is a valid group name: 3-30 characters
// drawn only from lowercase ASCII letters, digits, underscores and hyphens.
// The empty string is rejected by the length bound.
func isValidGroupName(name string) bool {
	return groupNamePattern.MatchString(name)
}
// isValidValidationEndpoint reports whether endpoint is acceptable as a
// validation path. An empty endpoint is allowed. Otherwise it must begin with
// "/" and must not contain "://", which rules out full URLs.
func isValidValidationEndpoint(endpoint string) bool {
	switch {
	case endpoint == "":
		return true
	case !strings.HasPrefix(endpoint, "/"):
		return false
	default:
		return !strings.Contains(endpoint, "://")
	}
}
// validateAndCleanConfig checks a group's config overrides and returns them in
// normalized form.
//
// Steps:
//  1. Reject any key that is not a JSON field name declared on
//     models.GroupConfig.
//  2. Validate the values with SettingsManager.ValidateGroupConfigOverrides.
//  3. Round-trip the map through models.GroupConfig (marshal, unmarshal,
//     marshal, unmarshal). The returned map then reflects that struct's JSON
//     encoding.
//
// A nil input returns (nil, nil).
func (s *Server) validateAndCleanConfig(configMap map[string]any) (map[string]any, error) {
	if configMap == nil {
		return nil, nil
	}

	// Step 1: collect the JSON names GroupConfig declares and reject anything else.
	allowed := make(map[string]bool)
	cfgType := reflect.TypeOf(models.GroupConfig{})
	for idx := 0; idx < cfgType.NumField(); idx++ {
		tagName, _, _ := strings.Cut(cfgType.Field(idx).Tag.Get("json"), ",")
		if tagName == "" || tagName == "-" {
			continue
		}
		allowed[tagName] = true
	}
	for key := range configMap {
		if !allowed[key] {
			return nil, fmt.Errorf("unknown config field: '%s'", key)
		}
	}

	// Step 2: value-level validation against the system settings rules.
	if err := s.SettingsManager.ValidateGroupConfigOverrides(configMap); err != nil {
		return nil, err
	}

	// Step 3: normalize by passing the data through the typed struct.
	rawIn, err := json.Marshal(configMap)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal config map: %w", err)
	}
	var typed models.GroupConfig
	if err = json.Unmarshal(rawIn, &typed); err != nil {
		return nil, fmt.Errorf("failed to unmarshal into validated config: %w", err)
	}
	rawOut, err := json.Marshal(typed)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal validated config: %w", err)
	}
	var normalized map[string]any
	if err = json.Unmarshal(rawOut, &normalized); err != nil {
		return nil, fmt.Errorf("failed to unmarshal into final map: %w", err)
	}
	return normalized, nil
}
// CreateGroup handles creating a new group from a JSON body bound into
// models.Group.
//
// Fields are trimmed and validated in this order: name, channel type, test
// model (required), upstreams, config, validation endpoint. The first failure
// is returned as a validation error. On success the group is inserted, the
// group cache is invalidated (a failure there is only logged), and the created
// group is returned.
func (s *Server) CreateGroup(c *gin.Context) {
	var payload models.Group
	if err := c.ShouldBindJSON(&payload); err != nil {
		response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
		return
	}

	groupName := strings.TrimSpace(payload.Name)
	if !isValidGroupName(groupName) {
		response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "无效的分组名称。只能包含小写字母、数字、中划线或下划线长度3-30位"))
		return
	}

	chType := strings.TrimSpace(payload.ChannelType)
	if !isValidChannelType(chType) {
		msg := fmt.Sprintf("Invalid channel type. Supported types are: %s", strings.Join(channel.GetChannels(), ", "))
		response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, msg))
		return
	}

	model := strings.TrimSpace(payload.TestModel)
	if model == "" {
		response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Test model is required"))
		return
	}

	upstreams, err := validateAndCleanUpstreams(json.RawMessage(payload.Upstreams))
	if err != nil {
		response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
		return
	}

	config, err := s.validateAndCleanConfig(payload.Config)
	if err != nil {
		response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, fmt.Sprintf("Invalid config format: %v", err)))
		return
	}

	endpoint := strings.TrimSpace(payload.ValidationEndpoint)
	if !isValidValidationEndpoint(endpoint) {
		response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "无效的测试路径。如果提供,必须是以 / 开头的有效路径且不能是完整的URL。"))
		return
	}

	newGroup := models.Group{
		Name:               groupName,
		DisplayName:        strings.TrimSpace(payload.DisplayName),
		Description:        strings.TrimSpace(payload.Description),
		Upstreams:          upstreams,
		ChannelType:        chType,
		Sort:               payload.Sort,
		TestModel:          model,
		ValidationEndpoint: endpoint,
		ParamOverrides:     payload.ParamOverrides,
		Config:             config,
		ProxyKeys:          strings.TrimSpace(payload.ProxyKeys),
	}
	if err := s.DB.Create(&newGroup).Error; err != nil {
		response.Error(c, app_errors.ParseDBError(err))
		return
	}

	// Cache invalidation is best-effort: the group is already persisted.
	if err := s.GroupManager.Invalidate(); err != nil {
		logrus.WithContext(c.Request.Context()).WithError(err).Error("failed to invalidate group cache")
	}
	response.Success(c, s.newGroupResponse(&newGroup))
}
// ListGroups handles listing all groups, ordered by sort ascending and then by
// id descending.
//
// The response slice is always non-nil. encoding/json renders a nil slice as
// null, so presumably (assuming response.Success JSON-encodes its payload) an
// empty result is now sent as [] instead of null.
func (s *Server) ListGroups(c *gin.Context) {
	var groups []models.Group
	if err := s.DB.Order("sort asc, id desc").Find(&groups).Error; err != nil {
		response.Error(c, app_errors.ParseDBError(err))
		return
	}
	// Non-nil and pre-sized: avoids a null payload and repeated slice growth.
	groupResponses := make([]GroupResponse, 0, len(groups))
	for i := range groups {
		groupResponses = append(groupResponses, *s.newGroupResponse(&groups[i]))
	}
	response.Success(c, groupResponses)
}
// GroupUpdateRequest is the JSON payload accepted by UpdateGroup.
//
// Most fields are pointers, raw JSON or maps, so an omitted field decodes to
// nil and can be told apart from an explicit zero value such as "sort": 0.
// This avoids GORM's habit of skipping zero values on update. TestModel is a
// plain string, so an empty value cannot be distinguished from "not sent".
type GroupUpdateRequest struct {
	Name               *string         `json:"name,omitempty"`
	DisplayName        *string         `json:"display_name,omitempty"`
	Description        *string         `json:"description,omitempty"`
	Upstreams          json.RawMessage `json:"upstreams"`
	ChannelType        *string         `json:"channel_type,omitempty"`
	Sort               *int            `json:"sort"`
	TestModel          string          `json:"test_model"`
	ValidationEndpoint *string         `json:"validation_endpoint,omitempty"`
	ParamOverrides     map[string]any  `json:"param_overrides"`
	Config             map[string]any  `json:"config"`
	ProxyKeys          *string         `json:"proxy_keys,omitempty"`
}
// UpdateGroup handles partially updating the group identified by the ":id"
// path parameter.
//
// Only fields present in the GroupUpdateRequest are changed. Each provided
// value is trimmed and validated before it is applied. An empty test_model is
// ignored, while a whitespace-only one is rejected. Every field is validated
// before a database transaction is opened, so rejected requests never hold a
// transaction. The group is then saved in a transaction, the group cache is
// invalidated (a failure there is only logged), and the updated group is
// returned.
func (s *Server) UpdateGroup(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID format"))
		return
	}

	var group models.Group
	if err := s.DB.First(&group, id).Error; err != nil {
		response.Error(c, app_errors.ParseDBError(err))
		return
	}

	var req GroupUpdateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
		return
	}

	// Apply the provided fields onto the loaded group, cleaning and validating
	// each one. Any failure returns before a transaction exists.
	if req.Name != nil {
		cleanedName := strings.TrimSpace(*req.Name)
		if !isValidGroupName(cleanedName) {
			response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "无效的分组名称格式。只能包含小写字母、数字、中划线或下划线长度3-30位"))
			return
		}
		group.Name = cleanedName
	}
	if req.DisplayName != nil {
		group.DisplayName = strings.TrimSpace(*req.DisplayName)
	}
	if req.Description != nil {
		group.Description = strings.TrimSpace(*req.Description)
	}
	if req.Upstreams != nil {
		cleanedUpstreams, err := validateAndCleanUpstreams(req.Upstreams)
		if err != nil {
			response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
			return
		}
		group.Upstreams = cleanedUpstreams
	}
	if req.ChannelType != nil {
		cleanedChannelType := strings.TrimSpace(*req.ChannelType)
		if !isValidChannelType(cleanedChannelType) {
			supported := strings.Join(channel.GetChannels(), ", ")
			response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, fmt.Sprintf("Invalid channel type. Supported types are: %s", supported)))
			return
		}
		group.ChannelType = cleanedChannelType
	}
	if req.Sort != nil {
		group.Sort = *req.Sort
	}
	if req.TestModel != "" {
		cleanedTestModel := strings.TrimSpace(req.TestModel)
		if cleanedTestModel == "" {
			response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Test model cannot be empty or just spaces."))
			return
		}
		group.TestModel = cleanedTestModel
	}
	if req.ParamOverrides != nil {
		group.ParamOverrides = req.ParamOverrides
	}
	if req.ValidationEndpoint != nil {
		validationEndpoint := strings.TrimSpace(*req.ValidationEndpoint)
		if !isValidValidationEndpoint(validationEndpoint) {
			response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "无效的测试路径。如果提供,必须是以 / 开头的有效路径且不能是完整的URL。"))
			return
		}
		group.ValidationEndpoint = validationEndpoint
	}
	if req.Config != nil {
		cleanedConfig, err := s.validateAndCleanConfig(req.Config)
		if err != nil {
			response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, fmt.Sprintf("Invalid config format: %v", err)))
			return
		}
		group.Config = cleanedConfig
	}
	if req.ProxyKeys != nil {
		group.ProxyKeys = strings.TrimSpace(*req.ProxyKeys)
	}

	// Open the transaction only now that the whole request is valid.
	tx := s.DB.Begin()
	if tx.Error != nil {
		response.Error(c, app_errors.ErrDatabase)
		return
	}
	// Roll back on every exit path, including panics. After a successful
	// Commit this rollback is expected to have no effect.
	defer tx.Rollback()

	if err := tx.Save(&group).Error; err != nil {
		response.Error(c, app_errors.ParseDBError(err))
		return
	}
	if err := tx.Commit().Error; err != nil {
		response.Error(c, app_errors.ErrDatabase)
		return
	}

	// Cache invalidation is best-effort: the update is already committed.
	if err := s.GroupManager.Invalidate(); err != nil {
		logrus.WithContext(c.Request.Context()).WithError(err).Error("failed to invalidate group cache")
	}
	response.Success(c, s.newGroupResponse(&group))
}
// GroupResponse is the API representation of a group returned by the group
// handlers. Besides the stored group columns it carries Endpoint, the proxy
// URL derived from the configured application URL (empty when none is set).
type GroupResponse struct {
	ID                 uint              `json:"id"`
	Name               string            `json:"name"`
	Endpoint           string            `json:"endpoint"`
	DisplayName        string            `json:"display_name"`
	Description        string            `json:"description"`
	Upstreams          datatypes.JSON    `json:"upstreams"`
	ChannelType        string            `json:"channel_type"`
	Sort               int               `json:"sort"`
	TestModel          string            `json:"test_model"`
	ValidationEndpoint string            `json:"validation_endpoint"`
	ParamOverrides     datatypes.JSONMap `json:"param_overrides"`
	Config             datatypes.JSONMap `json:"config"`
	ProxyKeys          string            `json:"proxy_keys"`
	LastValidatedAt    *time.Time        `json:"last_validated_at"`
	CreatedAt          time.Time         `json:"created_at"`
	UpdatedAt          time.Time         `json:"updated_at"`
}
// newGroupResponse converts a models.Group into a GroupResponse.
//
// Endpoint is the configured app URL with its path replaced by
// "<app url path without trailing slashes>/proxy/<group name>". It stays
// empty when no app URL is configured or when the app URL fails to parse.
func (s *Server) newGroupResponse(group *models.Group) *GroupResponse {
	var endpoint string
	if base := s.SettingsManager.GetAppUrl(); base != "" {
		if parsed, err := url.Parse(base); err == nil {
			parsed.Path = strings.TrimRight(parsed.Path, "/") + "/proxy/" + group.Name
			endpoint = parsed.String()
		}
	}

	resp := &GroupResponse{
		ID:                 group.ID,
		Name:               group.Name,
		Endpoint:           endpoint,
		DisplayName:        group.DisplayName,
		Description:        group.Description,
		Upstreams:          group.Upstreams,
		ChannelType:        group.ChannelType,
		Sort:               group.Sort,
		TestModel:          group.TestModel,
		ValidationEndpoint: group.ValidationEndpoint,
		ParamOverrides:     group.ParamOverrides,
		Config:             group.Config,
		ProxyKeys:          group.ProxyKeys,
		LastValidatedAt:    group.LastValidatedAt,
		CreatedAt:          group.CreatedAt,
		UpdatedAt:          group.UpdatedAt,
	}
	return resp
}
// DeleteGroup handles deleting the group identified by the ":id" path
// parameter together with all of its API keys.
//
// Inside one database transaction it does the following, in order:
//  1. Check that the group exists.
//  2. Collect the IDs of the group's API keys.
//  3. Delete those keys, then delete the group.
//  4. Remove the collected key IDs from the key store via
//     KeyService.KeyProvider.RemoveKeysFromStore.
//
// If the store removal fails, the transaction is rolled back. This is not
// fully atomic: if the store removal succeeds but Commit then fails, the keys
// are already gone from the store while the rows remain in the database.
func (s *Server) DeleteGroup(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID format"))
		return
	}

	tx := s.DB.Begin()
	if tx.Error != nil {
		response.Error(c, app_errors.ErrDatabase)
		return
	}
	// On panic, roll back and then re-panic. Swallowing the panic here would
	// make the handler return without writing any response.
	defer func() {
		if r := recover(); r != nil {
			tx.Rollback()
			panic(r)
		}
	}()

	var group models.Group
	if err := tx.First(&group, id).Error; err != nil {
		tx.Rollback()
		response.Error(c, app_errors.ParseDBError(err))
		return
	}

	// Read the key IDs in the same transaction as the delete, after the group
	// is known to exist, rather than in a separate query beforehand.
	var keyIDs []uint
	if err := tx.Model(&models.APIKey{}).Where("group_id = ?", id).Pluck("id", &keyIDs).Error; err != nil {
		tx.Rollback()
		response.Error(c, app_errors.ParseDBError(err))
		return
	}

	// Delete associated API keys first due to the foreign key constraint.
	if err := tx.Where("group_id = ?", id).Delete(&models.APIKey{}).Error; err != nil {
		tx.Rollback()
		response.Error(c, app_errors.ErrDatabase)
		return
	}

	if err := tx.Delete(&models.Group{}, id).Error; err != nil {
		tx.Rollback()
		response.Error(c, app_errors.ParseDBError(err))
		return
	}

	// Remove the keys from the key store before committing, so a store
	// failure can still undo the database changes.
	if len(keyIDs) > 0 {
		if err := s.KeyService.KeyProvider.RemoveKeysFromStore(uint(id), keyIDs); err != nil {
			tx.Rollback()
			logrus.WithFields(logrus.Fields{
				"groupID":  id,
				"keyCount": len(keyIDs),
				"error":    err,
			}).Error("Failed to remove keys from memory store, rolling back transaction")
			response.Error(c, app_errors.NewAPIError(app_errors.ErrDatabase,
				"Failed to delete group: unable to clean up cache"))
			return
		}
	}

	if err := tx.Commit().Error; err != nil {
		tx.Rollback()
		response.Error(c, app_errors.ErrDatabase)
		return
	}

	// Cache invalidation is best-effort: the deletion is already committed.
	if err := s.GroupManager.Invalidate(); err != nil {
		logrus.WithContext(c.Request.Context()).WithError(err).Error("failed to invalidate group cache")
	}
	response.Success(c, gin.H{"message": "Group and associated keys deleted successfully"})
}
// ConfigOption describes one group-level configuration option: its JSON key,
// a human-readable name and description, and a default value taken from the
// current system setting with the same key.
type ConfigOption struct {
	Key          string `json:"key"`
	Name         string `json:"name"`
	Description  string `json:"description"`
	DefaultValue any    `json:"default_value"`
}
// GetGroupConfigOptions returns the configuration options a group may override.
//
// It walks the fields of models.GroupConfig in declaration order. For each
// JSON key that also appears in the settings metadata generated from
// utils.DefaultSystemSettings, it emits a ConfigOption. The option's name and
// description come from that metadata. Its default value is the
// same-keyed field of the current system settings, or nil if the current
// settings have no field with that JSON name.
func (s *Server) GetGroupConfigOptions(c *gin.Context) {
	var options []ConfigOption

	// Index the setting metadata by key.
	defaults := utils.DefaultSystemSettings()
	metaByKey := make(map[string]models.SystemSettingInfo)
	for _, meta := range utils.GenerateSettingsMetadata(&defaults) {
		metaByKey[meta.Key] = meta
	}

	// Map JSON tag names of the current settings struct to Go field names.
	current := reflect.ValueOf(s.SettingsManager.GetSettings())
	currentType := current.Type()
	fieldByJSON := make(map[string]string)
	for idx := 0; idx < currentType.NumField(); idx++ {
		f := currentType.Field(idx)
		if tag, _, _ := strings.Cut(f.Tag.Get("json"), ","); tag != "" {
			fieldByJSON[tag] = f.Name
		}
	}

	// Build the response in GroupConfig field order.
	cfgType := reflect.TypeOf(models.GroupConfig{})
	for idx := 0; idx < cfgType.NumField(); idx++ {
		key, _, _ := strings.Cut(cfgType.Field(idx).Tag.Get("json"), ",")
		if key == "" || key == "-" {
			continue
		}
		meta, found := metaByKey[key]
		if !found {
			continue
		}
		var defaultValue any
		if goName, ok := fieldByJSON[key]; ok {
			defaultValue = current.FieldByName(goName).Interface()
		}
		options = append(options, ConfigOption{
			Key:          key,
			Name:         meta.Name,
			Description:  meta.Description,
			DefaultValue: defaultValue,
		})
	}

	response.Success(c, options)
}
// KeyStats holds API-key counts for a group.
type KeyStats struct {
	TotalKeys   int64 `json:"total_keys"`
	ActiveKeys  int64 `json:"active_keys"`
	InvalidKeys int64 `json:"invalid_keys"`
}

// RequestStats holds request counts and the failure ratio for one time window.
type RequestStats struct {
	TotalRequests  int64   `json:"total_requests"`
	FailedRequests int64   `json:"failed_requests"`
	FailureRate    float64 `json:"failure_rate"`
}

// GroupStatsResponse is the full statistics payload for a single group.
type GroupStatsResponse struct {
	KeyStats    KeyStats     `json:"key_stats"`
	HourlyStats RequestStats `json:"hourly_stats"` // 1 hour
	DailyStats  RequestStats `json:"daily_stats"`  // 24 hours
	WeeklyStats RequestStats `json:"weekly_stats"` // 7 days
}

// calculateRequestStats builds a RequestStats from the total and failed
// request counts. FailureRate is failed/total rounded to four decimal places
// (as "%.4f" formats it); it stays 0 when total is not positive.
func calculateRequestStats(total, failed int64) RequestStats {
	var rate float64
	if total > 0 {
		ratio := float64(failed) / float64(total)
		// ratio is finite here, so parsing the %.4f text cannot fail.
		rate, _ = strconv.ParseFloat(fmt.Sprintf("%.4f", ratio), 64)
	}
	return RequestStats{
		TotalRequests:  total,
		FailedRequests: failed,
		FailureRate:    rate,
	}
}
// GetGroupStats handles retrieving statistics for the group identified by the
// ":id" path parameter.
//
// After confirming the group exists, it runs these independent queries
// concurrently:
//   - key counts from models.APIKey (invalid = total - active);
//   - last-hour request totals from models.RequestLog;
//   - 24-hour and 7-day totals from models.GroupHourlyStat, covering whole
//     hours up to but excluding the current hour.
//
// Every query uses a GORM session bound to the request context, so in-flight
// queries stop when the request is cancelled. If any query fails, the first
// error is logged together with the total error count, and a database error
// is returned.
func (s *Server) GetGroupStats(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID format"))
		return
	}
	groupID := uint(id)

	ctx := c.Request.Context()
	// WithContext starts a new GORM session: safe to share across the
	// goroutines below and tied to the request's lifetime.
	db := s.DB.WithContext(ctx)

	// 1. Make sure the group exists.
	var group models.Group
	if err := db.First(&group, groupID).Error; err != nil {
		response.Error(c, app_errors.ParseDBError(err))
		return
	}

	var (
		resp GroupStatsResponse
		wg   sync.WaitGroup
		mu   sync.Mutex // guards resp and errs
		errs []error
	)

	// run executes fn in its own goroutine and records any error it returns.
	run := func(fn func() error) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := fn(); err != nil {
				mu.Lock()
				errs = append(errs, err)
				mu.Unlock()
			}
		}()
	}

	// 2. Key statistics.
	run(func() error {
		var totalKeys, activeKeys int64
		if err := db.Model(&models.APIKey{}).Where("group_id = ?", groupID).Count(&totalKeys).Error; err != nil {
			return fmt.Errorf("failed to get total keys: %w", err)
		}
		if err := db.Model(&models.APIKey{}).Where("group_id = ? AND status = ?", groupID, models.KeyStatusActive).Count(&activeKeys).Error; err != nil {
			return fmt.Errorf("failed to get active keys: %w", err)
		}
		mu.Lock()
		resp.KeyStats = KeyStats{
			TotalKeys:   totalKeys,
			ActiveKeys:  activeKeys,
			InvalidKeys: totalKeys - activeKeys,
		}
		mu.Unlock()
		return nil
	})

	// 3. Last-hour request statistics from the raw request log.
	run(func() error {
		var total, failed int64
		now := time.Now()
		oneHourAgo := now.Add(-time.Hour)
		if err := db.Model(&models.RequestLog{}).Where("group_id = ? AND timestamp BETWEEN ? AND ?", groupID, oneHourAgo, now).Count(&total).Error; err != nil {
			return fmt.Errorf("failed to get hourly total requests: %w", err)
		}
		if err := db.Model(&models.RequestLog{}).Where("group_id = ? AND timestamp BETWEEN ? AND ? AND is_success = ?", groupID, oneHourAgo, now, false).Count(&failed).Error; err != nil {
			return fmt.Errorf("failed to get hourly failed requests: %w", err)
		}
		mu.Lock()
		resp.HourlyStats = calculateRequestStats(total, failed)
		mu.Unlock()
		return nil
	})

	// queryHourlyStats sums the hourly aggregates over the window ending at the
	// top of the current hour. The current, incomplete hour is excluded.
	queryHourlyStats := func(window time.Duration) (RequestStats, error) {
		var result struct {
			SuccessCount int64
			FailureCount int64
		}
		endTime := time.Now().Truncate(time.Hour)
		startTime := endTime.Add(-window)
		err := db.Model(&models.GroupHourlyStat{}).
			Select("SUM(success_count) as success_count, SUM(failure_count) as failure_count").
			Where("group_id = ? AND time >= ? AND time < ?", groupID, startTime, endTime).
			Scan(&result).Error
		if err != nil {
			return RequestStats{}, err
		}
		return calculateRequestStats(result.SuccessCount+result.FailureCount, result.FailureCount), nil
	}

	// 4. 24-hour statistics.
	run(func() error {
		stats, err := queryHourlyStats(24 * time.Hour)
		if err != nil {
			return fmt.Errorf("failed to get daily stats: %w", err)
		}
		mu.Lock()
		resp.DailyStats = stats
		mu.Unlock()
		return nil
	})

	// 5. 7-day statistics.
	run(func() error {
		stats, err := queryHourlyStats(7 * 24 * time.Hour)
		if err != nil {
			return fmt.Errorf("failed to get weekly stats: %w", err)
		}
		mu.Lock()
		resp.WeeklyStats = stats
		mu.Unlock()
		return nil
	})

	wg.Wait()

	if len(errs) > 0 {
		// Log the first error in full and record how many occurred.
		logrus.WithContext(ctx).
			WithError(errs[0]).
			WithField("error_count", len(errs)).
			Error("Errors occurred while fetching group stats")
		response.Error(c, app_errors.NewAPIError(app_errors.ErrDatabase, "Failed to retrieve some statistics"))
		return
	}
	response.Success(c, resp)
}
// List returns all groups with only the id, name and display_name columns
// selected. On a query error it responds with a database error.
func (s *Server) List(c *gin.Context) {
	var groups []models.Group
	err := s.DB.Select("id, name,display_name").Find(&groups).Error
	if err != nil {
		response.Error(c, app_errors.NewAPIError(app_errors.ErrDatabase, "无法获取分组列表"))
		return
	}
	response.Success(c, groups)
}