From f03c338a5751eeae15b6f75650e96b602b5848ca Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 7 Jul 2025 23:48:15 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E4=BA=8B=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/keypool/provider.go | 149 ++++++++++++++++++++--------------- 1 file changed, 85 insertions(+), 64 deletions(-) diff --git a/internal/keypool/provider.go b/internal/keypool/provider.go index 96a768c..acc5f70 100644 --- a/internal/keypool/provider.go +++ b/internal/keypool/provider.go @@ -82,94 +82,112 @@ func (p *KeyProvider) UpdateStatus(keyID uint, groupID uint, isSuccess bool) { activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID) if isSuccess { - p.handleSuccess(keyID, keyHashKey, activeKeysListKey) + if err := p.handleSuccess(keyID, keyHashKey, activeKeysListKey); err != nil { + logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to handle key success") + } } else { - p.handleFailure(keyID, keyHashKey, activeKeysListKey) + if err := p.handleFailure(keyID, keyHashKey, activeKeysListKey); err != nil { + logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to handle key failure") + } } }() } -func (p *KeyProvider) handleSuccess(keyID uint, keyHashKey, activeKeysListKey string) { +func (p *KeyProvider) handleSuccess(keyID uint, keyHashKey, activeKeysListKey string) error { keyDetails, err := p.store.HGetAll(keyHashKey) if err != nil { - logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to get key details on success") - return + return fmt.Errorf("failed to get key details from store: %w", err) } failureCount, _ := strconv.ParseInt(keyDetails["failure_count"], 10, 64) isActive := keyDetails["status"] == models.KeyStatusActive if failureCount == 0 && isActive { - return + return nil } - if err := p.store.HSet(keyHashKey, map[string]any{"failure_count": 0}); err != nil { - logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to reset failure count in store, aborting DB update.") - return - } - - updates := map[string]any{"failure_count": 0} - - if !isActive { - logrus.WithField("keyID", keyID).Info("Key has recovered and is being restored to active pool.") - if err := p.store.HSet(keyHashKey, map[string]any{"status": models.KeyStatusActive}); err != nil { - logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to update key status to active in store, aborting DB update.") - return + return p.db.Transaction(func(tx *gorm.DB) error { + var key models.APIKey + if err := tx.Set("gorm:query_option", "FOR UPDATE").First(&key, keyID).Error; err != nil { + return fmt.Errorf("failed to lock key %d for update: %w", keyID, err) } - // To prevent duplicates, first remove any existing instance of the key from the list. - // This makes the recovery operation idempotent. - if err := p.store.LRem(activeKeysListKey, 0, keyID); err != nil { - logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to LRem key before LPush on recovery, aborting DB update.") - return - } - if err := p.store.LPush(activeKeysListKey, keyID); err != nil { - logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to LPush key back to active list, aborting DB update.") - return - } - updates["status"] = models.KeyStatusActive - } - if err := p.db.Model(&models.APIKey{}).Where("id = ?", keyID).Updates(updates).Error; err != nil { - logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to update key status in DB on success") - } + updates := map[string]any{"failure_count": 0} + if !isActive { + updates["status"] = models.KeyStatusActive + } + + if err := tx.Model(&key).Updates(updates).Error; err != nil { + return fmt.Errorf("failed to update key in DB: %w", err) + } + + if err := p.store.HSet(keyHashKey, updates); err != nil { + return fmt.Errorf("failed to update key details in store: %w", err) + } + + if !isActive { + logrus.WithField("keyID", keyID).Info("Key has recovered and is being restored to active pool.") + if err := p.store.LRem(activeKeysListKey, 0, keyID); err != nil { + return fmt.Errorf("failed to LRem key before LPush on recovery: %w", err) + } + if err := p.store.LPush(activeKeysListKey, keyID); err != nil { + return fmt.Errorf("failed to LPush key back to active list: %w", err) + } + } + + return nil + }) } -func (p *KeyProvider) handleFailure(keyID uint, keyHashKey, activeKeysListKey string) { +func (p *KeyProvider) handleFailure(keyID uint, keyHashKey, activeKeysListKey string) error { keyDetails, err := p.store.HGetAll(keyHashKey) if err != nil { - logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to get key details on failure") - return - } - if keyDetails["status"] == models.KeyStatusInvalid { - return + return fmt.Errorf("failed to get key details from store: %w", err) } - newFailureCount, err := p.store.HIncrBy(keyHashKey, "failure_count", 1) - if err != nil { - logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to increment failure count") - return + failureCount, _ := strconv.ParseInt(keyDetails["failure_count"], 10, 64) + + if keyDetails["status"] == models.KeyStatusInvalid { + return nil } settings := p.settingsManager.GetSettings() blacklistThreshold := settings.BlacklistThreshold - updates := map[string]any{"failure_count": newFailureCount} - if newFailureCount >= int64(blacklistThreshold) { - logrus.WithFields(logrus.Fields{"keyID": keyID, "threshold": blacklistThreshold}).Warn("Key has reached blacklist threshold, disabling.") - if err := p.store.LRem(activeKeysListKey, 0, keyID); err != nil { - logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to LRem key from active list, aborting DB update.") - return + return p.db.Transaction(func(tx *gorm.DB) error { + var key models.APIKey + if err := tx.Set("gorm:query_option", "FOR UPDATE").First(&key, keyID).Error; err != nil { + return fmt.Errorf("failed to lock key %d for update: %w", keyID, err) } - if err := p.store.HSet(keyHashKey, map[string]any{"status": models.KeyStatusInvalid}); err != nil { - logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to update key status to invalid in store, aborting DB update.") - return - } - updates["status"] = models.KeyStatusInvalid - } - if err := p.db.Model(&models.APIKey{}).Where("id = ?", keyID).Updates(updates).Error; err != nil { - logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to update key stats in DB on failure") - } + newFailureCount := failureCount + 1 + + updates := map[string]any{"failure_count": newFailureCount} + shouldBlacklist := newFailureCount >= int64(blacklistThreshold) + if shouldBlacklist { + updates["status"] = models.KeyStatusInvalid + } + + if err := tx.Model(&key).Updates(updates).Error; err != nil { + return fmt.Errorf("failed to update key stats in DB: %w", err) + } + + if _, err := p.store.HIncrBy(keyHashKey, "failure_count", 1); err != nil { + return fmt.Errorf("failed to increment failure count in store: %w", err) + } + + if shouldBlacklist { + logrus.WithFields(logrus.Fields{"keyID": keyID, "threshold": blacklistThreshold}).Warn("Key has reached blacklist threshold, disabling.") + if err := p.store.LRem(activeKeysListKey, 0, keyID); err != nil { + return fmt.Errorf("failed to LRem key from active list: %w", err) + } + if err := p.store.HSet(keyHashKey, map[string]any{"status": models.KeyStatusInvalid}); err != nil { + return fmt.Errorf("failed to update key status to invalid in store: %w", err) + } + } + + return nil + }) } // LoadKeysFromDB 从数据库加载所有分组和密钥,并填充到 Store 中。 @@ -278,7 +296,8 @@ func (p *KeyProvider) AddKeys(groupID uint, keys []models.APIKey) error { for _, key := range keys { if err := p.addKeyToStore(&key); err != nil { - logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to add key to store after DB creation") + logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to add key to store after DB creation, rolling back transaction") + return err } } return nil @@ -315,7 +334,8 @@ func (p *KeyProvider) RemoveKeys(groupID uint, keyValues []string) (int64, error for _, key := range keysToDelete { if err := p.removeKeyFromStore(key.ID, key.GroupID); err != nil { - logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to remove key from store after DB deletion") + logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to remove key from store after DB deletion, rolling back transaction") + return err } } @@ -353,7 +373,8 @@ func (p *KeyProvider) RestoreKeys(groupID uint) (int64, error) { key.Status = models.KeyStatusActive key.FailureCount = 0 if err := p.addKeyToStore(&key); err != nil { - logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to restore key in store after DB update") + logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to restore key in store after DB update, rolling back transaction") + return err } } return nil @@ -432,9 +453,9 @@ func (p *KeyProvider) RemoveInvalidKeys(groupID uint) (int64, error) { removedCount = result.RowsAffected for _, key := range invalidKeys { - keyHashKey := fmt.Sprintf("key:%d", key.ID) - if err := p.store.Delete(keyHashKey); err != nil { - logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to remove invalid key HASH from store after DB deletion") + if err := p.removeKeyFromStore(key.ID, key.GroupID); err != nil { + logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to remove invalid key from store after DB deletion, rolling back transaction") + return err } } return nil