522 lines
16 KiB
Go
522 lines
16 KiB
Go
package keypool
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"gpt-load/internal/config"
|
||
app_errors "gpt-load/internal/errors"
|
||
"gpt-load/internal/models"
|
||
"gpt-load/internal/store"
|
||
"strconv"
|
||
"time"
|
||
|
||
"github.com/sirupsen/logrus"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
type KeyProvider struct {
|
||
db *gorm.DB
|
||
store store.Store
|
||
settingsManager *config.SystemSettingsManager
|
||
}
|
||
|
||
// NewProvider 创建一个新的 KeyProvider 实例。
|
||
func NewProvider(db *gorm.DB, store store.Store, settingsManager *config.SystemSettingsManager) *KeyProvider {
|
||
return &KeyProvider{
|
||
db: db,
|
||
store: store,
|
||
settingsManager: settingsManager,
|
||
}
|
||
}
|
||
|
||
// SelectKey 为指定的分组原子性地选择并轮换一个可用的 APIKey。
|
||
func (p *KeyProvider) SelectKey(groupID uint) (*models.APIKey, error) {
|
||
activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID)
|
||
|
||
// 1. Atomically rotate the key ID from the list
|
||
keyIDStr, err := p.store.Rotate(activeKeysListKey)
|
||
if err != nil {
|
||
if errors.Is(err, store.ErrNotFound) {
|
||
return nil, app_errors.ErrNoActiveKeys
|
||
}
|
||
return nil, fmt.Errorf("failed to rotate key from store: %w", err)
|
||
}
|
||
|
||
keyID, err := strconv.ParseUint(keyIDStr, 10, 64)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to parse key ID '%s': %w", keyIDStr, err)
|
||
}
|
||
|
||
// 2. Get key details from HASH
|
||
keyHashKey := fmt.Sprintf("key:%d", keyID)
|
||
keyDetails, err := p.store.HGetAll(keyHashKey)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get key details for key ID %d: %w", keyID, err)
|
||
}
|
||
|
||
// 3. Manually unmarshal the map into an APIKey struct
|
||
failureCount, _ := strconv.ParseInt(keyDetails["failure_count"], 10, 64)
|
||
createdAt, _ := strconv.ParseInt(keyDetails["created_at"], 10, 64)
|
||
|
||
apiKey := &models.APIKey{
|
||
ID: uint(keyID),
|
||
KeyValue: keyDetails["key_string"],
|
||
Status: keyDetails["status"],
|
||
FailureCount: failureCount,
|
||
GroupID: groupID,
|
||
CreatedAt: time.Unix(createdAt, 0),
|
||
}
|
||
|
||
return apiKey, nil
|
||
}
|
||
|
||
// UpdateStatus 异步地提交一个 Key 状态更新任务。
|
||
func (p *KeyProvider) UpdateStatus(keyID uint, groupID uint, isSuccess bool) {
|
||
go func() {
|
||
keyHashKey := fmt.Sprintf("key:%d", keyID)
|
||
activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID)
|
||
|
||
if isSuccess {
|
||
if err := p.handleSuccess(keyID, keyHashKey, activeKeysListKey); err != nil {
|
||
logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to handle key success")
|
||
}
|
||
} else {
|
||
if err := p.handleFailure(keyID, keyHashKey, activeKeysListKey); err != nil {
|
||
logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to handle key failure")
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
|
||
func (p *KeyProvider) handleSuccess(keyID uint, keyHashKey, activeKeysListKey string) error {
|
||
keyDetails, err := p.store.HGetAll(keyHashKey)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to get key details from store: %w", err)
|
||
}
|
||
|
||
failureCount, _ := strconv.ParseInt(keyDetails["failure_count"], 10, 64)
|
||
isActive := keyDetails["status"] == models.KeyStatusActive
|
||
|
||
if failureCount == 0 && isActive {
|
||
return nil
|
||
}
|
||
|
||
return p.db.Transaction(func(tx *gorm.DB) error {
|
||
var key models.APIKey
|
||
if err := tx.Set("gorm:query_option", "FOR UPDATE").First(&key, keyID).Error; err != nil {
|
||
return fmt.Errorf("failed to lock key %d for update: %w", keyID, err)
|
||
}
|
||
|
||
updates := map[string]any{"failure_count": 0}
|
||
if !isActive {
|
||
updates["status"] = models.KeyStatusActive
|
||
}
|
||
|
||
if err := tx.Model(&key).Updates(updates).Error; err != nil {
|
||
return fmt.Errorf("failed to update key in DB: %w", err)
|
||
}
|
||
|
||
if err := p.store.HSet(keyHashKey, updates); err != nil {
|
||
return fmt.Errorf("failed to update key details in store: %w", err)
|
||
}
|
||
|
||
if !isActive {
|
||
logrus.WithField("keyID", keyID).Info("Key has recovered and is being restored to active pool.")
|
||
if err := p.store.LRem(activeKeysListKey, 0, keyID); err != nil {
|
||
return fmt.Errorf("failed to LRem key before LPush on recovery: %w", err)
|
||
}
|
||
if err := p.store.LPush(activeKeysListKey, keyID); err != nil {
|
||
return fmt.Errorf("failed to LPush key back to active list: %w", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
})
|
||
}
|
||
|
||
func (p *KeyProvider) handleFailure(keyID uint, keyHashKey, activeKeysListKey string) error {
|
||
keyDetails, err := p.store.HGetAll(keyHashKey)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to get key details from store: %w", err)
|
||
}
|
||
|
||
failureCount, _ := strconv.ParseInt(keyDetails["failure_count"], 10, 64)
|
||
|
||
if keyDetails["status"] == models.KeyStatusInvalid {
|
||
return nil
|
||
}
|
||
|
||
settings := p.settingsManager.GetSettings()
|
||
blacklistThreshold := settings.BlacklistThreshold
|
||
|
||
return p.db.Transaction(func(tx *gorm.DB) error {
|
||
var key models.APIKey
|
||
if err := tx.Set("gorm:query_option", "FOR UPDATE").First(&key, keyID).Error; err != nil {
|
||
return fmt.Errorf("failed to lock key %d for update: %w", keyID, err)
|
||
}
|
||
|
||
newFailureCount := failureCount + 1
|
||
|
||
updates := map[string]any{"failure_count": newFailureCount}
|
||
shouldBlacklist := newFailureCount >= int64(blacklistThreshold)
|
||
if shouldBlacklist {
|
||
updates["status"] = models.KeyStatusInvalid
|
||
}
|
||
|
||
if err := tx.Model(&key).Updates(updates).Error; err != nil {
|
||
return fmt.Errorf("failed to update key stats in DB: %w", err)
|
||
}
|
||
|
||
if _, err := p.store.HIncrBy(keyHashKey, "failure_count", 1); err != nil {
|
||
return fmt.Errorf("failed to increment failure count in store: %w", err)
|
||
}
|
||
|
||
if shouldBlacklist {
|
||
logrus.WithFields(logrus.Fields{"keyID": keyID, "threshold": blacklistThreshold}).Warn("Key has reached blacklist threshold, disabling.")
|
||
if err := p.store.LRem(activeKeysListKey, 0, keyID); err != nil {
|
||
return fmt.Errorf("failed to LRem key from active list: %w", err)
|
||
}
|
||
if err := p.store.HSet(keyHashKey, map[string]any{"status": models.KeyStatusInvalid}); err != nil {
|
||
return fmt.Errorf("failed to update key status to invalid in store: %w", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
})
|
||
}
|
||
|
||
// LoadKeysFromDB 从数据库加载所有分组和密钥,并填充到 Store 中。
|
||
func (p *KeyProvider) LoadKeysFromDB() error {
|
||
|
||
// 1. 分批从数据库加载并使用 Pipeline 写入 Redis
|
||
allActiveKeyIDs := make(map[uint][]any)
|
||
batchSize := 1000
|
||
var batchKeys []*models.APIKey
|
||
|
||
err := p.db.Model(&models.APIKey{}).FindInBatches(&batchKeys, batchSize, func(tx *gorm.DB, batch int) error {
|
||
logrus.Debugf("Processing batch %d with %d keys...", batch, len(batchKeys))
|
||
|
||
var pipeline store.Pipeliner
|
||
if redisStore, ok := p.store.(store.RedisPipeliner); ok {
|
||
pipeline = redisStore.Pipeline()
|
||
}
|
||
|
||
for _, key := range batchKeys {
|
||
keyHashKey := fmt.Sprintf("key:%d", key.ID)
|
||
keyDetails := p.apiKeyToMap(key)
|
||
|
||
if pipeline != nil {
|
||
pipeline.HSet(keyHashKey, keyDetails)
|
||
} else {
|
||
if err := p.store.HSet(keyHashKey, keyDetails); err != nil {
|
||
logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to HSet key details")
|
||
}
|
||
}
|
||
|
||
if key.Status == models.KeyStatusActive {
|
||
allActiveKeyIDs[key.GroupID] = append(allActiveKeyIDs[key.GroupID], key.ID)
|
||
}
|
||
}
|
||
|
||
if pipeline != nil {
|
||
if err := pipeline.Exec(); err != nil {
|
||
return fmt.Errorf("failed to execute pipeline for batch %d: %w", batch, err)
|
||
}
|
||
}
|
||
return nil
|
||
}).Error
|
||
|
||
if err != nil {
|
||
return fmt.Errorf("failed during batch processing of keys: %w", err)
|
||
}
|
||
|
||
// 2. 更新所有分组的 active_keys 列表
|
||
logrus.Info("Updating active key lists for all groups...")
|
||
for groupID, activeIDs := range allActiveKeyIDs {
|
||
if len(activeIDs) > 0 {
|
||
activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID)
|
||
p.store.Delete(activeKeysListKey)
|
||
if err := p.store.LPush(activeKeysListKey, activeIDs...); err != nil {
|
||
logrus.WithFields(logrus.Fields{"groupID": groupID, "error": err}).Error("Failed to LPush active keys for group")
|
||
}
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// AddKeys 批量添加新的 Key 到池和数据库中。
|
||
func (p *KeyProvider) AddKeys(groupID uint, keys []models.APIKey) error {
|
||
if len(keys) == 0 {
|
||
return nil
|
||
}
|
||
|
||
err := p.db.Transaction(func(tx *gorm.DB) error {
|
||
if err := tx.Create(&keys).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
for _, key := range keys {
|
||
if err := p.addKeyToStore(&key); err != nil {
|
||
logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to add key to store after DB creation, rolling back transaction")
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
})
|
||
|
||
return err
|
||
}
|
||
|
||
// RemoveKeys 批量从池和数据库中移除 Key。
|
||
func (p *KeyProvider) RemoveKeys(groupID uint, keyValues []string) (int64, error) {
|
||
if len(keyValues) == 0 {
|
||
return 0, nil
|
||
}
|
||
|
||
var keysToDelete []models.APIKey
|
||
var deletedCount int64
|
||
|
||
err := p.db.Transaction(func(tx *gorm.DB) error {
|
||
if err := tx.Where("group_id = ? AND key_value IN ?", groupID, keyValues).Find(&keysToDelete).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
if len(keysToDelete) == 0 {
|
||
return nil
|
||
}
|
||
|
||
keyIDsToDelete := pluckIDs(keysToDelete)
|
||
|
||
result := tx.Where("id IN ?", keyIDsToDelete).Delete(&models.APIKey{})
|
||
if result.Error != nil {
|
||
return result.Error
|
||
}
|
||
deletedCount = result.RowsAffected
|
||
|
||
for _, key := range keysToDelete {
|
||
if err := p.removeKeyFromStore(key.ID, key.GroupID); err != nil {
|
||
logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to remove key from store after DB deletion, rolling back transaction")
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
})
|
||
|
||
return deletedCount, err
|
||
}
|
||
|
||
// RestoreKeys 恢复组内所有无效的 Key。
|
||
func (p *KeyProvider) RestoreKeys(groupID uint) (int64, error) {
|
||
var invalidKeys []models.APIKey
|
||
var restoredCount int64
|
||
|
||
err := p.db.Transaction(func(tx *gorm.DB) error {
|
||
if err := tx.Where("group_id = ? AND status = ?", groupID, models.KeyStatusInvalid).Find(&invalidKeys).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
if len(invalidKeys) == 0 {
|
||
return nil
|
||
}
|
||
|
||
updates := map[string]any{
|
||
"status": models.KeyStatusActive,
|
||
"failure_count": 0,
|
||
}
|
||
result := tx.Model(&models.APIKey{}).Where("group_id = ? AND status = ?", groupID, models.KeyStatusInvalid).Updates(updates)
|
||
if result.Error != nil {
|
||
return result.Error
|
||
}
|
||
restoredCount = result.RowsAffected
|
||
|
||
for _, key := range invalidKeys {
|
||
key.Status = models.KeyStatusActive
|
||
key.FailureCount = 0
|
||
if err := p.addKeyToStore(&key); err != nil {
|
||
logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to restore key in store after DB update, rolling back transaction")
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
})
|
||
|
||
return restoredCount, err
|
||
}
|
||
|
||
// RestoreMultipleKeys 恢复指定的 Key。
|
||
func (p *KeyProvider) RestoreMultipleKeys(groupID uint, keyValues []string) (int64, error) {
|
||
if len(keyValues) == 0 {
|
||
return 0, nil
|
||
}
|
||
|
||
var keysToRestore []models.APIKey
|
||
var restoredCount int64
|
||
|
||
err := p.db.Transaction(func(tx *gorm.DB) error {
|
||
// 1. 查找要恢复的密钥
|
||
if err := tx.Where("group_id = ? AND key_value IN ? AND status = ?", groupID, keyValues, models.KeyStatusInvalid).Find(&keysToRestore).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
if len(keysToRestore) == 0 {
|
||
return nil
|
||
}
|
||
|
||
keyIDsToRestore := pluckIDs(keysToRestore)
|
||
|
||
// 2. 更新数据库中的状态
|
||
updates := map[string]any{
|
||
"status": models.KeyStatusActive,
|
||
"failure_count": 0,
|
||
}
|
||
result := tx.Model(&models.APIKey{}).Where("id IN ?", keyIDsToRestore).Updates(updates)
|
||
if result.Error != nil {
|
||
return result.Error
|
||
}
|
||
restoredCount = result.RowsAffected
|
||
|
||
// 3. 将密钥添加回 Redis
|
||
for _, key := range keysToRestore {
|
||
key.Status = models.KeyStatusActive
|
||
key.FailureCount = 0
|
||
if err := p.addKeyToStore(&key); err != nil {
|
||
// 在事务中,单个失败会回滚整个事务,但这里的日志记录仍然有用
|
||
logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to restore key in store after DB update")
|
||
return err // 返回错误以回滚事务
|
||
}
|
||
}
|
||
|
||
return nil
|
||
})
|
||
|
||
return restoredCount, err
|
||
}
|
||
|
||
// RemoveInvalidKeys 移除组内所有无效的 Key。
|
||
func (p *KeyProvider) RemoveInvalidKeys(groupID uint) (int64, error) {
|
||
var invalidKeys []models.APIKey
|
||
var removedCount int64
|
||
|
||
err := p.db.Transaction(func(tx *gorm.DB) error {
|
||
if err := tx.Where("group_id = ? AND status = ?", groupID, models.KeyStatusInvalid).Find(&invalidKeys).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
if len(invalidKeys) == 0 {
|
||
return nil
|
||
}
|
||
|
||
result := tx.Where("id IN ?", pluckIDs(invalidKeys)).Delete(&models.APIKey{})
|
||
if result.Error != nil {
|
||
return result.Error
|
||
}
|
||
removedCount = result.RowsAffected
|
||
|
||
for _, key := range invalidKeys {
|
||
if err := p.removeKeyFromStore(key.ID, key.GroupID); err != nil {
|
||
logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to remove invalid key from store after DB deletion, rolling back transaction")
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
})
|
||
|
||
return removedCount, err
|
||
}
|
||
|
||
// RemoveKeysFromStore 直接从内存存储中移除指定的键,不涉及数据库操作
|
||
// 这个方法适用于数据库已经删除但需要清理内存存储的场景
|
||
func (p *KeyProvider) RemoveKeysFromStore(groupID uint, keyIDs []uint) error {
|
||
if len(keyIDs) == 0 {
|
||
return nil
|
||
}
|
||
|
||
activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID)
|
||
|
||
// 第一步:直接删除整个 active_keys 列表
|
||
if err := p.store.Delete(activeKeysListKey); err != nil {
|
||
logrus.WithFields(logrus.Fields{
|
||
"groupID": groupID,
|
||
"error": err,
|
||
}).Error("Failed to delete active keys list")
|
||
return err
|
||
}
|
||
|
||
// 第二步:批量删除所有相关的key hash
|
||
for _, keyID := range keyIDs {
|
||
keyHashKey := fmt.Sprintf("key:%d", keyID)
|
||
if err := p.store.Delete(keyHashKey); err != nil {
|
||
logrus.WithFields(logrus.Fields{
|
||
"keyID": keyID,
|
||
"error": err,
|
||
}).Error("Failed to delete key hash")
|
||
}
|
||
}
|
||
|
||
logrus.WithFields(logrus.Fields{
|
||
"groupID": groupID,
|
||
"keyCount": len(keyIDs),
|
||
}).Info("Successfully cleaned up group keys from store")
|
||
|
||
return nil
|
||
}
|
||
|
||
// addKeyToStore is a helper to add a single key to the cache.
|
||
func (p *KeyProvider) addKeyToStore(key *models.APIKey) error {
|
||
// 1. Store key details in HASH
|
||
keyHashKey := fmt.Sprintf("key:%d", key.ID)
|
||
keyDetails := p.apiKeyToMap(key)
|
||
if err := p.store.HSet(keyHashKey, keyDetails); err != nil {
|
||
return fmt.Errorf("failed to HSet key details for key %d: %w", key.ID, err)
|
||
}
|
||
|
||
// 2. If active, add to the active LIST
|
||
if key.Status == models.KeyStatusActive {
|
||
activeKeysListKey := fmt.Sprintf("group:%d:active_keys", key.GroupID)
|
||
if err := p.store.LRem(activeKeysListKey, 0, key.ID); err != nil {
|
||
return fmt.Errorf("failed to LRem key %d before LPush for group %d: %w", key.ID, key.GroupID, err)
|
||
}
|
||
if err := p.store.LPush(activeKeysListKey, key.ID); err != nil {
|
||
return fmt.Errorf("failed to LPush key %d to group %d: %w", key.ID, key.GroupID, err)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// removeKeyFromStore is a helper to remove a single key from the cache.
|
||
func (p *KeyProvider) removeKeyFromStore(keyID, groupID uint) error {
|
||
activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID)
|
||
if err := p.store.LRem(activeKeysListKey, 0, keyID); err != nil {
|
||
logrus.WithFields(logrus.Fields{"keyID": keyID, "groupID": groupID, "error": err}).Error("Failed to LRem key from active list")
|
||
}
|
||
|
||
keyHashKey := fmt.Sprintf("key:%d", keyID)
|
||
if err := p.store.Delete(keyHashKey); err != nil {
|
||
return fmt.Errorf("failed to delete key HASH for key %d: %w", keyID, err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// apiKeyToMap converts an APIKey model to a map for HSET.
|
||
func (p *KeyProvider) apiKeyToMap(key *models.APIKey) map[string]any {
|
||
return map[string]any{
|
||
"id": fmt.Sprint(key.ID),
|
||
"key_string": key.KeyValue,
|
||
"status": key.Status,
|
||
"failure_count": key.FailureCount,
|
||
"group_id": key.GroupID,
|
||
"created_at": key.CreatedAt.Unix(),
|
||
}
|
||
}
|
||
|
||
// pluckIDs extracts IDs from a slice of APIKey.
|
||
func pluckIDs(keys []models.APIKey) []uint {
|
||
ids := make([]uint, len(keys))
|
||
for i, key := range keys {
|
||
ids[i] = key.ID
|
||
}
|
||
return ids
|
||
}
|