Files
gpt-load/internal/keypool/provider.go
2025-07-09 09:46:38 +08:00

522 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package keypool
import (
"errors"
"fmt"
"gpt-load/internal/config"
app_errors "gpt-load/internal/errors"
"gpt-load/internal/models"
"gpt-load/internal/store"
"strconv"
"time"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// KeyProvider manages the API key pool. It keeps the database (db) and the
// cache (store) in sync: each key's details are mirrored into a "key:<id>"
// hash, and the IDs of active keys into a per-group "group:<id>:active_keys"
// list that SelectKey rotates through.
type KeyProvider struct {
// db is the source of truth for keys.
db *gorm.DB
// store holds the cached key hashes and per-group active lists.
store store.Store
// settingsManager supplies runtime settings such as BlacklistThreshold.
settingsManager *config.SystemSettingsManager
}
// NewProvider builds a KeyProvider backed by the given database, cache store
// and settings manager. It performs no I/O; the store is populated separately
// (see LoadKeysFromDB).
func NewProvider(db *gorm.DB, store store.Store, settingsManager *config.SystemSettingsManager) *KeyProvider {
	p := new(KeyProvider)
	p.db = db
	p.store = store
	p.settingsManager = settingsManager
	return p
}
// SelectKey atomically rotates the active key list of the given group and
// returns the key it yields, with its details read from the "key:<id>" hash.
//
// It returns app_errors.ErrNoActiveKeys when the store reports the group's
// list as missing. If the rotated ID has no usable details hash (e.g. the
// hash was deleted while a stale ID stayed in the list), the stale ID is
// removed from the list on a best-effort basis. An error wrapping
// store.ErrNotFound is then returned instead of a key with an empty value.
func (p *KeyProvider) SelectKey(groupID uint) (*models.APIKey, error) {
	activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID)

	// 1. Atomically rotate the key ID from the list.
	keyIDStr, err := p.store.Rotate(activeKeysListKey)
	if err != nil {
		if errors.Is(err, store.ErrNotFound) {
			return nil, app_errors.ErrNoActiveKeys
		}
		return nil, fmt.Errorf("failed to rotate key from store: %w", err)
	}
	keyID, err := strconv.ParseUint(keyIDStr, 10, 64)
	if err != nil {
		return nil, fmt.Errorf("failed to parse key ID '%s': %w", keyIDStr, err)
	}

	// 2. Get key details from the HASH.
	keyHashKey := fmt.Sprintf("key:%d", keyID)
	keyDetails, err := p.store.HGetAll(keyHashKey)
	if err != nil {
		return nil, fmt.Errorf("failed to get key details for key ID %d: %w", keyID, err)
	}
	// NOTE(review): a missing hash may come back as an empty map rather than
	// an error (Redis HGETALL behaves this way). Handing out a key with an
	// empty value would only make the upstream request fail later.
	if len(keyDetails) == 0 || keyDetails["key_string"] == "" {
		// Keys are added to the hash before the list everywhere in this file,
		// so a listed ID without details is stale; drop it so later rotations
		// do not keep hitting it. uint matches the other LRem call sites.
		if lremErr := p.store.LRem(activeKeysListKey, 0, uint(keyID)); lremErr != nil {
			logrus.WithFields(logrus.Fields{"keyID": keyID, "error": lremErr}).Warn("Failed to remove stale key ID from active list")
		}
		return nil, fmt.Errorf("key details for key ID %d not found in store: %w", keyID, store.ErrNotFound)
	}

	// 3. Manually unmarshal the map into an APIKey struct. Malformed numeric
	// fields fall back to 0 (best effort, as before).
	failureCount, _ := strconv.ParseInt(keyDetails["failure_count"], 10, 64)
	createdAt, _ := strconv.ParseInt(keyDetails["created_at"], 10, 64)

	return &models.APIKey{
		ID:           uint(keyID),
		KeyValue:     keyDetails["key_string"],
		Status:       keyDetails["status"],
		FailureCount: failureCount,
		GroupID:      groupID,
		CreatedAt:    time.Unix(createdAt, 0),
	}, nil
}
// UpdateStatus records the outcome of a request made with the given key.
// The bookkeeping runs in a new goroutine, so the call returns immediately;
// any error is logged rather than returned.
func (p *KeyProvider) UpdateStatus(keyID uint, groupID uint, isSuccess bool) {
	go func() {
		hashKey := fmt.Sprintf("key:%d", keyID)
		listKey := fmt.Sprintf("group:%d:active_keys", groupID)

		handle, failMsg := p.handleFailure, "Failed to handle key failure"
		if isSuccess {
			handle, failMsg = p.handleSuccess, "Failed to handle key success"
		}

		if err := handle(keyID, hashKey, listKey); err != nil {
			logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error(failMsg)
		}
	}()
}
// handleSuccess resets a key's failure count after a successful request. If
// the store did not report the key as active, it also marks the key active
// again and re-inserts its ID into the group's active list. It does nothing
// when the store already shows the key as active with zero failures.
//
// NOTE(review): the row lock is requested with
// tx.Set("gorm:query_option", "FOR UPDATE"), which is GORM v1 syntax; with
// gorm.io/gorm (v2) locking is normally expressed via clause.Locking —
// confirm the lock is actually applied.
func (p *KeyProvider) handleSuccess(keyID uint, keyHashKey, activeKeysListKey string) error {
	details, err := p.store.HGetAll(keyHashKey)
	if err != nil {
		return fmt.Errorf("failed to get key details from store: %w", err)
	}

	wasActive := details["status"] == models.KeyStatusActive
	failures, _ := strconv.ParseInt(details["failure_count"], 10, 64)
	if wasActive && failures == 0 {
		return nil // nothing to reset
	}

	return p.db.Transaction(func(tx *gorm.DB) error {
		var locked models.APIKey
		if err := tx.Set("gorm:query_option", "FOR UPDATE").First(&locked, keyID).Error; err != nil {
			return fmt.Errorf("failed to lock key %d for update: %w", keyID, err)
		}

		changes := map[string]any{"failure_count": 0}
		if !wasActive {
			changes["status"] = models.KeyStatusActive
		}
		if err := tx.Model(&locked).Updates(changes).Error; err != nil {
			return fmt.Errorf("failed to update key in DB: %w", err)
		}
		if err := p.store.HSet(keyHashKey, changes); err != nil {
			return fmt.Errorf("failed to update key details in store: %w", err)
		}
		if wasActive {
			return nil
		}

		logrus.WithField("keyID", keyID).Info("Key has recovered and is being restored to active pool.")
		// Remove any existing occurrence first so the ID is never listed twice.
		if err := p.store.LRem(activeKeysListKey, 0, keyID); err != nil {
			return fmt.Errorf("failed to LRem key before LPush on recovery: %w", err)
		}
		if err := p.store.LPush(activeKeysListKey, keyID); err != nil {
			return fmt.Errorf("failed to LPush key back to active list: %w", err)
		}
		return nil
	})
}
// handleFailure records one failed request for a key. When the key's failure
// count reaches the configured BlacklistThreshold, the key is marked invalid
// in the database and the store and removed from the group's active list.
// Keys the store already reports as invalid are ignored.
//
// The counter is incremented with a single SQL expression
// (failure_count = failure_count + 1), and the resulting value is read back
// inside the same transaction. The store hash is then set to that value.
// Computing the new count from a store snapshot taken before the transaction
// loses increments when several failures for the same key are reported
// concurrently (UpdateStatus runs each report in its own goroutine). That
// lets the DB and store counters drift and delays blacklisting.
func (p *KeyProvider) handleFailure(keyID uint, keyHashKey, activeKeysListKey string) error {
	keyDetails, err := p.store.HGetAll(keyHashKey)
	if err != nil {
		return fmt.Errorf("failed to get key details from store: %w", err)
	}
	if keyDetails["status"] == models.KeyStatusInvalid {
		return nil
	}

	blacklistThreshold := p.settingsManager.GetSettings().BlacklistThreshold

	return p.db.Transaction(func(tx *gorm.DB) error {
		// The increment is evaluated by the database itself, so concurrent
		// reports cannot overwrite each other's updates.
		if err := tx.Model(&models.APIKey{}).
			Where("id = ?", keyID).
			Update("failure_count", gorm.Expr("failure_count + ?", 1)).Error; err != nil {
			return fmt.Errorf("failed to update key stats in DB: %w", err)
		}

		// Read back the incremented value within the same transaction.
		var key models.APIKey
		if err := tx.First(&key, keyID).Error; err != nil {
			return fmt.Errorf("failed to load key %d after update: %w", keyID, err)
		}

		newFailureCount := key.FailureCount
		shouldBlacklist := newFailureCount >= int64(blacklistThreshold)
		if shouldBlacklist {
			if err := tx.Model(&key).Update("status", models.KeyStatusInvalid).Error; err != nil {
				return fmt.Errorf("failed to update key stats in DB: %w", err)
			}
		}

		// Mirror the database counter instead of incrementing the store copy
		// independently, so the two cannot drift apart.
		if err := p.store.HSet(keyHashKey, map[string]any{"failure_count": newFailureCount}); err != nil {
			return fmt.Errorf("failed to update failure count in store: %w", err)
		}

		if shouldBlacklist {
			logrus.WithFields(logrus.Fields{"keyID": keyID, "threshold": blacklistThreshold}).Warn("Key has reached blacklist threshold, disabling.")
			if err := p.store.LRem(activeKeysListKey, 0, keyID); err != nil {
				return fmt.Errorf("failed to LRem key from active list: %w", err)
			}
			if err := p.store.HSet(keyHashKey, map[string]any{"status": models.KeyStatusInvalid}); err != nil {
				return fmt.Errorf("failed to update key status to invalid in store: %w", err)
			}
		}
		return nil
	})
}
// LoadKeysFromDB loads every API key from the database into the store.
//
// Keys are read in batches of 1000, and each key's details are written to its
// "key:<id>" hash (through a pipeline when the store supports one). Afterwards
// the "group:<id>:active_keys" list of every group that owns at least one key
// is rebuilt from scratch. This includes groups whose keys are all inactive,
// whose stale list is cleared so a persistent store stops handing out keys
// that are no longer active in the database.
//
// Per-key and per-group store errors are logged and skipped (best effort).
// Only a failing batch query or pipeline execution aborts the load.
func (p *KeyProvider) LoadKeysFromDB() error {
	const batchSize = 1000

	// activeKeyIDs gets an entry for every group seen, even when none of its
	// keys is active (nil slice), so that group's list is still reset below.
	activeKeyIDs := make(map[uint][]any)
	pipelineStore, canPipeline := p.store.(store.RedisPipeliner)

	var batchKeys []*models.APIKey
	err := p.db.Model(&models.APIKey{}).FindInBatches(&batchKeys, batchSize, func(tx *gorm.DB, batch int) error {
		logrus.Debugf("Processing batch %d with %d keys...", batch, len(batchKeys))

		var pipeline store.Pipeliner
		if canPipeline {
			pipeline = pipelineStore.Pipeline()
		}

		for _, key := range batchKeys {
			keyHashKey := fmt.Sprintf("key:%d", key.ID)
			keyDetails := p.apiKeyToMap(key)
			if pipeline != nil {
				pipeline.HSet(keyHashKey, keyDetails)
			} else if err := p.store.HSet(keyHashKey, keyDetails); err != nil {
				logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to HSet key details")
			}

			ids := activeKeyIDs[key.GroupID]
			if key.Status == models.KeyStatusActive {
				ids = append(ids, key.ID)
			}
			activeKeyIDs[key.GroupID] = ids
		}

		if pipeline != nil {
			if err := pipeline.Exec(); err != nil {
				return fmt.Errorf("failed to execute pipeline for batch %d: %w", batch, err)
			}
		}
		return nil
	}).Error
	if err != nil {
		return fmt.Errorf("failed during batch processing of keys: %w", err)
	}

	logrus.Info("Updating active key lists for all groups...")
	for groupID, activeIDs := range activeKeyIDs {
		activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID)
		// A missing list is fine; anything else is worth knowing about.
		if err := p.store.Delete(activeKeysListKey); err != nil && !errors.Is(err, store.ErrNotFound) {
			logrus.WithFields(logrus.Fields{"groupID": groupID, "error": err}).Error("Failed to delete active keys list for group")
		}
		if len(activeIDs) == 0 {
			continue
		}
		if err := p.store.LPush(activeKeysListKey, activeIDs...); err != nil {
			logrus.WithFields(logrus.Fields{"groupID": groupID, "error": err}).Error("Failed to LPush active keys for group")
		}
	}
	return nil
}
// AddKeys inserts keys into the database and mirrors them into the store
// within one database transaction. The IDs assigned by the database are
// written back into the elements of keys.
//
// Store writes are not undone by a database rollback. If the transaction
// fails after some keys were already written to the store, those keys are
// removed from the store again (best effort, failures are logged). Otherwise
// the store would keep serving keys whose database rows do not exist.
//
// groupID is currently unused; each key's own GroupID decides which group's
// active list it joins.
func (p *KeyProvider) AddKeys(groupID uint, keys []models.APIKey) error {
	if len(keys) == 0 {
		return nil
	}

	// stored counts keys whose store write was attempted. A partially written
	// key (hash set, list push failed) is included so it is cleaned up too.
	stored := 0
	err := p.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(&keys).Error; err != nil {
			return err
		}
		for i := range keys {
			stored = i + 1
			if err := p.addKeyToStore(&keys[i]); err != nil {
				logrus.WithFields(logrus.Fields{"keyID": keys[i].ID, "error": err}).Error("Failed to add key to store after DB creation, rolling back transaction")
				return err
			}
		}
		return nil
	})
	if err != nil {
		for _, key := range keys[:stored] {
			if rmErr := p.removeKeyFromStore(key.ID, key.GroupID); rmErr != nil {
				logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": rmErr}).Warn("Failed to clean up store after AddKeys rollback")
			}
		}
	}
	return err
}
// RemoveKeys deletes the keys of the given group whose values appear in
// keyValues from both the database and the store. It returns the number of
// rows deleted. A failure to clean up the store aborts and rolls back the
// database transaction; callers should ignore the count when err is non-nil.
func (p *KeyProvider) RemoveKeys(groupID uint, keyValues []string) (int64, error) {
	var deleted int64
	if len(keyValues) == 0 {
		return deleted, nil
	}

	err := p.db.Transaction(func(tx *gorm.DB) error {
		var matches []models.APIKey
		if err := tx.Where("group_id = ? AND key_value IN ?", groupID, keyValues).Find(&matches).Error; err != nil {
			return err
		}
		if len(matches) == 0 {
			return nil
		}

		res := tx.Where("id IN ?", pluckIDs(matches)).Delete(&models.APIKey{})
		if res.Error != nil {
			return res.Error
		}
		deleted = res.RowsAffected

		for _, k := range matches {
			if err := p.removeKeyFromStore(k.ID, k.GroupID); err != nil {
				logrus.WithFields(logrus.Fields{"keyID": k.ID, "error": err}).Error("Failed to remove key from store after DB deletion, rolling back transaction")
				return err
			}
		}
		return nil
	})
	return deleted, err
}
// RestoreKeys reactivates every invalid key of the group. In the database,
// status is set back to active and failure_count to 0, and each key is
// re-added to the store (hash plus active list). It returns the number of
// rows updated. A store failure rolls back the database update.
//
// The UPDATE targets exactly the IDs loaded by the preceding SELECT, as
// RestoreMultipleKeys does. Filtering the UPDATE by status again could also
// reactivate keys that became invalid between the two statements without
// re-adding them to the store, leaving the database and the store out of sync.
func (p *KeyProvider) RestoreKeys(groupID uint) (int64, error) {
	var invalidKeys []models.APIKey
	var restoredCount int64
	err := p.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Where("group_id = ? AND status = ?", groupID, models.KeyStatusInvalid).Find(&invalidKeys).Error; err != nil {
			return err
		}
		if len(invalidKeys) == 0 {
			return nil
		}

		ids := make([]uint, 0, len(invalidKeys))
		for _, key := range invalidKeys {
			ids = append(ids, key.ID)
		}

		updates := map[string]any{
			"status":        models.KeyStatusActive,
			"failure_count": 0,
		}
		result := tx.Model(&models.APIKey{}).Where("id IN ?", ids).Updates(updates)
		if result.Error != nil {
			return result.Error
		}
		restoredCount = result.RowsAffected

		for _, key := range invalidKeys {
			key.Status = models.KeyStatusActive
			key.FailureCount = 0
			if err := p.addKeyToStore(&key); err != nil {
				logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to restore key in store after DB update, rolling back transaction")
				return err
			}
		}
		return nil
	})
	return restoredCount, err
}
// RestoreMultipleKeys reactivates the invalid keys of the group whose values
// appear in keyValues. In the database, status becomes active and
// failure_count 0, and each key is re-added to the store. It returns the
// number of rows updated. A store failure rolls back the whole transaction.
func (p *KeyProvider) RestoreMultipleKeys(groupID uint, keyValues []string) (int64, error) {
	var restored int64
	if len(keyValues) == 0 {
		return restored, nil
	}

	err := p.db.Transaction(func(tx *gorm.DB) error {
		// 1. Find the requested keys that are currently invalid.
		var candidates []models.APIKey
		query := tx.Where("group_id = ? AND key_value IN ? AND status = ?", groupID, keyValues, models.KeyStatusInvalid)
		if err := query.Find(&candidates).Error; err != nil {
			return err
		}
		if len(candidates) == 0 {
			return nil
		}

		// 2. Reset them in the database.
		result := tx.Model(&models.APIKey{}).
			Where("id IN ?", pluckIDs(candidates)).
			Updates(map[string]any{
				"status":        models.KeyStatusActive,
				"failure_count": 0,
			})
		if result.Error != nil {
			return result.Error
		}
		restored = result.RowsAffected

		// 3. Put them back into the store; any error rolls back step 2.
		for _, k := range candidates {
			k.Status = models.KeyStatusActive
			k.FailureCount = 0
			if err := p.addKeyToStore(&k); err != nil {
				logrus.WithFields(logrus.Fields{"keyID": k.ID, "error": err}).Error("Failed to restore key in store after DB update")
				return err
			}
		}
		return nil
	})
	return restored, err
}
// RemoveInvalidKeys deletes every invalid key of the group from the database
// and the store and returns the number of deleted rows. A store failure rolls
// back the database deletion; the returned count is then not meaningful.
func (p *KeyProvider) RemoveInvalidKeys(groupID uint) (int64, error) {
	var removed int64
	err := p.db.Transaction(func(tx *gorm.DB) error {
		var invalid []models.APIKey
		if err := tx.Where("group_id = ? AND status = ?", groupID, models.KeyStatusInvalid).Find(&invalid).Error; err != nil {
			return err
		}
		if len(invalid) == 0 {
			return nil
		}

		ids := pluckIDs(invalid)
		result := tx.Where("id IN ?", ids).Delete(&models.APIKey{})
		if result.Error != nil {
			return result.Error
		}
		removed = result.RowsAffected

		for _, k := range invalid {
			if err := p.removeKeyFromStore(k.ID, k.GroupID); err != nil {
				logrus.WithFields(logrus.Fields{"keyID": k.ID, "error": err}).Error("Failed to remove invalid key from store after DB deletion, rolling back transaction")
				return err
			}
		}
		return nil
	})
	return removed, err
}
// RemoveKeysFromStore removes keys from the store only; the database is not
// touched. It is intended for cleaning up the store after the rows have
// already been deleted from the database.
//
// The whole "group:<id>:active_keys" list of groupID is deleted, not just the
// entries for keyIDs. A failure to delete that list is returned immediately.
// Failures to delete individual key hashes are logged and skipped.
func (p *KeyProvider) RemoveKeysFromStore(groupID uint, keyIDs []uint) error {
	if len(keyIDs) == 0 {
		return nil
	}

	listKey := fmt.Sprintf("group:%d:active_keys", groupID)
	if err := p.store.Delete(listKey); err != nil {
		logrus.WithFields(logrus.Fields{
			"groupID": groupID,
			"error":   err,
		}).Error("Failed to delete active keys list")
		return err
	}

	for _, id := range keyIDs {
		err := p.store.Delete(fmt.Sprintf("key:%d", id))
		if err == nil {
			continue
		}
		logrus.WithFields(logrus.Fields{
			"keyID": id,
			"error": err,
		}).Error("Failed to delete key hash")
	}

	logrus.WithFields(logrus.Fields{
		"groupID":  groupID,
		"keyCount": len(keyIDs),
	}).Info("Successfully cleaned up group keys from store")
	return nil
}
// addKeyToStore writes key's details to its "key:<id>" hash. When the key is
// active, its ID is also pushed (LPush) onto the group's active list after
// any existing occurrence is removed, so the ID is listed only once.
func (p *KeyProvider) addKeyToStore(key *models.APIKey) error {
	hashKey := fmt.Sprintf("key:%d", key.ID)
	if err := p.store.HSet(hashKey, p.apiKeyToMap(key)); err != nil {
		return fmt.Errorf("failed to HSet key details for key %d: %w", key.ID, err)
	}

	if key.Status != models.KeyStatusActive {
		return nil
	}

	listKey := fmt.Sprintf("group:%d:active_keys", key.GroupID)
	if err := p.store.LRem(listKey, 0, key.ID); err != nil {
		return fmt.Errorf("failed to LRem key %d before LPush for group %d: %w", key.ID, key.GroupID, err)
	}
	if err := p.store.LPush(listKey, key.ID); err != nil {
		return fmt.Errorf("failed to LPush key %d to group %d: %w", key.ID, key.GroupID, err)
	}
	return nil
}
// removeKeyFromStore drops keyID from the group's active list and deletes its
// "key:<id>" hash. A failure to remove the list entry is only logged; a
// failure to delete the hash is returned.
func (p *KeyProvider) removeKeyFromStore(keyID, groupID uint) error {
	listKey := fmt.Sprintf("group:%d:active_keys", groupID)
	if lremErr := p.store.LRem(listKey, 0, keyID); lremErr != nil {
		logrus.WithFields(logrus.Fields{"keyID": keyID, "groupID": groupID, "error": lremErr}).Error("Failed to LRem key from active list")
	}

	if err := p.store.Delete(fmt.Sprintf("key:%d", keyID)); err != nil {
		return fmt.Errorf("failed to delete key HASH for key %d: %w", keyID, err)
	}
	return nil
}
// apiKeyToMap flattens key into the field map written to its "key:<id>" hash.
// The field names are the ones read back by SelectKey and the status
// handlers; created_at is stored as Unix seconds.
func (p *KeyProvider) apiKeyToMap(key *models.APIKey) map[string]any {
	fields := make(map[string]any, 6)
	fields["id"] = fmt.Sprint(key.ID)
	fields["key_string"] = key.KeyValue
	fields["status"] = key.Status
	fields["failure_count"] = key.FailureCount
	fields["group_id"] = key.GroupID
	fields["created_at"] = key.CreatedAt.Unix()
	return fields
}
// pluckIDs returns the primary keys of keys, in the same order. The result is
// a non-nil (possibly empty) slice.
func pluckIDs(keys []models.APIKey) []uint {
	ids := make([]uint, 0, len(keys))
	for i := range keys {
		ids = append(ids, keys[i].ID)
	}
	return ids
}