feat: 任务功能性能及接口优化

This commit is contained in:
tbphp
2025-07-05 17:19:20 +08:00
parent 2edd48d0c1
commit c40ee7cb5b
7 changed files with 135 additions and 129 deletions

View File

@@ -1,31 +1,18 @@
package handler
import (
"gpt-load/internal/response"
app_errors "gpt-load/internal/errors"
"gpt-load/internal/response"
"github.com/gin-gonic/gin"
)
// GetTaskStatus handles requests for the status of the global long-running task.
func (s *Server) GetTaskStatus(c *gin.Context) {
taskStatus := s.TaskService.GetTaskStatus()
taskStatus, err := s.TaskService.GetTaskStatus()
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, "Failed to get task status"))
return
}
response.Success(c, taskStatus)
}
// GetTaskResult handles requests for the result of a finished task.
func (s *Server) GetTaskResult(c *gin.Context) {
taskID := c.Param("task_id")
if taskID == "" {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Task ID is required"))
return
}
result, found := s.TaskService.GetResult(taskID)
if !found {
response.Error(c, app_errors.ErrResourceNotFound)
return
}
response.Success(c, result)
}

View File

@@ -124,11 +124,7 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser
}
// Tasks
tasks := api.Group("/tasks")
{
tasks.GET("/key-validation/status", serverHandler.GetTaskStatus)
tasks.GET("/:task_id/result", serverHandler.GetTaskResult)
}
api.GET("/tasks/status", serverHandler.GetTaskStatus)
// 仪表板和日志
dashboard := api.Group("/dashboard")

View File

@@ -8,7 +8,6 @@ import (
"sync"
"time"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
@@ -49,13 +48,12 @@ func (s *KeyManualValidationService) StartValidationTask(group *models.Group) (*
return nil, fmt.Errorf("no keys to validate in group %s", group.Name)
}
taskID := uuid.New().String()
timeoutMinutes := s.SettingsManager.GetInt("key_validation_task_timeout_minutes", 60)
timeout := time.Duration(timeoutMinutes) * time.Minute
taskStatus, err := s.TaskService.StartTask(taskID, group.Name, len(keys), timeout)
taskStatus, err := s.TaskService.StartTask(group.Name, len(keys), timeout)
if err != nil {
return nil, err // A task is already running
return nil, err
}
// Run the validation in a separate goroutine
@@ -65,9 +63,7 @@ func (s *KeyManualValidationService) StartValidationTask(group *models.Group) (*
}
func (s *KeyManualValidationService) runValidation(group *models.Group, keys []models.APIKey, task *TaskStatus) {
defer s.TaskService.EndTask()
logrus.Infof("Starting manual validation for group %s (TaskID: %s)", group.Name, task.TaskID)
logrus.Infof("Starting manual validation for group %s", group.Name)
jobs := make(chan models.APIKey, len(keys))
results := make(chan bool, len(keys))
@@ -88,18 +84,34 @@ func (s *KeyManualValidationService) runValidation(group *models.Group, keys []m
}
close(jobs)
wg.Wait()
close(results)
// Wait for all workers to complete in a separate goroutine
go func() {
wg.Wait()
close(results)
}()
validCount := 0
processedCount := 0
lastUpdateTime := time.Now()
for isValid := range results {
processedCount++
if isValid {
validCount++
}
// Update progress
s.TaskService.UpdateProgress(processedCount)
// Throttle progress updates to once per second
if time.Since(lastUpdateTime) > time.Second {
if err := s.TaskService.UpdateProgress(processedCount); err != nil {
logrus.Warnf("Failed to update task progress: %v", err)
}
lastUpdateTime = time.Now()
}
}
// Ensure the final progress is always updated
if err := s.TaskService.UpdateProgress(processedCount); err != nil {
logrus.Warnf("Failed to update final task progress: %v", err)
}
result := ManualValidationResult{
@@ -108,11 +120,14 @@ func (s *KeyManualValidationService) runValidation(group *models.Group, keys []m
InvalidKeys: len(keys) - validCount,
}
// Store the final result
s.TaskService.StoreResult(task.TaskID, result)
logrus.Infof("Manual validation finished for group %s (TaskID: %s): %+v", group.Name, task.TaskID, result)
// End the task and store the final result
if err := s.TaskService.EndTask(result, nil); err != nil {
logrus.Errorf("Failed to end task for group %s: %v", group.Name, err)
}
logrus.Infof("Manual validation finished for group %s: %+v", group.Name, result)
}
// validationResult 包含验证结果信息
func (s *KeyManualValidationService) validationWorker(wg *sync.WaitGroup, group *models.Group, jobs <-chan models.APIKey, results chan<- bool) {
defer wg.Done()
for key := range jobs {

View File

@@ -32,8 +32,6 @@ func NewKeyValidatorService(db *gorm.DB, factory *channel.Factory) *KeyValidator
}
// ValidateSingleKey performs a validation check on a single API key.
// It does not modify the key's state in the database.
// It returns true if the key is valid, and an error if it's not.
func (s *KeyValidatorService) ValidateSingleKey(ctx context.Context, key *models.APIKey, group *models.Group) (bool, error) {
// 添加超时保护
if ctx.Err() != nil {
@@ -65,7 +63,7 @@ func (s *KeyValidatorService) ValidateSingleKey(ctx context.Context, key *models
"group_id": group.ID,
"group_name": group.Name,
"error": validationErr,
}).Warn("Key validation failed")
}).Debug("Key validation failed")
return false, validationErr
}

View File

@@ -1,125 +1,138 @@
package services
import (
"encoding/json"
"errors"
"sync"
"fmt"
"gpt-load/internal/store"
"time"
)
// TaskStatus represents the status of a long-running task.
const (
globalTaskKey = "global_task:key_validation"
ResultTTL = 60 * time.Minute
)
// TaskStatus represents the full lifecycle of a long-running task.
type TaskStatus struct {
IsRunning bool `json:"is_running"`
GroupName string `json:"group_name,omitempty"`
Processed int `json:"processed,omitempty"`
Total int `json:"total,omitempty"`
TaskID string `json:"task_id,omitempty"`
ExpiresAt time.Time `json:"-"` // Internal field to handle zombie tasks
lastUpdated time.Time
IsRunning bool `json:"is_running"`
GroupName string `json:"group_name,omitempty"`
Processed int `json:"processed"`
Total int `json:"total"`
Result any `json:"result,omitempty"`
Error string `json:"error,omitempty"`
StartedAt time.Time `json:"started_at"`
FinishedAt *time.Time `json:"finished_at,omitempty"`
DurationSeconds float64 `json:"duration_seconds,omitempty"`
}
// TaskService manages the state of a single, global, long-running task.
// TaskService manages the state of a single, global, long-running task using the store interface.
type TaskService struct {
mu sync.Mutex
status TaskStatus
resultsCache map[string]any
cacheOrder []string
maxCacheSize int
store store.Store
}
// NewTaskService creates a new TaskService.
func NewTaskService() *TaskService {
func NewTaskService(store store.Store) *TaskService {
return &TaskService{
resultsCache: make(map[string]any),
cacheOrder: make([]string, 0),
maxCacheSize: 100,
store: store,
}
}
// StartTask attempts to start a new task. It returns an error if a task is already running.
func (s *TaskService) StartTask(taskID, groupName string, total int, timeout time.Duration) (*TaskStatus, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Zombie task check
if s.status.IsRunning && time.Now().After(s.status.ExpiresAt) {
// The previous task is considered a zombie, reset it.
s.status = TaskStatus{}
func (s *TaskService) StartTask(groupName string, total int, timeout time.Duration) (*TaskStatus, error) {
currentStatus, err := s.GetTaskStatus()
if err != nil {
return nil, fmt.Errorf("failed to check current task status before starting a new one: %w", err)
}
if s.status.IsRunning {
return nil, errors.New("a task is already running")
if currentStatus.IsRunning {
return nil, errors.New("a task is already running, please wait")
}
s.status = TaskStatus{
IsRunning: true,
TaskID: taskID,
GroupName: groupName,
Total: total,
Processed: 0,
ExpiresAt: time.Now().Add(timeout),
lastUpdated: time.Now(),
status := &TaskStatus{
IsRunning: true,
GroupName: groupName,
Total: total,
Processed: 0,
StartedAt: time.Now(),
}
statusBytes, err := json.Marshal(status)
if err != nil {
return nil, fmt.Errorf("failed to serialize new task status: %w", err)
}
return &s.status, nil
if err := s.store.Set(globalTaskKey, statusBytes, timeout); err != nil {
return nil, fmt.Errorf("failed to set initial task status: %w", err)
}
return status, nil
}
// GetTaskStatus returns the current status of the task.
func (s *TaskService) GetTaskStatus() *TaskStatus {
s.mu.Lock()
defer s.mu.Unlock()
// Zombie task check
if s.status.IsRunning && time.Now().After(s.status.ExpiresAt) {
s.status = TaskStatus{} // Reset if expired
func (s *TaskService) GetTaskStatus() (*TaskStatus, error) {
statusBytes, err := s.store.Get(globalTaskKey)
if err != nil {
if errors.Is(err, store.ErrNotFound) {
return &TaskStatus{IsRunning: false}, nil
}
return nil, fmt.Errorf("failed to get task status: %w", err)
}
// Return a copy to prevent race conditions on the caller's side
statusCopy := s.status
return &statusCopy
var status TaskStatus
if err := json.Unmarshal(statusBytes, &status); err != nil {
return nil, fmt.Errorf("failed to deserialize task status: %w", err)
}
if !status.IsRunning && status.FinishedAt != nil {
status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds()
}
return &status, nil
}
// UpdateProgress updates the progress of the current task.
func (s *TaskService) UpdateProgress(processed int) {
s.mu.Lock()
defer s.mu.Unlock()
if !s.status.IsRunning {
return
func (s *TaskService) UpdateProgress(processed int) error {
status, err := s.GetTaskStatus()
if err != nil {
return err
}
if !status.IsRunning {
return nil
}
s.status.Processed = processed
s.status.lastUpdated = time.Now()
}
// EndTask marks the current task as finished.
func (s *TaskService) EndTask() {
s.mu.Lock()
defer s.mu.Unlock()
s.status.IsRunning = false
}
// StoreResult stores the result of a finished task.
func (s *TaskService) StoreResult(taskID string, result any) {
s.mu.Lock()
defer s.mu.Unlock()
if _, exists := s.resultsCache[taskID]; !exists {
if len(s.cacheOrder) >= s.maxCacheSize {
oldestTaskID := s.cacheOrder[0]
delete(s.resultsCache, oldestTaskID)
s.cacheOrder = s.cacheOrder[1:]
}
s.cacheOrder = append(s.cacheOrder, taskID)
status.Processed = processed
statusBytes, err := json.Marshal(status)
if err != nil {
return fmt.Errorf("failed to serialize updated status: %w", err)
}
s.resultsCache[taskID] = result
return s.store.Set(globalTaskKey, statusBytes, ResultTTL)
}
// GetResult retrieves the result of a finished task.
func (s *TaskService) GetResult(taskID string) (any, bool) {
s.mu.Lock()
defer s.mu.Unlock()
// EndTask marks the current task as finished and stores its final result.
func (s *TaskService) EndTask(resultData any, taskErr error) error {
status, err := s.GetTaskStatus()
if err != nil {
return fmt.Errorf("failed to get task object to end task: %w", err)
}
if !status.IsRunning {
return nil
}
result, found := s.resultsCache[taskID]
return result, found
now := time.Now()
status.IsRunning = false
status.FinishedAt = &now
status.DurationSeconds = now.Sub(status.StartedAt).Seconds()
if taskErr != nil {
status.Error = taskErr.Error()
} else {
status.Result = resultData
}
updatedTaskBytes, err := json.Marshal(status)
if err != nil {
return fmt.Errorf("failed to serialize final task status: %w", err)
}
return s.store.Set(globalTaskKey, updatedTaskBytes, ResultTTL)
}