feat: 密钥管理
This commit is contained in:
206
internal/services/key_cron_service.go
Normal file
206
internal/services/key_cron_service.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"gpt-load/internal/config"
|
||||
"gpt-load/internal/models"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// KeyCronService is responsible for periodically validating all API keys.
|
||||
type KeyCronService struct {
|
||||
DB *gorm.DB
|
||||
Validator *KeyValidatorService
|
||||
SettingsManager *config.SystemSettingsManager
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewKeyCronService creates a new KeyCronService.
|
||||
func NewKeyCronService(db *gorm.DB, validator *KeyValidatorService, settingsManager *config.SystemSettingsManager) *KeyCronService {
|
||||
return &KeyCronService{
|
||||
DB: db,
|
||||
Validator: validator,
|
||||
SettingsManager: settingsManager,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the cron job.
|
||||
func (s *KeyCronService) Start() {
|
||||
logrus.Info("Starting KeyCronService...")
|
||||
s.wg.Add(1)
|
||||
go s.run()
|
||||
}
|
||||
|
||||
// Stop stops the cron job.
|
||||
func (s *KeyCronService) Stop() {
|
||||
logrus.Info("Stopping KeyCronService...")
|
||||
close(s.stopChan)
|
||||
s.wg.Wait()
|
||||
logrus.Info("KeyCronService stopped.")
|
||||
}
|
||||
|
||||
func (s *KeyCronService) run() {
|
||||
defer s.wg.Done()
|
||||
ctx := context.Background()
|
||||
|
||||
// Run once on start
|
||||
s.validateAllGroups(ctx)
|
||||
|
||||
for {
|
||||
// Dynamically get the interval for the next run
|
||||
intervalMinutes := s.SettingsManager.GetInt("key_validation_interval_minutes", 60)
|
||||
if intervalMinutes <= 0 {
|
||||
intervalMinutes = 60 // Fallback to a safe default
|
||||
}
|
||||
nextRunTimer := time.NewTimer(time.Duration(intervalMinutes) * time.Minute)
|
||||
|
||||
select {
|
||||
case <-nextRunTimer.C:
|
||||
s.validateAllGroups(ctx)
|
||||
case <-s.stopChan:
|
||||
nextRunTimer.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KeyCronService) validateAllGroups(ctx context.Context) {
|
||||
logrus.Info("KeyCronService: Starting validation cycle for all groups.")
|
||||
var groups []models.Group
|
||||
if err := s.DB.Find(&groups).Error; err != nil {
|
||||
logrus.Errorf("KeyCronService: Failed to get groups: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
groupCopy := group // Create a copy for the closure
|
||||
go func(g models.Group) {
|
||||
// Get effective settings for the group
|
||||
effectiveSettings := s.SettingsManager.GetEffectiveConfig(g.Config)
|
||||
interval := time.Duration(effectiveSettings.KeyValidationIntervalMinutes) * time.Minute
|
||||
|
||||
// Check if it's time to validate this group
|
||||
if g.LastValidatedAt == nil || time.Since(*g.LastValidatedAt) > interval {
|
||||
s.validateGroup(ctx, &g)
|
||||
}
|
||||
}(groupCopy)
|
||||
}
|
||||
logrus.Info("KeyCronService: Validation cycle finished.")
|
||||
}
|
||||
|
||||
func (s *KeyCronService) validateGroup(ctx context.Context, group *models.Group) {
|
||||
var keys []models.APIKey
|
||||
if err := s.DB.Where("group_id = ?", group.ID).Find(&keys).Error; err != nil {
|
||||
logrus.Errorf("KeyCronService: Failed to get keys for group %s: %v", group.Name, err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
logrus.Infof("KeyCronService: Validating %d keys for group %s", len(keys), group.Name)
|
||||
|
||||
jobs := make(chan models.APIKey, len(keys))
|
||||
results := make(chan models.APIKey, len(keys))
|
||||
|
||||
concurrency := s.SettingsManager.GetInt("key_validation_concurrency", 10)
|
||||
if concurrency <= 0 {
|
||||
concurrency = 10 // Fallback to a safe default
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go s.worker(ctx, &wg, group, jobs, results)
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
jobs <- key
|
||||
}
|
||||
close(jobs)
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
var keysToUpdate []models.APIKey
|
||||
for key := range results {
|
||||
keysToUpdate = append(keysToUpdate, key)
|
||||
}
|
||||
|
||||
if len(keysToUpdate) > 0 {
|
||||
s.batchUpdateKeyStatus(keysToUpdate)
|
||||
}
|
||||
|
||||
// Update the last validated timestamp for the group
|
||||
if err := s.DB.Model(group).Update("last_validated_at", time.Now()).Error; err != nil {
|
||||
logrus.Errorf("KeyCronService: Failed to update last_validated_at for group %s: %v", group.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KeyCronService) worker(ctx context.Context, wg *sync.WaitGroup, group *models.Group, jobs <-chan models.APIKey, results chan<- models.APIKey) {
|
||||
defer wg.Done()
|
||||
for key := range jobs {
|
||||
isValid, err := s.Validator.ValidateSingleKey(ctx, &key, group)
|
||||
// Only update status if there was no error during validation
|
||||
if err != nil {
|
||||
logrus.Warnf("KeyCronService: Failed to validate key ID %d for group %s: %v. Skipping status update.", key.ID, group.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
newStatus := "inactive"
|
||||
if isValid {
|
||||
newStatus = "active"
|
||||
}
|
||||
|
||||
// Only send to results if the status has changed
|
||||
if key.Status != newStatus {
|
||||
key.Status = newStatus
|
||||
results <- key
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KeyCronService) batchUpdateKeyStatus(keys []models.APIKey) {
|
||||
if len(keys) == 0 {
|
||||
return
|
||||
}
|
||||
logrus.Infof("KeyCronService: Batch updating status for %d keys.", len(keys))
|
||||
|
||||
activeIDs := []uint{}
|
||||
inactiveIDs := []uint{}
|
||||
|
||||
for _, key := range keys {
|
||||
if key.Status == "active" {
|
||||
activeIDs = append(activeIDs, key.ID)
|
||||
} else {
|
||||
inactiveIDs = append(inactiveIDs, key.ID)
|
||||
}
|
||||
}
|
||||
|
||||
err := s.DB.Transaction(func(tx *gorm.DB) error {
|
||||
if len(activeIDs) > 0 {
|
||||
if err := tx.Model(&models.APIKey{}).Where("id IN ?", activeIDs).Update("status", "active").Error; err != nil {
|
||||
return err
|
||||
}
|
||||
logrus.Infof("KeyCronService: Set %d keys to 'active'.", len(activeIDs))
|
||||
}
|
||||
if len(inactiveIDs) > 0 {
|
||||
if err := tx.Model(&models.APIKey{}).Where("id IN ?", inactiveIDs).Update("status", "inactive").Error; err != nil {
|
||||
return err
|
||||
}
|
||||
logrus.Infof("KeyCronService: Set %d keys to 'inactive'.", len(inactiveIDs))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logrus.Errorf("KeyCronService: Failed to batch update key status: %v", err)
|
||||
}
|
||||
}
|
122
internal/services/key_manual_validation_service.go
Normal file
122
internal/services/key_manual_validation_service.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gpt-load/internal/config"
|
||||
"gpt-load/internal/models"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ManualValidationResult holds the result of a manual validation task.
|
||||
type ManualValidationResult struct {
|
||||
TotalKeys int `json:"total_keys"`
|
||||
ValidKeys int `json:"valid_keys"`
|
||||
InvalidKeys int `json:"invalid_keys"`
|
||||
}
|
||||
|
||||
// KeyManualValidationService handles user-initiated key validation for a group.
|
||||
type KeyManualValidationService struct {
|
||||
DB *gorm.DB
|
||||
Validator *KeyValidatorService
|
||||
TaskService *TaskService
|
||||
SettingsManager *config.SystemSettingsManager
|
||||
}
|
||||
|
||||
// NewKeyManualValidationService creates a new KeyManualValidationService.
|
||||
func NewKeyManualValidationService(db *gorm.DB, validator *KeyValidatorService, taskService *TaskService, settingsManager *config.SystemSettingsManager) *KeyManualValidationService {
|
||||
return &KeyManualValidationService{
|
||||
DB: db,
|
||||
Validator: validator,
|
||||
TaskService: taskService,
|
||||
SettingsManager: settingsManager,
|
||||
}
|
||||
}
|
||||
|
||||
// StartValidationTask starts a new manual validation task for a given group.
|
||||
func (s *KeyManualValidationService) StartValidationTask(group *models.Group) (*TaskStatus, error) {
|
||||
var keys []models.APIKey
|
||||
if err := s.DB.Where("group_id = ?", group.ID).Find(&keys).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to get keys for group %s: %w", group.Name, err)
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no keys to validate in group %s", group.Name)
|
||||
}
|
||||
|
||||
taskID := uuid.New().String()
|
||||
timeoutMinutes := s.SettingsManager.GetInt("key_validation_task_timeout_minutes", 60)
|
||||
timeout := time.Duration(timeoutMinutes) * time.Minute
|
||||
|
||||
taskStatus, err := s.TaskService.StartTask(taskID, group.Name, len(keys), timeout)
|
||||
if err != nil {
|
||||
return nil, err // A task is already running
|
||||
}
|
||||
|
||||
// Run the validation in a separate goroutine
|
||||
go s.runValidation(group, keys, taskStatus)
|
||||
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyManualValidationService) runValidation(group *models.Group, keys []models.APIKey, task *TaskStatus) {
|
||||
defer s.TaskService.EndTask()
|
||||
|
||||
logrus.Infof("Starting manual validation for group %s (TaskID: %s)", group.Name, task.TaskID)
|
||||
|
||||
jobs := make(chan models.APIKey, len(keys))
|
||||
results := make(chan bool, len(keys))
|
||||
|
||||
concurrency := s.SettingsManager.GetInt("key_validation_concurrency", 10)
|
||||
if concurrency <= 0 {
|
||||
concurrency = 10 // Fallback to a safe default
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go s.validationWorker(&wg, group, jobs, results)
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
jobs <- key
|
||||
}
|
||||
close(jobs)
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
validCount := 0
|
||||
processedCount := 0
|
||||
for isValid := range results {
|
||||
processedCount++
|
||||
if isValid {
|
||||
validCount++
|
||||
}
|
||||
// Update progress
|
||||
s.TaskService.UpdateProgress(processedCount)
|
||||
}
|
||||
|
||||
result := ManualValidationResult{
|
||||
TotalKeys: len(keys),
|
||||
ValidKeys: validCount,
|
||||
InvalidKeys: len(keys) - validCount,
|
||||
}
|
||||
|
||||
// Store the final result
|
||||
s.TaskService.StoreResult(task.TaskID, result)
|
||||
logrus.Infof("Manual validation finished for group %s (TaskID: %s): %+v", group.Name, task.TaskID, result)
|
||||
}
|
||||
|
||||
func (s *KeyManualValidationService) validationWorker(wg *sync.WaitGroup, group *models.Group, jobs <-chan models.APIKey, results chan<- bool) {
|
||||
defer wg.Done()
|
||||
for key := range jobs {
|
||||
isValid, _ := s.Validator.ValidateSingleKey(context.Background(), &key, group)
|
||||
results <- isValid
|
||||
}
|
||||
}
|
206
internal/services/key_service.go
Normal file
206
internal/services/key_service.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gpt-load/internal/models"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AddKeysResult holds the result of adding multiple keys.
|
||||
type AddKeysResult struct {
|
||||
AddedCount int `json:"added_count"`
|
||||
IgnoredCount int `json:"ignored_count"`
|
||||
TotalInGroup int64 `json:"total_in_group"`
|
||||
}
|
||||
|
||||
// KeyService provides services related to API keys.
|
||||
type KeyService struct {
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
// NewKeyService creates a new KeyService.
|
||||
func NewKeyService(db *gorm.DB) *KeyService {
|
||||
return &KeyService{DB: db}
|
||||
}
|
||||
|
||||
// AddMultipleKeys handles the business logic of creating new keys from a text block.
|
||||
func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysResult, error) {
|
||||
// 1. Parse keys from the text block
|
||||
keys := s.parseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found in the input text")
|
||||
}
|
||||
|
||||
// 2. Get the group information for validation
|
||||
var group models.Group
|
||||
if err := s.DB.First(&group, groupID).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to find group: %w", err)
|
||||
}
|
||||
|
||||
// 3. Get existing keys in the group for deduplication
|
||||
var existingKeys []models.APIKey
|
||||
if err := s.DB.Where("group_id = ?", groupID).Select("key_value").Find(&existingKeys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
existingKeyMap := make(map[string]bool)
|
||||
for _, k := range existingKeys {
|
||||
existingKeyMap[k.KeyValue] = true
|
||||
}
|
||||
|
||||
// 4. Prepare new keys with basic validation only
|
||||
var newKeysToCreate []models.APIKey
|
||||
uniqueNewKeys := make(map[string]bool)
|
||||
|
||||
for _, keyVal := range keys {
|
||||
trimmedKey := strings.TrimSpace(keyVal)
|
||||
if trimmedKey == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if key already exists
|
||||
if existingKeyMap[trimmedKey] || uniqueNewKeys[trimmedKey] {
|
||||
continue
|
||||
}
|
||||
|
||||
// 通用验证:只做基础格式检查,不做渠道特定验证
|
||||
if s.isValidKeyFormat(trimmedKey) {
|
||||
uniqueNewKeys[trimmedKey] = true
|
||||
newKeysToCreate = append(newKeysToCreate, models.APIKey{
|
||||
GroupID: groupID,
|
||||
KeyValue: trimmedKey,
|
||||
Status: "active",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
addedCount := len(newKeysToCreate)
|
||||
// 更准确的忽略计数:包括重复的和无效的
|
||||
ignoredCount := len(keys) - addedCount
|
||||
|
||||
// 5. Insert new keys if any
|
||||
if addedCount > 0 {
|
||||
if err := s.DB.Create(&newKeysToCreate).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Get the new total count
|
||||
var totalInGroup int64
|
||||
if err := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID).Count(&totalInGroup).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AddKeysResult{
|
||||
AddedCount: addedCount,
|
||||
IgnoredCount: ignoredCount,
|
||||
TotalInGroup: totalInGroup,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *KeyService) parseKeysFromText(text string) []string {
|
||||
var keys []string
|
||||
|
||||
// First, try to parse as a JSON array of strings
|
||||
if json.Unmarshal([]byte(text), &keys) == nil && len(keys) > 0 {
|
||||
return s.filterValidKeys(keys)
|
||||
}
|
||||
|
||||
// 通用解析:通过分隔符分割文本,不使用复杂的正则表达式
|
||||
delimiters := regexp.MustCompile(`[\s,;|\n\r\t]+`)
|
||||
splitKeys := delimiters.Split(strings.TrimSpace(text), -1)
|
||||
|
||||
for _, key := range splitKeys {
|
||||
key = strings.TrimSpace(key)
|
||||
if key != "" {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
|
||||
return s.filterValidKeys(keys)
|
||||
}
|
||||
|
||||
// filterValidKeys validates and filters potential API keys
|
||||
func (s *KeyService) filterValidKeys(keys []string) []string {
|
||||
var validKeys []string
|
||||
for _, key := range keys {
|
||||
key = strings.TrimSpace(key)
|
||||
if s.isValidKeyFormat(key) {
|
||||
validKeys = append(validKeys, key)
|
||||
}
|
||||
}
|
||||
return validKeys
|
||||
}
|
||||
|
||||
// isValidKeyFormat performs basic validation on key format
|
||||
func (s *KeyService) isValidKeyFormat(key string) bool {
|
||||
if len(key) < 4 || len(key) > 1000 {
|
||||
return false
|
||||
}
|
||||
|
||||
if key == "" ||
|
||||
strings.TrimSpace(key) == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
validChars := regexp.MustCompile(`^[a-zA-Z0-9_\-./+=:]+$`)
|
||||
return validChars.MatchString(key)
|
||||
}
|
||||
|
||||
// RestoreAllInvalidKeys sets the status of all 'inactive' keys in a group to 'active'.
|
||||
func (s *KeyService) RestoreAllInvalidKeys(groupID uint) (int64, error) {
|
||||
result := s.DB.Model(&models.APIKey{}).Where("group_id = ? AND status = ?", groupID, "inactive").Update("status", "active")
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// ClearAllInvalidKeys deletes all 'inactive' keys from a group.
|
||||
func (s *KeyService) ClearAllInvalidKeys(groupID uint) (int64, error) {
|
||||
result := s.DB.Where("group_id = ? AND status = ?", groupID, "inactive").Delete(&models.APIKey{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// DeleteSingleKey deletes a specific key from a group.
|
||||
func (s *KeyService) DeleteSingleKey(groupID, keyID uint) (int64, error) {
|
||||
result := s.DB.Where("group_id = ? AND id = ?", groupID, keyID).Delete(&models.APIKey{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// ExportKeys returns a list of keys for a group, filtered by status.
|
||||
func (s *KeyService) ExportKeys(groupID uint, filter string) ([]string, error) {
|
||||
query := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID)
|
||||
|
||||
switch filter {
|
||||
case "valid":
|
||||
query = query.Where("status = ?", "active")
|
||||
case "invalid":
|
||||
query = query.Where("status = ?", "inactive")
|
||||
case "all":
|
||||
// No status filter needed
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid filter value. Use 'all', 'valid', or 'invalid'")
|
||||
}
|
||||
|
||||
var keys []string
|
||||
if err := query.Pluck("key_value", &keys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// ListKeysInGroup lists all keys within a specific group, filtered by status.
|
||||
func (s *KeyService) ListKeysInGroup(groupID uint, statusFilter string) ([]models.APIKey, error) {
|
||||
var keys []models.APIKey
|
||||
query := s.DB.Where("group_id = ?", groupID)
|
||||
|
||||
if statusFilter != "" {
|
||||
query = query.Where("status = ?", statusFilter)
|
||||
}
|
||||
|
||||
if err := query.Find(&keys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
91
internal/services/key_validator_service.go
Normal file
91
internal/services/key_validator_service.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gpt-load/internal/channel"
|
||||
"gpt-load/internal/models"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// KeyValidatorService provides methods to validate API keys.
|
||||
type KeyValidatorService struct {
|
||||
DB *gorm.DB
|
||||
channelFactory *channel.Factory
|
||||
}
|
||||
|
||||
// NewKeyValidatorService creates a new KeyValidatorService.
|
||||
func NewKeyValidatorService(db *gorm.DB, factory *channel.Factory) *KeyValidatorService {
|
||||
return &KeyValidatorService{
|
||||
DB: db,
|
||||
channelFactory: factory,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateSingleKey performs a validation check on a single API key.
|
||||
// It does not modify the key's state in the database.
|
||||
// It returns true if the key is valid, and an error if it's not.
|
||||
func (s *KeyValidatorService) ValidateSingleKey(ctx context.Context, key *models.APIKey, group *models.Group) (bool, error) {
|
||||
// 添加超时保护
|
||||
if ctx.Err() != nil {
|
||||
return false, fmt.Errorf("context cancelled or timed out: %w", ctx.Err())
|
||||
}
|
||||
|
||||
ch, err := s.channelFactory.GetChannel(group)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"group_id": group.ID,
|
||||
"group_name": group.Name,
|
||||
"channel_type": group.ChannelType,
|
||||
"error": err,
|
||||
}).Error("Failed to get channel for key validation")
|
||||
return false, fmt.Errorf("failed to get channel for group %s: %w", group.Name, err)
|
||||
}
|
||||
|
||||
// 记录验证开始
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"key_id": key.ID,
|
||||
"group_id": group.ID,
|
||||
"group_name": group.Name,
|
||||
}).Debug("Starting key validation")
|
||||
|
||||
isValid, validationErr := ch.ValidateKey(ctx, key.KeyValue)
|
||||
if validationErr != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"key_id": key.ID,
|
||||
"group_id": group.ID,
|
||||
"group_name": group.Name,
|
||||
"error": validationErr,
|
||||
}).Warn("Key validation failed")
|
||||
return false, validationErr
|
||||
}
|
||||
|
||||
// 记录验证结果
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"key_id": key.ID,
|
||||
"group_id": group.ID,
|
||||
"group_name": group.Name,
|
||||
"is_valid": isValid,
|
||||
}).Debug("Key validation completed")
|
||||
|
||||
return isValid, nil
|
||||
}
|
||||
|
||||
// TestSingleKeyByID performs a synchronous validation test for a single API key by its ID.
|
||||
// It is intended for handling user-initiated "Test" actions.
|
||||
// It does not modify the key's state in the database.
|
||||
func (s *KeyValidatorService) TestSingleKeyByID(ctx context.Context, keyID uint) (bool, error) {
|
||||
var apiKey models.APIKey
|
||||
if err := s.DB.First(&apiKey, keyID).Error; err != nil {
|
||||
return false, fmt.Errorf("failed to find api key with id %d: %w", keyID, err)
|
||||
}
|
||||
|
||||
var group models.Group
|
||||
if err := s.DB.First(&group, apiKey.GroupID).Error; err != nil {
|
||||
return false, fmt.Errorf("failed to find group with id %d: %w", apiKey.GroupID, err)
|
||||
}
|
||||
|
||||
return s.ValidateSingleKey(ctx, &apiKey, &group)
|
||||
}
|
125
internal/services/task_service.go
Normal file
125
internal/services/task_service.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TaskStatus represents the status of a long-running task.
|
||||
type TaskStatus struct {
|
||||
IsRunning bool `json:"is_running"`
|
||||
GroupName string `json:"group_name,omitempty"`
|
||||
Processed int `json:"processed,omitempty"`
|
||||
Total int `json:"total,omitempty"`
|
||||
TaskID string `json:"task_id,omitempty"`
|
||||
ExpiresAt time.Time `json:"-"` // Internal field to handle zombie tasks
|
||||
lastUpdated time.Time
|
||||
}
|
||||
|
||||
// TaskService manages the state of a single, global, long-running task.
|
||||
type TaskService struct {
|
||||
mu sync.Mutex
|
||||
status TaskStatus
|
||||
resultsCache map[string]interface{}
|
||||
cacheOrder []string
|
||||
maxCacheSize int
|
||||
}
|
||||
|
||||
// NewTaskService creates a new TaskService.
|
||||
func NewTaskService() *TaskService {
|
||||
return &TaskService{
|
||||
resultsCache: make(map[string]interface{}),
|
||||
cacheOrder: make([]string, 0),
|
||||
maxCacheSize: 100, // Store results for the last 100 tasks
|
||||
}
|
||||
}
|
||||
|
||||
// StartTask attempts to start a new task. It returns an error if a task is already running.
|
||||
func (s *TaskService) StartTask(taskID, groupName string, total int, timeout time.Duration) (*TaskStatus, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Zombie task check
|
||||
if s.status.IsRunning && time.Now().After(s.status.ExpiresAt) {
|
||||
// The previous task is considered a zombie, reset it.
|
||||
s.status = TaskStatus{}
|
||||
}
|
||||
|
||||
if s.status.IsRunning {
|
||||
return nil, errors.New("a task is already running")
|
||||
}
|
||||
|
||||
s.status = TaskStatus{
|
||||
IsRunning: true,
|
||||
TaskID: taskID,
|
||||
GroupName: groupName,
|
||||
Total: total,
|
||||
Processed: 0,
|
||||
ExpiresAt: time.Now().Add(timeout),
|
||||
lastUpdated: time.Now(),
|
||||
}
|
||||
|
||||
return &s.status, nil
|
||||
}
|
||||
|
||||
// GetTaskStatus returns the current status of the task.
|
||||
func (s *TaskService) GetTaskStatus() *TaskStatus {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Zombie task check
|
||||
if s.status.IsRunning && time.Now().After(s.status.ExpiresAt) {
|
||||
s.status = TaskStatus{} // Reset if expired
|
||||
}
|
||||
|
||||
// Return a copy to prevent race conditions on the caller's side
|
||||
statusCopy := s.status
|
||||
return &statusCopy
|
||||
}
|
||||
|
||||
// UpdateProgress updates the progress of the current task.
|
||||
func (s *TaskService) UpdateProgress(processed int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.status.IsRunning {
|
||||
return
|
||||
}
|
||||
|
||||
s.status.Processed = processed
|
||||
s.status.lastUpdated = time.Now()
|
||||
}
|
||||
|
||||
// EndTask marks the current task as finished.
|
||||
func (s *TaskService) EndTask() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.status.IsRunning = false
|
||||
}
|
||||
|
||||
// StoreResult stores the result of a finished task.
|
||||
func (s *TaskService) StoreResult(taskID string, result interface{}) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, exists := s.resultsCache[taskID]; !exists {
|
||||
if len(s.cacheOrder) >= s.maxCacheSize {
|
||||
oldestTaskID := s.cacheOrder[0]
|
||||
delete(s.resultsCache, oldestTaskID)
|
||||
s.cacheOrder = s.cacheOrder[1:]
|
||||
}
|
||||
s.cacheOrder = append(s.cacheOrder, taskID)
|
||||
}
|
||||
s.resultsCache[taskID] = result
|
||||
}
|
||||
|
||||
// GetResult retrieves the result of a finished task.
|
||||
func (s *TaskService) GetResult(taskID string) (interface{}, bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
result, found := s.resultsCache[taskID]
|
||||
return result, found
|
||||
}
|
Reference in New Issue
Block a user