feat: 密钥调整为异步任务,取消数量限制

This commit is contained in:
tbphp
2025-07-20 08:52:54 +08:00
parent 86c28dc20c
commit 90498298bf
12 changed files with 213 additions and 40 deletions

View File

@@ -51,6 +51,9 @@ func BuildContainer() (*dig.Container, error) {
if err := container.Provide(services.NewKeyService); err != nil { if err := container.Provide(services.NewKeyService); err != nil {
return nil, err return nil, err
} }
if err := container.Provide(services.NewKeyImportService); err != nil {
return nil, err
}
if err := container.Provide(services.NewLogService); err != nil { if err := container.Provide(services.NewLogService); err != nil {
return nil, err return nil, err
} }

View File

@@ -23,6 +23,7 @@ type Server struct {
KeyManualValidationService *services.KeyManualValidationService KeyManualValidationService *services.KeyManualValidationService
TaskService *services.TaskService TaskService *services.TaskService
KeyService *services.KeyService KeyService *services.KeyService
KeyImportService *services.KeyImportService
LogService *services.LogService LogService *services.LogService
CommonHandler *CommonHandler CommonHandler *CommonHandler
} }
@@ -37,6 +38,7 @@ type NewServerParams struct {
KeyManualValidationService *services.KeyManualValidationService KeyManualValidationService *services.KeyManualValidationService
TaskService *services.TaskService TaskService *services.TaskService
KeyService *services.KeyService KeyService *services.KeyService
KeyImportService *services.KeyImportService
LogService *services.LogService LogService *services.LogService
CommonHandler *CommonHandler CommonHandler *CommonHandler
} }
@@ -51,6 +53,7 @@ func NewServer(params NewServerParams) *Server {
KeyManualValidationService: params.KeyManualValidationService, KeyManualValidationService: params.KeyManualValidationService,
TaskService: params.TaskService, TaskService: params.TaskService,
KeyService: params.KeyService, KeyService: params.KeyService,
KeyImportService: params.KeyImportService,
LogService: params.LogService, LogService: params.LogService,
CommonHandler: params.CommonHandler, CommonHandler: params.CommonHandler,
} }

View File

@@ -34,10 +34,6 @@ func validateKeysText(keysText string) error {
return fmt.Errorf("keys text cannot be empty") return fmt.Errorf("keys text cannot be empty")
} }
if len(keysText) > 10*1024*1024 {
return fmt.Errorf("keys text is too large (max 10MB)")
}
return nil return nil
} }
@@ -98,6 +94,33 @@ func (s *Server) AddMultipleKeys(c *gin.Context) {
response.Success(c, result) response.Success(c, result)
} }
// AddMultipleKeysAsync handles creating new keys from a text block within a specific group.
func (s *Server) AddMultipleKeysAsync(c *gin.Context) {
var req KeyTextRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
return
}
group, ok := s.findGroupByID(c, req.GroupID)
if !ok {
return
}
if err := validateKeysText(req.KeysText); err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
return
}
taskStatus, err := s.KeyImportService.StartImportTask(group, req.KeysText)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrTaskInProgress, err.Error()))
return
}
response.Success(c, taskStatus)
}
// ListKeysInGroup handles listing all keys within a specific group with pagination. // ListKeysInGroup handles listing all keys within a specific group with pagination.
func (s *Server) ListKeysInGroup(c *gin.Context) { func (s *Server) ListKeysInGroup(c *gin.Context) {
groupID, err := validateGroupIDFromQuery(c) groupID, err := validateGroupIDFromQuery(c)

View File

@@ -123,6 +123,7 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser
keys.GET("", serverHandler.ListKeysInGroup) keys.GET("", serverHandler.ListKeysInGroup)
keys.GET("/export", serverHandler.ExportKeys) keys.GET("/export", serverHandler.ExportKeys)
keys.POST("/add-multiple", serverHandler.AddMultipleKeys) keys.POST("/add-multiple", serverHandler.AddMultipleKeys)
keys.POST("/add-async", serverHandler.AddMultipleKeysAsync)
keys.POST("/delete-multiple", serverHandler.DeleteMultipleKeys) keys.POST("/delete-multiple", serverHandler.DeleteMultipleKeys)
keys.POST("/restore-multiple", serverHandler.RestoreMultipleKeys) keys.POST("/restore-multiple", serverHandler.RestoreMultipleKeys)
keys.POST("/restore-all-invalid", serverHandler.RestoreAllInvalidKeys) keys.POST("/restore-all-invalid", serverHandler.RestoreAllInvalidKeys)

View File

@@ -0,0 +1,76 @@
package services
import (
"fmt"
"gpt-load/internal/models"
"time"
"github.com/sirupsen/logrus"
)
const (
importChunkSize = 1000
importTimeout = 24 * time.Hour
)
// KeyImportResult holds the result of an import task.
type KeyImportResult struct {
AddedCount int `json:"added_count"`
IgnoredCount int `json:"ignored_count"`
}
// KeyImportService handles the asynchronous import of a large number of keys.
type KeyImportService struct {
TaskService *TaskService
KeyService *KeyService
}
// NewKeyImportService creates a new KeyImportService.
func NewKeyImportService(taskService *TaskService, keyService *KeyService) *KeyImportService {
return &KeyImportService{
TaskService: taskService,
KeyService: keyService,
}
}
// StartImportTask initiates a new asynchronous key import task.
func (s *KeyImportService) StartImportTask(group *models.Group, keysText string) (*TaskStatus, error) {
keys := s.KeyService.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found in the input text")
}
initialStatus, err := s.TaskService.StartTask(TaskTypeKeyImport, group.Name, len(keys), importTimeout)
if err != nil {
return nil, err
}
go s.runImport(group, keys)
return initialStatus, nil
}
func (s *KeyImportService) runImport(group *models.Group, keys []string) {
progressCallback := func(processed int) {
if err := s.TaskService.UpdateProgress(processed); err != nil {
logrus.Warnf("Failed to update task progress for group %d: %v", group.ID, err)
}
}
addedCount, ignoredCount, err := s.KeyService.processAndCreateKeys(group.ID, keys, progressCallback)
if err != nil {
if endErr := s.TaskService.EndTask(nil, err); endErr != nil {
logrus.Errorf("Failed to end task with error for group %d: %v (original error: %v)", group.ID, endErr, err)
}
return
}
result := KeyImportResult{
AddedCount: addedCount,
IgnoredCount: ignoredCount,
}
if endErr := s.TaskService.EndTask(result, nil); endErr != nil {
logrus.Errorf("Failed to end task with success result for group %d: %v", group.ID, endErr)
}
}

View File

@@ -53,7 +53,7 @@ func (s *KeyManualValidationService) StartValidationTask(group *models.Group) (*
timeout := 30 * time.Minute timeout := 30 * time.Minute
taskStatus, err := s.TaskService.StartTask(group.Name, len(keys), timeout) taskStatus, err := s.TaskService.StartTask(TaskTypeKeyValidation, group.Name, len(keys), timeout)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -55,8 +55,8 @@ func NewKeyService(db *gorm.DB, keyProvider *keypool.KeyProvider, keyValidator *
} }
// AddMultipleKeys handles the business logic of creating new keys from a text block. // AddMultipleKeys handles the business logic of creating new keys from a text block.
// deprecated: use KeyImportService for large imports
func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysResult, error) { func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysResult, error) {
// 1. Parse keys from the text block
keys := s.ParseKeysFromText(keysText) keys := s.ParseKeysFromText(keysText)
if len(keys) > maxRequestKeys { if len(keys) > maxRequestKeys {
return nil, fmt.Errorf("batch size exceeds the limit of %d keys, got %d", maxRequestKeys, len(keys)) return nil, fmt.Errorf("batch size exceeds the limit of %d keys, got %d", maxRequestKeys, len(keys))
@@ -65,17 +65,40 @@ func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysRes
return nil, fmt.Errorf("no valid keys found in the input text") return nil, fmt.Errorf("no valid keys found in the input text")
} }
// 2. Get existing keys in the group for deduplication addedCount, ignoredCount, err := s.processAndCreateKeys(groupID, keys, nil)
if err != nil {
return nil, err
}
var totalInGroup int64
if err := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID).Count(&totalInGroup).Error; err != nil {
return nil, err
}
return &AddKeysResult{
AddedCount: addedCount,
IgnoredCount: ignoredCount,
TotalInGroup: totalInGroup,
}, nil
}
// processAndCreateKeys is the lowest-level reusable function for adding keys.
func (s *KeyService) processAndCreateKeys(
groupID uint,
keys []string,
progressCallback func(processed int),
) (addedCount int, ignoredCount int, err error) {
// 1. Get existing keys in the group for deduplication
var existingKeys []models.APIKey var existingKeys []models.APIKey
if err := s.DB.Where("group_id = ?", groupID).Select("key_value").Find(&existingKeys).Error; err != nil { if err := s.DB.Where("group_id = ?", groupID).Select("key_value").Find(&existingKeys).Error; err != nil {
return nil, err return 0, 0, err
} }
existingKeyMap := make(map[string]bool) existingKeyMap := make(map[string]bool)
for _, k := range existingKeys { for _, k := range existingKeys {
existingKeyMap[k.KeyValue] = true existingKeyMap[k.KeyValue] = true
} }
// 3. Prepare new keys for creation // 2. Prepare new keys for creation
var newKeysToCreate []models.APIKey var newKeysToCreate []models.APIKey
uniqueNewKeys := make(map[string]bool) uniqueNewKeys := make(map[string]bool)
@@ -98,14 +121,10 @@ func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysRes
} }
if len(newKeysToCreate) == 0 { if len(newKeysToCreate) == 0 {
return &AddKeysResult{ return 0, len(keys), nil
AddedCount: 0,
IgnoredCount: len(keys),
TotalInGroup: int64(len(existingKeys)),
}, nil
} }
// 4. Use KeyProvider to add keys in chunks // 3. Use KeyProvider to add keys in chunks
for i := 0; i < len(newKeysToCreate); i += chunkSize { for i := 0; i < len(newKeysToCreate); i += chunkSize {
end := i + chunkSize end := i + chunkSize
if end > len(newKeysToCreate) { if end > len(newKeysToCreate) {
@@ -113,18 +132,16 @@ func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysRes
} }
chunk := newKeysToCreate[i:end] chunk := newKeysToCreate[i:end]
if err := s.KeyProvider.AddKeys(groupID, chunk); err != nil { if err := s.KeyProvider.AddKeys(groupID, chunk); err != nil {
return nil, err return addedCount, len(keys) - addedCount, err
}
addedCount += len(chunk)
if progressCallback != nil {
progressCallback(i + len(chunk))
} }
} }
// 5. Calculate new total count return addedCount, len(keys) - addedCount, nil
totalInGroup := int64(len(existingKeys) + len(newKeysToCreate))
return &AddKeysResult{
AddedCount: len(newKeysToCreate),
IgnoredCount: len(keys) - len(newKeysToCreate),
TotalInGroup: totalInGroup,
}, nil
} }
// ParseKeysFromText parses a string of keys from various formats into a string slice. // ParseKeysFromText parses a string of keys from various formats into a string slice.

View File

@@ -9,12 +9,18 @@ import (
) )
const ( const (
globalTaskKey = "global_task:key_validation" globalTaskKey = "global_task"
ResultTTL = 60 * time.Minute ResultTTL = 60 * time.Minute
)
const (
TaskTypeKeyValidation = "KEY_VALIDATION"
TaskTypeKeyImport = "KEY_IMPORT"
) )
// TaskStatus represents the full lifecycle of a long-running task. // TaskStatus represents the full lifecycle of a long-running task.
type TaskStatus struct { type TaskStatus struct {
TaskType string `json:"task_type"`
IsRunning bool `json:"is_running"` IsRunning bool `json:"is_running"`
GroupName string `json:"group_name,omitempty"` GroupName string `json:"group_name,omitempty"`
Processed int `json:"processed"` Processed int `json:"processed"`
@@ -39,7 +45,7 @@ func NewTaskService(store store.Store) *TaskService {
} }
// StartTask attempts to start a new task. It returns an error if a task is already running. // StartTask attempts to start a new task. It returns an error if a task is already running.
func (s *TaskService) StartTask(groupName string, total int, timeout time.Duration) (*TaskStatus, error) { func (s *TaskService) StartTask(taskType, groupName string, total int, timeout time.Duration) (*TaskStatus, error) {
currentStatus, err := s.GetTaskStatus() currentStatus, err := s.GetTaskStatus()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to check current task status before starting a new one: %w", err) return nil, fmt.Errorf("failed to check current task status before starting a new one: %w", err)
@@ -50,6 +56,7 @@ func (s *TaskService) StartTask(groupName string, total int, timeout time.Durati
} }
status := &TaskStatus{ status := &TaskStatus{
TaskType: taskType,
IsRunning: true, IsRunning: true,
GroupName: groupName, GroupName: groupName,
Total: total, Total: total,

View File

@@ -62,7 +62,7 @@ export const keysApi = {
return res.data; return res.data;
}, },
// 批量添加密钥 // 批量添加密钥-已弃用
async addMultipleKeys( async addMultipleKeys(
group_id: number, group_id: number,
keys_text: string keys_text: string
@@ -78,6 +78,15 @@ export const keysApi = {
return res.data; return res.data;
}, },
// 异步批量添加密钥
async addKeysAsync(group_id: number, keys_text: string): Promise<TaskInfo> {
const res = await http.post("/keys/add-async", {
group_id,
keys_text,
});
return res.data;
},
// 测试密钥 // 测试密钥
async testKeys( async testKeys(
group_id: number, group_id: number,

View File

@@ -5,7 +5,7 @@ import { appState } from "@/utils/app-state";
import { NButton, NCard, NProgress, NText, useMessage } from "naive-ui"; import { NButton, NCard, NProgress, NText, useMessage } from "naive-ui";
import { onBeforeUnmount, onMounted, ref, watch } from "vue"; import { onBeforeUnmount, onMounted, ref, watch } from "vue";
const taskInfo = ref<TaskInfo>({ is_running: false }); const taskInfo = ref<TaskInfo>({ is_running: false, task_type: "KEY_VALIDATION" });
const visible = ref(false); const visible = ref(false);
let pollTimer: number | null = null; let pollTimer: number | null = null;
let isPolling = false; // 添加标志位 let isPolling = false; // 添加标志位
@@ -46,8 +46,15 @@ async function pollOnce() {
if (task.result) { if (task.result) {
const lastTask = localStorage.getItem("last_closed_task"); const lastTask = localStorage.getItem("last_closed_task");
if (lastTask !== task.finished_at) { if (lastTask !== task.finished_at) {
const { total_keys, valid_keys, invalid_keys } = task.result; let msg = "任务已完成。";
const msg = `任务已完成,处理了 ${total_keys} 个密钥,其中 ${valid_keys} 个有效密钥,${invalid_keys} 个无效密钥。`; if (task.task_type === "KEY_VALIDATION") {
const result = task.result as import("@/types/models").KeyValidationResult;
msg = `密钥验证完成,处理了 ${result.total_keys} 个密钥,其中 ${result.valid_keys} 个有效,${result.invalid_keys} 个无效。`;
} else if (task.task_type === "KEY_IMPORT") {
const result = task.result as import("@/types/models").KeyImportResult;
msg = `密钥导入完成,成功添加 ${result.added_count} 个密钥,忽略了 ${result.ignored_count} 个。`;
}
message.info(msg, { message.info(msg, {
closable: true, closable: true,
duration: 0, duration: 0,
@@ -92,6 +99,20 @@ function getProgressText(): string {
function handleClose() { function handleClose() {
visible.value = false; visible.value = false;
} }
function getTaskTitle(): string {
if (!taskInfo.value) {
return "正在处理任务...";
}
switch (taskInfo.value.task_type) {
case "KEY_VALIDATION":
return `正在验证分组 [${taskInfo.value.group_name}] 的密钥`;
case "KEY_IMPORT":
return `正在向分组 [${taskInfo.value.group_name}] 导入密钥`;
default:
return "正在处理任务...";
}
}
</script> </script>
<template> <template>
@@ -102,7 +123,7 @@ function handleClose() {
<span class="progress-icon"></span> <span class="progress-icon"></span>
<div class="progress-details"> <div class="progress-details">
<n-text strong class="progress-title"> <n-text strong class="progress-title">
正在处理分组 {{ taskInfo.group_name }} 的任务 {{ getTaskTitle() }}
</n-text> </n-text>
<n-text depth="3" class="progress-subtitle"> <n-text depth="3" class="progress-subtitle">
{{ getProgressText() }} ({{ getProgressPercentage() }}%) {{ getProgressText() }} ({{ getProgressPercentage() }}%)

View File

@@ -1,5 +1,6 @@
<script setup lang="ts"> <script setup lang="ts">
import { keysApi } from "@/api/keys"; import { keysApi } from "@/api/keys";
import { appState } from "@/utils/app-state";
import { Close } from "@vicons/ionicons5"; import { Close } from "@vicons/ionicons5";
import { NButton, NCard, NInput, NModal } from "naive-ui"; import { NButton, NCard, NInput, NModal } from "naive-ui";
import { ref, watch } from "vue"; import { ref, watch } from "vue";
@@ -51,10 +52,11 @@ async function handleSubmit() {
try { try {
loading.value = true; loading.value = true;
await keysApi.addMultipleKeys(props.groupId, keysText.value); await keysApi.addKeysAsync(props.groupId, keysText.value);
resetForm();
emit("success");
handleClose(); handleClose();
window.$message.success("密钥导入任务已开始,请稍后在下方查看进度。");
appState.taskPollingTrigger++;
} finally { } finally {
loading.value = false; loading.value = false;
} }

View File

@@ -75,18 +75,29 @@ export interface RequestStats {
failure_rate: number; failure_rate: number;
} }
export type TaskType = "KEY_VALIDATION" | "KEY_IMPORT";
export interface KeyValidationResult {
invalid_keys: number;
total_keys: number;
valid_keys: number;
}
export interface KeyImportResult {
added_count: number;
ignored_count: number;
}
export interface TaskInfo { export interface TaskInfo {
task_type: TaskType;
is_running: boolean; is_running: boolean;
group_name?: string; group_name?: string;
processed?: number; processed?: number;
total?: number; total?: number;
started_at?: string; started_at?: string;
finished_at?: string; finished_at?: string;
result?: { result?: KeyValidationResult | KeyImportResult;
invalid_keys: number; error?: string;
total_keys: number;
valid_keys: number;
};
} }
// Based on backend response // Based on backend response