diff --git a/cmd/gpt-load/main.go b/cmd/gpt-load/main.go index a3ecae3..9dec6d1 100644 --- a/cmd/gpt-load/main.go +++ b/cmd/gpt-load/main.go @@ -14,6 +14,7 @@ import ( "syscall" "time" + "gpt-load/internal/channel" "gpt-load/internal/config" "gpt-load/internal/db" "gpt-load/internal/handler" @@ -74,15 +75,28 @@ func main() { go startRequestLogger(database, requestLogChan, &wg) // --- + // --- Service Initialization --- + taskService := services.NewTaskService() + channelFactory := channel.NewFactory(settingsManager) + keyValidatorService := services.NewKeyValidatorService(database, channelFactory) + + keyManualValidationService := services.NewKeyManualValidationService(database, keyValidatorService, taskService, settingsManager) + keyCronService := services.NewKeyCronService(database, keyValidatorService, settingsManager) + keyCronService.Start() + defer keyCronService.Stop() + + keyService := services.NewKeyService(database) + // --- + // Create proxy server - proxyServer, err := proxy.NewProxyServer(database, requestLogChan) + proxyServer, err := proxy.NewProxyServer(database, channelFactory, requestLogChan) if err != nil { logrus.Fatalf("Failed to create proxy server: %v", err) } defer proxyServer.Close() // Create handlers - serverHandler := handler.NewServer(database, configManager) + serverHandler := handler.NewServer(database, configManager, keyValidatorService, keyManualValidationService, taskService, keyService) logCleanupHandler := handler.NewLogCleanupHandler(logCleanupService) // Setup routes using the new router package diff --git a/go.mod b/go.mod index 6feb398..4c5a5e7 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/gin-contrib/static v1.1.5 github.com/gin-gonic/gin v1.10.1 github.com/go-sql-driver/mysql v1.8.1 + github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 github.com/sirupsen/logrus v1.9.3 gorm.io/datatypes v1.2.1 diff --git a/go.sum b/go.sum index 116e32c..33215af 100644 --- a/go.sum +++ b/go.sum @@ -41,6 +41,8 @@ github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EO github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= diff --git a/internal/channel/base_channel.go b/internal/channel/base_channel.go index 5f3bed9..fa6a1b0 100644 --- a/internal/channel/base_channel.go +++ b/internal/channel/base_channel.go @@ -8,38 +8,65 @@ import ( "net/http/httputil" "net/url" "strings" - "sync/atomic" + "sync" "github.com/gin-gonic/gin" - "github.com/gin-gonic/gin/binding" "github.com/sirupsen/logrus" ) +// UpstreamInfo holds the information for a single upstream server, including its weight. +type UpstreamInfo struct { + URL *url.URL + Weight int + CurrentWeight int +} + // BaseChannel provides common functionality for channel proxies. 
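// Upstream selection (see getUpstreamURL below) uses smooth weighted round-robin:
// on every pick each upstream's Weight is added to its CurrentWeight, the highest
// CurrentWeight wins, and the total weight is subtracted from the winner.
// Illustrative example: upstreams A (weight 3) and B (weight 1) are picked in the
// order A, A, B, A, repeating, giving a 3:1 split without sending A's whole share in a burst.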
type BaseChannel struct { - Name string - Upstreams []*url.URL - HTTPClient *http.Client - roundRobin uint64 + Name string + Upstreams []UpstreamInfo + HTTPClient *http.Client + upstreamLock sync.Mutex } // RequestModifier is a function that can modify the request before it's sent. type RequestModifier func(req *http.Request, key *models.APIKey) -// getUpstreamURL selects an upstream URL using round-robin. +// getUpstreamURL selects an upstream URL using a smooth weighted round-robin algorithm. func (b *BaseChannel) getUpstreamURL() *url.URL { + b.upstreamLock.Lock() + defer b.upstreamLock.Unlock() + if len(b.Upstreams) == 0 { return nil } if len(b.Upstreams) == 1 { - return b.Upstreams[0] + return b.Upstreams[0].URL } - index := atomic.AddUint64(&b.roundRobin, 1) - 1 - return b.Upstreams[index%uint64(len(b.Upstreams))] + + totalWeight := 0 + var best *UpstreamInfo + + for i := range b.Upstreams { + up := &b.Upstreams[i] + totalWeight += up.Weight + up.CurrentWeight += up.Weight + + if best == nil || up.CurrentWeight > best.CurrentWeight { + best = up + } + } + + if best == nil { + return b.Upstreams[0].URL // 降级到第一个可用的 + } + + best.CurrentWeight -= totalWeight + return best.URL } // ProcessRequest handles the common logic of processing and forwarding a request. -func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier) error { +func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier, ch ChannelProxy) error { upstreamURL := b.getUpstreamURL() if upstreamURL == nil { return fmt.Errorf("no upstream URL configured for channel %s", b.Name) @@ -78,7 +105,7 @@ func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modi } // Check if the client request is for a streaming endpoint - if isStreamingRequest(c) { + if ch.IsStreamingRequest(c) { return b.handleStreaming(c, proxy) } @@ -87,6 +114,9 @@ func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modi } func (b *BaseChannel) handleStreaming(c *gin.Context, proxy *httputil.ReverseProxy) error { + var wg sync.WaitGroup + wg.Add(1) + c.Writer.Header().Set("Content-Type", "text/event-stream") c.Writer.Header().Set("Cache-Control", "no-cache") c.Writer.Header().Set("Connection", "keep-alive") @@ -96,13 +126,12 @@ func (b *BaseChannel) handleStreaming(c *gin.Context, proxy *httputil.ReversePro pr, pw := io.Pipe() defer pr.Close() - // Create a new request with the pipe reader as the body - // This is a bit of a hack to get ReverseProxy to stream req := c.Request.Clone(c.Request.Context()) req.Body = pr // Start the proxy in a goroutine go func() { + defer wg.Done() defer pw.Close() proxy.ServeHTTP(c.Writer, req) }() @@ -111,32 +140,16 @@ func (b *BaseChannel) handleStreaming(c *gin.Context, proxy *httputil.ReversePro _, err := io.Copy(pw, c.Request.Body) if err != nil { logrus.Errorf("Error copying request body to pipe: %v", err) + wg.Wait() // Wait for the goroutine to finish even if copy fails return err } + // Wait for the proxy to finish + wg.Wait() + return nil } -// isStreamingRequest checks if the request is for a streaming response. -func isStreamingRequest(c *gin.Context) bool { - // For Gemini, streaming is indicated by the path. - if strings.Contains(c.Request.URL.Path, ":streamGenerateContent") { - return true - } - - // For OpenAI, streaming is indicated by a "stream": true field in the JSON body. - // We use ShouldBindBodyWith to check the body without consuming it, so it can be read again by the proxy. 
- type streamPayload struct { - Stream bool `json:"stream"` - } - var p streamPayload - if err := c.ShouldBindBodyWith(&p, binding.JSON); err == nil { - return p.Stream - } - - return false -} - // singleJoiningSlash joins two URL paths with a single slash. func singleJoiningSlash(a, b string) string { aslash := strings.HasSuffix(a, "/") diff --git a/internal/channel/channel.go b/internal/channel/channel.go index 9db511a..4317cfe 100644 --- a/internal/channel/channel.go +++ b/internal/channel/channel.go @@ -1,6 +1,7 @@ package channel import ( + "context" "gpt-load/internal/models" "github.com/gin-gonic/gin" @@ -11,4 +12,10 @@ type ChannelProxy interface { // Handle takes a context, an API key, and the original request, // then forwards the request to the upstream service. Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error -} \ No newline at end of file + + // ValidateKey checks if the given API key is valid. + ValidateKey(ctx context.Context, key string) (bool, error) + + // IsStreamingRequest checks if the request is for a streaming response. + IsStreamingRequest(c *gin.Context) bool +} diff --git a/internal/channel/factory.go b/internal/channel/factory.go index 892838c..c3e38dd 100644 --- a/internal/channel/factory.go +++ b/internal/channel/factory.go @@ -1,6 +1,7 @@ package channel import ( + "encoding/json" "fmt" "gpt-load/internal/config" "gpt-load/internal/models" @@ -11,36 +12,61 @@ import ( "gorm.io/datatypes" ) +// Factory is responsible for creating channel proxies. +type Factory struct { + settingsManager *config.SystemSettingsManager +} + +// NewFactory creates a new channel factory. +func NewFactory(settingsManager *config.SystemSettingsManager) *Factory { + return &Factory{ + settingsManager: settingsManager, + } +} + // GetChannel returns a channel proxy based on the group's channel type. -func GetChannel(group *models.Group) (ChannelProxy, error) { +func (f *Factory) GetChannel(group *models.Group) (ChannelProxy, error) { switch group.ChannelType { case "openai": - return NewOpenAIChannel(group.Upstreams, group.Config) + return f.NewOpenAIChannel(group) case "gemini": - return NewGeminiChannel(group.Upstreams, group.Config) + return f.NewGeminiChannel(group) default: return nil, fmt.Errorf("unsupported channel type: %s", group.ChannelType) } } -// newBaseChannelWithUpstreams is a helper function to create and configure a BaseChannel. -func newBaseChannelWithUpstreams(name string, upstreams []string, groupConfig datatypes.JSONMap) (BaseChannel, error) { - if len(upstreams) == 0 { - return BaseChannel{}, fmt.Errorf("at least one upstream is required for %s channel", name) +// newBaseChannel is a helper function to create and configure a BaseChannel. 
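// The upstreams argument is expected to be a JSON array of {url, weight} objects.
// A hypothetical example (URLs are placeholders):
//
//	[{"url": "https://api.openai.com", "weight": 2}, {"url": "https://mirror.example.com", "weight": 1}]
//
// A missing or non-positive weight falls back to 1.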
+func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, groupConfig datatypes.JSONMap) (*BaseChannel, error) { + type upstreamDef struct { + URL string `json:"url"` + Weight int `json:"weight"` } - var upstreamURLs []*url.URL - for _, us := range upstreams { - u, err := url.Parse(us) + var defs []upstreamDef + if err := json.Unmarshal(upstreamsJSON, &defs); err != nil { + return nil, fmt.Errorf("failed to unmarshal upstreams for %s channel: %w", name, err) + } + + if len(defs) == 0 { + return nil, fmt.Errorf("at least one upstream is required for %s channel", name) + } + + var upstreamInfos []UpstreamInfo + for _, def := range defs { + u, err := url.Parse(def.URL) if err != nil { - return BaseChannel{}, fmt.Errorf("failed to parse upstream url '%s' for %s channel: %w", us, name, err) + return nil, fmt.Errorf("failed to parse upstream url '%s' for %s channel: %w", def.URL, name, err) } - upstreamURLs = append(upstreamURLs, u) + weight := def.Weight + if weight <= 0 { + weight = 1 // Default weight to 1 if not specified or invalid + } + upstreamInfos = append(upstreamInfos, UpstreamInfo{URL: u, Weight: weight}) } // Get effective settings by merging system and group configs - settingsManager := config.GetSystemSettingsManager() - effectiveSettings := settingsManager.GetEffectiveConfig(groupConfig) + effectiveSettings := f.settingsManager.GetEffectiveConfig(groupConfig) // Configure the HTTP client with the effective timeouts httpClient := &http.Client{ @@ -50,9 +76,9 @@ func newBaseChannelWithUpstreams(name string, upstreams []string, groupConfig da Timeout: time.Duration(effectiveSettings.RequestTimeout) * time.Second, } - return BaseChannel{ + return &BaseChannel{ Name: name, - Upstreams: upstreamURLs, + Upstreams: upstreamInfos, HTTPClient: httpClient, }, nil } diff --git a/internal/channel/gemini_channel.go b/internal/channel/gemini_channel.go index 79d2880..0353eae 100644 --- a/internal/channel/gemini_channel.go +++ b/internal/channel/gemini_channel.go @@ -1,19 +1,22 @@ package channel import ( + "context" + "fmt" "gpt-load/internal/models" "net/http" + "strings" + "github.com/gin-gonic/gin" - "gorm.io/datatypes" ) type GeminiChannel struct { - BaseChannel + *BaseChannel } -func NewGeminiChannel(upstreams []string, config datatypes.JSONMap) (*GeminiChannel, error) { - base, err := newBaseChannelWithUpstreams("gemini", upstreams, config) +func (f *Factory) NewGeminiChannel(group *models.Group) (*GeminiChannel, error) { + base, err := f.newBaseChannel("gemini", group.Upstreams, group.Config) if err != nil { return nil, err } @@ -29,5 +32,40 @@ func (ch *GeminiChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *mo q.Set("key", key.KeyValue) req.URL.RawQuery = q.Encode() } - return ch.ProcessRequest(c, apiKey, modifier) + return ch.ProcessRequest(c, apiKey, modifier, ch) +} + +// ValidateKey checks if the given API key is valid by making a request to the models endpoint. +func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, error) { + upstreamURL := ch.getUpstreamURL() + if upstreamURL == nil { + return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name) + } + + // Construct the request URL for listing models. 
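// For an upstream of https://generativelanguage.googleapis.com (placeholder), this
// yields e.g. https://generativelanguage.googleapis.com/v1beta/models?key=<API_KEY>.
// The key is interpolated without url.QueryEscape, which is fine for typical
// alphanumeric keys but worth escaping if keys could contain reserved characters.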
+ reqURL := fmt.Sprintf("%s/v1beta/models?key=%s", upstreamURL.String(), key) + + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) + if err != nil { + return false, fmt.Errorf("failed to create validation request: %w", err) + } + + resp, err := ch.HTTPClient.Do(req) + if err != nil { + return false, fmt.Errorf("failed to send validation request: %w", err) + } + defer resp.Body.Close() + + // A 200 OK status code indicates the key is valid. + return resp.StatusCode == http.StatusOK, nil +} + +// IsStreamingRequest checks if the request is for a streaming response. +func (ch *GeminiChannel) IsStreamingRequest(c *gin.Context) bool { + // For Gemini, streaming is indicated by the path containing streaming keywords + path := c.Request.URL.Path + return strings.Contains(path, ":streamGenerateContent") || + strings.Contains(path, "streamGenerateContent") || + strings.Contains(path, ":stream") || + strings.Contains(path, "/stream") } diff --git a/internal/channel/openai_channel.go b/internal/channel/openai_channel.go index 85e3063..d2dfe1d 100644 --- a/internal/channel/openai_channel.go +++ b/internal/channel/openai_channel.go @@ -1,19 +1,21 @@ package channel import ( + "context" + "fmt" "gpt-load/internal/models" "net/http" "github.com/gin-gonic/gin" - "gorm.io/datatypes" + "github.com/gin-gonic/gin/binding" ) type OpenAIChannel struct { - BaseChannel + *BaseChannel } -func NewOpenAIChannel(upstreams []string, config datatypes.JSONMap) (*OpenAIChannel, error) { - base, err := newBaseChannelWithUpstreams("openai", upstreams, config) +func (f *Factory) NewOpenAIChannel(group *models.Group) (*OpenAIChannel, error) { + base, err := f.newBaseChannel("openai", group.Upstreams, group.Config) if err != nil { return nil, err } @@ -27,5 +29,46 @@ func (ch *OpenAIChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *mo modifier := func(req *http.Request, key *models.APIKey) { req.Header.Set("Authorization", "Bearer "+key.KeyValue) } - return ch.ProcessRequest(c, apiKey, modifier) + return ch.ProcessRequest(c, apiKey, modifier, ch) +} + +// ValidateKey checks if the given API key is valid by making a request to the models endpoint. +func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, error) { + upstreamURL := ch.getUpstreamURL() + if upstreamURL == nil { + return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name) + } + + // Construct the request URL for listing models, a common endpoint for key validation. + reqURL := upstreamURL.String() + "/v1/models" + + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) + if err != nil { + return false, fmt.Errorf("failed to create validation request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+key) + + resp, err := ch.HTTPClient.Do(req) + if err != nil { + return false, fmt.Errorf("failed to send validation request: %w", err) + } + defer resp.Body.Close() + + // A 200 OK status code indicates the key is valid. + // Other status codes (e.g., 401 Unauthorized) indicate an invalid key. + return resp.StatusCode == http.StatusOK, nil +} + +// IsStreamingRequest checks if the request is for a streaming response. +func (ch *OpenAIChannel) IsStreamingRequest(c *gin.Context) bool { + // For OpenAI, streaming is indicated by a "stream": true field in the JSON body. + // We use ShouldBindBodyWith to check the body without consuming it, so it can be read again by the proxy. 
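// Caveat worth verifying: gin's ShouldBindBodyWith caches the raw bytes on the Context
// under gin.BodyBytesKey but does not restore c.Request.Body itself, so a later direct
// read of c.Request.Body may find it already drained. A minimal sketch of restoring it
// from that cache, if needed (assumes bytes and io are imported):
//
//	if cached, ok := c.Get(gin.BodyBytesKey); ok {
//		if b, ok := cached.([]byte); ok {
//			c.Request.Body = io.NopCloser(bytes.NewReader(b))
//		}
//	}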
+ type streamPayload struct { + Stream bool `json:"stream"` + } + var p streamPayload + if err := c.ShouldBindBodyWith(&p, binding.JSON); err == nil { + return p.Stream + } + return false } diff --git a/internal/config/manager.go b/internal/config/manager.go index d468c28..39b49a3 100644 --- a/internal/config/manager.go +++ b/internal/config/manager.go @@ -276,3 +276,16 @@ func getEnvOrDefault(key, defaultValue string) string { } return defaultValue } + +// GetInt is a helper function for SystemSettingsManager to get an integer value with a default. +func (s *SystemSettingsManager) GetInt(key string, defaultValue int) int { + s.mu.RLock() + defer s.mu.RUnlock() + + if valStr, ok := s.settingsCache[key]; ok { + if valInt, err := strconv.Atoi(valStr); err == nil { + return valInt + } + } + return defaultValue +} diff --git a/internal/config/system_settings.go b/internal/config/system_settings.go index fecd1ac..645dc4a 100644 --- a/internal/config/system_settings.go +++ b/internal/config/system_settings.go @@ -34,6 +34,11 @@ type SystemSettings struct { // 请求日志配置(数据库日志) RequestLogRetentionDays int `json:"request_log_retention_days" default:"30" name:"日志保留天数" category:"日志配置" desc:"请求日志在数据库中的保留天数" validate:"min=1"` + + // 密钥验证配置 + KeyValidationIntervalMinutes int `json:"key_validation_interval_minutes" default:"60" name:"定时验证周期" category:"密钥验证" desc:"后台定时验证密钥的默认周期(分钟)" validate:"min=5"` + KeyValidationConcurrency int `json:"key_validation_concurrency" default:"10" name:"验证并发数" category:"密钥验证" desc:"执行密钥验证时的并发 goroutine 数量" validate:"min=1,max=100"` + KeyValidationTaskTimeoutMinutes int `json:"key_validation_task_timeout_minutes" default:"60" name:"手动验证超时" category:"密钥验证" desc:"手动触发的全量验证任务的超时时间(分钟)" validate:"min=10"` } // GenerateSettingsMetadata 使用反射从 SystemSettings 结构体动态生成元数据 @@ -105,8 +110,9 @@ func DefaultSystemSettings() SystemSettings { // SystemSettingsManager 管理系统配置 type SystemSettingsManager struct { - settings SystemSettings - mu sync.RWMutex + settings SystemSettings + settingsCache map[string]string // Cache for raw string values + mu sync.RWMutex } var globalSystemSettings *SystemSettingsManager @@ -169,6 +175,8 @@ func (sm *SystemSettingsManager) LoadFromDatabase() error { sm.mu.Lock() defer sm.mu.Unlock() + sm.settingsCache = settingsMap + // 使用默认值,然后用数据库中的值覆盖 sm.settings = DefaultSystemSettings() sm.mapToStruct(settingsMap, &sm.settings) @@ -334,6 +342,8 @@ func (sm *SystemSettingsManager) DisplayCurrentSettings() { logrus.Infof(" Request timeouts: request=%ds, response=%ds, idle_conn=%ds", sm.settings.RequestTimeout, sm.settings.ResponseTimeout, sm.settings.IdleConnTimeout) logrus.Infof(" Request log retention: %d days", sm.settings.RequestLogRetentionDays) + logrus.Infof(" Key validation: interval=%dmin, concurrency=%d, task_timeout=%dmin", + sm.settings.KeyValidationIntervalMinutes, sm.settings.KeyValidationConcurrency, sm.settings.KeyValidationTaskTimeoutMinutes) } // 辅助方法 diff --git a/internal/errors/errors.go b/internal/errors/errors.go index 6353e57..f20122c 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -31,6 +31,8 @@ var ( ErrDatabase = &APIError{HTTPStatus: http.StatusInternalServerError, Code: "DATABASE_ERROR", Message: "Database operation failed"} ErrUnauthorized = &APIError{HTTPStatus: http.StatusUnauthorized, Code: "UNAUTHORIZED", Message: "Authentication failed"} ErrForbidden = &APIError{HTTPStatus: http.StatusForbidden, Code: "FORBIDDEN", Message: "You do not have permission to access this resource"} + ErrTaskInProgress = 
&APIError{HTTPStatus: http.StatusConflict, Code: "TASK_IN_PROGRESS", Message: "A task is already in progress"} + ErrBadGateway = &APIError{HTTPStatus: http.StatusBadGateway, Code: "BAD_GATEWAY", Message: "Upstream service error"} ) // NewAPIError creates a new APIError with a custom message. diff --git a/internal/handler/group_handler.go b/internal/handler/group_handler.go index 232f584..85bfd3e 100644 --- a/internal/handler/group_handler.go +++ b/internal/handler/group_handler.go @@ -3,6 +3,7 @@ package handler import ( "encoding/json" + "fmt" app_errors "gpt-load/internal/errors" "gpt-load/internal/models" "gpt-load/internal/response" @@ -10,6 +11,7 @@ import ( "strconv" "github.com/gin-gonic/gin" + "gorm.io/datatypes" ) // isValidGroupName checks if the group name is valid. @@ -22,6 +24,50 @@ func isValidGroupName(name string) bool { return match } +// validateAndCleanConfig validates the group config against the GroupConfig struct. +func validateAndCleanConfig(configMap map[string]interface{}) (map[string]interface{}, error) { + if configMap == nil { + return nil, nil + } + + configBytes, err := json.Marshal(configMap) + if err != nil { + return nil, err + } + + var validatedConfig models.GroupConfig + if err := json.Unmarshal(configBytes, &validatedConfig); err != nil { + return nil, err + } + + // 验证配置项的合理范围 + if validatedConfig.BlacklistThreshold != nil && *validatedConfig.BlacklistThreshold < 0 { + return nil, fmt.Errorf("blacklist_threshold must be >= 0") + } + if validatedConfig.MaxRetries != nil && (*validatedConfig.MaxRetries < 0 || *validatedConfig.MaxRetries > 10) { + return nil, fmt.Errorf("max_retries must be between 0 and 10") + } + if validatedConfig.RequestTimeout != nil && (*validatedConfig.RequestTimeout < 1 || *validatedConfig.RequestTimeout > 3600) { + return nil, fmt.Errorf("request_timeout must be between 1 and 3600 seconds") + } + if validatedConfig.KeyValidationIntervalMinutes != nil && (*validatedConfig.KeyValidationIntervalMinutes < 5 || *validatedConfig.KeyValidationIntervalMinutes > 1440) { + return nil, fmt.Errorf("key_validation_interval_minutes must be between 5 and 1440 minutes") + } + + // Marshal back to a map to remove any fields not in GroupConfig + validatedBytes, err := json.Marshal(validatedConfig) + if err != nil { + return nil, err + } + + var cleanedMap map[string]interface{} + if err := json.Unmarshal(validatedBytes, &cleanedMap); err != nil { + return nil, err + } + + return cleanedMap, nil +} + // CreateGroup handles the creation of a new group. func (s *Server) CreateGroup(c *gin.Context) { var group models.Group @@ -43,6 +89,17 @@ func (s *Server) CreateGroup(c *gin.Context) { response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Channel type is required")) return } + if group.TestModel == "" { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Test model is required")) + return + } + + cleanedConfig, err := validateAndCleanConfig(group.Config) + if err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid config format")) + return + } + group.Config = cleanedConfig if err := s.DB.Create(&group).Error; err != nil { response.Error(c, app_errors.ParseDBError(err)) @@ -62,6 +119,20 @@ func (s *Server) ListGroups(c *gin.Context) { response.Success(c, groups) } +// GroupUpdateRequest defines the payload for updating a group. +// Using a dedicated struct avoids issues with zero values being ignored by GORM's Update. 
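// An illustrative PUT body (values are placeholders); pointer fields such as "sort"
// let the handler distinguish "not provided" from an explicit zero:
//
//	{
//	  "display_name": "OpenAI Primary",
//	  "sort": 0,
//	  "test_model": "gpt-4o-mini",
//	  "config": {"key_validation_interval_minutes": 30}
//	}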
+type GroupUpdateRequest struct { + Name string `json:"name"` + DisplayName string `json:"display_name"` + Description string `json:"description"` + Upstreams json.RawMessage `json:"upstreams"` + ChannelType string `json:"channel_type"` + Sort *int `json:"sort"` + TestModel string `json:"test_model"` + ParamOverrides map[string]interface{} `json:"param_overrides"` + Config map[string]interface{} `json:"config"` +} + // UpdateGroup handles updating an existing group. func (s *Server) UpdateGroup(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) @@ -76,84 +147,70 @@ func (s *Server) UpdateGroup(c *gin.Context) { return } - var updateData models.Group - if err := c.ShouldBindJSON(&updateData); err != nil { + var req GroupUpdateRequest + if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error())) return } - // Validate group name if it's being updated - if updateData.Name != "" && !isValidGroupName(updateData.Name) { - response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid group name format. Use 3-30 lowercase letters, numbers, and underscores.")) - return - } - - // Use a transaction to ensure atomicity + // Start a transaction tx := s.DB.Begin() if tx.Error != nil { response.Error(c, app_errors.ErrDatabase) return } + defer tx.Rollback() // Rollback on panic - // Convert updateData to a map to ensure zero values (like Sort: 0) are updated - var updateMap map[string]interface{} - updateBytes, _ := json.Marshal(updateData) - if err := json.Unmarshal(updateBytes, &updateMap); err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Failed to process update data")) - return - } - - // If config is being updated, it needs to be marshalled to JSON string for GORM - if config, ok := updateMap["config"]; ok { - if configMap, isMap := config.(map[string]interface{}); isMap { - configJSON, err := json.Marshal(configMap) - if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Failed to process config data")) - return - } - updateMap["config"] = string(configJSON) + // Apply updates from the request + if req.Name != "" { + if !isValidGroupName(req.Name) { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid group name format.")) + return } + group.Name = req.Name } - - // Handle upstreams field specifically - if upstreams, ok := updateMap["upstreams"]; ok { - if upstreamsSlice, isSlice := upstreams.([]interface{}); isSlice { - upstreamsJSON, err := json.Marshal(upstreamsSlice) - if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Failed to process upstreams data")) - return - } - updateMap["upstreams"] = string(upstreamsJSON) + if req.DisplayName != "" { + group.DisplayName = req.DisplayName + } + if req.Description != "" { + group.Description = req.Description + } + if req.Upstreams != nil { + group.Upstreams = datatypes.JSON(req.Upstreams) + } + if req.ChannelType != "" { + group.ChannelType = req.ChannelType + } + if req.Sort != nil { + group.Sort = *req.Sort + } + if req.TestModel != "" { + group.TestModel = req.TestModel + } + if req.ParamOverrides != nil { + group.ParamOverrides = req.ParamOverrides + } + if req.Config != nil { + cleanedConfig, err := validateAndCleanConfig(req.Config) + if err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid config format")) + return } + group.Config = cleanedConfig } - // Remove fields that are not actual columns 
or should not be updated from the map - delete(updateMap, "id") - delete(updateMap, "api_keys") - delete(updateMap, "created_at") - delete(updateMap, "updated_at") - - // Use Updates with a map to only update provided fields, including zero values - if err := tx.Model(&group).Updates(updateMap).Error; err != nil { - tx.Rollback() + // Save the updated group object + if err := tx.Save(&group).Error; err != nil { response.Error(c, app_errors.ParseDBError(err)) return } if err := tx.Commit().Error; err != nil { - tx.Rollback() response.Error(c, app_errors.ErrDatabase) return } - // Re-fetch the group to return the updated data - var updatedGroup models.Group - if err := s.DB.First(&updatedGroup, id).Error; err != nil { - response.Error(c, app_errors.ParseDBError(err)) - return - } - - response.Success(c, updatedGroup) + response.Success(c, group) } // DeleteGroup handles deleting a group. diff --git a/internal/handler/handler.go b/internal/handler/handler.go index fc40397..b9b8bbf 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -6,6 +6,7 @@ import ( "time" "gpt-load/internal/models" + "gpt-load/internal/services" "gpt-load/internal/types" "github.com/gin-gonic/gin" @@ -14,15 +15,30 @@ import ( // Server contains dependencies for HTTP handlers type Server struct { - DB *gorm.DB - config types.ConfigManager + DB *gorm.DB + config types.ConfigManager + KeyValidatorService *services.KeyValidatorService + KeyManualValidationService *services.KeyManualValidationService + TaskService *services.TaskService + KeyService *services.KeyService } // NewServer creates a new handler instance -func NewServer(db *gorm.DB, config types.ConfigManager) *Server { +func NewServer( + db *gorm.DB, + config types.ConfigManager, + keyValidatorService *services.KeyValidatorService, + keyManualValidationService *services.KeyManualValidationService, + taskService *services.TaskService, + keyService *services.KeyService, +) *Server { return &Server{ - DB: db, - config: config, + DB: db, + config: config, + KeyValidatorService: keyValidatorService, + KeyManualValidationService: keyManualValidationService, + TaskService: taskService, + KeyService: keyService, } } diff --git a/internal/handler/key_handler.go b/internal/handler/key_handler.go index 5b2a64b..501fe5c 100644 --- a/internal/handler/key_handler.go +++ b/internal/handler/key_handler.go @@ -1,60 +1,123 @@ -// Package handler provides HTTP handlers for the application package handler import ( + "fmt" app_errors "gpt-load/internal/errors" "gpt-load/internal/models" "gpt-load/internal/response" "strconv" + "strings" "github.com/gin-gonic/gin" + "gorm.io/gorm" ) -type CreateKeysRequest struct { - Keys []string `json:"keys" binding:"required"` +// validateGroupID validates and parses group ID from request parameter +func validateGroupID(c *gin.Context) (uint, error) { + groupIDStr := c.Param("id") + if groupIDStr == "" { + return 0, fmt.Errorf("group ID is required") + } + + groupID, err := strconv.Atoi(groupIDStr) + if err != nil || groupID <= 0 { + return 0, fmt.Errorf("invalid group ID format") + } + + return uint(groupID), nil } -// CreateKeysInGroup handles creating new keys within a specific group. 
-func (s *Server) CreateKeysInGroup(c *gin.Context) { - groupID, err := strconv.Atoi(c.Param("id")) +// validateKeyID validates and parses key ID from request parameter +func validateKeyID(c *gin.Context) (uint, error) { + keyIDStr := c.Param("key_id") + if keyIDStr == "" { + return 0, fmt.Errorf("key ID is required") + } + + keyID, err := strconv.Atoi(keyIDStr) + if err != nil || keyID <= 0 { + return 0, fmt.Errorf("invalid key ID format") + } + + return uint(keyID), nil +} + +// validateKeysText validates the keys text input +func validateKeysText(keysText string) error { + if strings.TrimSpace(keysText) == "" { + return fmt.Errorf("keys text cannot be empty") + } + + if len(keysText) > 1024*1024 { // 1MB limit + return fmt.Errorf("keys text is too large (max 1MB)") + } + + return nil +} + +// findGroupByID is a helper function to find a group by its ID. +func (s *Server) findGroupByID(c *gin.Context, groupID int) (*models.Group, bool) { + var group models.Group + if err := s.DB.First(&group, groupID).Error; err != nil { + if err == gorm.ErrRecordNotFound { + response.Error(c, app_errors.ErrResourceNotFound) + } else { + response.Error(c, app_errors.ParseDBError(err)) + } + return nil, false + } + return &group, true +} + +// AddMultipleKeysRequest defines the payload for adding multiple keys from a text block. +type AddMultipleKeysRequest struct { + KeysText string `json:"keys_text" binding:"required"` +} + +// AddMultipleKeys handles creating new keys from a text block within a specific group. +func (s *Server) AddMultipleKeys(c *gin.Context) { + groupID, err := validateGroupID(c) if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID format")) + response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, err.Error())) return } - var req CreateKeysRequest + var req AddMultipleKeysRequest if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error())) return } - var newKeys []models.APIKey - for _, keyVal := range req.Keys { - newKeys = append(newKeys, models.APIKey{ - GroupID: uint(groupID), - KeyValue: keyVal, - Status: "active", - }) + if err := validateKeysText(req.KeysText); err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error())) + return } - if err := s.DB.Create(&newKeys).Error; err != nil { + result, err := s.KeyService.AddMultipleKeys(groupID, req.KeysText) + if err != nil { response.Error(c, app_errors.ParseDBError(err)) return } - response.Success(c, newKeys) + response.Success(c, result) } // ListKeysInGroup handles listing all keys within a specific group. 
func (s *Server) ListKeysInGroup(c *gin.Context) { groupID, err := strconv.Atoi(c.Param("id")) if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID format")) + response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID")) return } - var keys []models.APIKey - if err := s.DB.Where("group_id = ?", groupID).Find(&keys).Error; err != nil { + statusFilter := c.Query("status") + if statusFilter != "" && statusFilter != "active" && statusFilter != "inactive" { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid status filter")) + return + } + + keys, err := s.KeyService.ListKeysInGroup(uint(groupID), statusFilter) + if err != nil { response.Error(c, app_errors.ParseDBError(err)) return } @@ -62,90 +125,124 @@ func (s *Server) ListKeysInGroup(c *gin.Context) { response.Success(c, keys) } -// UpdateKey handles updating a specific key. -func (s *Server) UpdateKey(c *gin.Context) { +// DeleteSingleKey handles deleting a specific key. +func (s *Server) DeleteSingleKey(c *gin.Context) { groupID, err := strconv.Atoi(c.Param("id")) if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID format")) + response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID")) return } keyID, err := strconv.Atoi(c.Param("key_id")) if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid key ID format")) + response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid key ID")) return } - var key models.APIKey - if err := s.DB.Where("group_id = ? AND id = ?", groupID, keyID).First(&key).Error; err != nil { + rowsAffected, err := s.KeyService.DeleteSingleKey(uint(groupID), uint(keyID)) + if err != nil { response.Error(c, app_errors.ParseDBError(err)) return } - - var updateData struct { - Status string `json:"status"` - } - if err := c.ShouldBindJSON(&updateData); err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error())) + if rowsAffected == 0 { + response.Error(c, app_errors.ErrResourceNotFound) return } - key.Status = updateData.Status - if err := s.DB.Save(&key).Error; err != nil { - response.Error(c, app_errors.ParseDBError(err)) + response.Success(c, gin.H{"message": "Key deleted successfully"}) +} + +// TestSingleKey handles a one-off validation test for a single key. +func (s *Server) TestSingleKey(c *gin.Context) { + keyID, err := validateKeyID(c) + if err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, err.Error())) return } - response.Success(c, key) + isValid, validationErr := s.KeyValidatorService.TestSingleKeyByID(c.Request.Context(), keyID) + if validationErr != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrBadGateway, validationErr.Error())) + return + } + + if isValid { + response.Success(c, gin.H{"success": true, "message": "Key is valid."}) + } else { + response.Success(c, gin.H{"success": false, "message": "Key is invalid or has insufficient quota."}) + } } -type DeleteKeysRequest struct { - KeyIDs []uint `json:"key_ids" binding:"required"` -} - -// DeleteKeys handles deleting one or more keys. -func (s *Server) DeleteKeys(c *gin.Context) { +// ValidateGroupKeys initiates a manual validation task for all keys in a group. 
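// Expected flow, with paths relative to the protected API prefix configured in the
// router: POST /groups/:id/validate-keys starts the task and returns its status; if
// another task is already running the handler responds with 409 TASK_IN_PROGRESS.
// Progress is then polled via GET /tasks/key-validation/status and the final summary
// fetched from GET /tasks/:task_id/result.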
+func (s *Server) ValidateGroupKeys(c *gin.Context) { groupID, err := strconv.Atoi(c.Param("id")) if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID format")) + response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID")) return } - var req DeleteKeysRequest - if err := c.ShouldBindJSON(&req); err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error())) + group, ok := s.findGroupByID(c, groupID) + if !ok { return } - if len(req.KeyIDs) == 0 { - response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "No key IDs provided")) + taskStatus, err := s.KeyManualValidationService.StartValidationTask(group) + if err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrTaskInProgress, err.Error())) return } - // Start a transaction - tx := s.DB.Begin() - - // Verify all keys belong to the specified group - var count int64 - if err := tx.Model(&models.APIKey{}).Where("id IN ? AND group_id = ?", req.KeyIDs, groupID).Count(&count).Error; err != nil { - tx.Rollback() - response.Error(c, app_errors.ParseDBError(err)) - return - } - - if count != int64(len(req.KeyIDs)) { - tx.Rollback() - response.Error(c, app_errors.NewAPIError(app_errors.ErrForbidden, "One or more keys do not belong to the specified group")) - return - } - - // Delete the keys - if err := tx.Where("id IN ?", req.KeyIDs).Delete(&models.APIKey{}).Error; err != nil { - tx.Rollback() - response.Error(c, app_errors.ParseDBError(err)) - return - } - - tx.Commit() - response.Success(c, gin.H{"message": "Keys deleted successfully"}) + response.Success(c, taskStatus) +} + +// RestoreAllInvalidKeys sets the status of all 'inactive' keys in a group to 'active'. +func (s *Server) RestoreAllInvalidKeys(c *gin.Context) { + groupID, err := strconv.Atoi(c.Param("id")) + if err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID")) + return + } + + rowsAffected, err := s.KeyService.RestoreAllInvalidKeys(uint(groupID)) + if err != nil { + response.Error(c, app_errors.ParseDBError(err)) + return + } + + response.Success(c, gin.H{"message": fmt.Sprintf("%d keys restored.", rowsAffected)}) +} + +// ClearAllInvalidKeys deletes all 'inactive' keys from a group. +func (s *Server) ClearAllInvalidKeys(c *gin.Context) { + groupID, err := strconv.Atoi(c.Param("id")) + if err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID")) + return + } + + rowsAffected, err := s.KeyService.ClearAllInvalidKeys(uint(groupID)) + if err != nil { + response.Error(c, app_errors.ParseDBError(err)) + return + } + + response.Success(c, gin.H{"message": fmt.Sprintf("%d invalid keys cleared.", rowsAffected)}) +} + +// ExportKeys returns a list of keys for a group, filtered by status. 
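// The "filter" query parameter accepts "all" (default), "valid" or "invalid", and the
// response wraps the key values as {"keys": [...]}. Illustrative request:
// GET /groups/42/keys/export?filter=invalid (relative to the protected API prefix).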
+func (s *Server) ExportKeys(c *gin.Context) { + groupID, err := strconv.Atoi(c.Param("id")) + if err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID")) + return + } + + filter := c.DefaultQuery("filter", "all") + keys, err := s.KeyService.ExportKeys(uint(groupID), filter) + if err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error())) + return + } + + response.Success(c, gin.H{"keys": keys}) } diff --git a/internal/handler/task_handler.go b/internal/handler/task_handler.go new file mode 100644 index 0000000..d3464b8 --- /dev/null +++ b/internal/handler/task_handler.go @@ -0,0 +1,31 @@ +package handler + +import ( + "gpt-load/internal/response" + app_errors "gpt-load/internal/errors" + + "github.com/gin-gonic/gin" +) + +// GetTaskStatus handles requests for the status of the global long-running task. +func (s *Server) GetTaskStatus(c *gin.Context) { + taskStatus := s.TaskService.GetTaskStatus() + response.Success(c, taskStatus) +} + +// GetTaskResult handles requests for the result of a finished task. +func (s *Server) GetTaskResult(c *gin.Context) { + taskID := c.Param("task_id") + if taskID == "" { + response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Task ID is required")) + return + } + + result, found := s.TaskService.GetResult(taskID) + if !found { + response.Error(c, app_errors.ErrResourceNotFound) + return + } + + response.Success(c, result) +} diff --git a/internal/models/types.go b/internal/models/types.go index d5cc5c1..6193f91 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -1,9 +1,6 @@ package models import ( - "database/sql/driver" - "encoding/json" - "errors" "time" "gorm.io/datatypes" @@ -19,52 +16,36 @@ type SystemSetting struct { UpdatedAt time.Time `json:"updated_at"` } -// Upstreams 是一个上游地址的切片,可以被 GORM 正确处理 -type Upstreams []string - -// Value 实现 driver.Valuer 接口,用于将 Upstreams 类型转换为数据库值 -func (u Upstreams) Value() (driver.Value, error) { - if len(u) == 0 { - return "[]", nil - } - return json.Marshal(u) -} - -// Scan 实现 sql.Scanner 接口,用于将数据库值扫描到 Upstreams 类型 -func (u *Upstreams) Scan(value interface{}) error { - bytes, ok := value.([]byte) - if !ok { - return errors.New("type assertion to []byte failed") - } - return json.Unmarshal(bytes, u) -} - // GroupConfig 存储特定于分组的配置 type GroupConfig struct { - BlacklistThreshold *int `json:"blacklist_threshold,omitempty"` - MaxRetries *int `json:"max_retries,omitempty"` - ServerReadTimeout *int `json:"server_read_timeout,omitempty"` - ServerWriteTimeout *int `json:"server_write_timeout,omitempty"` - ServerIdleTimeout *int `json:"server_idle_timeout,omitempty"` + BlacklistThreshold *int `json:"blacklist_threshold,omitempty"` + MaxRetries *int `json:"max_retries,omitempty"` + ServerReadTimeout *int `json:"server_read_timeout,omitempty"` + ServerWriteTimeout *int `json:"server_write_timeout,omitempty"` + ServerIdleTimeout *int `json:"server_idle_timeout,omitempty"` ServerGracefulShutdownTimeout *int `json:"server_graceful_shutdown_timeout,omitempty"` - RequestTimeout *int `json:"request_timeout,omitempty"` - ResponseTimeout *int `json:"response_timeout,omitempty"` - IdleConnTimeout *int `json:"idle_conn_timeout,omitempty"` + RequestTimeout *int `json:"request_timeout,omitempty"` + ResponseTimeout *int `json:"response_timeout,omitempty"` + IdleConnTimeout *int `json:"idle_conn_timeout,omitempty"` + KeyValidationIntervalMinutes *int `json:"key_validation_interval_minutes,omitempty"` } // Group 对应 groups 表 type 
Group struct { - ID uint `gorm:"primaryKey;autoIncrement" json:"id"` - Name string `gorm:"type:varchar(255);not null;unique" json:"name"` - DisplayName string `gorm:"type:varchar(255)" json:"display_name"` - Description string `gorm:"type:varchar(512)" json:"description"` - Upstreams Upstreams `gorm:"type:json;not null" json:"upstreams"` - ChannelType string `gorm:"type:varchar(50);not null" json:"channel_type"` - Sort int `gorm:"default:0" json:"sort"` - Config datatypes.JSONMap `gorm:"type:json" json:"config"` - APIKeys []APIKey `gorm:"foreignKey:GroupID" json:"api_keys"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + Name string `gorm:"type:varchar(255);not null;unique" json:"name"` + DisplayName string `gorm:"type:varchar(255)" json:"display_name"` + Description string `gorm:"type:varchar(512)" json:"description"` + Upstreams datatypes.JSON `gorm:"type:json;not null" json:"upstreams"` + ChannelType string `gorm:"type:varchar(50);not null" json:"channel_type"` + Sort int `gorm:"default:0" json:"sort"` + TestModel string `gorm:"type:varchar(255);not null" json:"test_model"` + ParamOverrides datatypes.JSONMap `gorm:"type:json" json:"param_overrides"` + Config datatypes.JSONMap `gorm:"type:json" json:"config"` + APIKeys []APIKey `gorm:"foreignKey:GroupID" json:"api_keys"` + LastValidatedAt *time.Time `json:"last_validated_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // APIKey 对应 api_keys 表 @@ -94,7 +75,7 @@ type RequestLog struct { // GroupRequestStat 用于表示每个分组的请求统计 type GroupRequestStat struct { - DisplayName string `json:"display_name"` + DisplayName string `json:"display_name"` RequestCount int64 `json:"request_count"` } diff --git a/internal/proxy/server.go b/internal/proxy/server.go index 4407386..35086ec 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -2,11 +2,14 @@ package proxy import ( + "bytes" + "encoding/json" "fmt" "gpt-load/internal/channel" app_errors "gpt-load/internal/errors" "gpt-load/internal/models" "gpt-load/internal/response" + "io" "sync" "sync/atomic" "time" @@ -19,14 +22,16 @@ import ( // ProxyServer represents the proxy server type ProxyServer struct { DB *gorm.DB + channelFactory *channel.Factory groupCounters sync.Map // map[uint]*atomic.Uint64 requestLogChan chan models.RequestLog } // NewProxyServer creates a new proxy server -func NewProxyServer(db *gorm.DB, requestLogChan chan models.RequestLog) (*ProxyServer, error) { +func NewProxyServer(db *gorm.DB, channelFactory *channel.Factory, requestLogChan chan models.RequestLog) (*ProxyServer, error) { return &ProxyServer{ DB: db, + channelFactory: channelFactory, groupCounters: sync.Map{}, requestLogChan: requestLogChan, }, nil @@ -52,22 +57,31 @@ func (ps *ProxyServer) HandleProxy(c *gin.Context) { } // 3. Get the appropriate channel handler from the factory - channelHandler, err := channel.GetChannel(&group) + channelHandler, err := ps.channelFactory.GetChannel(&group) if err != nil { response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to get channel for group '%s': %v", groupName, err))) return } - // 4. Forward the request using the channel handler + // 4. 
Apply parameter overrides if they exist + if len(group.ParamOverrides) > 0 { + err := ps.applyParamOverrides(c, &group) + if err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to apply parameter overrides: %v", err))) + return + } + } + + // 5. Forward the request using the channel handler err = channelHandler.Handle(c, apiKey, &group) - // 5. Log the request asynchronously + // 6. Log the request asynchronously isSuccess := err == nil if !isSuccess { logrus.WithFields(logrus.Fields{ - "group": group.Name, + "group": group.Name, "key_id": apiKey.ID, - "error": err.Error(), + "error": err.Error(), }).Error("Channel handler failed") } go ps.logRequest(c, &group, apiKey, startTime, isSuccess) @@ -145,3 +159,51 @@ func (ps *ProxyServer) updateKeyStats(keyID uint, success bool) { func (ps *ProxyServer) Close() { // Nothing to close for now } + +func (ps *ProxyServer) applyParamOverrides(c *gin.Context, group *models.Group) error { + // Read the original request body + bodyBytes, err := io.ReadAll(c.Request.Body) + if err != nil { + return fmt.Errorf("failed to read request body: %w", err) + } + c.Request.Body.Close() // Close the original body + + // If body is empty, nothing to override, just restore the body + if len(bodyBytes) == 0 { + c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + return nil + } + + // Save the original Content-Type + originalContentType := c.GetHeader("Content-Type") + + // Unmarshal the body into a map + var requestData map[string]interface{} + if err := json.Unmarshal(bodyBytes, &requestData); err != nil { + // If not a valid JSON, just pass it through + c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + return nil + } + + // Merge the overrides into the request data + for key, value := range group.ParamOverrides { + requestData[key] = value + } + + // Marshal the new data back to JSON + newBodyBytes, err := json.Marshal(requestData) + if err != nil { + return fmt.Errorf("failed to marshal new request body: %w", err) + } + + // Replace the request body with the new one + c.Request.Body = io.NopCloser(bytes.NewBuffer(newBodyBytes)) + c.Request.ContentLength = int64(len(newBodyBytes)) + + // Restore the original Content-Type header + if originalContentType != "" { + c.Request.Header.Set("Content-Type", originalContentType) + } + + return nil +} diff --git a/internal/router/router.go b/internal/router/router.go index b107633..d2df176 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -106,13 +106,27 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser groups.PUT("/:id", serverHandler.UpdateGroup) groups.DELETE("/:id", serverHandler.DeleteGroup) + // Key-specific routes keys := groups.Group("/:id/keys") { - keys.POST("", serverHandler.CreateKeysInGroup) keys.GET("", serverHandler.ListKeysInGroup) - keys.PUT("/:key_id", serverHandler.UpdateKey) - keys.DELETE("", serverHandler.DeleteKeys) + keys.POST("/add-multiple", serverHandler.AddMultipleKeys) + keys.POST("/restore-all-invalid", serverHandler.RestoreAllInvalidKeys) + keys.POST("/clear-all-invalid", serverHandler.ClearAllInvalidKeys) + keys.GET("/export", serverHandler.ExportKeys) + keys.DELETE("/:key_id", serverHandler.DeleteSingleKey) + keys.POST("/:key_id/test", serverHandler.TestSingleKey) } + + // Group-level actions + groups.POST("/:id/validate-keys", serverHandler.ValidateGroupKeys) + } + + // Tasks + tasks := api.Group("/tasks") + { + tasks.GET("/key-validation/status", 
serverHandler.GetTaskStatus) + tasks.GET("/:task_id/result", serverHandler.GetTaskResult) } // 仪表板和日志 diff --git a/internal/services/key_cron_service.go b/internal/services/key_cron_service.go new file mode 100644 index 0000000..6fc9914 --- /dev/null +++ b/internal/services/key_cron_service.go @@ -0,0 +1,206 @@ +package services + +import ( + "context" + "gpt-load/internal/config" + "gpt-load/internal/models" + "sync" + "time" + + "github.com/sirupsen/logrus" + "gorm.io/gorm" +) + +// KeyCronService is responsible for periodically validating all API keys. +type KeyCronService struct { + DB *gorm.DB + Validator *KeyValidatorService + SettingsManager *config.SystemSettingsManager + stopChan chan struct{} + wg sync.WaitGroup +} + +// NewKeyCronService creates a new KeyCronService. +func NewKeyCronService(db *gorm.DB, validator *KeyValidatorService, settingsManager *config.SystemSettingsManager) *KeyCronService { + return &KeyCronService{ + DB: db, + Validator: validator, + SettingsManager: settingsManager, + stopChan: make(chan struct{}), + } +} + +// Start begins the cron job. +func (s *KeyCronService) Start() { + logrus.Info("Starting KeyCronService...") + s.wg.Add(1) + go s.run() +} + +// Stop stops the cron job. +func (s *KeyCronService) Stop() { + logrus.Info("Stopping KeyCronService...") + close(s.stopChan) + s.wg.Wait() + logrus.Info("KeyCronService stopped.") +} + +func (s *KeyCronService) run() { + defer s.wg.Done() + ctx := context.Background() + + // Run once on start + s.validateAllGroups(ctx) + + for { + // Dynamically get the interval for the next run + intervalMinutes := s.SettingsManager.GetInt("key_validation_interval_minutes", 60) + if intervalMinutes <= 0 { + intervalMinutes = 60 // Fallback to a safe default + } + nextRunTimer := time.NewTimer(time.Duration(intervalMinutes) * time.Minute) + + select { + case <-nextRunTimer.C: + s.validateAllGroups(ctx) + case <-s.stopChan: + nextRunTimer.Stop() + return + } + } +} + +func (s *KeyCronService) validateAllGroups(ctx context.Context) { + logrus.Info("KeyCronService: Starting validation cycle for all groups.") + var groups []models.Group + if err := s.DB.Find(&groups).Error; err != nil { + logrus.Errorf("KeyCronService: Failed to get groups: %v", err) + return + } + + for _, group := range groups { + groupCopy := group // Create a copy for the closure + go func(g models.Group) { + // Get effective settings for the group + effectiveSettings := s.SettingsManager.GetEffectiveConfig(g.Config) + interval := time.Duration(effectiveSettings.KeyValidationIntervalMinutes) * time.Minute + + // Check if it's time to validate this group + if g.LastValidatedAt == nil || time.Since(*g.LastValidatedAt) > interval { + s.validateGroup(ctx, &g) + } + }(groupCopy) + } + logrus.Info("KeyCronService: Validation cycle finished.") +} + +func (s *KeyCronService) validateGroup(ctx context.Context, group *models.Group) { + var keys []models.APIKey + if err := s.DB.Where("group_id = ?", group.ID).Find(&keys).Error; err != nil { + logrus.Errorf("KeyCronService: Failed to get keys for group %s: %v", group.Name, err) + return + } + + if len(keys) == 0 { + return + } + + logrus.Infof("KeyCronService: Validating %d keys for group %s", len(keys), group.Name) + + jobs := make(chan models.APIKey, len(keys)) + results := make(chan models.APIKey, len(keys)) + + concurrency := s.SettingsManager.GetInt("key_validation_concurrency", 10) + if concurrency <= 0 { + concurrency = 10 // Fallback to a safe default + } + + var wg sync.WaitGroup + for i := 0; i < 
concurrency; i++ { + wg.Add(1) + go s.worker(ctx, &wg, group, jobs, results) + } + + for _, key := range keys { + jobs <- key + } + close(jobs) + + wg.Wait() + close(results) + + var keysToUpdate []models.APIKey + for key := range results { + keysToUpdate = append(keysToUpdate, key) + } + + if len(keysToUpdate) > 0 { + s.batchUpdateKeyStatus(keysToUpdate) + } + + // Update the last validated timestamp for the group + if err := s.DB.Model(group).Update("last_validated_at", time.Now()).Error; err != nil { + logrus.Errorf("KeyCronService: Failed to update last_validated_at for group %s: %v", group.Name, err) + } +} + +func (s *KeyCronService) worker(ctx context.Context, wg *sync.WaitGroup, group *models.Group, jobs <-chan models.APIKey, results chan<- models.APIKey) { + defer wg.Done() + for key := range jobs { + isValid, err := s.Validator.ValidateSingleKey(ctx, &key, group) + // Only update status if there was no error during validation + if err != nil { + logrus.Warnf("KeyCronService: Failed to validate key ID %d for group %s: %v. Skipping status update.", key.ID, group.Name, err) + continue + } + + newStatus := "inactive" + if isValid { + newStatus = "active" + } + + // Only send to results if the status has changed + if key.Status != newStatus { + key.Status = newStatus + results <- key + } + } +} + +func (s *KeyCronService) batchUpdateKeyStatus(keys []models.APIKey) { + if len(keys) == 0 { + return + } + logrus.Infof("KeyCronService: Batch updating status for %d keys.", len(keys)) + + activeIDs := []uint{} + inactiveIDs := []uint{} + + for _, key := range keys { + if key.Status == "active" { + activeIDs = append(activeIDs, key.ID) + } else { + inactiveIDs = append(inactiveIDs, key.ID) + } + } + + err := s.DB.Transaction(func(tx *gorm.DB) error { + if len(activeIDs) > 0 { + if err := tx.Model(&models.APIKey{}).Where("id IN ?", activeIDs).Update("status", "active").Error; err != nil { + return err + } + logrus.Infof("KeyCronService: Set %d keys to 'active'.", len(activeIDs)) + } + if len(inactiveIDs) > 0 { + if err := tx.Model(&models.APIKey{}).Where("id IN ?", inactiveIDs).Update("status", "inactive").Error; err != nil { + return err + } + logrus.Infof("KeyCronService: Set %d keys to 'inactive'.", len(inactiveIDs)) + } + return nil + }) + + if err != nil { + logrus.Errorf("KeyCronService: Failed to batch update key status: %v", err) + } +} diff --git a/internal/services/key_manual_validation_service.go b/internal/services/key_manual_validation_service.go new file mode 100644 index 0000000..211857b --- /dev/null +++ b/internal/services/key_manual_validation_service.go @@ -0,0 +1,122 @@ +package services + +import ( + "context" + "fmt" + "gpt-load/internal/config" + "gpt-load/internal/models" + "sync" + "time" + + "github.com/google/uuid" + "github.com/sirupsen/logrus" + "gorm.io/gorm" +) + +// ManualValidationResult holds the result of a manual validation task. +type ManualValidationResult struct { + TotalKeys int `json:"total_keys"` + ValidKeys int `json:"valid_keys"` + InvalidKeys int `json:"invalid_keys"` +} + +// KeyManualValidationService handles user-initiated key validation for a group. +type KeyManualValidationService struct { + DB *gorm.DB + Validator *KeyValidatorService + TaskService *TaskService + SettingsManager *config.SystemSettingsManager +} + +// NewKeyManualValidationService creates a new KeyManualValidationService. 
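// Wired up in cmd/gpt-load/main.go as:
//
//	keyManualValidationService := services.NewKeyManualValidationService(database, keyValidatorService, taskService, settingsManager)
//
// Only one manual task may run at a time; StartValidationTask passes TaskService's
// "already running" error back to the caller.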
+func NewKeyManualValidationService(db *gorm.DB, validator *KeyValidatorService, taskService *TaskService, settingsManager *config.SystemSettingsManager) *KeyManualValidationService { + return &KeyManualValidationService{ + DB: db, + Validator: validator, + TaskService: taskService, + SettingsManager: settingsManager, + } +} + +// StartValidationTask starts a new manual validation task for a given group. +func (s *KeyManualValidationService) StartValidationTask(group *models.Group) (*TaskStatus, error) { + var keys []models.APIKey + if err := s.DB.Where("group_id = ?", group.ID).Find(&keys).Error; err != nil { + return nil, fmt.Errorf("failed to get keys for group %s: %w", group.Name, err) + } + + if len(keys) == 0 { + return nil, fmt.Errorf("no keys to validate in group %s", group.Name) + } + + taskID := uuid.New().String() + timeoutMinutes := s.SettingsManager.GetInt("key_validation_task_timeout_minutes", 60) + timeout := time.Duration(timeoutMinutes) * time.Minute + + taskStatus, err := s.TaskService.StartTask(taskID, group.Name, len(keys), timeout) + if err != nil { + return nil, err // A task is already running + } + + // Run the validation in a separate goroutine + go s.runValidation(group, keys, taskStatus) + + return taskStatus, nil +} + +func (s *KeyManualValidationService) runValidation(group *models.Group, keys []models.APIKey, task *TaskStatus) { + defer s.TaskService.EndTask() + + logrus.Infof("Starting manual validation for group %s (TaskID: %s)", group.Name, task.TaskID) + + jobs := make(chan models.APIKey, len(keys)) + results := make(chan bool, len(keys)) + + concurrency := s.SettingsManager.GetInt("key_validation_concurrency", 10) + if concurrency <= 0 { + concurrency = 10 // Fallback to a safe default + } + + var wg sync.WaitGroup + for i := 0; i < concurrency; i++ { + wg.Add(1) + go s.validationWorker(&wg, group, jobs, results) + } + + for _, key := range keys { + jobs <- key + } + close(jobs) + + wg.Wait() + close(results) + + validCount := 0 + processedCount := 0 + for isValid := range results { + processedCount++ + if isValid { + validCount++ + } + // Update progress + s.TaskService.UpdateProgress(processedCount) + } + + result := ManualValidationResult{ + TotalKeys: len(keys), + ValidKeys: validCount, + InvalidKeys: len(keys) - validCount, + } + + // Store the final result + s.TaskService.StoreResult(task.TaskID, result) + logrus.Infof("Manual validation finished for group %s (TaskID: %s): %+v", group.Name, task.TaskID, result) +} + +func (s *KeyManualValidationService) validationWorker(wg *sync.WaitGroup, group *models.Group, jobs <-chan models.APIKey, results chan<- bool) { + defer wg.Done() + for key := range jobs { + isValid, _ := s.Validator.ValidateSingleKey(context.Background(), &key, group) + results <- isValid + } +} diff --git a/internal/services/key_service.go b/internal/services/key_service.go new file mode 100644 index 0000000..624a2eb --- /dev/null +++ b/internal/services/key_service.go @@ -0,0 +1,206 @@ +package services + +import ( + "encoding/json" + "fmt" + "gpt-load/internal/models" + "regexp" + "strings" + + "gorm.io/gorm" +) + +// AddKeysResult holds the result of adding multiple keys. +type AddKeysResult struct { + AddedCount int `json:"added_count"` + IgnoredCount int `json:"ignored_count"` + TotalInGroup int64 `json:"total_in_group"` +} + +// KeyService provides services related to API keys. +type KeyService struct { + DB *gorm.DB +} + +// NewKeyService creates a new KeyService. 
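// AddMultipleKeys (below) accepts keys_text either as a JSON array of strings or as
// free text split on whitespace, commas, semicolons and pipes; each entry must be
// 4-1000 characters drawn from [a-zA-Z0-9_./+=:-]. Illustrative input (key values are
// placeholders):
//
//	sk-aaa111, sk-bbb222
//	sk-ccc333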
diff --git a/internal/services/key_service.go b/internal/services/key_service.go
new file mode 100644
index 0000000..624a2eb
--- /dev/null
+++ b/internal/services/key_service.go
@@ -0,0 +1,206 @@
+package services
+
+import (
+	"encoding/json"
+	"fmt"
+	"gpt-load/internal/models"
+	"regexp"
+	"strings"
+
+	"gorm.io/gorm"
+)
+
+// AddKeysResult holds the result of adding multiple keys.
+type AddKeysResult struct {
+	AddedCount   int   `json:"added_count"`
+	IgnoredCount int   `json:"ignored_count"`
+	TotalInGroup int64 `json:"total_in_group"`
+}
+
+// KeyService provides services related to API keys.
+type KeyService struct {
+	DB *gorm.DB
+}
+
+// NewKeyService creates a new KeyService.
+func NewKeyService(db *gorm.DB) *KeyService {
+	return &KeyService{DB: db}
+}
+
+// AddMultipleKeys handles the business logic of creating new keys from a text block.
+func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysResult, error) {
+	// 1. Parse keys from the text block
+	keys := s.parseKeysFromText(keysText)
+	if len(keys) == 0 {
+		return nil, fmt.Errorf("no valid keys found in the input text")
+	}
+
+	// 2. Get the group information for validation
+	var group models.Group
+	if err := s.DB.First(&group, groupID).Error; err != nil {
+		return nil, fmt.Errorf("failed to find group: %w", err)
+	}
+
+	// 3. Get existing keys in the group for deduplication
+	var existingKeys []models.APIKey
+	if err := s.DB.Where("group_id = ?", groupID).Select("key_value").Find(&existingKeys).Error; err != nil {
+		return nil, err
+	}
+	existingKeyMap := make(map[string]bool)
+	for _, k := range existingKeys {
+		existingKeyMap[k.KeyValue] = true
+	}
+
+	// 4. Prepare new keys with basic validation only
+	var newKeysToCreate []models.APIKey
+	uniqueNewKeys := make(map[string]bool)
+
+	for _, keyVal := range keys {
+		trimmedKey := strings.TrimSpace(keyVal)
+		if trimmedKey == "" {
+			continue
+		}
+
+		// Check if key already exists
+		if existingKeyMap[trimmedKey] || uniqueNewKeys[trimmedKey] {
+			continue
+		}
+
+		// Generic validation: basic format checks only, no channel-specific rules
+		if s.isValidKeyFormat(trimmedKey) {
+			uniqueNewKeys[trimmedKey] = true
+			newKeysToCreate = append(newKeysToCreate, models.APIKey{
+				GroupID:  groupID,
+				KeyValue: trimmedKey,
+				Status:   "active",
+			})
+		}
+	}
+
+	addedCount := len(newKeysToCreate)
+	// The ignored count covers duplicates as well as keys that failed the format check
+	ignoredCount := len(keys) - addedCount
+
+	// 5. Insert new keys if any
+	if addedCount > 0 {
+		if err := s.DB.Create(&newKeysToCreate).Error; err != nil {
+			return nil, err
+		}
+	}
+
+	// 6. Get the new total count
+	var totalInGroup int64
+	if err := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID).Count(&totalInGroup).Error; err != nil {
+		return nil, err
+	}
+
+	return &AddKeysResult{
+		AddedCount:   addedCount,
+		IgnoredCount: ignoredCount,
+		TotalInGroup: totalInGroup,
+	}, nil
+}
+
+func (s *KeyService) parseKeysFromText(text string) []string {
+	var keys []string
+
+	// First, try to parse as a JSON array of strings
+	if json.Unmarshal([]byte(text), &keys) == nil && len(keys) > 0 {
+		return s.filterValidKeys(keys)
+	}
+
+	// Generic parsing: split the text on common delimiters instead of channel-specific regexes
+	delimiters := regexp.MustCompile(`[\s,;|\n\r\t]+`)
+	splitKeys := delimiters.Split(strings.TrimSpace(text), -1)
+
+	for _, key := range splitKeys {
+		key = strings.TrimSpace(key)
+		if key != "" {
+			keys = append(keys, key)
+		}
+	}
+
+	return s.filterValidKeys(keys)
+}
+
+// filterValidKeys validates and filters potential API keys
+func (s *KeyService) filterValidKeys(keys []string) []string {
+	var validKeys []string
+	for _, key := range keys {
+		key = strings.TrimSpace(key)
+		if s.isValidKeyFormat(key) {
+			validKeys = append(validKeys, key)
+		}
+	}
+	return validKeys
+}
+
+// isValidKeyFormat performs basic validation on key format
+func (s *KeyService) isValidKeyFormat(key string) bool {
+	if len(key) < 4 || len(key) > 1000 {
+		return false
+	}
+
+	if key == "" ||
+		strings.TrimSpace(key) == "" {
+		return false
+	}
+
+	validChars := regexp.MustCompile(`^[a-zA-Z0-9_\-./+=:]+$`)
+	return validChars.MatchString(key)
+}
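// Both input styles below are accepted by AddMultipleKeys given the parsing
// rules above: a JSON array of strings, or free text split on whitespace,
// commas, semicolons, pipes and newlines. The db handle, group ID and key
// values are illustrative assumptions, not part of this patch.
package keysketch

import (
	"gorm.io/gorm"

	"gpt-load/internal/services"
)

func seedExampleKeys(db *gorm.DB, groupID uint) error {
	svc := services.NewKeyService(db)

	// Style 1: a JSON array of strings.
	if _, err := svc.AddMultipleKeys(groupID, `["sk-aaaa1111", "sk-bbbb2222"]`); err != nil {
		return err
	}

	// Style 2: delimiter-separated text; keys already in the group are ignored.
	_, err := svc.AddMultipleKeys(groupID, "sk-cccc3333\nsk-dddd4444, sk-eeee5555")
	return err
}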
+
+// RestoreAllInvalidKeys sets the status of all 'inactive' keys in a group to 'active'.
+func (s *KeyService) RestoreAllInvalidKeys(groupID uint) (int64, error) {
+	result := s.DB.Model(&models.APIKey{}).Where("group_id = ? AND status = ?", groupID, "inactive").Update("status", "active")
+	return result.RowsAffected, result.Error
+}
+
+// ClearAllInvalidKeys deletes all 'inactive' keys from a group.
+func (s *KeyService) ClearAllInvalidKeys(groupID uint) (int64, error) {
+	result := s.DB.Where("group_id = ? AND status = ?", groupID, "inactive").Delete(&models.APIKey{})
+	return result.RowsAffected, result.Error
+}
+
+// DeleteSingleKey deletes a specific key from a group.
+func (s *KeyService) DeleteSingleKey(groupID, keyID uint) (int64, error) {
+	result := s.DB.Where("group_id = ? AND id = ?", groupID, keyID).Delete(&models.APIKey{})
+	return result.RowsAffected, result.Error
+}
+
+// ExportKeys returns a list of keys for a group, filtered by status.
+func (s *KeyService) ExportKeys(groupID uint, filter string) ([]string, error) {
+	query := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID)
+
+	switch filter {
+	case "valid":
+		query = query.Where("status = ?", "active")
+	case "invalid":
+		query = query.Where("status = ?", "inactive")
+	case "all":
+		// No status filter needed
+	default:
+		return nil, fmt.Errorf("invalid filter value. Use 'all', 'valid', or 'invalid'")
+	}
+
+	var keys []string
+	if err := query.Pluck("key_value", &keys).Error; err != nil {
+		return nil, err
+	}
+	return keys, nil
+}
+
+// ListKeysInGroup lists all keys within a specific group, filtered by status.
+func (s *KeyService) ListKeysInGroup(groupID uint, statusFilter string) ([]models.APIKey, error) {
+	var keys []models.APIKey
+	query := s.DB.Where("group_id = ?", groupID)
+
+	if statusFilter != "" {
+		query = query.Where("status = ?", statusFilter)
+	}
+
+	if err := query.Find(&keys).Error; err != nil {
+		return nil, err
+	}
+	return keys, nil
+}
diff --git a/internal/services/key_validator_service.go b/internal/services/key_validator_service.go
new file mode 100644
index 0000000..54d49dc
--- /dev/null
+++ b/internal/services/key_validator_service.go
@@ -0,0 +1,91 @@
+package services
+
+import (
+	"context"
+	"fmt"
+	"gpt-load/internal/channel"
+	"gpt-load/internal/models"
+
+	"github.com/sirupsen/logrus"
+	"gorm.io/gorm"
+)
+
+// KeyValidatorService provides methods to validate API keys.
+type KeyValidatorService struct {
+	DB             *gorm.DB
+	channelFactory *channel.Factory
+}
+
+// NewKeyValidatorService creates a new KeyValidatorService.
+func NewKeyValidatorService(db *gorm.DB, factory *channel.Factory) *KeyValidatorService {
+	return &KeyValidatorService{
+		DB:             db,
+		channelFactory: factory,
+	}
+}
+
+// ValidateSingleKey performs a validation check on a single API key.
+// It does not modify the key's state in the database.
+// It reports whether the key is valid, and returns an error only if the check itself could not be completed.
+func (s *KeyValidatorService) ValidateSingleKey(ctx context.Context, key *models.APIKey, group *models.Group) (bool, error) {
+	// Bail out early if the caller's context is already cancelled or timed out
+	if ctx.Err() != nil {
+		return false, fmt.Errorf("context cancelled or timed out: %w", ctx.Err())
+	}
+
+	ch, err := s.channelFactory.GetChannel(group)
+	if err != nil {
+		logrus.WithFields(logrus.Fields{
+			"group_id":     group.ID,
+			"group_name":   group.Name,
+			"channel_type": group.ChannelType,
+			"error":        err,
+		}).Error("Failed to get channel for key validation")
+		return false, fmt.Errorf("failed to get channel for group %s: %w", group.Name, err)
+	}
+
+	// Log the start of validation
+	logrus.WithFields(logrus.Fields{
+		"key_id":     key.ID,
+		"group_id":   group.ID,
+		"group_name": group.Name,
+	}).Debug("Starting key validation")
+
+	isValid, validationErr := ch.ValidateKey(ctx, key.KeyValue)
+	if validationErr != nil {
+		logrus.WithFields(logrus.Fields{
+			"key_id":     key.ID,
+			"group_id":   group.ID,
+			"group_name": group.Name,
+			"error":      validationErr,
+		}).Warn("Key validation failed")
+		return false, validationErr
+	}
+
+	// Log the validation result
+	logrus.WithFields(logrus.Fields{
+		"key_id":     key.ID,
+		"group_id":   group.ID,
+		"group_name": group.Name,
+		"is_valid":   isValid,
+	}).Debug("Key validation completed")
+
+	return isValid, nil
+}
+
+// TestSingleKeyByID performs a synchronous validation test for a single API key by its ID.
+// It is intended for handling user-initiated "Test" actions.
+// It does not modify the key's state in the database.
+func (s *KeyValidatorService) TestSingleKeyByID(ctx context.Context, keyID uint) (bool, error) {
+	var apiKey models.APIKey
+	if err := s.DB.First(&apiKey, keyID).Error; err != nil {
+		return false, fmt.Errorf("failed to find api key with id %d: %w", keyID, err)
+	}
+
+	var group models.Group
+	if err := s.DB.First(&group, apiKey.GroupID).Error; err != nil {
+		return false, fmt.Errorf("failed to find group with id %d: %w", apiKey.GroupID, err)
+	}
+
+	return s.ValidateSingleKey(ctx, &apiKey, &group)
+}
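// TestSingleKeyByID is the natural entry point for a user-initiated "test this
// key" action: it loads the key and its group, then delegates to ValidateSingleKey
// without touching the stored status. A hedged calling sketch; the 30-second
// timeout and the surrounding wiring are assumptions, not part of this patch.
package validatorsketch

import (
	"context"
	"time"

	"gpt-load/internal/services"
)

func testKey(validator *services.KeyValidatorService, keyID uint) (bool, error) {
	// Bound the check so a hung upstream cannot block the caller indefinitely.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	return validator.TestSingleKeyByID(ctx, keyID)
}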
diff --git a/internal/services/task_service.go b/internal/services/task_service.go
new file mode 100644
index 0000000..139be45
--- /dev/null
+++ b/internal/services/task_service.go
@@ -0,0 +1,125 @@
+package services
+
+import (
+	"errors"
+	"sync"
+	"time"
+)
+
+// TaskStatus represents the status of a long-running task.
+type TaskStatus struct {
+	IsRunning   bool      `json:"is_running"`
+	GroupName   string    `json:"group_name,omitempty"`
+	Processed   int       `json:"processed,omitempty"`
+	Total       int       `json:"total,omitempty"`
+	TaskID      string    `json:"task_id,omitempty"`
+	ExpiresAt   time.Time `json:"-"` // Internal field to handle zombie tasks
+	lastUpdated time.Time
+}
+
+// TaskService manages the state of a single, global, long-running task.
+type TaskService struct {
+	mu           sync.Mutex
+	status       TaskStatus
+	resultsCache map[string]interface{}
+	cacheOrder   []string
+	maxCacheSize int
+}
+
+// NewTaskService creates a new TaskService.
+func NewTaskService() *TaskService {
+	return &TaskService{
+		resultsCache: make(map[string]interface{}),
+		cacheOrder:   make([]string, 0),
+		maxCacheSize: 100, // Store results for the last 100 tasks
+	}
+}
+
+// StartTask attempts to start a new task. It returns an error if a task is already running.
+func (s *TaskService) StartTask(taskID, groupName string, total int, timeout time.Duration) (*TaskStatus, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// Zombie task check
+	if s.status.IsRunning && time.Now().After(s.status.ExpiresAt) {
+		// The previous task is considered a zombie, reset it.
+		s.status = TaskStatus{}
+	}
+
+	if s.status.IsRunning {
+		return nil, errors.New("a task is already running")
+	}
+
+	s.status = TaskStatus{
+		IsRunning:   true,
+		TaskID:      taskID,
+		GroupName:   groupName,
+		Total:       total,
+		Processed:   0,
+		ExpiresAt:   time.Now().Add(timeout),
+		lastUpdated: time.Now(),
+	}
+
+	return &s.status, nil
+}
+
+// GetTaskStatus returns the current status of the task.
+func (s *TaskService) GetTaskStatus() *TaskStatus {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// Zombie task check
+	if s.status.IsRunning && time.Now().After(s.status.ExpiresAt) {
+		s.status = TaskStatus{} // Reset if expired
+	}
+
+	// Return a copy to prevent race conditions on the caller's side
+	statusCopy := s.status
+	return &statusCopy
+}
+
+// UpdateProgress updates the progress of the current task.
+func (s *TaskService) UpdateProgress(processed int) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if !s.status.IsRunning {
+		return
+	}
+
+	s.status.Processed = processed
+	s.status.lastUpdated = time.Now()
+}
+
+// EndTask marks the current task as finished.
+func (s *TaskService) EndTask() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	s.status.IsRunning = false
+}
+
+// StoreResult stores the result of a finished task.
+func (s *TaskService) StoreResult(taskID string, result interface{}) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if _, exists := s.resultsCache[taskID]; !exists {
+		if len(s.cacheOrder) >= s.maxCacheSize {
+			oldestTaskID := s.cacheOrder[0]
+			delete(s.resultsCache, oldestTaskID)
+			s.cacheOrder = s.cacheOrder[1:]
+		}
+		s.cacheOrder = append(s.cacheOrder, taskID)
+	}
+	s.resultsCache[taskID] = result
+}
+
+// GetResult retrieves the result of a finished task.
+func (s *TaskService) GetResult(taskID string) (interface{}, bool) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	result, found := s.resultsCache[taskID]
+	return result, found
+}
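// TaskService tracks exactly one global long-running task at a time, with an
// ExpiresAt guard so a crashed goroutine cannot wedge the service forever. A
// hedged end-to-end sketch of the intended lifecycle; the task ID, group name
// and result payload are illustrative, not part of this patch.
package tasksketch

import (
	"fmt"
	"time"

	"gpt-load/internal/services"
)

func runOnce() error {
	tasks := services.NewTaskService()

	status, err := tasks.StartTask("demo-task-id", "demo-group", 3, time.Minute)
	if err != nil {
		return err // another task is still running (and not yet expired)
	}

	for i := 1; i <= status.Total; i++ {
		// ... one unit of work per key would go here ...
		tasks.UpdateProgress(i)
	}

	tasks.StoreResult(status.TaskID, map[string]int{"processed": status.Total})
	tasks.EndTask()

	if result, ok := tasks.GetResult("demo-task-id"); ok {
		fmt.Printf("task finished: %+v\n", result)
	}
	return nil
}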