From c40ee7cb5badc4a8b79c19e78dd57cbaa32908d0 Mon Sep 17 00:00:00 2001 From: tbphp Date: Sat, 5 Jul 2025 17:19:20 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BB=BB=E5=8A=A1=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E6=80=A7=E8=83=BD=E5=8F=8A=E6=8E=A5=E5=8F=A3=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 1 - go.sum | 2 - internal/handler/task_handler.go | 25 +-- internal/router/router.go | 6 +- .../services/key_manual_validation_service.go | 43 ++-- internal/services/key_validator_service.go | 4 +- internal/services/task_service.go | 183 ++++++++++-------- 7 files changed, 135 insertions(+), 129 deletions(-) diff --git a/go.mod b/go.mod index 3ebe6e7..a3292a4 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,6 @@ require ( github.com/gin-contrib/static v1.1.5 github.com/gin-gonic/gin v1.10.1 github.com/go-sql-driver/mysql v1.8.1 - github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 github.com/redis/go-redis/v9 v9.5.3 github.com/sirupsen/logrus v1.9.3 diff --git a/go.sum b/go.sum index acb8e11..df7cfbd 100644 --- a/go.sum +++ b/go.sum @@ -49,8 +49,6 @@ github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EO github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= diff --git a/internal/handler/task_handler.go b/internal/handler/task_handler.go index d3464b8..541a3d8 100644 --- a/internal/handler/task_handler.go +++ b/internal/handler/task_handler.go @@ -1,31 +1,18 @@ package handler import ( - "gpt-load/internal/response" app_errors "gpt-load/internal/errors" + "gpt-load/internal/response" "github.com/gin-gonic/gin" ) // GetTaskStatus handles requests for the status of the global long-running task. func (s *Server) GetTaskStatus(c *gin.Context) { - taskStatus := s.TaskService.GetTaskStatus() + taskStatus, err := s.TaskService.GetTaskStatus() + if err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, "Failed to get task status")) + return + } response.Success(c, taskStatus) } - -// GetTaskResult handles requests for the result of a finished task. -func (s *Server) GetTaskResult(c *gin.Context) { - taskID := c.Param("task_id") - if taskID == "" { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Task ID is required")) - return - } - - result, found := s.TaskService.GetResult(taskID) - if !found { - response.Error(c, app_errors.ErrResourceNotFound) - return - } - - response.Success(c, result) -} diff --git a/internal/router/router.go b/internal/router/router.go index 84db09e..91e4abe 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -124,11 +124,7 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser } // Tasks - tasks := api.Group("/tasks") - { - tasks.GET("/key-validation/status", serverHandler.GetTaskStatus) - tasks.GET("/:task_id/result", serverHandler.GetTaskResult) - } + api.GET("/tasks/status", serverHandler.GetTaskStatus) // 仪表板和日志 dashboard := api.Group("/dashboard") diff --git a/internal/services/key_manual_validation_service.go b/internal/services/key_manual_validation_service.go index 211857b..ed83890 100644 --- a/internal/services/key_manual_validation_service.go +++ b/internal/services/key_manual_validation_service.go @@ -8,7 +8,6 @@ import ( "sync" "time" - "github.com/google/uuid" "github.com/sirupsen/logrus" "gorm.io/gorm" ) @@ -49,13 +48,12 @@ func (s *KeyManualValidationService) StartValidationTask(group *models.Group) (* return nil, fmt.Errorf("no keys to validate in group %s", group.Name) } - taskID := uuid.New().String() timeoutMinutes := s.SettingsManager.GetInt("key_validation_task_timeout_minutes", 60) timeout := time.Duration(timeoutMinutes) * time.Minute - taskStatus, err := s.TaskService.StartTask(taskID, group.Name, len(keys), timeout) + taskStatus, err := s.TaskService.StartTask(group.Name, len(keys), timeout) if err != nil { - return nil, err // A task is already running + return nil, err } // Run the validation in a separate goroutine @@ -65,9 +63,7 @@ func (s *KeyManualValidationService) StartValidationTask(group *models.Group) (* } func (s *KeyManualValidationService) runValidation(group *models.Group, keys []models.APIKey, task *TaskStatus) { - defer s.TaskService.EndTask() - - logrus.Infof("Starting manual validation for group %s (TaskID: %s)", group.Name, task.TaskID) + logrus.Infof("Starting manual validation for group %s", group.Name) jobs := make(chan models.APIKey, len(keys)) results := make(chan bool, len(keys)) @@ -88,18 +84,34 @@ func (s *KeyManualValidationService) runValidation(group *models.Group, keys []m } close(jobs) - wg.Wait() - close(results) + // Wait for all workers to complete in a separate goroutine + go func() { + wg.Wait() + close(results) + }() validCount := 0 processedCount := 0 + lastUpdateTime := time.Now() + for isValid := range results { processedCount++ if isValid { validCount++ } - // Update progress - s.TaskService.UpdateProgress(processedCount) + + // Throttle progress updates to once per second + if time.Since(lastUpdateTime) > time.Second { + if err := s.TaskService.UpdateProgress(processedCount); err != nil { + logrus.Warnf("Failed to update task progress: %v", err) + } + lastUpdateTime = time.Now() + } + } + + // Ensure the final progress is always updated + if err := s.TaskService.UpdateProgress(processedCount); err != nil { + logrus.Warnf("Failed to update final task progress: %v", err) } result := ManualValidationResult{ @@ -108,11 +120,14 @@ func (s *KeyManualValidationService) runValidation(group *models.Group, keys []m InvalidKeys: len(keys) - validCount, } - // Store the final result - s.TaskService.StoreResult(task.TaskID, result) - logrus.Infof("Manual validation finished for group %s (TaskID: %s): %+v", group.Name, task.TaskID, result) + // End the task and store the final result + if err := s.TaskService.EndTask(result, nil); err != nil { + logrus.Errorf("Failed to end task for group %s: %v", group.Name, err) + } + logrus.Infof("Manual validation finished for group %s: %+v", group.Name, result) } +// validationResult 包含验证结果信息 func (s *KeyManualValidationService) validationWorker(wg *sync.WaitGroup, group *models.Group, jobs <-chan models.APIKey, results chan<- bool) { defer wg.Done() for key := range jobs { diff --git a/internal/services/key_validator_service.go b/internal/services/key_validator_service.go index 72c9e25..6cc11d2 100644 --- a/internal/services/key_validator_service.go +++ b/internal/services/key_validator_service.go @@ -32,8 +32,6 @@ func NewKeyValidatorService(db *gorm.DB, factory *channel.Factory) *KeyValidator } // ValidateSingleKey performs a validation check on a single API key. -// It does not modify the key's state in the database. -// It returns true if the key is valid, and an error if it's not. func (s *KeyValidatorService) ValidateSingleKey(ctx context.Context, key *models.APIKey, group *models.Group) (bool, error) { // 添加超时保护 if ctx.Err() != nil { @@ -65,7 +63,7 @@ func (s *KeyValidatorService) ValidateSingleKey(ctx context.Context, key *models "group_id": group.ID, "group_name": group.Name, "error": validationErr, - }).Warn("Key validation failed") + }).Debug("Key validation failed") return false, validationErr } diff --git a/internal/services/task_service.go b/internal/services/task_service.go index ea6689d..c6be925 100644 --- a/internal/services/task_service.go +++ b/internal/services/task_service.go @@ -1,125 +1,138 @@ package services import ( + "encoding/json" "errors" - "sync" + "fmt" + "gpt-load/internal/store" "time" ) -// TaskStatus represents the status of a long-running task. +const ( + globalTaskKey = "global_task:key_validation" + ResultTTL = 60 * time.Minute +) + +// TaskStatus represents the full lifecycle of a long-running task. type TaskStatus struct { - IsRunning bool `json:"is_running"` - GroupName string `json:"group_name,omitempty"` - Processed int `json:"processed,omitempty"` - Total int `json:"total,omitempty"` - TaskID string `json:"task_id,omitempty"` - ExpiresAt time.Time `json:"-"` // Internal field to handle zombie tasks - lastUpdated time.Time + IsRunning bool `json:"is_running"` + GroupName string `json:"group_name,omitempty"` + Processed int `json:"processed"` + Total int `json:"total"` + Result any `json:"result,omitempty"` + Error string `json:"error,omitempty"` + StartedAt time.Time `json:"started_at"` + FinishedAt *time.Time `json:"finished_at,omitempty"` + DurationSeconds float64 `json:"duration_seconds,omitempty"` } -// TaskService manages the state of a single, global, long-running task. +// TaskService manages the state of a single, global, long-running task using the store interface. type TaskService struct { - mu sync.Mutex - status TaskStatus - resultsCache map[string]any - cacheOrder []string - maxCacheSize int + store store.Store } // NewTaskService creates a new TaskService. -func NewTaskService() *TaskService { +func NewTaskService(store store.Store) *TaskService { return &TaskService{ - resultsCache: make(map[string]any), - cacheOrder: make([]string, 0), - maxCacheSize: 100, + store: store, } } // StartTask attempts to start a new task. It returns an error if a task is already running. -func (s *TaskService) StartTask(taskID, groupName string, total int, timeout time.Duration) (*TaskStatus, error) { - s.mu.Lock() - defer s.mu.Unlock() - - // Zombie task check - if s.status.IsRunning && time.Now().After(s.status.ExpiresAt) { - // The previous task is considered a zombie, reset it. - s.status = TaskStatus{} +func (s *TaskService) StartTask(groupName string, total int, timeout time.Duration) (*TaskStatus, error) { + currentStatus, err := s.GetTaskStatus() + if err != nil { + return nil, fmt.Errorf("failed to check current task status before starting a new one: %w", err) } - if s.status.IsRunning { - return nil, errors.New("a task is already running") + if currentStatus.IsRunning { + return nil, errors.New("a task is already running, please wait") } - s.status = TaskStatus{ - IsRunning: true, - TaskID: taskID, - GroupName: groupName, - Total: total, - Processed: 0, - ExpiresAt: time.Now().Add(timeout), - lastUpdated: time.Now(), + status := &TaskStatus{ + IsRunning: true, + GroupName: groupName, + Total: total, + Processed: 0, + StartedAt: time.Now(), + } + statusBytes, err := json.Marshal(status) + if err != nil { + return nil, fmt.Errorf("failed to serialize new task status: %w", err) } - return &s.status, nil + if err := s.store.Set(globalTaskKey, statusBytes, timeout); err != nil { + return nil, fmt.Errorf("failed to set initial task status: %w", err) + } + + return status, nil } // GetTaskStatus returns the current status of the task. -func (s *TaskService) GetTaskStatus() *TaskStatus { - s.mu.Lock() - defer s.mu.Unlock() - - // Zombie task check - if s.status.IsRunning && time.Now().After(s.status.ExpiresAt) { - s.status = TaskStatus{} // Reset if expired +func (s *TaskService) GetTaskStatus() (*TaskStatus, error) { + statusBytes, err := s.store.Get(globalTaskKey) + if err != nil { + if errors.Is(err, store.ErrNotFound) { + return &TaskStatus{IsRunning: false}, nil + } + return nil, fmt.Errorf("failed to get task status: %w", err) } - // Return a copy to prevent race conditions on the caller's side - statusCopy := s.status - return &statusCopy + var status TaskStatus + if err := json.Unmarshal(statusBytes, &status); err != nil { + return nil, fmt.Errorf("failed to deserialize task status: %w", err) + } + + if !status.IsRunning && status.FinishedAt != nil { + status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds() + } + + return &status, nil } // UpdateProgress updates the progress of the current task. -func (s *TaskService) UpdateProgress(processed int) { - s.mu.Lock() - defer s.mu.Unlock() - - if !s.status.IsRunning { - return +func (s *TaskService) UpdateProgress(processed int) error { + status, err := s.GetTaskStatus() + if err != nil { + return err + } + if !status.IsRunning { + return nil } - s.status.Processed = processed - s.status.lastUpdated = time.Now() -} - -// EndTask marks the current task as finished. -func (s *TaskService) EndTask() { - s.mu.Lock() - defer s.mu.Unlock() - - s.status.IsRunning = false -} - -// StoreResult stores the result of a finished task. -func (s *TaskService) StoreResult(taskID string, result any) { - s.mu.Lock() - defer s.mu.Unlock() - - if _, exists := s.resultsCache[taskID]; !exists { - if len(s.cacheOrder) >= s.maxCacheSize { - oldestTaskID := s.cacheOrder[0] - delete(s.resultsCache, oldestTaskID) - s.cacheOrder = s.cacheOrder[1:] - } - s.cacheOrder = append(s.cacheOrder, taskID) + status.Processed = processed + statusBytes, err := json.Marshal(status) + if err != nil { + return fmt.Errorf("failed to serialize updated status: %w", err) } - s.resultsCache[taskID] = result + + return s.store.Set(globalTaskKey, statusBytes, ResultTTL) } -// GetResult retrieves the result of a finished task. -func (s *TaskService) GetResult(taskID string) (any, bool) { - s.mu.Lock() - defer s.mu.Unlock() +// EndTask marks the current task as finished and stores its final result. +func (s *TaskService) EndTask(resultData any, taskErr error) error { + status, err := s.GetTaskStatus() + if err != nil { + return fmt.Errorf("failed to get task object to end task: %w", err) + } + if !status.IsRunning { + return nil + } - result, found := s.resultsCache[taskID] - return result, found + now := time.Now() + status.IsRunning = false + status.FinishedAt = &now + status.DurationSeconds = now.Sub(status.StartedAt).Seconds() + if taskErr != nil { + status.Error = taskErr.Error() + } else { + status.Result = resultData + } + + updatedTaskBytes, err := json.Marshal(status) + if err != nil { + return fmt.Errorf("failed to serialize final task status: %w", err) + } + + return s.store.Set(globalTaskKey, updatedTaskBytes, ResultTTL) }