package keypool import ( "errors" "fmt" "gpt-load/internal/config" app_errors "gpt-load/internal/errors" "gpt-load/internal/models" "gpt-load/internal/store" "strconv" "time" "github.com/sirupsen/logrus" "gorm.io/gorm" ) type KeyProvider struct { db *gorm.DB store store.Store settingsManager *config.SystemSettingsManager } // NewProvider 创建一个新的 KeyProvider 实例。 func NewProvider(db *gorm.DB, store store.Store, settingsManager *config.SystemSettingsManager) *KeyProvider { return &KeyProvider{ db: db, store: store, settingsManager: settingsManager, } } // SelectKey 为指定的分组原子性地选择并轮换一个可用的 APIKey。 func (p *KeyProvider) SelectKey(groupID uint) (*models.APIKey, error) { activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID) // 1. Atomically rotate the key ID from the list keyIDStr, err := p.store.Rotate(activeKeysListKey) if err != nil { if errors.Is(err, store.ErrNotFound) { return nil, app_errors.ErrNoActiveKeys } return nil, fmt.Errorf("failed to rotate key from store: %w", err) } keyID, err := strconv.ParseUint(keyIDStr, 10, 64) if err != nil { return nil, fmt.Errorf("failed to parse key ID '%s': %w", keyIDStr, err) } // 2. Get key details from HASH keyHashKey := fmt.Sprintf("key:%d", keyID) keyDetails, err := p.store.HGetAll(keyHashKey) if err != nil { return nil, fmt.Errorf("failed to get key details for key ID %d: %w", keyID, err) } // 3. Manually unmarshal the map into an APIKey struct failureCount, _ := strconv.ParseInt(keyDetails["failure_count"], 10, 64) createdAt, _ := strconv.ParseInt(keyDetails["created_at"], 10, 64) apiKey := &models.APIKey{ ID: uint(keyID), KeyValue: keyDetails["key_string"], Status: keyDetails["status"], FailureCount: failureCount, GroupID: groupID, CreatedAt: time.Unix(createdAt, 0), } return apiKey, nil } // UpdateStatus 异步地提交一个 Key 状态更新任务。 func (p *KeyProvider) UpdateStatus(apiKey *models.APIKey, group *models.Group, isSuccess bool) { go func() { keyHashKey := fmt.Sprintf("key:%d", apiKey.ID) activeKeysListKey := fmt.Sprintf("group:%d:active_keys", group.ID) if isSuccess { if err := p.handleSuccess(apiKey.ID, keyHashKey, activeKeysListKey); err != nil { logrus.WithFields(logrus.Fields{"keyID": apiKey.ID, "error": err}).Error("Failed to handle key success") } } else { if err := p.handleFailure(apiKey, group, keyHashKey, activeKeysListKey); err != nil { logrus.WithFields(logrus.Fields{"keyID": apiKey.ID, "error": err}).Error("Failed to handle key failure") } } }() } func (p *KeyProvider) handleSuccess(keyID uint, keyHashKey, activeKeysListKey string) error { keyDetails, err := p.store.HGetAll(keyHashKey) if err != nil { return fmt.Errorf("failed to get key details from store: %w", err) } failureCount, _ := strconv.ParseInt(keyDetails["failure_count"], 10, 64) isActive := keyDetails["status"] == models.KeyStatusActive if failureCount == 0 && isActive { return nil } return p.db.Transaction(func(tx *gorm.DB) error { var key models.APIKey if err := tx.Set("gorm:query_option", "FOR UPDATE").First(&key, keyID).Error; err != nil { return fmt.Errorf("failed to lock key %d for update: %w", keyID, err) } updates := map[string]any{"failure_count": 0} if !isActive { updates["status"] = models.KeyStatusActive } if err := tx.Model(&key).Updates(updates).Error; err != nil { return fmt.Errorf("failed to update key in DB: %w", err) } if err := p.store.HSet(keyHashKey, updates); err != nil { return fmt.Errorf("failed to update key details in store: %w", err) } if !isActive { logrus.WithField("keyID", keyID).Info("Key has recovered and is being restored to active pool.") if err := p.store.LRem(activeKeysListKey, 0, keyID); err != nil { return fmt.Errorf("failed to LRem key before LPush on recovery: %w", err) } if err := p.store.LPush(activeKeysListKey, keyID); err != nil { return fmt.Errorf("failed to LPush key back to active list: %w", err) } } return nil }) } func (p *KeyProvider) handleFailure(apiKey *models.APIKey, group *models.Group, keyHashKey, activeKeysListKey string) error { keyDetails, err := p.store.HGetAll(keyHashKey) if err != nil { return fmt.Errorf("failed to get key details from store: %w", err) } failureCount, _ := strconv.ParseInt(keyDetails["failure_count"], 10, 64) if keyDetails["status"] == models.KeyStatusInvalid { return nil } // 获取该分组的有效配置 blacklistThreshold := group.EffectiveConfig.BlacklistThreshold return p.db.Transaction(func(tx *gorm.DB) error { var key models.APIKey if err := tx.Set("gorm:query_option", "FOR UPDATE").First(&key, apiKey.ID).Error; err != nil { return fmt.Errorf("failed to lock key %d for update: %w", apiKey.ID, err) } newFailureCount := failureCount + 1 updates := map[string]any{"failure_count": newFailureCount} shouldBlacklist := blacklistThreshold > 0 && newFailureCount >= int64(blacklistThreshold) if shouldBlacklist { updates["status"] = models.KeyStatusInvalid } if err := tx.Model(&key).Updates(updates).Error; err != nil { return fmt.Errorf("failed to update key stats in DB: %w", err) } if _, err := p.store.HIncrBy(keyHashKey, "failure_count", 1); err != nil { return fmt.Errorf("failed to increment failure count in store: %w", err) } if shouldBlacklist { logrus.WithFields(logrus.Fields{"keyID": apiKey.ID, "threshold": blacklistThreshold}).Warn("Key has reached blacklist threshold, disabling.") if err := p.store.LRem(activeKeysListKey, 0, apiKey.ID); err != nil { return fmt.Errorf("failed to LRem key from active list: %w", err) } if err := p.store.HSet(keyHashKey, map[string]any{"status": models.KeyStatusInvalid}); err != nil { return fmt.Errorf("failed to update key status to invalid in store: %w", err) } } return nil }) } // LoadKeysFromDB 从数据库加载所有分组和密钥,并填充到 Store 中。 func (p *KeyProvider) LoadKeysFromDB() error { initFlagKey := "initialization:db_keys_loaded" exists, err := p.store.Exists(initFlagKey) if err != nil { return fmt.Errorf("failed to check initialization flag: %w", err) } if exists { logrus.Debug("Keys have already been loaded into the store. Skipping.") return nil } logrus.Debug("First time startup, loading keys from DB...") // 1. 分批从数据库加载并使用 Pipeline 写入 Redis allActiveKeyIDs := make(map[uint][]any) batchSize := 1000 var batchKeys []*models.APIKey err = p.db.Model(&models.APIKey{}).FindInBatches(&batchKeys, batchSize, func(tx *gorm.DB, batch int) error { logrus.Debugf("Processing batch %d with %d keys...", batch, len(batchKeys)) var pipeline store.Pipeliner if redisStore, ok := p.store.(store.RedisPipeliner); ok { pipeline = redisStore.Pipeline() } for _, key := range batchKeys { keyHashKey := fmt.Sprintf("key:%d", key.ID) keyDetails := p.apiKeyToMap(key) if pipeline != nil { pipeline.HSet(keyHashKey, keyDetails) } else { if err := p.store.HSet(keyHashKey, keyDetails); err != nil { logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to HSet key details") } } if key.Status == models.KeyStatusActive { allActiveKeyIDs[key.GroupID] = append(allActiveKeyIDs[key.GroupID], key.ID) } } if pipeline != nil { if err := pipeline.Exec(); err != nil { return fmt.Errorf("failed to execute pipeline for batch %d: %w", batch, err) } } return nil }).Error if err != nil { return fmt.Errorf("failed during batch processing of keys: %w", err) } // 2. 更新所有分组的 active_keys 列表 logrus.Info("Updating active key lists for all groups...") for groupID, activeIDs := range allActiveKeyIDs { if len(activeIDs) > 0 { activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID) p.store.Delete(activeKeysListKey) if err := p.store.LPush(activeKeysListKey, activeIDs...); err != nil { logrus.WithFields(logrus.Fields{"groupID": groupID, "error": err}).Error("Failed to LPush active keys for group") } } } if err := p.store.Set(initFlagKey, []byte("1"), 0); err != nil { logrus.WithField("flagKey", initFlagKey).Error("Failed to set initialization flag after loading keys") } return nil } // AddKeys 批量添加新的 Key 到池和数据库中。 func (p *KeyProvider) AddKeys(groupID uint, keys []models.APIKey) error { if len(keys) == 0 { return nil } err := p.db.Transaction(func(tx *gorm.DB) error { if err := tx.Create(&keys).Error; err != nil { return err } for _, key := range keys { if err := p.addKeyToStore(&key); err != nil { logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to add key to store after DB creation, rolling back transaction") return err } } return nil }) return err } // RemoveKeys 批量从池和数据库中移除 Key。 func (p *KeyProvider) RemoveKeys(groupID uint, keyValues []string) (int64, error) { if len(keyValues) == 0 { return 0, nil } var keysToDelete []models.APIKey var deletedCount int64 err := p.db.Transaction(func(tx *gorm.DB) error { if err := tx.Where("group_id = ? AND key_value IN ?", groupID, keyValues).Find(&keysToDelete).Error; err != nil { return err } if len(keysToDelete) == 0 { return nil } keyIDsToDelete := pluckIDs(keysToDelete) result := tx.Where("id IN ?", keyIDsToDelete).Delete(&models.APIKey{}) if result.Error != nil { return result.Error } deletedCount = result.RowsAffected for _, key := range keysToDelete { if err := p.removeKeyFromStore(key.ID, key.GroupID); err != nil { logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to remove key from store after DB deletion, rolling back transaction") return err } } return nil }) return deletedCount, err } // RestoreKeys 恢复组内所有无效的 Key。 func (p *KeyProvider) RestoreKeys(groupID uint) (int64, error) { var invalidKeys []models.APIKey var restoredCount int64 err := p.db.Transaction(func(tx *gorm.DB) error { if err := tx.Where("group_id = ? AND status = ?", groupID, models.KeyStatusInvalid).Find(&invalidKeys).Error; err != nil { return err } if len(invalidKeys) == 0 { return nil } updates := map[string]any{ "status": models.KeyStatusActive, "failure_count": 0, } result := tx.Model(&models.APIKey{}).Where("group_id = ? AND status = ?", groupID, models.KeyStatusInvalid).Updates(updates) if result.Error != nil { return result.Error } restoredCount = result.RowsAffected for _, key := range invalidKeys { key.Status = models.KeyStatusActive key.FailureCount = 0 if err := p.addKeyToStore(&key); err != nil { logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to restore key in store after DB update, rolling back transaction") return err } } return nil }) return restoredCount, err } // RestoreMultipleKeys 恢复指定的 Key。 func (p *KeyProvider) RestoreMultipleKeys(groupID uint, keyValues []string) (int64, error) { if len(keyValues) == 0 { return 0, nil } var keysToRestore []models.APIKey var restoredCount int64 err := p.db.Transaction(func(tx *gorm.DB) error { // 1. 查找要恢复的密钥 if err := tx.Where("group_id = ? AND key_value IN ? AND status = ?", groupID, keyValues, models.KeyStatusInvalid).Find(&keysToRestore).Error; err != nil { return err } if len(keysToRestore) == 0 { return nil } keyIDsToRestore := pluckIDs(keysToRestore) // 2. 更新数据库中的状态 updates := map[string]any{ "status": models.KeyStatusActive, "failure_count": 0, } result := tx.Model(&models.APIKey{}).Where("id IN ?", keyIDsToRestore).Updates(updates) if result.Error != nil { return result.Error } restoredCount = result.RowsAffected // 3. 将密钥添加回 Redis for _, key := range keysToRestore { key.Status = models.KeyStatusActive key.FailureCount = 0 if err := p.addKeyToStore(&key); err != nil { // 在事务中,单个失败会回滚整个事务,但这里的日志记录仍然有用 logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to restore key in store after DB update") return err // 返回错误以回滚事务 } } return nil }) return restoredCount, err } // RemoveInvalidKeys 移除组内所有无效的 Key。 func (p *KeyProvider) RemoveInvalidKeys(groupID uint) (int64, error) { var invalidKeys []models.APIKey var removedCount int64 err := p.db.Transaction(func(tx *gorm.DB) error { if err := tx.Where("group_id = ? AND status = ?", groupID, models.KeyStatusInvalid).Find(&invalidKeys).Error; err != nil { return err } if len(invalidKeys) == 0 { return nil } result := tx.Where("id IN ?", pluckIDs(invalidKeys)).Delete(&models.APIKey{}) if result.Error != nil { return result.Error } removedCount = result.RowsAffected for _, key := range invalidKeys { if err := p.removeKeyFromStore(key.ID, key.GroupID); err != nil { logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to remove invalid key from store after DB deletion, rolling back transaction") return err } } return nil }) return removedCount, err } // RemoveKeysFromStore 直接从内存存储中移除指定的键,不涉及数据库操作 // 这个方法适用于数据库已经删除但需要清理内存存储的场景 func (p *KeyProvider) RemoveKeysFromStore(groupID uint, keyIDs []uint) error { if len(keyIDs) == 0 { return nil } activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID) // 第一步:直接删除整个 active_keys 列表 if err := p.store.Delete(activeKeysListKey); err != nil { logrus.WithFields(logrus.Fields{ "groupID": groupID, "error": err, }).Error("Failed to delete active keys list") return err } // 第二步:批量删除所有相关的key hash for _, keyID := range keyIDs { keyHashKey := fmt.Sprintf("key:%d", keyID) if err := p.store.Delete(keyHashKey); err != nil { logrus.WithFields(logrus.Fields{ "keyID": keyID, "error": err, }).Error("Failed to delete key hash") } } logrus.WithFields(logrus.Fields{ "groupID": groupID, "keyCount": len(keyIDs), }).Info("Successfully cleaned up group keys from store") return nil } // addKeyToStore is a helper to add a single key to the cache. func (p *KeyProvider) addKeyToStore(key *models.APIKey) error { // 1. Store key details in HASH keyHashKey := fmt.Sprintf("key:%d", key.ID) keyDetails := p.apiKeyToMap(key) if err := p.store.HSet(keyHashKey, keyDetails); err != nil { return fmt.Errorf("failed to HSet key details for key %d: %w", key.ID, err) } // 2. If active, add to the active LIST if key.Status == models.KeyStatusActive { activeKeysListKey := fmt.Sprintf("group:%d:active_keys", key.GroupID) if err := p.store.LRem(activeKeysListKey, 0, key.ID); err != nil { return fmt.Errorf("failed to LRem key %d before LPush for group %d: %w", key.ID, key.GroupID, err) } if err := p.store.LPush(activeKeysListKey, key.ID); err != nil { return fmt.Errorf("failed to LPush key %d to group %d: %w", key.ID, key.GroupID, err) } } return nil } // removeKeyFromStore is a helper to remove a single key from the cache. func (p *KeyProvider) removeKeyFromStore(keyID, groupID uint) error { activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID) if err := p.store.LRem(activeKeysListKey, 0, keyID); err != nil { logrus.WithFields(logrus.Fields{"keyID": keyID, "groupID": groupID, "error": err}).Error("Failed to LRem key from active list") } keyHashKey := fmt.Sprintf("key:%d", keyID) if err := p.store.Delete(keyHashKey); err != nil { return fmt.Errorf("failed to delete key HASH for key %d: %w", keyID, err) } return nil } // apiKeyToMap converts an APIKey model to a map for HSET. func (p *KeyProvider) apiKeyToMap(key *models.APIKey) map[string]any { return map[string]any{ "id": fmt.Sprint(key.ID), "key_string": key.KeyValue, "status": key.Status, "failure_count": key.FailureCount, "group_id": key.GroupID, "created_at": key.CreatedAt.Unix(), } } // pluckIDs extracts IDs from a slice of APIKey. func pluckIDs(keys []models.APIKey) []uint { ids := make([]uint, len(keys)) for i, key := range keys { ids[i] = key.ID } return ids }