feat: 任务功能性能及接口优化
This commit is contained in:
@@ -1,31 +1,18 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"gpt-load/internal/response"
|
||||
app_errors "gpt-load/internal/errors"
|
||||
"gpt-load/internal/response"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetTaskStatus handles requests for the status of the global long-running task.
|
||||
func (s *Server) GetTaskStatus(c *gin.Context) {
|
||||
taskStatus := s.TaskService.GetTaskStatus()
|
||||
taskStatus, err := s.TaskService.GetTaskStatus()
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, "Failed to get task status"))
|
||||
return
|
||||
}
|
||||
response.Success(c, taskStatus)
|
||||
}
|
||||
|
||||
// GetTaskResult handles requests for the result of a finished task.
|
||||
func (s *Server) GetTaskResult(c *gin.Context) {
|
||||
taskID := c.Param("task_id")
|
||||
if taskID == "" {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Task ID is required"))
|
||||
return
|
||||
}
|
||||
|
||||
result, found := s.TaskService.GetResult(taskID)
|
||||
if !found {
|
||||
response.Error(c, app_errors.ErrResourceNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
@@ -124,11 +124,7 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser
|
||||
}
|
||||
|
||||
// Tasks
|
||||
tasks := api.Group("/tasks")
|
||||
{
|
||||
tasks.GET("/key-validation/status", serverHandler.GetTaskStatus)
|
||||
tasks.GET("/:task_id/result", serverHandler.GetTaskResult)
|
||||
}
|
||||
api.GET("/tasks/status", serverHandler.GetTaskStatus)
|
||||
|
||||
// 仪表板和日志
|
||||
dashboard := api.Group("/dashboard")
|
||||
|
@@ -8,7 +8,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -49,13 +48,12 @@ func (s *KeyManualValidationService) StartValidationTask(group *models.Group) (*
|
||||
return nil, fmt.Errorf("no keys to validate in group %s", group.Name)
|
||||
}
|
||||
|
||||
taskID := uuid.New().String()
|
||||
timeoutMinutes := s.SettingsManager.GetInt("key_validation_task_timeout_minutes", 60)
|
||||
timeout := time.Duration(timeoutMinutes) * time.Minute
|
||||
|
||||
taskStatus, err := s.TaskService.StartTask(taskID, group.Name, len(keys), timeout)
|
||||
taskStatus, err := s.TaskService.StartTask(group.Name, len(keys), timeout)
|
||||
if err != nil {
|
||||
return nil, err // A task is already running
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Run the validation in a separate goroutine
|
||||
@@ -65,9 +63,7 @@ func (s *KeyManualValidationService) StartValidationTask(group *models.Group) (*
|
||||
}
|
||||
|
||||
func (s *KeyManualValidationService) runValidation(group *models.Group, keys []models.APIKey, task *TaskStatus) {
|
||||
defer s.TaskService.EndTask()
|
||||
|
||||
logrus.Infof("Starting manual validation for group %s (TaskID: %s)", group.Name, task.TaskID)
|
||||
logrus.Infof("Starting manual validation for group %s", group.Name)
|
||||
|
||||
jobs := make(chan models.APIKey, len(keys))
|
||||
results := make(chan bool, len(keys))
|
||||
@@ -88,18 +84,34 @@ func (s *KeyManualValidationService) runValidation(group *models.Group, keys []m
|
||||
}
|
||||
close(jobs)
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
// Wait for all workers to complete in a separate goroutine
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(results)
|
||||
}()
|
||||
|
||||
validCount := 0
|
||||
processedCount := 0
|
||||
lastUpdateTime := time.Now()
|
||||
|
||||
for isValid := range results {
|
||||
processedCount++
|
||||
if isValid {
|
||||
validCount++
|
||||
}
|
||||
// Update progress
|
||||
s.TaskService.UpdateProgress(processedCount)
|
||||
|
||||
// Throttle progress updates to once per second
|
||||
if time.Since(lastUpdateTime) > time.Second {
|
||||
if err := s.TaskService.UpdateProgress(processedCount); err != nil {
|
||||
logrus.Warnf("Failed to update task progress: %v", err)
|
||||
}
|
||||
lastUpdateTime = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure the final progress is always updated
|
||||
if err := s.TaskService.UpdateProgress(processedCount); err != nil {
|
||||
logrus.Warnf("Failed to update final task progress: %v", err)
|
||||
}
|
||||
|
||||
result := ManualValidationResult{
|
||||
@@ -108,11 +120,14 @@ func (s *KeyManualValidationService) runValidation(group *models.Group, keys []m
|
||||
InvalidKeys: len(keys) - validCount,
|
||||
}
|
||||
|
||||
// Store the final result
|
||||
s.TaskService.StoreResult(task.TaskID, result)
|
||||
logrus.Infof("Manual validation finished for group %s (TaskID: %s): %+v", group.Name, task.TaskID, result)
|
||||
// End the task and store the final result
|
||||
if err := s.TaskService.EndTask(result, nil); err != nil {
|
||||
logrus.Errorf("Failed to end task for group %s: %v", group.Name, err)
|
||||
}
|
||||
logrus.Infof("Manual validation finished for group %s: %+v", group.Name, result)
|
||||
}
|
||||
|
||||
// validationResult 包含验证结果信息
|
||||
func (s *KeyManualValidationService) validationWorker(wg *sync.WaitGroup, group *models.Group, jobs <-chan models.APIKey, results chan<- bool) {
|
||||
defer wg.Done()
|
||||
for key := range jobs {
|
||||
|
@@ -32,8 +32,6 @@ func NewKeyValidatorService(db *gorm.DB, factory *channel.Factory) *KeyValidator
|
||||
}
|
||||
|
||||
// ValidateSingleKey performs a validation check on a single API key.
|
||||
// It does not modify the key's state in the database.
|
||||
// It returns true if the key is valid, and an error if it's not.
|
||||
func (s *KeyValidatorService) ValidateSingleKey(ctx context.Context, key *models.APIKey, group *models.Group) (bool, error) {
|
||||
// 添加超时保护
|
||||
if ctx.Err() != nil {
|
||||
@@ -65,7 +63,7 @@ func (s *KeyValidatorService) ValidateSingleKey(ctx context.Context, key *models
|
||||
"group_id": group.ID,
|
||||
"group_name": group.Name,
|
||||
"error": validationErr,
|
||||
}).Warn("Key validation failed")
|
||||
}).Debug("Key validation failed")
|
||||
return false, validationErr
|
||||
}
|
||||
|
||||
|
@@ -1,125 +1,138 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
"fmt"
|
||||
"gpt-load/internal/store"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TaskStatus represents the status of a long-running task.
|
||||
const (
|
||||
globalTaskKey = "global_task:key_validation"
|
||||
ResultTTL = 60 * time.Minute
|
||||
)
|
||||
|
||||
// TaskStatus represents the full lifecycle of a long-running task.
|
||||
type TaskStatus struct {
|
||||
IsRunning bool `json:"is_running"`
|
||||
GroupName string `json:"group_name,omitempty"`
|
||||
Processed int `json:"processed,omitempty"`
|
||||
Total int `json:"total,omitempty"`
|
||||
TaskID string `json:"task_id,omitempty"`
|
||||
ExpiresAt time.Time `json:"-"` // Internal field to handle zombie tasks
|
||||
lastUpdated time.Time
|
||||
IsRunning bool `json:"is_running"`
|
||||
GroupName string `json:"group_name,omitempty"`
|
||||
Processed int `json:"processed"`
|
||||
Total int `json:"total"`
|
||||
Result any `json:"result,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
FinishedAt *time.Time `json:"finished_at,omitempty"`
|
||||
DurationSeconds float64 `json:"duration_seconds,omitempty"`
|
||||
}
|
||||
|
||||
// TaskService manages the state of a single, global, long-running task.
|
||||
// TaskService manages the state of a single, global, long-running task using the store interface.
|
||||
type TaskService struct {
|
||||
mu sync.Mutex
|
||||
status TaskStatus
|
||||
resultsCache map[string]any
|
||||
cacheOrder []string
|
||||
maxCacheSize int
|
||||
store store.Store
|
||||
}
|
||||
|
||||
// NewTaskService creates a new TaskService.
|
||||
func NewTaskService() *TaskService {
|
||||
func NewTaskService(store store.Store) *TaskService {
|
||||
return &TaskService{
|
||||
resultsCache: make(map[string]any),
|
||||
cacheOrder: make([]string, 0),
|
||||
maxCacheSize: 100,
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// StartTask attempts to start a new task. It returns an error if a task is already running.
|
||||
func (s *TaskService) StartTask(taskID, groupName string, total int, timeout time.Duration) (*TaskStatus, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Zombie task check
|
||||
if s.status.IsRunning && time.Now().After(s.status.ExpiresAt) {
|
||||
// The previous task is considered a zombie, reset it.
|
||||
s.status = TaskStatus{}
|
||||
func (s *TaskService) StartTask(groupName string, total int, timeout time.Duration) (*TaskStatus, error) {
|
||||
currentStatus, err := s.GetTaskStatus()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check current task status before starting a new one: %w", err)
|
||||
}
|
||||
|
||||
if s.status.IsRunning {
|
||||
return nil, errors.New("a task is already running")
|
||||
if currentStatus.IsRunning {
|
||||
return nil, errors.New("a task is already running, please wait")
|
||||
}
|
||||
|
||||
s.status = TaskStatus{
|
||||
IsRunning: true,
|
||||
TaskID: taskID,
|
||||
GroupName: groupName,
|
||||
Total: total,
|
||||
Processed: 0,
|
||||
ExpiresAt: time.Now().Add(timeout),
|
||||
lastUpdated: time.Now(),
|
||||
status := &TaskStatus{
|
||||
IsRunning: true,
|
||||
GroupName: groupName,
|
||||
Total: total,
|
||||
Processed: 0,
|
||||
StartedAt: time.Now(),
|
||||
}
|
||||
statusBytes, err := json.Marshal(status)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to serialize new task status: %w", err)
|
||||
}
|
||||
|
||||
return &s.status, nil
|
||||
if err := s.store.Set(globalTaskKey, statusBytes, timeout); err != nil {
|
||||
return nil, fmt.Errorf("failed to set initial task status: %w", err)
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
// GetTaskStatus returns the current status of the task.
|
||||
func (s *TaskService) GetTaskStatus() *TaskStatus {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Zombie task check
|
||||
if s.status.IsRunning && time.Now().After(s.status.ExpiresAt) {
|
||||
s.status = TaskStatus{} // Reset if expired
|
||||
func (s *TaskService) GetTaskStatus() (*TaskStatus, error) {
|
||||
statusBytes, err := s.store.Get(globalTaskKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, store.ErrNotFound) {
|
||||
return &TaskStatus{IsRunning: false}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get task status: %w", err)
|
||||
}
|
||||
|
||||
// Return a copy to prevent race conditions on the caller's side
|
||||
statusCopy := s.status
|
||||
return &statusCopy
|
||||
var status TaskStatus
|
||||
if err := json.Unmarshal(statusBytes, &status); err != nil {
|
||||
return nil, fmt.Errorf("failed to deserialize task status: %w", err)
|
||||
}
|
||||
|
||||
if !status.IsRunning && status.FinishedAt != nil {
|
||||
status.DurationSeconds = status.FinishedAt.Sub(status.StartedAt).Seconds()
|
||||
}
|
||||
|
||||
return &status, nil
|
||||
}
|
||||
|
||||
// UpdateProgress updates the progress of the current task.
|
||||
func (s *TaskService) UpdateProgress(processed int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.status.IsRunning {
|
||||
return
|
||||
func (s *TaskService) UpdateProgress(processed int) error {
|
||||
status, err := s.GetTaskStatus()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !status.IsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.status.Processed = processed
|
||||
s.status.lastUpdated = time.Now()
|
||||
}
|
||||
|
||||
// EndTask marks the current task as finished.
|
||||
func (s *TaskService) EndTask() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.status.IsRunning = false
|
||||
}
|
||||
|
||||
// StoreResult stores the result of a finished task.
|
||||
func (s *TaskService) StoreResult(taskID string, result any) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, exists := s.resultsCache[taskID]; !exists {
|
||||
if len(s.cacheOrder) >= s.maxCacheSize {
|
||||
oldestTaskID := s.cacheOrder[0]
|
||||
delete(s.resultsCache, oldestTaskID)
|
||||
s.cacheOrder = s.cacheOrder[1:]
|
||||
}
|
||||
s.cacheOrder = append(s.cacheOrder, taskID)
|
||||
status.Processed = processed
|
||||
statusBytes, err := json.Marshal(status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to serialize updated status: %w", err)
|
||||
}
|
||||
s.resultsCache[taskID] = result
|
||||
|
||||
return s.store.Set(globalTaskKey, statusBytes, ResultTTL)
|
||||
}
|
||||
|
||||
// GetResult retrieves the result of a finished task.
|
||||
func (s *TaskService) GetResult(taskID string) (any, bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
// EndTask marks the current task as finished and stores its final result.
|
||||
func (s *TaskService) EndTask(resultData any, taskErr error) error {
|
||||
status, err := s.GetTaskStatus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get task object to end task: %w", err)
|
||||
}
|
||||
if !status.IsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
result, found := s.resultsCache[taskID]
|
||||
return result, found
|
||||
now := time.Now()
|
||||
status.IsRunning = false
|
||||
status.FinishedAt = &now
|
||||
status.DurationSeconds = now.Sub(status.StartedAt).Seconds()
|
||||
if taskErr != nil {
|
||||
status.Error = taskErr.Error()
|
||||
} else {
|
||||
status.Result = resultData
|
||||
}
|
||||
|
||||
updatedTaskBytes, err := json.Marshal(status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to serialize final task status: %w", err)
|
||||
}
|
||||
|
||||
return s.store.Set(globalTaskKey, updatedTaskBytes, ResultTTL)
|
||||
}
|
||||
|
Reference in New Issue
Block a user