diff --git a/internal/app/app.go b/internal/app/app.go index 79ee07e..7ae538b 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -9,6 +9,7 @@ import ( "time" "gpt-load/internal/config" + "gpt-load/internal/keypool" "gpt-load/internal/models" "gpt-load/internal/proxy" "gpt-load/internal/services" @@ -29,6 +30,7 @@ type App struct { logCleanupService *services.LogCleanupService keyCronService *services.KeyCronService keyValidationPool *services.KeyValidationPool + keyPoolProvider *keypool.KeyProvider proxyServer *proxy.ProxyServer storage store.Store db *gorm.DB @@ -46,6 +48,7 @@ type AppParams struct { LogCleanupService *services.LogCleanupService KeyCronService *services.KeyCronService KeyValidationPool *services.KeyValidationPool + KeyPoolProvider *keypool.KeyProvider ProxyServer *proxy.ProxyServer Storage store.Store DB *gorm.DB @@ -61,6 +64,7 @@ func NewApp(params AppParams) *App { logCleanupService: params.LogCleanupService, keyCronService: params.KeyCronService, keyValidationPool: params.KeyValidationPool, + keyPoolProvider: params.KeyPoolProvider, proxyServer: params.ProxyServer, storage: params.Storage, db: params.DB, @@ -75,6 +79,11 @@ func (a *App) Start() error { return fmt.Errorf("failed to initialize system settings: %w", err) } logrus.Info("System settings initialized") + + logrus.Info("Loading API keys into the key pool...") + if err := a.keyPoolProvider.LoadKeysFromDB(); err != nil { + return fmt.Errorf("failed to load keys into key pool: %w", err) + } a.settingsManager.DisplayCurrentSettings() a.configManager.DisplayConfig() diff --git a/internal/container/container.go b/internal/container/container.go index b46c19e..787219f 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -7,6 +7,7 @@ import ( "gpt-load/internal/config" "gpt-load/internal/db" "gpt-load/internal/handler" + "gpt-load/internal/keypool" "gpt-load/internal/proxy" "gpt-load/internal/router" "gpt-load/internal/services" @@ -58,6 +59,9 @@ func BuildContainer() (*dig.Container, error) { if err := container.Provide(services.NewLogCleanupService); err != nil { return nil, err } + if err := container.Provide(keypool.NewProvider); err != nil { + return nil, err + } // Handlers if err := container.Provide(handler.NewServer); err != nil { diff --git a/internal/errors/errors.go b/internal/errors/errors.go index f20122c..4431907 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -33,6 +33,7 @@ var ( ErrForbidden = &APIError{HTTPStatus: http.StatusForbidden, Code: "FORBIDDEN", Message: "You do not have permission to access this resource"} ErrTaskInProgress = &APIError{HTTPStatus: http.StatusConflict, Code: "TASK_IN_PROGRESS", Message: "A task is already in progress"} ErrBadGateway = &APIError{HTTPStatus: http.StatusBadGateway, Code: "BAD_GATEWAY", Message: "Upstream service error"} + ErrNoActiveKeys = &APIError{HTTPStatus: http.StatusServiceUnavailable, Code: "NO_ACTIVE_KEYS", Message: "No active API keys available for this group"} ) // NewAPIError creates a new APIError with a custom message. diff --git a/internal/keypool/provider.go b/internal/keypool/provider.go new file mode 100644 index 0000000..dbdcdbe --- /dev/null +++ b/internal/keypool/provider.go @@ -0,0 +1,451 @@ +package keypool + +import ( + "errors" + "fmt" + "gpt-load/internal/config" + app_errors "gpt-load/internal/errors" + "gpt-load/internal/models" + "gpt-load/internal/store" + "strconv" + "time" + + "github.com/sirupsen/logrus" + "gorm.io/gorm" +) + +const ( + keypoolInitializedKey = "keypool:initialized" + keypoolLoadingKey = "keypool:loading" +) + +type KeyProvider struct { + db *gorm.DB + store store.Store + settingsManager *config.SystemSettingsManager +} + +// NewProvider 创建一个新的 KeyProvider 实例。 +func NewProvider(db *gorm.DB, store store.Store, settingsManager *config.SystemSettingsManager) *KeyProvider { + return &KeyProvider{ + db: db, + store: store, + settingsManager: settingsManager, + } +} + +// SelectKey 为指定的分组原子性地选择并轮换一个可用的 APIKey。 +func (p *KeyProvider) SelectKey(groupID uint) (*models.APIKey, error) { + activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID) + + // 1. Atomically rotate the key ID from the list + keyIDStr, err := p.store.Rotate(activeKeysListKey) + if err != nil { + if errors.Is(err, store.ErrNotFound) { + return nil, app_errors.ErrNoActiveKeys + } + return nil, fmt.Errorf("failed to rotate key from store: %w", err) + } + + keyID, err := strconv.ParseUint(keyIDStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse key ID '%s': %w", keyIDStr, err) + } + + // 2. Get key details from HASH + keyHashKey := fmt.Sprintf("key:%d", keyID) + keyDetails, err := p.store.HGetAll(keyHashKey) + if err != nil { + return nil, fmt.Errorf("failed to get key details for key ID %d: %w", keyID, err) + } + + // 3. Manually unmarshal the map into an APIKey struct + failureCount, _ := strconv.ParseInt(keyDetails["failure_count"], 10, 64) + createdAt, _ := strconv.ParseInt(keyDetails["created_at"], 10, 64) + + apiKey := &models.APIKey{ + ID: uint(keyID), + KeyValue: keyDetails["key_string"], + Status: keyDetails["status"], + FailureCount: failureCount, + GroupID: groupID, + CreatedAt: time.Unix(createdAt, 0), + } + + return apiKey, nil +} + +// UpdateStatus 异步地提交一个 Key 状态更新任务。 +func (p *KeyProvider) UpdateStatus(keyID uint, groupID uint, isSuccess bool) { + go func() { + keyHashKey := fmt.Sprintf("key:%d", keyID) + activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID) + + if isSuccess { + p.handleSuccess(keyID, keyHashKey, activeKeysListKey) + } else { + p.handleFailure(keyID, keyHashKey, activeKeysListKey) + } + }() +} + +func (p *KeyProvider) handleSuccess(keyID uint, keyHashKey, activeKeysListKey string) { + keyDetails, err := p.store.HGetAll(keyHashKey) + if err != nil { + logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to get key details on success") + return + } + + failureCount, _ := strconv.ParseInt(keyDetails["failure_count"], 10, 64) + isInvalid := keyDetails["status"] == models.KeyStatusInvalid + + if failureCount == 0 && !isInvalid { + return + } + + if err := p.store.HSet(keyHashKey, "failure_count", 0); err != nil { + logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to reset failure count in store, aborting DB update.") + return + } + + updates := map[string]any{"failure_count": 0} + + if isInvalid { + logrus.WithField("keyID", keyID).Info("Key has recovered and is being restored to active pool.") + if err := p.store.HSet(keyHashKey, "status", models.KeyStatusActive); err != nil { + logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to update key status to active in store, aborting DB update.") + return + } + // To prevent duplicates, first remove any existing instance of the key from the list. + // This makes the recovery operation idempotent. + if err := p.store.LRem(activeKeysListKey, 0, keyID); err != nil { + logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to LRem key before LPush on recovery, aborting DB update.") + return + } + if err := p.store.LPush(activeKeysListKey, keyID); err != nil { + logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to LPush key back to active list, aborting DB update.") + return + } + updates["status"] = models.KeyStatusActive + } + + if err := p.db.Model(&models.APIKey{}).Where("id = ?", keyID).Updates(updates).Error; err != nil { + logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to update key status in DB on success") + } +} + +func (p *KeyProvider) handleFailure(keyID uint, keyHashKey, activeKeysListKey string) { + keyDetails, err := p.store.HGetAll(keyHashKey) + if err != nil { + logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to get key details on failure") + return + } + if keyDetails["status"] == models.KeyStatusInvalid { + return + } + + newFailureCount, err := p.store.HIncrBy(keyHashKey, "failure_count", 1) + if err != nil { + logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to increment failure count") + return + } + + settings := p.settingsManager.GetSettings() + blacklistThreshold := settings.BlacklistThreshold + updates := map[string]any{"failure_count": newFailureCount} + + if newFailureCount >= int64(blacklistThreshold) { + logrus.WithFields(logrus.Fields{"keyID": keyID, "threshold": blacklistThreshold}).Warn("Key has reached blacklist threshold, disabling.") + if err := p.store.LRem(activeKeysListKey, 0, keyID); err != nil { + logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to LRem key from active list, aborting DB update.") + return + } + if err := p.store.HSet(keyHashKey, "status", models.KeyStatusInvalid); err != nil { + logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to update key status to invalid in store, aborting DB update.") + return + } + updates["status"] = models.KeyStatusInvalid + } + + if err := p.db.Model(&models.APIKey{}).Where("id = ?", keyID).Updates(updates).Error; err != nil { + logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to update key stats in DB on failure") + } +} + +// LoadKeysFromDB 从数据库加载所有分组和密钥,并填充到 Store 中。 +func (p *KeyProvider) LoadKeysFromDB() error { + // 1. 检查是否已初始化 + initialized, err := p.store.Exists(keypoolInitializedKey) + if err != nil { + return fmt.Errorf("failed to check for keypool initialization flag: %w", err) + } + if initialized { + logrus.Info("Key pool already initialized, skipping database load.") + return nil + } + + // 2. 设置加载锁,防止集群中多个节点同时加载 + lockAcquired, err := p.store.SetNX(keypoolLoadingKey, []byte("1"), 10*time.Minute) + if err != nil { + return fmt.Errorf("failed to acquire loading lock: %w", err) + } + if !lockAcquired { + logrus.Info("Another instance is already loading the key pool. Skipping.") + return nil + } + defer p.store.Delete(keypoolLoadingKey) + + logrus.Info("Acquired loading lock. Starting first-time initialization of key pool...") + + // 3. 分批从数据库加载并使用 Pipeline 写入 Redis + allActiveKeyIDs := make(map[uint][]any) + batchSize := 1000 + + err = p.db.Model(&models.APIKey{}).FindInBatches(&[]*models.APIKey{}, batchSize, func(tx *gorm.DB, batch int) error { + keys := tx.RowsAffected + logrus.Infof("Processing batch %d with %d keys...", batch, keys) + + var pipeline store.Pipeliner + if redisStore, ok := p.store.(store.RedisPipeliner); ok { + pipeline = redisStore.Pipeline() + } + + var batchKeys []*models.APIKey + if err := tx.Find(&batchKeys).Error; err != nil { + return err + } + + for _, key := range batchKeys { + keyHashKey := fmt.Sprintf("key:%d", key.ID) + keyDetails := p.apiKeyToMap(key) + + if pipeline != nil { + pipeline.HSet(keyHashKey, keyDetails) + } else { + for field, value := range keyDetails { + if err := p.store.HSet(keyHashKey, field, value); err != nil { + logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to HSet key details") + } + } + } + + if key.Status == models.KeyStatusActive { + allActiveKeyIDs[key.GroupID] = append(allActiveKeyIDs[key.GroupID], key.ID) + } + } + + if pipeline != nil { + if err := pipeline.Exec(); err != nil { + return fmt.Errorf("failed to execute pipeline for batch %d: %w", batch, err) + } + } + return nil + }).Error + + if err != nil { + return fmt.Errorf("failed during batch processing of keys: %w", err) + } + + // 4. 更新所有分组的 active_keys 列表 + logrus.Info("Updating active key lists for all groups...") + for groupID, activeIDs := range allActiveKeyIDs { + if len(activeIDs) > 0 { + activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID) + p.store.Delete(activeKeysListKey) // Clean slate + if err := p.store.LPush(activeKeysListKey, activeIDs...); err != nil { + logrus.WithFields(logrus.Fields{"groupID": groupID, "error": err}).Error("Failed to LPush active keys for group") + } + } + } + + // 5. 设置最终的初始化成功标志 + logrus.Info("Key pool loaded successfully. Setting initialization flag.") + if err := p.store.Set(keypoolInitializedKey, []byte("1"), 0); err != nil { + logrus.WithError(err).Error("Critical: Failed to set final initialization flag. Next startup might re-run initialization.") + } + + return nil +} + +// AddKeys 批量添加新的 Key 到池和数据库中。 +func (p *KeyProvider) AddKeys(groupID uint, keys []models.APIKey) error { + if len(keys) == 0 { + return nil + } + + err := p.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Create(&keys).Error; err != nil { + return err + } + + for _, key := range keys { + if err := p.addKeyToStore(&key); err != nil { + logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to add key to store after DB creation") + } + } + return nil + }) + + return err +} + +// RemoveKeys 批量从池和数据库中移除 Key。 +func (p *KeyProvider) RemoveKeys(groupID uint, keyValues []string) (int64, error) { + if len(keyValues) == 0 { + return 0, nil + } + + var keysToDelete []models.APIKey + var deletedCount int64 + + err := p.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Where("group_id = ? AND key_value IN ?", groupID, keyValues).Find(&keysToDelete).Error; err != nil { + return err + } + + if len(keysToDelete) == 0 { + return nil + } + + result := tx.Where("group_id = ? AND key_value IN ?", groupID, keyValues).Delete(&models.APIKey{}) + if result.Error != nil { + return result.Error + } + deletedCount = result.RowsAffected + + for _, key := range keysToDelete { + if err := p.removeKeyFromStore(key.ID, key.GroupID); err != nil { + logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to remove key from store after DB deletion") + } + } + + return nil + }) + + return deletedCount, err +} + +// RestoreKeys 恢复组内所有无效的 Key。 +func (p *KeyProvider) RestoreKeys(groupID uint) (int64, error) { + var invalidKeys []models.APIKey + var restoredCount int64 + + err := p.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Where("group_id = ? AND status = ?", groupID, models.KeyStatusInvalid).Find(&invalidKeys).Error; err != nil { + return err + } + + if len(invalidKeys) == 0 { + return nil + } + + result := tx.Model(&models.APIKey{}).Where("group_id = ? AND status = ?", groupID, models.KeyStatusInvalid).Update("status", models.KeyStatusActive) + if result.Error != nil { + return result.Error + } + restoredCount = result.RowsAffected + + for _, key := range invalidKeys { + key.Status = models.KeyStatusActive + if err := p.addKeyToStore(&key); err != nil { + logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to restore key in store after DB update") + } + } + return nil + }) + + return restoredCount, err +} + +// RemoveInvalidKeys 移除组内所有无效的 Key。 +func (p *KeyProvider) RemoveInvalidKeys(groupID uint) (int64, error) { + var invalidKeys []models.APIKey + var removedCount int64 + + err := p.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Where("group_id = ? AND status = ?", groupID, models.KeyStatusInvalid).Find(&invalidKeys).Error; err != nil { + return err + } + + if len(invalidKeys) == 0 { + return nil + } + + result := tx.Where("id IN ?", pluckIDs(invalidKeys)).Delete(&models.APIKey{}) + if result.Error != nil { + return result.Error + } + removedCount = result.RowsAffected + + for _, key := range invalidKeys { + keyHashKey := fmt.Sprintf("key:%d", key.ID) + if err := p.store.Delete(keyHashKey); err != nil { + logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to remove invalid key HASH from store after DB deletion") + } + } + return nil + }) + + return removedCount, err +} + +// addKeyToStore is a helper to add a single key to the cache. +func (p *KeyProvider) addKeyToStore(key *models.APIKey) error { + // 1. Store key details in HASH + keyHashKey := fmt.Sprintf("key:%d", key.ID) + keyDetails := p.apiKeyToMap(key) + for field, value := range keyDetails { + if err := p.store.HSet(keyHashKey, field, value); err != nil { + return fmt.Errorf("failed to HSet key details for key %d: %w", key.ID, err) + } + } + + // 2. If active, add to the active LIST + if key.Status == models.KeyStatusActive { + activeKeysListKey := fmt.Sprintf("group:%d:active_keys", key.GroupID) + // To prevent duplicates, first remove any existing instance of the key from the list. + // This makes the add operation idempotent regarding the list. + if err := p.store.LRem(activeKeysListKey, 0, key.ID); err != nil { + return fmt.Errorf("failed to LRem key %d before LPush for group %d: %w", key.ID, key.GroupID, err) + } + if err := p.store.LPush(activeKeysListKey, key.ID); err != nil { + return fmt.Errorf("failed to LPush key %d to group %d: %w", key.ID, key.GroupID, err) + } + } + return nil +} + +// removeKeyFromStore is a helper to remove a single key from the cache. +func (p *KeyProvider) removeKeyFromStore(keyID, groupID uint) error { + activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID) + if err := p.store.LRem(activeKeysListKey, 0, keyID); err != nil { + logrus.WithFields(logrus.Fields{"keyID": keyID, "groupID": groupID, "error": err}).Error("Failed to LRem key from active list") + } + + keyHashKey := fmt.Sprintf("key:%d", keyID) + if err := p.store.Delete(keyHashKey); err != nil { + return fmt.Errorf("failed to delete key HASH for key %d: %w", keyID, err) + } + return nil +} + +// apiKeyToMap converts an APIKey model to a map for HSET. +func (p *KeyProvider) apiKeyToMap(key *models.APIKey) map[string]any { + return map[string]any{ + "id": fmt.Sprint(key.ID), // Use fmt.Sprint for consistency in pipeline + "key_string": key.KeyValue, + "status": key.Status, + "failure_count": key.FailureCount, + "group_id": key.GroupID, + "created_at": key.CreatedAt.Unix(), + } +} + +// pluckIDs extracts IDs from a slice of APIKey. +func pluckIDs(keys []models.APIKey) []uint { + ids := make([]uint, len(keys)) + for i, key := range keys { + ids[i] = key.ID + } + return ids +} diff --git a/internal/proxy/server.go b/internal/proxy/server.go index ba99fbb..3eaa45c 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -7,11 +7,10 @@ import ( "fmt" "gpt-load/internal/channel" app_errors "gpt-load/internal/errors" + "gpt-load/internal/keypool" "gpt-load/internal/models" "gpt-load/internal/response" "io" - "sync" - "sync/atomic" "time" "github.com/gin-gonic/gin" @@ -23,16 +22,21 @@ import ( type ProxyServer struct { DB *gorm.DB channelFactory *channel.Factory - groupCounters sync.Map // map[uint]*atomic.Uint64 + keyProvider *keypool.KeyProvider requestLogChan chan models.RequestLog } // NewProxyServer creates a new proxy server -func NewProxyServer(db *gorm.DB, channelFactory *channel.Factory, requestLogChan chan models.RequestLog) (*ProxyServer, error) { +func NewProxyServer( + db *gorm.DB, + channelFactory *channel.Factory, + keyProvider *keypool.KeyProvider, + requestLogChan chan models.RequestLog, +) (*ProxyServer, error) { return &ProxyServer{ DB: db, channelFactory: channelFactory, - groupCounters: sync.Map{}, + keyProvider: keyProvider, requestLogChan: requestLogChan, }, nil } @@ -42,17 +46,22 @@ func (ps *ProxyServer) HandleProxy(c *gin.Context) { startTime := time.Now() groupName := c.Param("group_name") - // 1. Find the group by name + // 1. Find the group by name (without preloading keys) var group models.Group - if err := ps.DB.Preload("APIKeys").Where("name = ?", groupName).First(&group).Error; err != nil { + if err := ps.DB.Where("name = ?", groupName).First(&group).Error; err != nil { response.Error(c, app_errors.ParseDBError(err)) return } - // 2. Select an available API key from the group - apiKey, err := ps.selectAPIKey(&group) + // 2. Select an available API key from the KeyPool + apiKey, err := ps.keyProvider.SelectKey(group.ID) if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, err.Error())) + // Properly handle the case where no keys are available + if apiErr, ok := err.(*app_errors.APIError); ok { + response.Error(c, apiErr) + } else { + response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, err.Error())) + } return } @@ -75,8 +84,10 @@ func (ps *ProxyServer) HandleProxy(c *gin.Context) { // 5. Forward the request using the channel handler err = channelHandler.Handle(c, apiKey, &group) - // 6. Log the request asynchronously + // 6. Update key status and log the request asynchronously isSuccess := err == nil + ps.keyProvider.UpdateStatus(apiKey.ID, group.ID, isSuccess) + if !isSuccess { logrus.WithFields(logrus.Fields{ "group": group.Name, @@ -84,37 +95,10 @@ func (ps *ProxyServer) HandleProxy(c *gin.Context) { "error": err.Error(), }).Error("Channel handler failed") } - go ps.logRequest(c, &group, apiKey, startTime, isSuccess) + go ps.logRequest(c, &group, apiKey, startTime) } -// selectAPIKey selects an API key from a group using round-robin -func (ps *ProxyServer) selectAPIKey(group *models.Group) (*models.APIKey, error) { - activeKeys := make([]models.APIKey, 0, len(group.APIKeys)) - for _, key := range group.APIKeys { - if key.Status == "active" { - activeKeys = append(activeKeys, key) - } - } - - if len(activeKeys) == 0 { - return nil, fmt.Errorf("no active API keys available in group '%s'", group.Name) - } - - // Get or create a counter for the group. The value is a pointer to a uint64. - val, _ := ps.groupCounters.LoadOrStore(group.ID, new(atomic.Uint64)) - counter := val.(*atomic.Uint64) - - // Atomically increment the counter and get the index for this request. - index := counter.Add(1) - 1 - selectedKey := activeKeys[int(index%uint64(len(activeKeys)))] - - return &selectedKey, nil -} - -func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *models.APIKey, startTime time.Time, isSuccess bool) { - // Update key stats based on request success - go ps.updateKeyStats(key.ID, isSuccess) - +func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *models.APIKey, startTime time.Time) { logEntry := models.RequestLog{ ID: fmt.Sprintf("req_%d", time.Now().UnixNano()), Timestamp: startTime, @@ -134,27 +118,6 @@ func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *mode } } -// updateKeyStats atomically updates the request and failure counts for a key -func (ps *ProxyServer) updateKeyStats(keyID uint, success bool) { - // Always increment the request count - updates := map[string]any{ - "request_count": gorm.Expr("request_count + 1"), - } - - // Additionally, increment the failure count if the request was not successful - if !success { - updates["failure_count"] = gorm.Expr("failure_count + 1") - } - - result := ps.DB.Model(&models.APIKey{}).Where("id = ?", keyID).Updates(updates) - if result.Error != nil { - logrus.WithFields(logrus.Fields{ - "keyID": keyID, - "error": result.Error, - }).Error("Failed to update key stats") - } -} - // Close cleans up resources func (ps *ProxyServer) Close() { // Nothing to close for now diff --git a/internal/services/key_service.go b/internal/services/key_service.go index e812e23..5613e38 100644 --- a/internal/services/key_service.go +++ b/internal/services/key_service.go @@ -3,6 +3,7 @@ package services import ( "encoding/json" "fmt" + "gpt-load/internal/keypool" "gpt-load/internal/models" "regexp" "strings" @@ -26,12 +27,16 @@ type DeleteKeysResult struct { // KeyService provides services related to API keys. type KeyService struct { - DB *gorm.DB + DB *gorm.DB + KeyProvider *keypool.KeyProvider } // NewKeyService creates a new KeyService. -func NewKeyService(db *gorm.DB) *KeyService { - return &KeyService{DB: db} +func NewKeyService(db *gorm.DB, keyProvider *keypool.KeyProvider) *KeyService { + return &KeyService{ + DB: db, + KeyProvider: keyProvider, + } } // AddMultipleKeys handles the business logic of creating new keys from a text block. @@ -42,13 +47,7 @@ func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysRes return nil, fmt.Errorf("no valid keys found in the input text") } - // 2. Get the group information for validation - var group models.Group - if err := s.DB.First(&group, groupID).Error; err != nil { - return nil, fmt.Errorf("failed to find group: %w", err) - } - - // 3. Get existing keys in the group for deduplication + // 2. Get existing keys in the group for deduplication var existingKeys []models.APIKey if err := s.DB.Where("group_id = ?", groupID).Select("key_value").Find(&existingKeys).Error; err != nil { return nil, err @@ -58,7 +57,7 @@ func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysRes existingKeyMap[k.KeyValue] = true } - // 4. Prepare new keys with basic validation only + // 3. Prepare new keys for creation var newKeysToCreate []models.APIKey uniqueNewKeys := make(map[string]bool) @@ -67,43 +66,44 @@ func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysRes if trimmedKey == "" { continue } - - // Check if key already exists if existingKeyMap[trimmedKey] || uniqueNewKeys[trimmedKey] { continue } - - // 通用验证:只做基础格式检查,不做渠道特定验证 if s.isValidKeyFormat(trimmedKey) { uniqueNewKeys[trimmedKey] = true newKeysToCreate = append(newKeysToCreate, models.APIKey{ GroupID: groupID, KeyValue: trimmedKey, - Status: "active", + Status: models.KeyStatusActive, }) } } - addedCount := len(newKeysToCreate) - // 更准确的忽略计数:包括重复的和无效的 - ignoredCount := len(keys) - addedCount - - // 5. Insert new keys if any - if addedCount > 0 { - if err := s.DB.Create(&newKeysToCreate).Error; err != nil { - return nil, err - } + if len(newKeysToCreate) == 0 { + var totalInGroup int64 + s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID).Count(&totalInGroup) + return &AddKeysResult{ + AddedCount: 0, + IgnoredCount: len(keys), + TotalInGroup: totalInGroup, + }, nil } - // 6. Get the new total count + // 4. Use KeyProvider to add keys, which handles DB and cache + err := s.KeyProvider.AddKeys(groupID, newKeysToCreate) + if err != nil { + return nil, err + } + + // 5. Get the new total count var totalInGroup int64 if err := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID).Count(&totalInGroup).Error; err != nil { return nil, err } return &AddKeysResult{ - AddedCount: addedCount, - IgnoredCount: ignoredCount, + AddedCount: len(newKeysToCreate), + IgnoredCount: len(keys) - len(newKeysToCreate), TotalInGroup: totalInGroup, }, nil } @@ -161,14 +161,12 @@ func (s *KeyService) isValidKeyFormat(key string) bool { // RestoreAllInvalidKeys sets the status of all 'inactive' keys in a group to 'active'. func (s *KeyService) RestoreAllInvalidKeys(groupID uint) (int64, error) { - result := s.DB.Model(&models.APIKey{}).Where("group_id = ? AND status = ?", groupID, "inactive").Update("status", "active") - return result.RowsAffected, result.Error + return s.KeyProvider.RestoreKeys(groupID) } // ClearAllInvalidKeys deletes all 'inactive' keys from a group. func (s *KeyService) ClearAllInvalidKeys(groupID uint) (int64, error) { - result := s.DB.Where("group_id = ? AND status = ?", groupID, "inactive").Delete(&models.APIKey{}) - return result.RowsAffected, result.Error + return s.KeyProvider.RemoveInvalidKeys(groupID) } // DeleteMultipleKeys handles the business logic of deleting keys from a text block. @@ -179,16 +177,13 @@ func (s *KeyService) DeleteMultipleKeys(groupID uint, keysText string) (*DeleteK return nil, fmt.Errorf("no valid keys found in the input text") } - // 2. Perform the deletion - // GORM's batch delete doesn't easily return which ones were deleted vs. ignored. - // We perform a bulk delete and then count the remaining to calculate the result. - result := s.DB.Where("group_id = ? AND key_value IN ?", groupID, keysToDelete).Delete(&models.APIKey{}) - if result.Error != nil { - return nil, result.Error + // 2. Use KeyProvider to delete keys, which handles DB and cache + deletedCount, err := s.KeyProvider.RemoveKeys(groupID, keysToDelete) + if err != nil { + return nil, err } - deletedCount := int(result.RowsAffected) - ignoredCount := len(keysToDelete) - deletedCount + ignoredCount := len(keysToDelete) - int(deletedCount) // 3. Get the new total count var totalInGroup int64 @@ -197,7 +192,7 @@ func (s *KeyService) DeleteMultipleKeys(groupID uint, keysText string) (*DeleteK } return &DeleteKeysResult{ - DeletedCount: deletedCount, + DeletedCount: int(deletedCount), IgnoredCount: ignoredCount, TotalInGroup: totalInGroup, }, nil @@ -219,4 +214,3 @@ func (s *KeyService) ListKeysInGroupQuery(groupID uint, statusFilter string, sea return query } - diff --git a/internal/store/memory.go b/internal/store/memory.go index 6904042..d966cd4 100644 --- a/internal/store/memory.go +++ b/internal/store/memory.go @@ -1,6 +1,8 @@ package store import ( + "fmt" + "strconv" "sync" "time" ) @@ -12,50 +14,29 @@ type memoryStoreItem struct { } // MemoryStore is an in-memory key-value store that is safe for concurrent use. +// It now supports simple K/V, HASH, and LIST data types. type MemoryStore struct { - mu sync.RWMutex - data map[string]memoryStoreItem - stopCh chan struct{} // Channel to stop the cleanup goroutine + mu sync.RWMutex + // Using 'any' to store different data structures (memoryStoreItem, map[string]string, []string) + data map[string]any } // NewMemoryStore creates and returns a new MemoryStore instance. -// It also starts a background goroutine to periodically clean up expired keys. func NewMemoryStore() *MemoryStore { s := &MemoryStore{ - data: make(map[string]memoryStoreItem), - stopCh: make(chan struct{}), + data: make(map[string]any), } - go s.cleanupLoop(1 * time.Minute) + // The cleanup loop was removed as it's not compatible with multiple data types + // without a unified expiration mechanism, and the KeyPool feature does not rely on TTLs. return s } -// Close stops the background cleanup goroutine. +// Close cleans up resources. func (s *MemoryStore) Close() error { - close(s.stopCh) + // Nothing to close for now. return nil } -// cleanupLoop periodically iterates through the store and removes expired keys. -func (s *MemoryStore) cleanupLoop(interval time.Duration) { - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - s.mu.Lock() - now := time.Now().UnixNano() - for key, item := range s.data { - if item.expiresAt > 0 && now > item.expiresAt { - delete(s.data, key) - } - } - s.mu.Unlock() - case <-s.stopCh: - return - } - } -} // Set stores a key-value pair. func (s *MemoryStore) Set(key string, value []byte, ttl time.Duration) error { @@ -77,13 +58,18 @@ func (s *MemoryStore) Set(key string, value []byte, ttl time.Duration) error { // Get retrieves a value by its key. func (s *MemoryStore) Get(key string) ([]byte, error) { s.mu.RLock() - item, exists := s.data[key] + rawItem, exists := s.data[key] s.mu.RUnlock() if !exists { return nil, ErrNotFound } + item, ok := rawItem.(memoryStoreItem) + if !ok { + return nil, fmt.Errorf("type mismatch: key '%s' holds a different data type", key) + } + // Check for expiration if item.expiresAt > 0 && time.Now().UnixNano() > item.expiresAt { // Lazy deletion @@ -107,20 +93,213 @@ func (s *MemoryStore) Delete(key string) error { // Exists checks if a key exists. func (s *MemoryStore) Exists(key string) (bool, error) { s.mu.RLock() - item, exists := s.data[key] + rawItem, exists := s.data[key] s.mu.RUnlock() if !exists { return false, nil } - if item.expiresAt > 0 && time.Now().UnixNano() > item.expiresAt { - // Lazy deletion - s.mu.Lock() - delete(s.data, key) - s.mu.Unlock() - return false, nil + // Check for expiration only if it's a simple K/V item + if item, ok := rawItem.(memoryStoreItem); ok { + if item.expiresAt > 0 && time.Now().UnixNano() > item.expiresAt { + // Lazy deletion + s.mu.Lock() + delete(s.data, key) + s.mu.Unlock() + return false, nil + } } return true, nil } + +// SetNX sets a key-value pair if the key does not already exist. +func (s *MemoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // In memory store, we need to manually check for existence and expiration + rawItem, exists := s.data[key] + if exists { + if item, ok := rawItem.(memoryStoreItem); ok { + if item.expiresAt == 0 || time.Now().UnixNano() < item.expiresAt { + // Key exists and is not expired + return false, nil + } + } else { + // Key exists but is not a simple K/V item, treat as existing + return false, nil + } + } + + // Key does not exist or is expired, so we can set it. + var expiresAt int64 + if ttl > 0 { + expiresAt = time.Now().UnixNano() + ttl.Nanoseconds() + } + s.data[key] = memoryStoreItem{ + value: value, + expiresAt: expiresAt, + } + return true, nil +} + +// --- HASH operations --- + +func (s *MemoryStore) HSet(key, field string, value any) error { + s.mu.Lock() + defer s.mu.Unlock() + + var hash map[string]string + rawHash, exists := s.data[key] + if !exists { + hash = make(map[string]string) + s.data[key] = hash + } else { + var ok bool + hash, ok = rawHash.(map[string]string) + if !ok { + return fmt.Errorf("type mismatch: key '%s' holds a different data type", key) + } + } + + hash[field] = fmt.Sprint(value) + return nil +} + +func (s *MemoryStore) HGetAll(key string) (map[string]string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + rawHash, exists := s.data[key] + if !exists { + // Per Redis convention, HGETALL on a non-existent key returns an empty map, not an error. + return make(map[string]string), nil + } + + hash, ok := rawHash.(map[string]string) + if !ok { + return nil, fmt.Errorf("type mismatch: key '%s' holds a different data type", key) + } + + // Return a copy to prevent race conditions on the returned map + result := make(map[string]string, len(hash)) + for k, v := range hash { + result[k] = v + } + + return result, nil +} + +func (s *MemoryStore) HIncrBy(key, field string, incr int64) (int64, error) { + s.mu.Lock() + defer s.mu.Unlock() + + var hash map[string]string + rawHash, exists := s.data[key] + if !exists { + hash = make(map[string]string) + s.data[key] = hash + } else { + var ok bool + hash, ok = rawHash.(map[string]string) + if !ok { + return 0, fmt.Errorf("type mismatch: key '%s' holds a different data type", key) + } + } + + currentVal, _ := strconv.ParseInt(hash[field], 10, 64) + newVal := currentVal + incr + hash[field] = strconv.FormatInt(newVal, 10) + + return newVal, nil +} + +// --- LIST operations --- + +func (s *MemoryStore) LPush(key string, values ...any) error { + s.mu.Lock() + defer s.mu.Unlock() + + var list []string + rawList, exists := s.data[key] + if !exists { + list = make([]string, 0) + } else { + var ok bool + list, ok = rawList.([]string) + if !ok { + return fmt.Errorf("type mismatch: key '%s' holds a different data type", key) + } + } + + strValues := make([]string, len(values)) + for i, v := range values { + strValues[i] = fmt.Sprint(v) + } + + s.data[key] = append(strValues, list...) // Prepend + return nil +} + +func (s *MemoryStore) LRem(key string, count int64, value any) error { + s.mu.Lock() + defer s.mu.Unlock() + + rawList, exists := s.data[key] + if !exists { + return nil + } + + list, ok := rawList.([]string) + if !ok { + return fmt.Errorf("type mismatch: key '%s' holds a different data type", key) + } + + strValue := fmt.Sprint(value) + newList := make([]string, 0, len(list)) + + // LREM with count = 0: Remove all elements equal to value. + if count != 0 { + // For now, only implement count = 0 behavior as it's what we need. + return fmt.Errorf("LRem with non-zero count is not implemented in MemoryStore") + } + + for _, item := range list { + if item != strValue { + newList = append(newList, item) + } + } + s.data[key] = newList + return nil +} + +func (s *MemoryStore) Rotate(key string) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + + rawList, exists := s.data[key] + if !exists { + return "", ErrNotFound + } + + list, ok := rawList.([]string) + if !ok { + return "", fmt.Errorf("type mismatch: key '%s' holds a different data type", key) + } + + if len(list) == 0 { + return "", ErrNotFound + } + + // "RPOP" + lastIndex := len(list) - 1 + item := list[lastIndex] + + // "LPUSH" + newList := append([]string{item}, list[:lastIndex]...) + s.data[key] = newList + + return item, nil +} diff --git a/internal/store/redis.go b/internal/store/redis.go index 634d796..10b024c 100644 --- a/internal/store/redis.go +++ b/internal/store/redis.go @@ -49,7 +49,71 @@ func (s *RedisStore) Exists(key string) (bool, error) { return val > 0, nil } +// SetNX sets a key-value pair in Redis if the key does not already exist. +func (s *RedisStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) { + return s.client.SetNX(context.Background(), key, value, ttl).Result() +} + // Close closes the Redis client connection. func (s *RedisStore) Close() error { return s.client.Close() } + +// --- HASH operations --- + +func (s *RedisStore) HSet(key, field string, value any) error { + return s.client.HSet(context.Background(), key, field, value).Err() +} + +func (s *RedisStore) HGetAll(key string) (map[string]string, error) { + return s.client.HGetAll(context.Background(), key).Result() +} + +func (s *RedisStore) HIncrBy(key, field string, incr int64) (int64, error) { + return s.client.HIncrBy(context.Background(), key, field, incr).Result() +} + +// --- LIST operations --- + +func (s *RedisStore) LPush(key string, values ...any) error { + return s.client.LPush(context.Background(), key, values...).Err() +} + +func (s *RedisStore) LRem(key string, count int64, value any) error { + return s.client.LRem(context.Background(), key, count, value).Err() +} + +func (s *RedisStore) Rotate(key string) (string, error) { + val, err := s.client.RPopLPush(context.Background(), key, key).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return "", ErrNotFound + } + return "", err + } + return val, nil +} + +// --- Pipeliner implementation --- + +type redisPipeliner struct { + pipe redis.Pipeliner +} + +// HSet adds an HSET command to the pipeline. +func (p *redisPipeliner) HSet(key string, values map[string]any) { + p.pipe.HSet(context.Background(), key, values) +} + +// Exec executes all commands in the pipeline. +func (p *redisPipeliner) Exec() error { + _, err := p.pipe.Exec(context.Background()) + return err +} + +// Pipeline creates a new pipeline. +func (s *RedisStore) Pipeline() Pipeliner { + return &redisPipeliner{ + pipe: s.client.Pipeline(), + } +} diff --git a/internal/store/store.go b/internal/store/store.go index f4a1fac..a5ee81b 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -28,6 +28,31 @@ type Store interface { // Exists checks if a key exists in the store. Exists(key string) (bool, error) + // SetNX sets a key-value pair if the key does not already exist. + // It returns true if the key was set, false otherwise. + SetNX(key string, value []byte, ttl time.Duration) (bool, error) + + // HASH operations + HSet(key, field string, value any) error + HGetAll(key string) (map[string]string, error) + HIncrBy(key, field string, incr int64) (int64, error) + + // LIST operations + LPush(key string, values ...any) error + LRem(key string, count int64, value any) error + Rotate(key string) (string, error) + // Close closes the store and releases any underlying resources. Close() error } + +// Pipeliner defines an interface for executing a batch of commands. +type Pipeliner interface { + HSet(key string, values map[string]any) + Exec() error +} + +// RedisPipeliner is an optional interface that a Store can implement to provide pipelining. +type RedisPipeliner interface { + Pipeline() Pipeliner +}