Files
gpt-load/internal/keypool/provider.go
2025-07-07 18:55:06 +08:00

452 lines
15 KiB
Go

package keypool
import (
"errors"
"fmt"
"gpt-load/internal/config"
app_errors "gpt-load/internal/errors"
"gpt-load/internal/models"
"gpt-load/internal/store"
"strconv"
"time"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// Store keys that coordinate the one-time load of the key pool.
const (
// keypoolInitializedKey is written once LoadKeysFromDB has finished; while it
// exists, LoadKeysFromDB skips the database load.
keypoolInitializedKey = "keypool:initialized"
// keypoolLoadingKey is the SetNX lock held (with a 10 minute expiry) while
// LoadKeysFromDB is filling the store.
keypoolLoadingKey = "keypool:loading"
)
// KeyProvider selects API keys for a group and keeps their state in sync
// between the database and the store.
type KeyProvider struct {
// db is the GORM handle used to read and write models.APIKey rows.
db *gorm.DB
// store holds the per-key hashes ("key:<id>") and the per-group lists of
// active key IDs ("group:<id>:active_keys").
store store.Store
// settingsManager supplies the current settings; handleFailure reads
// BlacklistThreshold from it.
settingsManager *config.SystemSettingsManager
}
// NewProvider builds a KeyProvider wired to the given database handle, store
// and settings manager.
func NewProvider(db *gorm.DB, store store.Store, settingsManager *config.SystemSettingsManager) *KeyProvider {
	provider := new(KeyProvider)
	provider.db = db
	provider.store = store
	provider.settingsManager = settingsManager
	return provider
}
// SelectKey rotates the active-key list of the given group and returns the
// key whose ID the rotation yielded.
//
// It returns app_errors.ErrNoActiveKeys when the store reports
// store.ErrNotFound for the group's list. It returns an error when the
// rotated ID cannot be parsed, when the key's hash cannot be read, or when
// the hash carries no key material. In the last case the stale ID is also
// removed from the active list (best effort), so that an empty credential
// is never handed to callers.
func (p *KeyProvider) SelectKey(groupID uint) (*models.APIKey, error) {
	activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID)

	// 1. Rotate the list and take the key ID it yields.
	keyIDStr, err := p.store.Rotate(activeKeysListKey)
	if err != nil {
		if errors.Is(err, store.ErrNotFound) {
			return nil, app_errors.ErrNoActiveKeys
		}
		return nil, fmt.Errorf("failed to rotate key from store: %w", err)
	}

	keyID, err := strconv.ParseUint(keyIDStr, 10, 64)
	if err != nil {
		return nil, fmt.Errorf("failed to parse key ID '%s': %w", keyIDStr, err)
	}

	// 2. Get key details from HASH.
	keyHashKey := fmt.Sprintf("key:%d", keyID)
	keyDetails, err := p.store.HGetAll(keyHashKey)
	if err != nil {
		return nil, fmt.Errorf("failed to get key details for key ID %d: %w", keyID, err)
	}

	// A list entry whose hash is missing (or lacks the key string) cannot be
	// used for a request. Drop it from the rotation and report the problem
	// instead of returning an APIKey with an empty KeyValue.
	if keyDetails["key_string"] == "" {
		if remErr := p.store.LRem(activeKeysListKey, 0, uint(keyID)); remErr != nil {
			logrus.WithFields(logrus.Fields{"keyID": keyID, "groupID": groupID, "error": remErr}).Warn("Failed to evict key without details from active list")
		}
		return nil, fmt.Errorf("key ID %d is in the active list of group %d but has no key details in store", keyID, groupID)
	}

	// 3. Manually unmarshal the map into an APIKey struct.
	// Parse errors are deliberately ignored: a missing or malformed counter or
	// timestamp falls back to zero rather than failing key selection.
	failureCount, _ := strconv.ParseInt(keyDetails["failure_count"], 10, 64)
	createdAt, _ := strconv.ParseInt(keyDetails["created_at"], 10, 64)

	apiKey := &models.APIKey{
		ID:           uint(keyID),
		KeyValue:     keyDetails["key_string"],
		Status:       keyDetails["status"],
		FailureCount: failureCount,
		GroupID:      groupID,
		CreatedAt:    time.Unix(createdAt, 0),
	}
	return apiKey, nil
}
// UpdateStatus records the outcome of a request made with the given key.
// The work runs in a new goroutine, so the call returns immediately and
// failures are only visible through the logs of the handlers.
func (p *KeyProvider) UpdateStatus(keyID uint, groupID uint, isSuccess bool) {
	go func(id, gid uint, succeeded bool) {
		hashKey := fmt.Sprintf("key:%d", id)
		listKey := fmt.Sprintf("group:%d:active_keys", gid)

		// Pick the handler first, then invoke it once.
		handle := p.handleFailure
		if succeeded {
			handle = p.handleSuccess
		}
		handle(id, hashKey, listKey)
	}(keyID, groupID, isSuccess)
}
// handleSuccess processes a successful request for a key.
//
// Nothing is written when the key already has a zero failure count and is not
// marked invalid. Otherwise the failure count is reset in the store; a key
// that was invalid is additionally marked active and pushed back onto the
// group's active list (after removing any existing entry for it). The same
// changes are then written to the database. Any store error is logged and
// aborts the remaining steps, including the database update.
func (p *KeyProvider) handleSuccess(keyID uint, keyHashKey, activeKeysListKey string) {
	entry := logrus.WithField("keyID", keyID)

	details, err := p.store.HGetAll(keyHashKey)
	if err != nil {
		entry.WithField("error", err).Error("Failed to get key details on success")
		return
	}

	// A malformed or missing counter is treated as zero.
	failures, _ := strconv.ParseInt(details["failure_count"], 10, 64)
	wasInvalid := details["status"] == models.KeyStatusInvalid
	if !wasInvalid && failures == 0 {
		return
	}

	if err = p.store.HSet(keyHashKey, "failure_count", 0); err != nil {
		entry.WithField("error", err).Error("Failed to reset failure count in store, aborting DB update.")
		return
	}

	dbUpdates := map[string]any{"failure_count": 0}
	if wasInvalid {
		entry.Info("Key has recovered and is being restored to active pool.")

		// The recovery steps must run in this order; the first failure stops
		// everything. LRem before LPush keeps the list free of duplicates.
		recovery := []struct {
			run     func() error
			failMsg string
		}{
			{
				run:     func() error { return p.store.HSet(keyHashKey, "status", models.KeyStatusActive) },
				failMsg: "Failed to update key status to active in store, aborting DB update.",
			},
			{
				run:     func() error { return p.store.LRem(activeKeysListKey, 0, keyID) },
				failMsg: "Failed to LRem key before LPush on recovery, aborting DB update.",
			},
			{
				run:     func() error { return p.store.LPush(activeKeysListKey, keyID) },
				failMsg: "Failed to LPush key back to active list, aborting DB update.",
			},
		}
		for _, step := range recovery {
			if stepErr := step.run(); stepErr != nil {
				entry.WithField("error", stepErr).Error(step.failMsg)
				return
			}
		}
		dbUpdates["status"] = models.KeyStatusActive
	}

	dbErr := p.db.Model(&models.APIKey{}).Where("id = ?", keyID).Updates(dbUpdates).Error
	if dbErr != nil {
		entry.WithField("error", dbErr).Error("Failed to update key status in DB on success")
	}
}
// handleFailure processes a failed request for a key.
//
// Keys already marked invalid are left alone. Otherwise the failure count in
// the store is incremented; once it reaches the configured blacklist
// threshold the key is removed from the group's active list and marked
// invalid in the store. The new count (and status, if changed) is then
// written to the database. Any store error is logged and aborts the
// remaining steps, including the database update.
func (p *KeyProvider) handleFailure(keyID uint, keyHashKey, activeKeysListKey string) {
	entry := logrus.WithField("keyID", keyID)

	details, err := p.store.HGetAll(keyHashKey)
	if err != nil {
		entry.WithField("error", err).Error("Failed to get key details on failure")
		return
	}
	if details["status"] == models.KeyStatusInvalid {
		return
	}

	failures, err := p.store.HIncrBy(keyHashKey, "failure_count", 1)
	if err != nil {
		entry.WithField("error", err).Error("Failed to increment failure count")
		return
	}

	threshold := p.settingsManager.GetSettings().BlacklistThreshold
	dbUpdates := map[string]any{"failure_count": failures}

	if failures >= int64(threshold) {
		entry.WithField("threshold", threshold).Warn("Key has reached blacklist threshold, disabling.")

		// Take the key out of rotation first, then flag it as invalid.
		if err = p.store.LRem(activeKeysListKey, 0, keyID); err != nil {
			entry.WithField("error", err).Error("Failed to LRem key from active list, aborting DB update.")
			return
		}
		if err = p.store.HSet(keyHashKey, "status", models.KeyStatusInvalid); err != nil {
			entry.WithField("error", err).Error("Failed to update key status to invalid in store, aborting DB update.")
			return
		}
		dbUpdates["status"] = models.KeyStatusInvalid
	}

	dbErr := p.db.Model(&models.APIKey{}).Where("id = ?", keyID).Updates(dbUpdates).Error
	if dbErr != nil {
		entry.WithField("error", dbErr).Error("Failed to update key stats in DB on failure")
	}
}
// LoadKeysFromDB fills the store with every API key found in the database.
//
// The load is skipped when the initialization flag already exists in the
// store, or when another caller currently holds the loading lock. Keys are
// read in batches; each key is written to its "key:<id>" hash (through a
// pipeline when the store offers one), and the IDs of active keys are
// collected per group and pushed to "group:<id>:active_keys" afterwards.
//
// It returns an error when the flag or lock cannot be checked, when reading a
// batch fails, or when a pipeline fails to execute. Failures of individual
// non-pipelined HSet calls, of the final list pushes and of setting the
// initialization flag are only logged.
func (p *KeyProvider) LoadKeysFromDB() error {
	// 1. Skip when the store has already been initialized.
	initialized, err := p.store.Exists(keypoolInitializedKey)
	if err != nil {
		return fmt.Errorf("failed to check for keypool initialization flag: %w", err)
	}
	if initialized {
		logrus.Info("Key pool already initialized, skipping database load.")
		return nil
	}

	// 2. Take the loading lock so that only one caller performs the load.
	lockAcquired, err := p.store.SetNX(keypoolLoadingKey, []byte("1"), 10*time.Minute)
	if err != nil {
		return fmt.Errorf("failed to acquire loading lock: %w", err)
	}
	if !lockAcquired {
		logrus.Info("Another instance is already loading the key pool. Skipping.")
		return nil
	}
	// Best effort: the lock also expires on its own.
	defer p.store.Delete(keypoolLoadingKey)

	logrus.Info("Acquired loading lock. Starting first-time initialization of key pool...")

	// 3. Read the keys in batches and write them to the store.
	allActiveKeyIDs := make(map[uint][]any)
	batchSize := 1000

	// FindInBatches refills batchKeys before each callback invocation. The
	// callback must iterate this slice and must not query through tx: tx is a
	// fresh session without the batch's conditions or limit, so a query on it
	// would read the whole table again for every batch and each active key ID
	// would be collected once per batch.
	var batchKeys []*models.APIKey
	err = p.db.Model(&models.APIKey{}).FindInBatches(&batchKeys, batchSize, func(tx *gorm.DB, batch int) error {
		logrus.Infof("Processing batch %d with %d keys...", batch, tx.RowsAffected)

		var pipeline store.Pipeliner
		if redisStore, ok := p.store.(store.RedisPipeliner); ok {
			pipeline = redisStore.Pipeline()
		}

		for _, key := range batchKeys {
			keyHashKey := fmt.Sprintf("key:%d", key.ID)
			keyDetails := p.apiKeyToMap(key)

			if pipeline != nil {
				pipeline.HSet(keyHashKey, keyDetails)
			} else {
				for field, value := range keyDetails {
					if err := p.store.HSet(keyHashKey, field, value); err != nil {
						logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to HSet key details")
					}
				}
			}

			if key.Status == models.KeyStatusActive {
				allActiveKeyIDs[key.GroupID] = append(allActiveKeyIDs[key.GroupID], key.ID)
			}
		}

		if pipeline != nil {
			if err := pipeline.Exec(); err != nil {
				return fmt.Errorf("failed to execute pipeline for batch %d: %w", batch, err)
			}
		}
		return nil
	}).Error
	if err != nil {
		return fmt.Errorf("failed during batch processing of keys: %w", err)
	}

	// 4. Rebuild the active-key list of every group that has active keys.
	logrus.Info("Updating active key lists for all groups...")
	for groupID, activeIDs := range allActiveKeyIDs {
		if len(activeIDs) > 0 {
			activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID)
			p.store.Delete(activeKeysListKey) // Clean slate; the result is intentionally ignored.
			if err := p.store.LPush(activeKeysListKey, activeIDs...); err != nil {
				logrus.WithFields(logrus.Fields{"groupID": groupID, "error": err}).Error("Failed to LPush active keys for group")
			}
		}
	}

	// 5. Record that initialization has completed.
	logrus.Info("Key pool loaded successfully. Setting initialization flag.")
	if err := p.store.Set(keypoolInitializedKey, []byte("1"), 0); err != nil {
		logrus.WithError(err).Error("Critical: Failed to set final initialization flag. Next startup might re-run initialization.")
	}

	return nil
}
// AddKeys inserts the given keys into the database in one transaction and
// mirrors each of them into the store. A key that cannot be written to the
// store is logged but does not fail the transaction. The groupID argument is
// not consulted; each key's own GroupID decides where it is stored.
func (p *KeyProvider) AddKeys(groupID uint, keys []models.APIKey) error {
	if len(keys) == 0 {
		return nil
	}

	return p.db.Transaction(func(tx *gorm.DB) error {
		if createErr := tx.Create(&keys).Error; createErr != nil {
			return createErr
		}

		for i := range keys {
			created := &keys[i]
			storeErr := p.addKeyToStore(created)
			if storeErr == nil {
				continue
			}
			logrus.WithFields(logrus.Fields{"keyID": created.ID, "error": storeErr}).Error("Failed to add key to store after DB creation")
		}
		return nil
	})
}
// RemoveKeys deletes the keys of a group whose values appear in keyValues,
// from both the database and the store. It returns the number of rows the
// database reported as deleted. Store removal failures are logged only.
func (p *KeyProvider) RemoveKeys(groupID uint, keyValues []string) (int64, error) {
	if len(keyValues) == 0 {
		return 0, nil
	}

	const filter = "group_id = ? AND key_value IN ?"
	var (
		removed int64
		targets []models.APIKey
	)

	txErr := p.db.Transaction(func(tx *gorm.DB) error {
		// Look the rows up first so their IDs are known for store cleanup.
		if findErr := tx.Where(filter, groupID, keyValues).Find(&targets).Error; findErr != nil {
			return findErr
		}
		if len(targets) == 0 {
			return nil
		}

		res := tx.Where(filter, groupID, keyValues).Delete(&models.APIKey{})
		if res.Error != nil {
			return res.Error
		}
		removed = res.RowsAffected

		for i := range targets {
			id, gid := targets[i].ID, targets[i].GroupID
			if storeErr := p.removeKeyFromStore(id, gid); storeErr != nil {
				logrus.WithFields(logrus.Fields{"keyID": id, "error": storeErr}).Error("Failed to remove key from store after DB deletion")
			}
		}
		return nil
	})

	return removed, txErr
}
// RestoreKeys sets every invalid key of the group back to active in the
// database and re-adds each of them to the store. It returns the number of
// rows the database reported as updated. Store failures are logged only.
func (p *KeyProvider) RestoreKeys(groupID uint) (int64, error) {
	const filter = "group_id = ? AND status = ?"
	var (
		restored int64
		stale    []models.APIKey
	)

	txErr := p.db.Transaction(func(tx *gorm.DB) error {
		if findErr := tx.Where(filter, groupID, models.KeyStatusInvalid).Find(&stale).Error; findErr != nil {
			return findErr
		}
		if len(stale) == 0 {
			return nil
		}

		res := tx.Model(&models.APIKey{}).Where(filter, groupID, models.KeyStatusInvalid).Update("status", models.KeyStatusActive)
		if res.Error != nil {
			return res.Error
		}
		restored = res.RowsAffected

		for i := range stale {
			// Work on a copy so the rows read above are left as loaded.
			revived := stale[i]
			revived.Status = models.KeyStatusActive
			if storeErr := p.addKeyToStore(&revived); storeErr != nil {
				logrus.WithFields(logrus.Fields{"keyID": revived.ID, "error": storeErr}).Error("Failed to restore key in store after DB update")
			}
		}
		return nil
	})

	return restored, txErr
}
// RemoveInvalidKeys deletes every invalid key of the group from the database
// and removes the corresponding "key:<id>" hashes from the store. It returns
// the number of rows the database reported as deleted. Store failures are
// logged only.
func (p *KeyProvider) RemoveInvalidKeys(groupID uint) (int64, error) {
	var (
		removed int64
		stale   []models.APIKey
	)

	txErr := p.db.Transaction(func(tx *gorm.DB) error {
		lookup := tx.Where("group_id = ? AND status = ?", groupID, models.KeyStatusInvalid).Find(&stale)
		if lookup.Error != nil {
			return lookup.Error
		}
		if len(stale) == 0 {
			return nil
		}

		ids := pluckIDs(stale)
		res := tx.Where("id IN ?", ids).Delete(&models.APIKey{})
		if res.Error != nil {
			return res.Error
		}
		removed = res.RowsAffected

		for i := range stale {
			id := stale[i].ID
			delErr := p.store.Delete(fmt.Sprintf("key:%d", id))
			if delErr == nil {
				continue
			}
			logrus.WithFields(logrus.Fields{"keyID": id, "error": delErr}).Error("Failed to remove invalid key HASH from store after DB deletion")
		}
		return nil
	})

	return removed, txErr
}
// addKeyToStore writes one key into the store: its fields go into the
// "key:<id>" hash and, when the key is active, its ID is placed on the
// group's active list. The first store error is returned.
func (p *KeyProvider) addKeyToStore(key *models.APIKey) error {
	// 1. Write every field of the key into its hash.
	hashKey := fmt.Sprintf("key:%d", key.ID)
	for field, value := range p.apiKeyToMap(key) {
		if err := p.store.HSet(hashKey, field, value); err != nil {
			return fmt.Errorf("failed to HSet key details for key %d: %w", key.ID, err)
		}
	}

	// 2. Only active keys belong on the active list.
	if key.Status != models.KeyStatusActive {
		return nil
	}

	listKey := fmt.Sprintf("group:%d:active_keys", key.GroupID)

	// Remove any existing entry before pushing so the ID is not listed twice.
	if err := p.store.LRem(listKey, 0, key.ID); err != nil {
		return fmt.Errorf("failed to LRem key %d before LPush for group %d: %w", key.ID, key.GroupID, err)
	}
	if err := p.store.LPush(listKey, key.ID); err != nil {
		return fmt.Errorf("failed to LPush key %d to group %d: %w", key.ID, key.GroupID, err)
	}
	return nil
}
// removeKeyFromStore drops one key from the store: its ID is removed from the
// group's active list and its "key:<id>" hash is deleted. A failure to remove
// the list entry is logged and does not stop the hash deletion; a failure to
// delete the hash is returned.
func (p *KeyProvider) removeKeyFromStore(keyID, groupID uint) error {
	listKey := fmt.Sprintf("group:%d:active_keys", groupID)
	if lremErr := p.store.LRem(listKey, 0, keyID); lremErr != nil {
		logrus.WithFields(logrus.Fields{"keyID": keyID, "groupID": groupID, "error": lremErr}).Error("Failed to LRem key from active list")
	}

	delErr := p.store.Delete(fmt.Sprintf("key:%d", keyID))
	if delErr == nil {
		return nil
	}
	return fmt.Errorf("failed to delete key HASH for key %d: %w", keyID, delErr)
}
// apiKeyToMap flattens an APIKey into the field/value pairs stored in its
// hash. The ID is rendered as a decimal string and the creation time as Unix
// seconds; the remaining fields are passed through unchanged.
func (p *KeyProvider) apiKeyToMap(key *models.APIKey) map[string]any {
	fields := make(map[string]any, 6)
	fields["id"] = strconv.FormatUint(uint64(key.ID), 10)
	fields["key_string"] = key.KeyValue
	fields["status"] = key.Status
	fields["failure_count"] = key.FailureCount
	fields["group_id"] = key.GroupID
	fields["created_at"] = key.CreatedAt.Unix()
	return fields
}
// pluckIDs returns the IDs of the given keys, in the same order.
func pluckIDs(keys []models.APIKey) []uint {
	ids := make([]uint, 0, len(keys))
	for i := range keys {
		ids = append(ids, keys[i].ID)
	}
	return ids
}