feat: 密钥调整为异步任务,取消数量限制
This commit is contained in:
@@ -51,6 +51,9 @@ func BuildContainer() (*dig.Container, error) {
|
||||
if err := container.Provide(services.NewKeyService); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := container.Provide(services.NewKeyImportService); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := container.Provide(services.NewLogService); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@@ -23,6 +23,7 @@ type Server struct {
|
||||
KeyManualValidationService *services.KeyManualValidationService
|
||||
TaskService *services.TaskService
|
||||
KeyService *services.KeyService
|
||||
KeyImportService *services.KeyImportService
|
||||
LogService *services.LogService
|
||||
CommonHandler *CommonHandler
|
||||
}
|
||||
@@ -37,6 +38,7 @@ type NewServerParams struct {
|
||||
KeyManualValidationService *services.KeyManualValidationService
|
||||
TaskService *services.TaskService
|
||||
KeyService *services.KeyService
|
||||
KeyImportService *services.KeyImportService
|
||||
LogService *services.LogService
|
||||
CommonHandler *CommonHandler
|
||||
}
|
||||
@@ -51,6 +53,7 @@ func NewServer(params NewServerParams) *Server {
|
||||
KeyManualValidationService: params.KeyManualValidationService,
|
||||
TaskService: params.TaskService,
|
||||
KeyService: params.KeyService,
|
||||
KeyImportService: params.KeyImportService,
|
||||
LogService: params.LogService,
|
||||
CommonHandler: params.CommonHandler,
|
||||
}
|
||||
|
@@ -34,10 +34,6 @@ func validateKeysText(keysText string) error {
|
||||
return fmt.Errorf("keys text cannot be empty")
|
||||
}
|
||||
|
||||
if len(keysText) > 10*1024*1024 {
|
||||
return fmt.Errorf("keys text is too large (max 10MB)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -98,6 +94,33 @@ func (s *Server) AddMultipleKeys(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// AddMultipleKeysAsync handles creating new keys from a text block within a specific group.
|
||||
func (s *Server) AddMultipleKeysAsync(c *gin.Context) {
|
||||
var req KeyTextRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
group, ok := s.findGroupByID(c, req.GroupID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateKeysText(req.KeysText); err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
taskStatus, err := s.KeyImportService.StartImportTask(group, req.KeysText)
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrTaskInProgress, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, taskStatus)
|
||||
}
|
||||
|
||||
// ListKeysInGroup handles listing all keys within a specific group with pagination.
|
||||
func (s *Server) ListKeysInGroup(c *gin.Context) {
|
||||
groupID, err := validateGroupIDFromQuery(c)
|
||||
|
@@ -123,6 +123,7 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser
|
||||
keys.GET("", serverHandler.ListKeysInGroup)
|
||||
keys.GET("/export", serverHandler.ExportKeys)
|
||||
keys.POST("/add-multiple", serverHandler.AddMultipleKeys)
|
||||
keys.POST("/add-async", serverHandler.AddMultipleKeysAsync)
|
||||
keys.POST("/delete-multiple", serverHandler.DeleteMultipleKeys)
|
||||
keys.POST("/restore-multiple", serverHandler.RestoreMultipleKeys)
|
||||
keys.POST("/restore-all-invalid", serverHandler.RestoreAllInvalidKeys)
|
||||
|
76
internal/services/key_import_service.go
Normal file
76
internal/services/key_import_service.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gpt-load/internal/models"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
importChunkSize = 1000
|
||||
importTimeout = 24 * time.Hour
|
||||
)
|
||||
|
||||
// KeyImportResult holds the result of an import task.
|
||||
type KeyImportResult struct {
|
||||
AddedCount int `json:"added_count"`
|
||||
IgnoredCount int `json:"ignored_count"`
|
||||
}
|
||||
|
||||
// KeyImportService handles the asynchronous import of a large number of keys.
|
||||
type KeyImportService struct {
|
||||
TaskService *TaskService
|
||||
KeyService *KeyService
|
||||
}
|
||||
|
||||
// NewKeyImportService creates a new KeyImportService.
|
||||
func NewKeyImportService(taskService *TaskService, keyService *KeyService) *KeyImportService {
|
||||
return &KeyImportService{
|
||||
TaskService: taskService,
|
||||
KeyService: keyService,
|
||||
}
|
||||
}
|
||||
|
||||
// StartImportTask initiates a new asynchronous key import task.
|
||||
func (s *KeyImportService) StartImportTask(group *models.Group, keysText string) (*TaskStatus, error) {
|
||||
keys := s.KeyService.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found in the input text")
|
||||
}
|
||||
|
||||
initialStatus, err := s.TaskService.StartTask(TaskTypeKeyImport, group.Name, len(keys), importTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go s.runImport(group, keys)
|
||||
|
||||
return initialStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyImportService) runImport(group *models.Group, keys []string) {
|
||||
progressCallback := func(processed int) {
|
||||
if err := s.TaskService.UpdateProgress(processed); err != nil {
|
||||
logrus.Warnf("Failed to update task progress for group %d: %v", group.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
addedCount, ignoredCount, err := s.KeyService.processAndCreateKeys(group.ID, keys, progressCallback)
|
||||
if err != nil {
|
||||
if endErr := s.TaskService.EndTask(nil, err); endErr != nil {
|
||||
logrus.Errorf("Failed to end task with error for group %d: %v (original error: %v)", group.ID, endErr, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
result := KeyImportResult{
|
||||
AddedCount: addedCount,
|
||||
IgnoredCount: ignoredCount,
|
||||
}
|
||||
|
||||
if endErr := s.TaskService.EndTask(result, nil); endErr != nil {
|
||||
logrus.Errorf("Failed to end task with success result for group %d: %v", group.ID, endErr)
|
||||
}
|
||||
}
|
@@ -53,7 +53,7 @@ func (s *KeyManualValidationService) StartValidationTask(group *models.Group) (*
|
||||
|
||||
timeout := 30 * time.Minute
|
||||
|
||||
taskStatus, err := s.TaskService.StartTask(group.Name, len(keys), timeout)
|
||||
taskStatus, err := s.TaskService.StartTask(TaskTypeKeyValidation, group.Name, len(keys), timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@@ -55,8 +55,8 @@ func NewKeyService(db *gorm.DB, keyProvider *keypool.KeyProvider, keyValidator *
|
||||
}
|
||||
|
||||
// AddMultipleKeys handles the business logic of creating new keys from a text block.
|
||||
// deprecated: use KeyImportService for large imports
|
||||
func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysResult, error) {
|
||||
// 1. Parse keys from the text block
|
||||
keys := s.ParseKeysFromText(keysText)
|
||||
if len(keys) > maxRequestKeys {
|
||||
return nil, fmt.Errorf("batch size exceeds the limit of %d keys, got %d", maxRequestKeys, len(keys))
|
||||
@@ -65,17 +65,40 @@ func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysRes
|
||||
return nil, fmt.Errorf("no valid keys found in the input text")
|
||||
}
|
||||
|
||||
// 2. Get existing keys in the group for deduplication
|
||||
addedCount, ignoredCount, err := s.processAndCreateKeys(groupID, keys, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var totalInGroup int64
|
||||
if err := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID).Count(&totalInGroup).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AddKeysResult{
|
||||
AddedCount: addedCount,
|
||||
IgnoredCount: ignoredCount,
|
||||
TotalInGroup: totalInGroup,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// processAndCreateKeys is the lowest-level reusable function for adding keys.
|
||||
func (s *KeyService) processAndCreateKeys(
|
||||
groupID uint,
|
||||
keys []string,
|
||||
progressCallback func(processed int),
|
||||
) (addedCount int, ignoredCount int, err error) {
|
||||
// 1. Get existing keys in the group for deduplication
|
||||
var existingKeys []models.APIKey
|
||||
if err := s.DB.Where("group_id = ?", groupID).Select("key_value").Find(&existingKeys).Error; err != nil {
|
||||
return nil, err
|
||||
return 0, 0, err
|
||||
}
|
||||
existingKeyMap := make(map[string]bool)
|
||||
for _, k := range existingKeys {
|
||||
existingKeyMap[k.KeyValue] = true
|
||||
}
|
||||
|
||||
// 3. Prepare new keys for creation
|
||||
// 2. Prepare new keys for creation
|
||||
var newKeysToCreate []models.APIKey
|
||||
uniqueNewKeys := make(map[string]bool)
|
||||
|
||||
@@ -98,14 +121,10 @@ func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysRes
|
||||
}
|
||||
|
||||
if len(newKeysToCreate) == 0 {
|
||||
return &AddKeysResult{
|
||||
AddedCount: 0,
|
||||
IgnoredCount: len(keys),
|
||||
TotalInGroup: int64(len(existingKeys)),
|
||||
}, nil
|
||||
return 0, len(keys), nil
|
||||
}
|
||||
|
||||
// 4. Use KeyProvider to add keys in chunks
|
||||
// 3. Use KeyProvider to add keys in chunks
|
||||
for i := 0; i < len(newKeysToCreate); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(newKeysToCreate) {
|
||||
@@ -113,18 +132,16 @@ func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysRes
|
||||
}
|
||||
chunk := newKeysToCreate[i:end]
|
||||
if err := s.KeyProvider.AddKeys(groupID, chunk); err != nil {
|
||||
return nil, err
|
||||
return addedCount, len(keys) - addedCount, err
|
||||
}
|
||||
addedCount += len(chunk)
|
||||
|
||||
if progressCallback != nil {
|
||||
progressCallback(i + len(chunk))
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Calculate new total count
|
||||
totalInGroup := int64(len(existingKeys) + len(newKeysToCreate))
|
||||
|
||||
return &AddKeysResult{
|
||||
AddedCount: len(newKeysToCreate),
|
||||
IgnoredCount: len(keys) - len(newKeysToCreate),
|
||||
TotalInGroup: totalInGroup,
|
||||
}, nil
|
||||
return addedCount, len(keys) - addedCount, nil
|
||||
}
|
||||
|
||||
// ParseKeysFromText parses a string of keys from various formats into a string slice.
|
||||
|
@@ -9,12 +9,18 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
globalTaskKey = "global_task:key_validation"
|
||||
ResultTTL = 60 * time.Minute
|
||||
globalTaskKey = "global_task"
|
||||
ResultTTL = 60 * time.Minute
|
||||
)
|
||||
|
||||
const (
|
||||
TaskTypeKeyValidation = "KEY_VALIDATION"
|
||||
TaskTypeKeyImport = "KEY_IMPORT"
|
||||
)
|
||||
|
||||
// TaskStatus represents the full lifecycle of a long-running task.
|
||||
type TaskStatus struct {
|
||||
TaskType string `json:"task_type"`
|
||||
IsRunning bool `json:"is_running"`
|
||||
GroupName string `json:"group_name,omitempty"`
|
||||
Processed int `json:"processed"`
|
||||
@@ -39,7 +45,7 @@ func NewTaskService(store store.Store) *TaskService {
|
||||
}
|
||||
|
||||
// StartTask attempts to start a new task. It returns an error if a task is already running.
|
||||
func (s *TaskService) StartTask(groupName string, total int, timeout time.Duration) (*TaskStatus, error) {
|
||||
func (s *TaskService) StartTask(taskType, groupName string, total int, timeout time.Duration) (*TaskStatus, error) {
|
||||
currentStatus, err := s.GetTaskStatus()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check current task status before starting a new one: %w", err)
|
||||
@@ -50,6 +56,7 @@ func (s *TaskService) StartTask(groupName string, total int, timeout time.Durati
|
||||
}
|
||||
|
||||
status := &TaskStatus{
|
||||
TaskType: taskType,
|
||||
IsRunning: true,
|
||||
GroupName: groupName,
|
||||
Total: total,
|
||||
|
Reference in New Issue
Block a user