feat: 密钥管理

This commit is contained in:
tbphp
2025-07-04 21:19:15 +08:00
parent 7c10474d19
commit 01b86f7e30
23 changed files with 1427 additions and 250 deletions

View File

@@ -0,0 +1,206 @@
package services
import (
"context"
"gpt-load/internal/config"
"gpt-load/internal/models"
"sync"
"time"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// KeyCronService is responsible for periodically validating all API keys.
type KeyCronService struct {
DB *gorm.DB
Validator *KeyValidatorService
SettingsManager *config.SystemSettingsManager
stopChan chan struct{}
wg sync.WaitGroup
}
// NewKeyCronService creates a new KeyCronService.
func NewKeyCronService(db *gorm.DB, validator *KeyValidatorService, settingsManager *config.SystemSettingsManager) *KeyCronService {
return &KeyCronService{
DB: db,
Validator: validator,
SettingsManager: settingsManager,
stopChan: make(chan struct{}),
}
}
// Start begins the cron job.
func (s *KeyCronService) Start() {
logrus.Info("Starting KeyCronService...")
s.wg.Add(1)
go s.run()
}
// Stop stops the cron job.
func (s *KeyCronService) Stop() {
logrus.Info("Stopping KeyCronService...")
close(s.stopChan)
s.wg.Wait()
logrus.Info("KeyCronService stopped.")
}
func (s *KeyCronService) run() {
defer s.wg.Done()
ctx := context.Background()
// Run once on start
s.validateAllGroups(ctx)
for {
// Dynamically get the interval for the next run
intervalMinutes := s.SettingsManager.GetInt("key_validation_interval_minutes", 60)
if intervalMinutes <= 0 {
intervalMinutes = 60 // Fallback to a safe default
}
nextRunTimer := time.NewTimer(time.Duration(intervalMinutes) * time.Minute)
select {
case <-nextRunTimer.C:
s.validateAllGroups(ctx)
case <-s.stopChan:
nextRunTimer.Stop()
return
}
}
}
func (s *KeyCronService) validateAllGroups(ctx context.Context) {
logrus.Info("KeyCronService: Starting validation cycle for all groups.")
var groups []models.Group
if err := s.DB.Find(&groups).Error; err != nil {
logrus.Errorf("KeyCronService: Failed to get groups: %v", err)
return
}
for _, group := range groups {
groupCopy := group // Create a copy for the closure
go func(g models.Group) {
// Get effective settings for the group
effectiveSettings := s.SettingsManager.GetEffectiveConfig(g.Config)
interval := time.Duration(effectiveSettings.KeyValidationIntervalMinutes) * time.Minute
// Check if it's time to validate this group
if g.LastValidatedAt == nil || time.Since(*g.LastValidatedAt) > interval {
s.validateGroup(ctx, &g)
}
}(groupCopy)
}
logrus.Info("KeyCronService: Validation cycle finished.")
}
func (s *KeyCronService) validateGroup(ctx context.Context, group *models.Group) {
var keys []models.APIKey
if err := s.DB.Where("group_id = ?", group.ID).Find(&keys).Error; err != nil {
logrus.Errorf("KeyCronService: Failed to get keys for group %s: %v", group.Name, err)
return
}
if len(keys) == 0 {
return
}
logrus.Infof("KeyCronService: Validating %d keys for group %s", len(keys), group.Name)
jobs := make(chan models.APIKey, len(keys))
results := make(chan models.APIKey, len(keys))
concurrency := s.SettingsManager.GetInt("key_validation_concurrency", 10)
if concurrency <= 0 {
concurrency = 10 // Fallback to a safe default
}
var wg sync.WaitGroup
for i := 0; i < concurrency; i++ {
wg.Add(1)
go s.worker(ctx, &wg, group, jobs, results)
}
for _, key := range keys {
jobs <- key
}
close(jobs)
wg.Wait()
close(results)
var keysToUpdate []models.APIKey
for key := range results {
keysToUpdate = append(keysToUpdate, key)
}
if len(keysToUpdate) > 0 {
s.batchUpdateKeyStatus(keysToUpdate)
}
// Update the last validated timestamp for the group
if err := s.DB.Model(group).Update("last_validated_at", time.Now()).Error; err != nil {
logrus.Errorf("KeyCronService: Failed to update last_validated_at for group %s: %v", group.Name, err)
}
}
func (s *KeyCronService) worker(ctx context.Context, wg *sync.WaitGroup, group *models.Group, jobs <-chan models.APIKey, results chan<- models.APIKey) {
defer wg.Done()
for key := range jobs {
isValid, err := s.Validator.ValidateSingleKey(ctx, &key, group)
// Only update status if there was no error during validation
if err != nil {
logrus.Warnf("KeyCronService: Failed to validate key ID %d for group %s: %v. Skipping status update.", key.ID, group.Name, err)
continue
}
newStatus := "inactive"
if isValid {
newStatus = "active"
}
// Only send to results if the status has changed
if key.Status != newStatus {
key.Status = newStatus
results <- key
}
}
}
func (s *KeyCronService) batchUpdateKeyStatus(keys []models.APIKey) {
if len(keys) == 0 {
return
}
logrus.Infof("KeyCronService: Batch updating status for %d keys.", len(keys))
activeIDs := []uint{}
inactiveIDs := []uint{}
for _, key := range keys {
if key.Status == "active" {
activeIDs = append(activeIDs, key.ID)
} else {
inactiveIDs = append(inactiveIDs, key.ID)
}
}
err := s.DB.Transaction(func(tx *gorm.DB) error {
if len(activeIDs) > 0 {
if err := tx.Model(&models.APIKey{}).Where("id IN ?", activeIDs).Update("status", "active").Error; err != nil {
return err
}
logrus.Infof("KeyCronService: Set %d keys to 'active'.", len(activeIDs))
}
if len(inactiveIDs) > 0 {
if err := tx.Model(&models.APIKey{}).Where("id IN ?", inactiveIDs).Update("status", "inactive").Error; err != nil {
return err
}
logrus.Infof("KeyCronService: Set %d keys to 'inactive'.", len(inactiveIDs))
}
return nil
})
if err != nil {
logrus.Errorf("KeyCronService: Failed to batch update key status: %v", err)
}
}

View File

@@ -0,0 +1,122 @@
package services
import (
"context"
"fmt"
"gpt-load/internal/config"
"gpt-load/internal/models"
"sync"
"time"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// ManualValidationResult holds the result of a manual validation task.
type ManualValidationResult struct {
TotalKeys int `json:"total_keys"`
ValidKeys int `json:"valid_keys"`
InvalidKeys int `json:"invalid_keys"`
}
// KeyManualValidationService handles user-initiated key validation for a group.
type KeyManualValidationService struct {
DB *gorm.DB
Validator *KeyValidatorService
TaskService *TaskService
SettingsManager *config.SystemSettingsManager
}
// NewKeyManualValidationService creates a new KeyManualValidationService.
func NewKeyManualValidationService(db *gorm.DB, validator *KeyValidatorService, taskService *TaskService, settingsManager *config.SystemSettingsManager) *KeyManualValidationService {
return &KeyManualValidationService{
DB: db,
Validator: validator,
TaskService: taskService,
SettingsManager: settingsManager,
}
}
// StartValidationTask starts a new manual validation task for a given group.
func (s *KeyManualValidationService) StartValidationTask(group *models.Group) (*TaskStatus, error) {
var keys []models.APIKey
if err := s.DB.Where("group_id = ?", group.ID).Find(&keys).Error; err != nil {
return nil, fmt.Errorf("failed to get keys for group %s: %w", group.Name, err)
}
if len(keys) == 0 {
return nil, fmt.Errorf("no keys to validate in group %s", group.Name)
}
taskID := uuid.New().String()
timeoutMinutes := s.SettingsManager.GetInt("key_validation_task_timeout_minutes", 60)
timeout := time.Duration(timeoutMinutes) * time.Minute
taskStatus, err := s.TaskService.StartTask(taskID, group.Name, len(keys), timeout)
if err != nil {
return nil, err // A task is already running
}
// Run the validation in a separate goroutine
go s.runValidation(group, keys, taskStatus)
return taskStatus, nil
}
func (s *KeyManualValidationService) runValidation(group *models.Group, keys []models.APIKey, task *TaskStatus) {
defer s.TaskService.EndTask()
logrus.Infof("Starting manual validation for group %s (TaskID: %s)", group.Name, task.TaskID)
jobs := make(chan models.APIKey, len(keys))
results := make(chan bool, len(keys))
concurrency := s.SettingsManager.GetInt("key_validation_concurrency", 10)
if concurrency <= 0 {
concurrency = 10 // Fallback to a safe default
}
var wg sync.WaitGroup
for i := 0; i < concurrency; i++ {
wg.Add(1)
go s.validationWorker(&wg, group, jobs, results)
}
for _, key := range keys {
jobs <- key
}
close(jobs)
wg.Wait()
close(results)
validCount := 0
processedCount := 0
for isValid := range results {
processedCount++
if isValid {
validCount++
}
// Update progress
s.TaskService.UpdateProgress(processedCount)
}
result := ManualValidationResult{
TotalKeys: len(keys),
ValidKeys: validCount,
InvalidKeys: len(keys) - validCount,
}
// Store the final result
s.TaskService.StoreResult(task.TaskID, result)
logrus.Infof("Manual validation finished for group %s (TaskID: %s): %+v", group.Name, task.TaskID, result)
}
func (s *KeyManualValidationService) validationWorker(wg *sync.WaitGroup, group *models.Group, jobs <-chan models.APIKey, results chan<- bool) {
defer wg.Done()
for key := range jobs {
isValid, _ := s.Validator.ValidateSingleKey(context.Background(), &key, group)
results <- isValid
}
}

View File

@@ -0,0 +1,206 @@
package services
import (
"encoding/json"
"fmt"
"gpt-load/internal/models"
"regexp"
"strings"
"gorm.io/gorm"
)
// AddKeysResult holds the result of adding multiple keys.
type AddKeysResult struct {
AddedCount int `json:"added_count"`
IgnoredCount int `json:"ignored_count"`
TotalInGroup int64 `json:"total_in_group"`
}
// KeyService provides services related to API keys.
type KeyService struct {
DB *gorm.DB
}
// NewKeyService creates a new KeyService.
func NewKeyService(db *gorm.DB) *KeyService {
return &KeyService{DB: db}
}
// AddMultipleKeys handles the business logic of creating new keys from a text block.
func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysResult, error) {
// 1. Parse keys from the text block
keys := s.parseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found in the input text")
}
// 2. Get the group information for validation
var group models.Group
if err := s.DB.First(&group, groupID).Error; err != nil {
return nil, fmt.Errorf("failed to find group: %w", err)
}
// 3. Get existing keys in the group for deduplication
var existingKeys []models.APIKey
if err := s.DB.Where("group_id = ?", groupID).Select("key_value").Find(&existingKeys).Error; err != nil {
return nil, err
}
existingKeyMap := make(map[string]bool)
for _, k := range existingKeys {
existingKeyMap[k.KeyValue] = true
}
// 4. Prepare new keys with basic validation only
var newKeysToCreate []models.APIKey
uniqueNewKeys := make(map[string]bool)
for _, keyVal := range keys {
trimmedKey := strings.TrimSpace(keyVal)
if trimmedKey == "" {
continue
}
// Check if key already exists
if existingKeyMap[trimmedKey] || uniqueNewKeys[trimmedKey] {
continue
}
// 通用验证:只做基础格式检查,不做渠道特定验证
if s.isValidKeyFormat(trimmedKey) {
uniqueNewKeys[trimmedKey] = true
newKeysToCreate = append(newKeysToCreate, models.APIKey{
GroupID: groupID,
KeyValue: trimmedKey,
Status: "active",
})
}
}
addedCount := len(newKeysToCreate)
// 更准确的忽略计数:包括重复的和无效的
ignoredCount := len(keys) - addedCount
// 5. Insert new keys if any
if addedCount > 0 {
if err := s.DB.Create(&newKeysToCreate).Error; err != nil {
return nil, err
}
}
// 6. Get the new total count
var totalInGroup int64
if err := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID).Count(&totalInGroup).Error; err != nil {
return nil, err
}
return &AddKeysResult{
AddedCount: addedCount,
IgnoredCount: ignoredCount,
TotalInGroup: totalInGroup,
}, nil
}
func (s *KeyService) parseKeysFromText(text string) []string {
var keys []string
// First, try to parse as a JSON array of strings
if json.Unmarshal([]byte(text), &keys) == nil && len(keys) > 0 {
return s.filterValidKeys(keys)
}
// 通用解析:通过分隔符分割文本,不使用复杂的正则表达式
delimiters := regexp.MustCompile(`[\s,;|\n\r\t]+`)
splitKeys := delimiters.Split(strings.TrimSpace(text), -1)
for _, key := range splitKeys {
key = strings.TrimSpace(key)
if key != "" {
keys = append(keys, key)
}
}
return s.filterValidKeys(keys)
}
// filterValidKeys validates and filters potential API keys
func (s *KeyService) filterValidKeys(keys []string) []string {
var validKeys []string
for _, key := range keys {
key = strings.TrimSpace(key)
if s.isValidKeyFormat(key) {
validKeys = append(validKeys, key)
}
}
return validKeys
}
// isValidKeyFormat performs basic validation on key format
func (s *KeyService) isValidKeyFormat(key string) bool {
if len(key) < 4 || len(key) > 1000 {
return false
}
if key == "" ||
strings.TrimSpace(key) == "" {
return false
}
validChars := regexp.MustCompile(`^[a-zA-Z0-9_\-./+=:]+$`)
return validChars.MatchString(key)
}
// RestoreAllInvalidKeys sets the status of all 'inactive' keys in a group to 'active'.
func (s *KeyService) RestoreAllInvalidKeys(groupID uint) (int64, error) {
result := s.DB.Model(&models.APIKey{}).Where("group_id = ? AND status = ?", groupID, "inactive").Update("status", "active")
return result.RowsAffected, result.Error
}
// ClearAllInvalidKeys deletes all 'inactive' keys from a group.
func (s *KeyService) ClearAllInvalidKeys(groupID uint) (int64, error) {
result := s.DB.Where("group_id = ? AND status = ?", groupID, "inactive").Delete(&models.APIKey{})
return result.RowsAffected, result.Error
}
// DeleteSingleKey deletes a specific key from a group.
func (s *KeyService) DeleteSingleKey(groupID, keyID uint) (int64, error) {
result := s.DB.Where("group_id = ? AND id = ?", groupID, keyID).Delete(&models.APIKey{})
return result.RowsAffected, result.Error
}
// ExportKeys returns a list of keys for a group, filtered by status.
func (s *KeyService) ExportKeys(groupID uint, filter string) ([]string, error) {
query := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID)
switch filter {
case "valid":
query = query.Where("status = ?", "active")
case "invalid":
query = query.Where("status = ?", "inactive")
case "all":
// No status filter needed
default:
return nil, fmt.Errorf("invalid filter value. Use 'all', 'valid', or 'invalid'")
}
var keys []string
if err := query.Pluck("key_value", &keys).Error; err != nil {
return nil, err
}
return keys, nil
}
// ListKeysInGroup lists all keys within a specific group, filtered by status.
func (s *KeyService) ListKeysInGroup(groupID uint, statusFilter string) ([]models.APIKey, error) {
var keys []models.APIKey
query := s.DB.Where("group_id = ?", groupID)
if statusFilter != "" {
query = query.Where("status = ?", statusFilter)
}
if err := query.Find(&keys).Error; err != nil {
return nil, err
}
return keys, nil
}

View File

@@ -0,0 +1,91 @@
package services
import (
"context"
"fmt"
"gpt-load/internal/channel"
"gpt-load/internal/models"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// KeyValidatorService provides methods to validate API keys.
type KeyValidatorService struct {
DB *gorm.DB
channelFactory *channel.Factory
}
// NewKeyValidatorService creates a new KeyValidatorService.
func NewKeyValidatorService(db *gorm.DB, factory *channel.Factory) *KeyValidatorService {
return &KeyValidatorService{
DB: db,
channelFactory: factory,
}
}
// ValidateSingleKey performs a validation check on a single API key.
// It does not modify the key's state in the database.
// It returns true if the key is valid, and an error if it's not.
func (s *KeyValidatorService) ValidateSingleKey(ctx context.Context, key *models.APIKey, group *models.Group) (bool, error) {
// 添加超时保护
if ctx.Err() != nil {
return false, fmt.Errorf("context cancelled or timed out: %w", ctx.Err())
}
ch, err := s.channelFactory.GetChannel(group)
if err != nil {
logrus.WithFields(logrus.Fields{
"group_id": group.ID,
"group_name": group.Name,
"channel_type": group.ChannelType,
"error": err,
}).Error("Failed to get channel for key validation")
return false, fmt.Errorf("failed to get channel for group %s: %w", group.Name, err)
}
// 记录验证开始
logrus.WithFields(logrus.Fields{
"key_id": key.ID,
"group_id": group.ID,
"group_name": group.Name,
}).Debug("Starting key validation")
isValid, validationErr := ch.ValidateKey(ctx, key.KeyValue)
if validationErr != nil {
logrus.WithFields(logrus.Fields{
"key_id": key.ID,
"group_id": group.ID,
"group_name": group.Name,
"error": validationErr,
}).Warn("Key validation failed")
return false, validationErr
}
// 记录验证结果
logrus.WithFields(logrus.Fields{
"key_id": key.ID,
"group_id": group.ID,
"group_name": group.Name,
"is_valid": isValid,
}).Debug("Key validation completed")
return isValid, nil
}
// TestSingleKeyByID performs a synchronous validation test for a single API key by its ID.
// It is intended for handling user-initiated "Test" actions.
// It does not modify the key's state in the database.
func (s *KeyValidatorService) TestSingleKeyByID(ctx context.Context, keyID uint) (bool, error) {
var apiKey models.APIKey
if err := s.DB.First(&apiKey, keyID).Error; err != nil {
return false, fmt.Errorf("failed to find api key with id %d: %w", keyID, err)
}
var group models.Group
if err := s.DB.First(&group, apiKey.GroupID).Error; err != nil {
return false, fmt.Errorf("failed to find group with id %d: %w", apiKey.GroupID, err)
}
return s.ValidateSingleKey(ctx, &apiKey, &group)
}

View File

@@ -0,0 +1,125 @@
package services
import (
"errors"
"sync"
"time"
)
// TaskStatus represents the status of a long-running task.
type TaskStatus struct {
IsRunning bool `json:"is_running"`
GroupName string `json:"group_name,omitempty"`
Processed int `json:"processed,omitempty"`
Total int `json:"total,omitempty"`
TaskID string `json:"task_id,omitempty"`
ExpiresAt time.Time `json:"-"` // Internal field to handle zombie tasks
lastUpdated time.Time
}
// TaskService manages the state of a single, global, long-running task.
type TaskService struct {
mu sync.Mutex
status TaskStatus
resultsCache map[string]interface{}
cacheOrder []string
maxCacheSize int
}
// NewTaskService creates a new TaskService.
func NewTaskService() *TaskService {
return &TaskService{
resultsCache: make(map[string]interface{}),
cacheOrder: make([]string, 0),
maxCacheSize: 100, // Store results for the last 100 tasks
}
}
// StartTask attempts to start a new task. It returns an error if a task is already running.
func (s *TaskService) StartTask(taskID, groupName string, total int, timeout time.Duration) (*TaskStatus, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Zombie task check
if s.status.IsRunning && time.Now().After(s.status.ExpiresAt) {
// The previous task is considered a zombie, reset it.
s.status = TaskStatus{}
}
if s.status.IsRunning {
return nil, errors.New("a task is already running")
}
s.status = TaskStatus{
IsRunning: true,
TaskID: taskID,
GroupName: groupName,
Total: total,
Processed: 0,
ExpiresAt: time.Now().Add(timeout),
lastUpdated: time.Now(),
}
return &s.status, nil
}
// GetTaskStatus returns the current status of the task.
func (s *TaskService) GetTaskStatus() *TaskStatus {
s.mu.Lock()
defer s.mu.Unlock()
// Zombie task check
if s.status.IsRunning && time.Now().After(s.status.ExpiresAt) {
s.status = TaskStatus{} // Reset if expired
}
// Return a copy to prevent race conditions on the caller's side
statusCopy := s.status
return &statusCopy
}
// UpdateProgress updates the progress of the current task.
func (s *TaskService) UpdateProgress(processed int) {
s.mu.Lock()
defer s.mu.Unlock()
if !s.status.IsRunning {
return
}
s.status.Processed = processed
s.status.lastUpdated = time.Now()
}
// EndTask marks the current task as finished.
func (s *TaskService) EndTask() {
s.mu.Lock()
defer s.mu.Unlock()
s.status.IsRunning = false
}
// StoreResult stores the result of a finished task.
func (s *TaskService) StoreResult(taskID string, result interface{}) {
s.mu.Lock()
defer s.mu.Unlock()
if _, exists := s.resultsCache[taskID]; !exists {
if len(s.cacheOrder) >= s.maxCacheSize {
oldestTaskID := s.cacheOrder[0]
delete(s.resultsCache, oldestTaskID)
s.cacheOrder = s.cacheOrder[1:]
}
s.cacheOrder = append(s.cacheOrder, taskID)
}
s.resultsCache[taskID] = result
}
// GetResult retrieves the result of a finished task.
func (s *TaskService) GetResult(taskID string) (interface{}, bool) {
s.mu.Lock()
defer s.mu.Unlock()
result, found := s.resultsCache[taskID]
return result, found
}