diff --git a/internal/container/container.go b/internal/container/container.go index 06f2d35..2362fce 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -51,6 +51,9 @@ func BuildContainer() (*dig.Container, error) { if err := container.Provide(services.NewKeyService); err != nil { return nil, err } + if err := container.Provide(services.NewKeyImportService); err != nil { + return nil, err + } if err := container.Provide(services.NewLogService); err != nil { return nil, err } diff --git a/internal/handler/handler.go b/internal/handler/handler.go index a4e5e5c..f7a4d5c 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -23,6 +23,7 @@ type Server struct { KeyManualValidationService *services.KeyManualValidationService TaskService *services.TaskService KeyService *services.KeyService + KeyImportService *services.KeyImportService LogService *services.LogService CommonHandler *CommonHandler } @@ -37,6 +38,7 @@ type NewServerParams struct { KeyManualValidationService *services.KeyManualValidationService TaskService *services.TaskService KeyService *services.KeyService + KeyImportService *services.KeyImportService LogService *services.LogService CommonHandler *CommonHandler } @@ -51,6 +53,7 @@ func NewServer(params NewServerParams) *Server { KeyManualValidationService: params.KeyManualValidationService, TaskService: params.TaskService, KeyService: params.KeyService, + KeyImportService: params.KeyImportService, LogService: params.LogService, CommonHandler: params.CommonHandler, } diff --git a/internal/handler/key_handler.go b/internal/handler/key_handler.go index 0c00fc9..09949b5 100644 --- a/internal/handler/key_handler.go +++ b/internal/handler/key_handler.go @@ -34,10 +34,6 @@ func validateKeysText(keysText string) error { return fmt.Errorf("keys text cannot be empty") } - if len(keysText) > 10*1024*1024 { - return fmt.Errorf("keys text is too large (max 10MB)") - } - return nil } @@ -98,6 +94,33 @@ func (s *Server) AddMultipleKeys(c *gin.Context) { response.Success(c, result) } +// AddMultipleKeysAsync handles creating new keys from a text block within a specific group. +func (s *Server) AddMultipleKeysAsync(c *gin.Context) { + var req KeyTextRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error())) + return + } + + group, ok := s.findGroupByID(c, req.GroupID) + if !ok { + return + } + + if err := validateKeysText(req.KeysText); err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error())) + return + } + + taskStatus, err := s.KeyImportService.StartImportTask(group, req.KeysText) + if err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrTaskInProgress, err.Error())) + return + } + + response.Success(c, taskStatus) +} + // ListKeysInGroup handles listing all keys within a specific group with pagination. func (s *Server) ListKeysInGroup(c *gin.Context) { groupID, err := validateGroupIDFromQuery(c) diff --git a/internal/router/router.go b/internal/router/router.go index 849387b..57acb5a 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -123,6 +123,7 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser keys.GET("", serverHandler.ListKeysInGroup) keys.GET("/export", serverHandler.ExportKeys) keys.POST("/add-multiple", serverHandler.AddMultipleKeys) + keys.POST("/add-async", serverHandler.AddMultipleKeysAsync) keys.POST("/delete-multiple", serverHandler.DeleteMultipleKeys) keys.POST("/restore-multiple", serverHandler.RestoreMultipleKeys) keys.POST("/restore-all-invalid", serverHandler.RestoreAllInvalidKeys) diff --git a/internal/services/key_import_service.go b/internal/services/key_import_service.go new file mode 100644 index 0000000..d4d7d8c --- /dev/null +++ b/internal/services/key_import_service.go @@ -0,0 +1,76 @@ +package services + +import ( + "fmt" + "gpt-load/internal/models" + "time" + + "github.com/sirupsen/logrus" +) + +const ( + importChunkSize = 1000 + importTimeout = 24 * time.Hour +) + +// KeyImportResult holds the result of an import task. +type KeyImportResult struct { + AddedCount int `json:"added_count"` + IgnoredCount int `json:"ignored_count"` +} + +// KeyImportService handles the asynchronous import of a large number of keys. +type KeyImportService struct { + TaskService *TaskService + KeyService *KeyService +} + +// NewKeyImportService creates a new KeyImportService. +func NewKeyImportService(taskService *TaskService, keyService *KeyService) *KeyImportService { + return &KeyImportService{ + TaskService: taskService, + KeyService: keyService, + } +} + +// StartImportTask initiates a new asynchronous key import task. +func (s *KeyImportService) StartImportTask(group *models.Group, keysText string) (*TaskStatus, error) { + keys := s.KeyService.ParseKeysFromText(keysText) + if len(keys) == 0 { + return nil, fmt.Errorf("no valid keys found in the input text") + } + + initialStatus, err := s.TaskService.StartTask(TaskTypeKeyImport, group.Name, len(keys), importTimeout) + if err != nil { + return nil, err + } + + go s.runImport(group, keys) + + return initialStatus, nil +} + +func (s *KeyImportService) runImport(group *models.Group, keys []string) { + progressCallback := func(processed int) { + if err := s.TaskService.UpdateProgress(processed); err != nil { + logrus.Warnf("Failed to update task progress for group %d: %v", group.ID, err) + } + } + + addedCount, ignoredCount, err := s.KeyService.processAndCreateKeys(group.ID, keys, progressCallback) + if err != nil { + if endErr := s.TaskService.EndTask(nil, err); endErr != nil { + logrus.Errorf("Failed to end task with error for group %d: %v (original error: %v)", group.ID, endErr, err) + } + return + } + + result := KeyImportResult{ + AddedCount: addedCount, + IgnoredCount: ignoredCount, + } + + if endErr := s.TaskService.EndTask(result, nil); endErr != nil { + logrus.Errorf("Failed to end task with success result for group %d: %v", group.ID, endErr) + } +} diff --git a/internal/services/key_manual_validation_service.go b/internal/services/key_manual_validation_service.go index 811f29f..80ef858 100644 --- a/internal/services/key_manual_validation_service.go +++ b/internal/services/key_manual_validation_service.go @@ -53,7 +53,7 @@ func (s *KeyManualValidationService) StartValidationTask(group *models.Group) (* timeout := 30 * time.Minute - taskStatus, err := s.TaskService.StartTask(group.Name, len(keys), timeout) + taskStatus, err := s.TaskService.StartTask(TaskTypeKeyValidation, group.Name, len(keys), timeout) if err != nil { return nil, err } diff --git a/internal/services/key_service.go b/internal/services/key_service.go index bb7b241..625b692 100644 --- a/internal/services/key_service.go +++ b/internal/services/key_service.go @@ -55,8 +55,8 @@ func NewKeyService(db *gorm.DB, keyProvider *keypool.KeyProvider, keyValidator * } // AddMultipleKeys handles the business logic of creating new keys from a text block. +// deprecated: use KeyImportService for large imports func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysResult, error) { - // 1. Parse keys from the text block keys := s.ParseKeysFromText(keysText) if len(keys) > maxRequestKeys { return nil, fmt.Errorf("batch size exceeds the limit of %d keys, got %d", maxRequestKeys, len(keys)) @@ -65,17 +65,40 @@ func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysRes return nil, fmt.Errorf("no valid keys found in the input text") } - // 2. Get existing keys in the group for deduplication + addedCount, ignoredCount, err := s.processAndCreateKeys(groupID, keys, nil) + if err != nil { + return nil, err + } + + var totalInGroup int64 + if err := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID).Count(&totalInGroup).Error; err != nil { + return nil, err + } + + return &AddKeysResult{ + AddedCount: addedCount, + IgnoredCount: ignoredCount, + TotalInGroup: totalInGroup, + }, nil +} + +// processAndCreateKeys is the lowest-level reusable function for adding keys. +func (s *KeyService) processAndCreateKeys( + groupID uint, + keys []string, + progressCallback func(processed int), +) (addedCount int, ignoredCount int, err error) { + // 1. Get existing keys in the group for deduplication var existingKeys []models.APIKey if err := s.DB.Where("group_id = ?", groupID).Select("key_value").Find(&existingKeys).Error; err != nil { - return nil, err + return 0, 0, err } existingKeyMap := make(map[string]bool) for _, k := range existingKeys { existingKeyMap[k.KeyValue] = true } - // 3. Prepare new keys for creation + // 2. Prepare new keys for creation var newKeysToCreate []models.APIKey uniqueNewKeys := make(map[string]bool) @@ -98,14 +121,10 @@ func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysRes } if len(newKeysToCreate) == 0 { - return &AddKeysResult{ - AddedCount: 0, - IgnoredCount: len(keys), - TotalInGroup: int64(len(existingKeys)), - }, nil + return 0, len(keys), nil } - // 4. Use KeyProvider to add keys in chunks + // 3. Use KeyProvider to add keys in chunks for i := 0; i < len(newKeysToCreate); i += chunkSize { end := i + chunkSize if end > len(newKeysToCreate) { @@ -113,18 +132,16 @@ func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysRes } chunk := newKeysToCreate[i:end] if err := s.KeyProvider.AddKeys(groupID, chunk); err != nil { - return nil, err + return addedCount, len(keys) - addedCount, err + } + addedCount += len(chunk) + + if progressCallback != nil { + progressCallback(i + len(chunk)) } } - // 5. Calculate new total count - totalInGroup := int64(len(existingKeys) + len(newKeysToCreate)) - - return &AddKeysResult{ - AddedCount: len(newKeysToCreate), - IgnoredCount: len(keys) - len(newKeysToCreate), - TotalInGroup: totalInGroup, - }, nil + return addedCount, len(keys) - addedCount, nil } // ParseKeysFromText parses a string of keys from various formats into a string slice. diff --git a/internal/services/task_service.go b/internal/services/task_service.go index c6be925..7be1eb9 100644 --- a/internal/services/task_service.go +++ b/internal/services/task_service.go @@ -9,12 +9,18 @@ import ( ) const ( - globalTaskKey = "global_task:key_validation" - ResultTTL = 60 * time.Minute + globalTaskKey = "global_task" + ResultTTL = 60 * time.Minute +) + +const ( + TaskTypeKeyValidation = "KEY_VALIDATION" + TaskTypeKeyImport = "KEY_IMPORT" ) // TaskStatus represents the full lifecycle of a long-running task. type TaskStatus struct { + TaskType string `json:"task_type"` IsRunning bool `json:"is_running"` GroupName string `json:"group_name,omitempty"` Processed int `json:"processed"` @@ -39,7 +45,7 @@ func NewTaskService(store store.Store) *TaskService { } // StartTask attempts to start a new task. It returns an error if a task is already running. -func (s *TaskService) StartTask(groupName string, total int, timeout time.Duration) (*TaskStatus, error) { +func (s *TaskService) StartTask(taskType, groupName string, total int, timeout time.Duration) (*TaskStatus, error) { currentStatus, err := s.GetTaskStatus() if err != nil { return nil, fmt.Errorf("failed to check current task status before starting a new one: %w", err) @@ -50,6 +56,7 @@ func (s *TaskService) StartTask(groupName string, total int, timeout time.Durati } status := &TaskStatus{ + TaskType: taskType, IsRunning: true, GroupName: groupName, Total: total, diff --git a/web/src/api/keys.ts b/web/src/api/keys.ts index 97d3906..2a10a1e 100644 --- a/web/src/api/keys.ts +++ b/web/src/api/keys.ts @@ -62,7 +62,7 @@ export const keysApi = { return res.data; }, - // 批量添加密钥 + // 批量添加密钥-已弃用 async addMultipleKeys( group_id: number, keys_text: string @@ -78,6 +78,15 @@ export const keysApi = { return res.data; }, + // 异步批量添加密钥 + async addKeysAsync(group_id: number, keys_text: string): Promise { + const res = await http.post("/keys/add-async", { + group_id, + keys_text, + }); + return res.data; + }, + // 测试密钥 async testKeys( group_id: number, diff --git a/web/src/components/GlobalTaskProgressBar.vue b/web/src/components/GlobalTaskProgressBar.vue index f885ce4..d248a59 100644 --- a/web/src/components/GlobalTaskProgressBar.vue +++ b/web/src/components/GlobalTaskProgressBar.vue @@ -5,7 +5,7 @@ import { appState } from "@/utils/app-state"; import { NButton, NCard, NProgress, NText, useMessage } from "naive-ui"; import { onBeforeUnmount, onMounted, ref, watch } from "vue"; -const taskInfo = ref({ is_running: false }); +const taskInfo = ref({ is_running: false, task_type: "KEY_VALIDATION" }); const visible = ref(false); let pollTimer: number | null = null; let isPolling = false; // 添加标志位 @@ -46,8 +46,15 @@ async function pollOnce() { if (task.result) { const lastTask = localStorage.getItem("last_closed_task"); if (lastTask !== task.finished_at) { - const { total_keys, valid_keys, invalid_keys } = task.result; - const msg = `任务已完成,处理了 ${total_keys} 个密钥,其中 ${valid_keys} 个有效密钥,${invalid_keys} 个无效密钥。`; + let msg = "任务已完成。"; + if (task.task_type === "KEY_VALIDATION") { + const result = task.result as import("@/types/models").KeyValidationResult; + msg = `密钥验证完成,处理了 ${result.total_keys} 个密钥,其中 ${result.valid_keys} 个有效,${result.invalid_keys} 个无效。`; + } else if (task.task_type === "KEY_IMPORT") { + const result = task.result as import("@/types/models").KeyImportResult; + msg = `密钥导入完成,成功添加 ${result.added_count} 个密钥,忽略了 ${result.ignored_count} 个。`; + } + message.info(msg, { closable: true, duration: 0, @@ -92,6 +99,20 @@ function getProgressText(): string { function handleClose() { visible.value = false; } + +function getTaskTitle(): string { + if (!taskInfo.value) { + return "正在处理任务..."; + } + switch (taskInfo.value.task_type) { + case "KEY_VALIDATION": + return `正在验证分组 [${taskInfo.value.group_name}] 的密钥`; + case "KEY_IMPORT": + return `正在向分组 [${taskInfo.value.group_name}] 导入密钥`; + default: + return "正在处理任务..."; + } +}