feat: key provider
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gpt-load/internal/config"
|
"gpt-load/internal/config"
|
||||||
|
"gpt-load/internal/keypool"
|
||||||
"gpt-load/internal/models"
|
"gpt-load/internal/models"
|
||||||
"gpt-load/internal/proxy"
|
"gpt-load/internal/proxy"
|
||||||
"gpt-load/internal/services"
|
"gpt-load/internal/services"
|
||||||
@@ -29,6 +30,7 @@ type App struct {
|
|||||||
logCleanupService *services.LogCleanupService
|
logCleanupService *services.LogCleanupService
|
||||||
keyCronService *services.KeyCronService
|
keyCronService *services.KeyCronService
|
||||||
keyValidationPool *services.KeyValidationPool
|
keyValidationPool *services.KeyValidationPool
|
||||||
|
keyPoolProvider *keypool.KeyProvider
|
||||||
proxyServer *proxy.ProxyServer
|
proxyServer *proxy.ProxyServer
|
||||||
storage store.Store
|
storage store.Store
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
@@ -46,6 +48,7 @@ type AppParams struct {
|
|||||||
LogCleanupService *services.LogCleanupService
|
LogCleanupService *services.LogCleanupService
|
||||||
KeyCronService *services.KeyCronService
|
KeyCronService *services.KeyCronService
|
||||||
KeyValidationPool *services.KeyValidationPool
|
KeyValidationPool *services.KeyValidationPool
|
||||||
|
KeyPoolProvider *keypool.KeyProvider
|
||||||
ProxyServer *proxy.ProxyServer
|
ProxyServer *proxy.ProxyServer
|
||||||
Storage store.Store
|
Storage store.Store
|
||||||
DB *gorm.DB
|
DB *gorm.DB
|
||||||
@@ -61,6 +64,7 @@ func NewApp(params AppParams) *App {
|
|||||||
logCleanupService: params.LogCleanupService,
|
logCleanupService: params.LogCleanupService,
|
||||||
keyCronService: params.KeyCronService,
|
keyCronService: params.KeyCronService,
|
||||||
keyValidationPool: params.KeyValidationPool,
|
keyValidationPool: params.KeyValidationPool,
|
||||||
|
keyPoolProvider: params.KeyPoolProvider,
|
||||||
proxyServer: params.ProxyServer,
|
proxyServer: params.ProxyServer,
|
||||||
storage: params.Storage,
|
storage: params.Storage,
|
||||||
db: params.DB,
|
db: params.DB,
|
||||||
@@ -75,6 +79,11 @@ func (a *App) Start() error {
|
|||||||
return fmt.Errorf("failed to initialize system settings: %w", err)
|
return fmt.Errorf("failed to initialize system settings: %w", err)
|
||||||
}
|
}
|
||||||
logrus.Info("System settings initialized")
|
logrus.Info("System settings initialized")
|
||||||
|
|
||||||
|
logrus.Info("Loading API keys into the key pool...")
|
||||||
|
if err := a.keyPoolProvider.LoadKeysFromDB(); err != nil {
|
||||||
|
return fmt.Errorf("failed to load keys into key pool: %w", err)
|
||||||
|
}
|
||||||
a.settingsManager.DisplayCurrentSettings()
|
a.settingsManager.DisplayCurrentSettings()
|
||||||
a.configManager.DisplayConfig()
|
a.configManager.DisplayConfig()
|
||||||
|
|
||||||
|
@@ -7,6 +7,7 @@ import (
|
|||||||
"gpt-load/internal/config"
|
"gpt-load/internal/config"
|
||||||
"gpt-load/internal/db"
|
"gpt-load/internal/db"
|
||||||
"gpt-load/internal/handler"
|
"gpt-load/internal/handler"
|
||||||
|
"gpt-load/internal/keypool"
|
||||||
"gpt-load/internal/proxy"
|
"gpt-load/internal/proxy"
|
||||||
"gpt-load/internal/router"
|
"gpt-load/internal/router"
|
||||||
"gpt-load/internal/services"
|
"gpt-load/internal/services"
|
||||||
@@ -58,6 +59,9 @@ func BuildContainer() (*dig.Container, error) {
|
|||||||
if err := container.Provide(services.NewLogCleanupService); err != nil {
|
if err := container.Provide(services.NewLogCleanupService); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if err := container.Provide(keypool.NewProvider); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// Handlers
|
// Handlers
|
||||||
if err := container.Provide(handler.NewServer); err != nil {
|
if err := container.Provide(handler.NewServer); err != nil {
|
||||||
|
@@ -33,6 +33,7 @@ var (
|
|||||||
ErrForbidden = &APIError{HTTPStatus: http.StatusForbidden, Code: "FORBIDDEN", Message: "You do not have permission to access this resource"}
|
ErrForbidden = &APIError{HTTPStatus: http.StatusForbidden, Code: "FORBIDDEN", Message: "You do not have permission to access this resource"}
|
||||||
ErrTaskInProgress = &APIError{HTTPStatus: http.StatusConflict, Code: "TASK_IN_PROGRESS", Message: "A task is already in progress"}
|
ErrTaskInProgress = &APIError{HTTPStatus: http.StatusConflict, Code: "TASK_IN_PROGRESS", Message: "A task is already in progress"}
|
||||||
ErrBadGateway = &APIError{HTTPStatus: http.StatusBadGateway, Code: "BAD_GATEWAY", Message: "Upstream service error"}
|
ErrBadGateway = &APIError{HTTPStatus: http.StatusBadGateway, Code: "BAD_GATEWAY", Message: "Upstream service error"}
|
||||||
|
ErrNoActiveKeys = &APIError{HTTPStatus: http.StatusServiceUnavailable, Code: "NO_ACTIVE_KEYS", Message: "No active API keys available for this group"}
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewAPIError creates a new APIError with a custom message.
|
// NewAPIError creates a new APIError with a custom message.
|
||||||
|
451
internal/keypool/provider.go
Normal file
451
internal/keypool/provider.go
Normal file
@@ -0,0 +1,451 @@
|
|||||||
|
package keypool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"gpt-load/internal/config"
|
||||||
|
app_errors "gpt-load/internal/errors"
|
||||||
|
"gpt-load/internal/models"
|
||||||
|
"gpt-load/internal/store"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
keypoolInitializedKey = "keypool:initialized"
|
||||||
|
keypoolLoadingKey = "keypool:loading"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KeyProvider struct {
|
||||||
|
db *gorm.DB
|
||||||
|
store store.Store
|
||||||
|
settingsManager *config.SystemSettingsManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProvider 创建一个新的 KeyProvider 实例。
|
||||||
|
func NewProvider(db *gorm.DB, store store.Store, settingsManager *config.SystemSettingsManager) *KeyProvider {
|
||||||
|
return &KeyProvider{
|
||||||
|
db: db,
|
||||||
|
store: store,
|
||||||
|
settingsManager: settingsManager,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SelectKey 为指定的分组原子性地选择并轮换一个可用的 APIKey。
|
||||||
|
func (p *KeyProvider) SelectKey(groupID uint) (*models.APIKey, error) {
|
||||||
|
activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID)
|
||||||
|
|
||||||
|
// 1. Atomically rotate the key ID from the list
|
||||||
|
keyIDStr, err := p.store.Rotate(activeKeysListKey)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, store.ErrNotFound) {
|
||||||
|
return nil, app_errors.ErrNoActiveKeys
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to rotate key from store: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keyID, err := strconv.ParseUint(keyIDStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse key ID '%s': %w", keyIDStr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Get key details from HASH
|
||||||
|
keyHashKey := fmt.Sprintf("key:%d", keyID)
|
||||||
|
keyDetails, err := p.store.HGetAll(keyHashKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get key details for key ID %d: %w", keyID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Manually unmarshal the map into an APIKey struct
|
||||||
|
failureCount, _ := strconv.ParseInt(keyDetails["failure_count"], 10, 64)
|
||||||
|
createdAt, _ := strconv.ParseInt(keyDetails["created_at"], 10, 64)
|
||||||
|
|
||||||
|
apiKey := &models.APIKey{
|
||||||
|
ID: uint(keyID),
|
||||||
|
KeyValue: keyDetails["key_string"],
|
||||||
|
Status: keyDetails["status"],
|
||||||
|
FailureCount: failureCount,
|
||||||
|
GroupID: groupID,
|
||||||
|
CreatedAt: time.Unix(createdAt, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
return apiKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateStatus 异步地提交一个 Key 状态更新任务。
|
||||||
|
func (p *KeyProvider) UpdateStatus(keyID uint, groupID uint, isSuccess bool) {
|
||||||
|
go func() {
|
||||||
|
keyHashKey := fmt.Sprintf("key:%d", keyID)
|
||||||
|
activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID)
|
||||||
|
|
||||||
|
if isSuccess {
|
||||||
|
p.handleSuccess(keyID, keyHashKey, activeKeysListKey)
|
||||||
|
} else {
|
||||||
|
p.handleFailure(keyID, keyHashKey, activeKeysListKey)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *KeyProvider) handleSuccess(keyID uint, keyHashKey, activeKeysListKey string) {
|
||||||
|
keyDetails, err := p.store.HGetAll(keyHashKey)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to get key details on success")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
failureCount, _ := strconv.ParseInt(keyDetails["failure_count"], 10, 64)
|
||||||
|
isInvalid := keyDetails["status"] == models.KeyStatusInvalid
|
||||||
|
|
||||||
|
if failureCount == 0 && !isInvalid {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.store.HSet(keyHashKey, "failure_count", 0); err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to reset failure count in store, aborting DB update.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updates := map[string]any{"failure_count": 0}
|
||||||
|
|
||||||
|
if isInvalid {
|
||||||
|
logrus.WithField("keyID", keyID).Info("Key has recovered and is being restored to active pool.")
|
||||||
|
if err := p.store.HSet(keyHashKey, "status", models.KeyStatusActive); err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to update key status to active in store, aborting DB update.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// To prevent duplicates, first remove any existing instance of the key from the list.
|
||||||
|
// This makes the recovery operation idempotent.
|
||||||
|
if err := p.store.LRem(activeKeysListKey, 0, keyID); err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to LRem key before LPush on recovery, aborting DB update.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := p.store.LPush(activeKeysListKey, keyID); err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to LPush key back to active list, aborting DB update.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
updates["status"] = models.KeyStatusActive
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.db.Model(&models.APIKey{}).Where("id = ?", keyID).Updates(updates).Error; err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to update key status in DB on success")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *KeyProvider) handleFailure(keyID uint, keyHashKey, activeKeysListKey string) {
|
||||||
|
keyDetails, err := p.store.HGetAll(keyHashKey)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to get key details on failure")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if keyDetails["status"] == models.KeyStatusInvalid {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newFailureCount, err := p.store.HIncrBy(keyHashKey, "failure_count", 1)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to increment failure count")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := p.settingsManager.GetSettings()
|
||||||
|
blacklistThreshold := settings.BlacklistThreshold
|
||||||
|
updates := map[string]any{"failure_count": newFailureCount}
|
||||||
|
|
||||||
|
if newFailureCount >= int64(blacklistThreshold) {
|
||||||
|
logrus.WithFields(logrus.Fields{"keyID": keyID, "threshold": blacklistThreshold}).Warn("Key has reached blacklist threshold, disabling.")
|
||||||
|
if err := p.store.LRem(activeKeysListKey, 0, keyID); err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to LRem key from active list, aborting DB update.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := p.store.HSet(keyHashKey, "status", models.KeyStatusInvalid); err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to update key status to invalid in store, aborting DB update.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
updates["status"] = models.KeyStatusInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.db.Model(&models.APIKey{}).Where("id = ?", keyID).Updates(updates).Error; err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{"keyID": keyID, "error": err}).Error("Failed to update key stats in DB on failure")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadKeysFromDB 从数据库加载所有分组和密钥,并填充到 Store 中。
|
||||||
|
func (p *KeyProvider) LoadKeysFromDB() error {
|
||||||
|
// 1. 检查是否已初始化
|
||||||
|
initialized, err := p.store.Exists(keypoolInitializedKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check for keypool initialization flag: %w", err)
|
||||||
|
}
|
||||||
|
if initialized {
|
||||||
|
logrus.Info("Key pool already initialized, skipping database load.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 设置加载锁,防止集群中多个节点同时加载
|
||||||
|
lockAcquired, err := p.store.SetNX(keypoolLoadingKey, []byte("1"), 10*time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to acquire loading lock: %w", err)
|
||||||
|
}
|
||||||
|
if !lockAcquired {
|
||||||
|
logrus.Info("Another instance is already loading the key pool. Skipping.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer p.store.Delete(keypoolLoadingKey)
|
||||||
|
|
||||||
|
logrus.Info("Acquired loading lock. Starting first-time initialization of key pool...")
|
||||||
|
|
||||||
|
// 3. 分批从数据库加载并使用 Pipeline 写入 Redis
|
||||||
|
allActiveKeyIDs := make(map[uint][]any)
|
||||||
|
batchSize := 1000
|
||||||
|
|
||||||
|
err = p.db.Model(&models.APIKey{}).FindInBatches(&[]*models.APIKey{}, batchSize, func(tx *gorm.DB, batch int) error {
|
||||||
|
keys := tx.RowsAffected
|
||||||
|
logrus.Infof("Processing batch %d with %d keys...", batch, keys)
|
||||||
|
|
||||||
|
var pipeline store.Pipeliner
|
||||||
|
if redisStore, ok := p.store.(store.RedisPipeliner); ok {
|
||||||
|
pipeline = redisStore.Pipeline()
|
||||||
|
}
|
||||||
|
|
||||||
|
var batchKeys []*models.APIKey
|
||||||
|
if err := tx.Find(&batchKeys).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, key := range batchKeys {
|
||||||
|
keyHashKey := fmt.Sprintf("key:%d", key.ID)
|
||||||
|
keyDetails := p.apiKeyToMap(key)
|
||||||
|
|
||||||
|
if pipeline != nil {
|
||||||
|
pipeline.HSet(keyHashKey, keyDetails)
|
||||||
|
} else {
|
||||||
|
for field, value := range keyDetails {
|
||||||
|
if err := p.store.HSet(keyHashKey, field, value); err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to HSet key details")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if key.Status == models.KeyStatusActive {
|
||||||
|
allActiveKeyIDs[key.GroupID] = append(allActiveKeyIDs[key.GroupID], key.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if pipeline != nil {
|
||||||
|
if err := pipeline.Exec(); err != nil {
|
||||||
|
return fmt.Errorf("failed to execute pipeline for batch %d: %w", batch, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed during batch processing of keys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 更新所有分组的 active_keys 列表
|
||||||
|
logrus.Info("Updating active key lists for all groups...")
|
||||||
|
for groupID, activeIDs := range allActiveKeyIDs {
|
||||||
|
if len(activeIDs) > 0 {
|
||||||
|
activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID)
|
||||||
|
p.store.Delete(activeKeysListKey) // Clean slate
|
||||||
|
if err := p.store.LPush(activeKeysListKey, activeIDs...); err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{"groupID": groupID, "error": err}).Error("Failed to LPush active keys for group")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. 设置最终的初始化成功标志
|
||||||
|
logrus.Info("Key pool loaded successfully. Setting initialization flag.")
|
||||||
|
if err := p.store.Set(keypoolInitializedKey, []byte("1"), 0); err != nil {
|
||||||
|
logrus.WithError(err).Error("Critical: Failed to set final initialization flag. Next startup might re-run initialization.")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddKeys 批量添加新的 Key 到池和数据库中。
|
||||||
|
func (p *KeyProvider) AddKeys(groupID uint, keys []models.APIKey) error {
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := p.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
if err := tx.Create(&keys).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, key := range keys {
|
||||||
|
if err := p.addKeyToStore(&key); err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to add key to store after DB creation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveKeys 批量从池和数据库中移除 Key。
|
||||||
|
func (p *KeyProvider) RemoveKeys(groupID uint, keyValues []string) (int64, error) {
|
||||||
|
if len(keyValues) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var keysToDelete []models.APIKey
|
||||||
|
var deletedCount int64
|
||||||
|
|
||||||
|
err := p.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
if err := tx.Where("group_id = ? AND key_value IN ?", groupID, keyValues).Find(&keysToDelete).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(keysToDelete) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := tx.Where("group_id = ? AND key_value IN ?", groupID, keyValues).Delete(&models.APIKey{})
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
deletedCount = result.RowsAffected
|
||||||
|
|
||||||
|
for _, key := range keysToDelete {
|
||||||
|
if err := p.removeKeyFromStore(key.ID, key.GroupID); err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to remove key from store after DB deletion")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return deletedCount, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestoreKeys 恢复组内所有无效的 Key。
|
||||||
|
func (p *KeyProvider) RestoreKeys(groupID uint) (int64, error) {
|
||||||
|
var invalidKeys []models.APIKey
|
||||||
|
var restoredCount int64
|
||||||
|
|
||||||
|
err := p.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
if err := tx.Where("group_id = ? AND status = ?", groupID, models.KeyStatusInvalid).Find(&invalidKeys).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(invalidKeys) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := tx.Model(&models.APIKey{}).Where("group_id = ? AND status = ?", groupID, models.KeyStatusInvalid).Update("status", models.KeyStatusActive)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
restoredCount = result.RowsAffected
|
||||||
|
|
||||||
|
for _, key := range invalidKeys {
|
||||||
|
key.Status = models.KeyStatusActive
|
||||||
|
if err := p.addKeyToStore(&key); err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to restore key in store after DB update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return restoredCount, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveInvalidKeys 移除组内所有无效的 Key。
|
||||||
|
func (p *KeyProvider) RemoveInvalidKeys(groupID uint) (int64, error) {
|
||||||
|
var invalidKeys []models.APIKey
|
||||||
|
var removedCount int64
|
||||||
|
|
||||||
|
err := p.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
if err := tx.Where("group_id = ? AND status = ?", groupID, models.KeyStatusInvalid).Find(&invalidKeys).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(invalidKeys) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := tx.Where("id IN ?", pluckIDs(invalidKeys)).Delete(&models.APIKey{})
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
removedCount = result.RowsAffected
|
||||||
|
|
||||||
|
for _, key := range invalidKeys {
|
||||||
|
keyHashKey := fmt.Sprintf("key:%d", key.ID)
|
||||||
|
if err := p.store.Delete(keyHashKey); err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{"keyID": key.ID, "error": err}).Error("Failed to remove invalid key HASH from store after DB deletion")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return removedCount, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// addKeyToStore is a helper to add a single key to the cache.
|
||||||
|
func (p *KeyProvider) addKeyToStore(key *models.APIKey) error {
|
||||||
|
// 1. Store key details in HASH
|
||||||
|
keyHashKey := fmt.Sprintf("key:%d", key.ID)
|
||||||
|
keyDetails := p.apiKeyToMap(key)
|
||||||
|
for field, value := range keyDetails {
|
||||||
|
if err := p.store.HSet(keyHashKey, field, value); err != nil {
|
||||||
|
return fmt.Errorf("failed to HSet key details for key %d: %w", key.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. If active, add to the active LIST
|
||||||
|
if key.Status == models.KeyStatusActive {
|
||||||
|
activeKeysListKey := fmt.Sprintf("group:%d:active_keys", key.GroupID)
|
||||||
|
// To prevent duplicates, first remove any existing instance of the key from the list.
|
||||||
|
// This makes the add operation idempotent regarding the list.
|
||||||
|
if err := p.store.LRem(activeKeysListKey, 0, key.ID); err != nil {
|
||||||
|
return fmt.Errorf("failed to LRem key %d before LPush for group %d: %w", key.ID, key.GroupID, err)
|
||||||
|
}
|
||||||
|
if err := p.store.LPush(activeKeysListKey, key.ID); err != nil {
|
||||||
|
return fmt.Errorf("failed to LPush key %d to group %d: %w", key.ID, key.GroupID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeKeyFromStore is a helper to remove a single key from the cache.
|
||||||
|
func (p *KeyProvider) removeKeyFromStore(keyID, groupID uint) error {
|
||||||
|
activeKeysListKey := fmt.Sprintf("group:%d:active_keys", groupID)
|
||||||
|
if err := p.store.LRem(activeKeysListKey, 0, keyID); err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{"keyID": keyID, "groupID": groupID, "error": err}).Error("Failed to LRem key from active list")
|
||||||
|
}
|
||||||
|
|
||||||
|
keyHashKey := fmt.Sprintf("key:%d", keyID)
|
||||||
|
if err := p.store.Delete(keyHashKey); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete key HASH for key %d: %w", keyID, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// apiKeyToMap converts an APIKey model to a map for HSET.
|
||||||
|
func (p *KeyProvider) apiKeyToMap(key *models.APIKey) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"id": fmt.Sprint(key.ID), // Use fmt.Sprint for consistency in pipeline
|
||||||
|
"key_string": key.KeyValue,
|
||||||
|
"status": key.Status,
|
||||||
|
"failure_count": key.FailureCount,
|
||||||
|
"group_id": key.GroupID,
|
||||||
|
"created_at": key.CreatedAt.Unix(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// pluckIDs extracts IDs from a slice of APIKey.
|
||||||
|
func pluckIDs(keys []models.APIKey) []uint {
|
||||||
|
ids := make([]uint, len(keys))
|
||||||
|
for i, key := range keys {
|
||||||
|
ids[i] = key.ID
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
@@ -7,11 +7,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"gpt-load/internal/channel"
|
"gpt-load/internal/channel"
|
||||||
app_errors "gpt-load/internal/errors"
|
app_errors "gpt-load/internal/errors"
|
||||||
|
"gpt-load/internal/keypool"
|
||||||
"gpt-load/internal/models"
|
"gpt-load/internal/models"
|
||||||
"gpt-load/internal/response"
|
"gpt-load/internal/response"
|
||||||
"io"
|
"io"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -23,16 +22,21 @@ import (
|
|||||||
type ProxyServer struct {
|
type ProxyServer struct {
|
||||||
DB *gorm.DB
|
DB *gorm.DB
|
||||||
channelFactory *channel.Factory
|
channelFactory *channel.Factory
|
||||||
groupCounters sync.Map // map[uint]*atomic.Uint64
|
keyProvider *keypool.KeyProvider
|
||||||
requestLogChan chan models.RequestLog
|
requestLogChan chan models.RequestLog
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewProxyServer creates a new proxy server
|
// NewProxyServer creates a new proxy server
|
||||||
func NewProxyServer(db *gorm.DB, channelFactory *channel.Factory, requestLogChan chan models.RequestLog) (*ProxyServer, error) {
|
func NewProxyServer(
|
||||||
|
db *gorm.DB,
|
||||||
|
channelFactory *channel.Factory,
|
||||||
|
keyProvider *keypool.KeyProvider,
|
||||||
|
requestLogChan chan models.RequestLog,
|
||||||
|
) (*ProxyServer, error) {
|
||||||
return &ProxyServer{
|
return &ProxyServer{
|
||||||
DB: db,
|
DB: db,
|
||||||
channelFactory: channelFactory,
|
channelFactory: channelFactory,
|
||||||
groupCounters: sync.Map{},
|
keyProvider: keyProvider,
|
||||||
requestLogChan: requestLogChan,
|
requestLogChan: requestLogChan,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -42,17 +46,22 @@ func (ps *ProxyServer) HandleProxy(c *gin.Context) {
|
|||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
groupName := c.Param("group_name")
|
groupName := c.Param("group_name")
|
||||||
|
|
||||||
// 1. Find the group by name
|
// 1. Find the group by name (without preloading keys)
|
||||||
var group models.Group
|
var group models.Group
|
||||||
if err := ps.DB.Preload("APIKeys").Where("name = ?", groupName).First(&group).Error; err != nil {
|
if err := ps.DB.Where("name = ?", groupName).First(&group).Error; err != nil {
|
||||||
response.Error(c, app_errors.ParseDBError(err))
|
response.Error(c, app_errors.ParseDBError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Select an available API key from the group
|
// 2. Select an available API key from the KeyPool
|
||||||
apiKey, err := ps.selectAPIKey(&group)
|
apiKey, err := ps.keyProvider.SelectKey(group.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, err.Error()))
|
// Properly handle the case where no keys are available
|
||||||
|
if apiErr, ok := err.(*app_errors.APIError); ok {
|
||||||
|
response.Error(c, apiErr)
|
||||||
|
} else {
|
||||||
|
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, err.Error()))
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -75,8 +84,10 @@ func (ps *ProxyServer) HandleProxy(c *gin.Context) {
|
|||||||
// 5. Forward the request using the channel handler
|
// 5. Forward the request using the channel handler
|
||||||
err = channelHandler.Handle(c, apiKey, &group)
|
err = channelHandler.Handle(c, apiKey, &group)
|
||||||
|
|
||||||
// 6. Log the request asynchronously
|
// 6. Update key status and log the request asynchronously
|
||||||
isSuccess := err == nil
|
isSuccess := err == nil
|
||||||
|
ps.keyProvider.UpdateStatus(apiKey.ID, group.ID, isSuccess)
|
||||||
|
|
||||||
if !isSuccess {
|
if !isSuccess {
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
"group": group.Name,
|
"group": group.Name,
|
||||||
@@ -84,37 +95,10 @@ func (ps *ProxyServer) HandleProxy(c *gin.Context) {
|
|||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
}).Error("Channel handler failed")
|
}).Error("Channel handler failed")
|
||||||
}
|
}
|
||||||
go ps.logRequest(c, &group, apiKey, startTime, isSuccess)
|
go ps.logRequest(c, &group, apiKey, startTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
// selectAPIKey selects an API key from a group using round-robin
|
func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *models.APIKey, startTime time.Time) {
|
||||||
func (ps *ProxyServer) selectAPIKey(group *models.Group) (*models.APIKey, error) {
|
|
||||||
activeKeys := make([]models.APIKey, 0, len(group.APIKeys))
|
|
||||||
for _, key := range group.APIKeys {
|
|
||||||
if key.Status == "active" {
|
|
||||||
activeKeys = append(activeKeys, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(activeKeys) == 0 {
|
|
||||||
return nil, fmt.Errorf("no active API keys available in group '%s'", group.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get or create a counter for the group. The value is a pointer to a uint64.
|
|
||||||
val, _ := ps.groupCounters.LoadOrStore(group.ID, new(atomic.Uint64))
|
|
||||||
counter := val.(*atomic.Uint64)
|
|
||||||
|
|
||||||
// Atomically increment the counter and get the index for this request.
|
|
||||||
index := counter.Add(1) - 1
|
|
||||||
selectedKey := activeKeys[int(index%uint64(len(activeKeys)))]
|
|
||||||
|
|
||||||
return &selectedKey, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *models.APIKey, startTime time.Time, isSuccess bool) {
|
|
||||||
// Update key stats based on request success
|
|
||||||
go ps.updateKeyStats(key.ID, isSuccess)
|
|
||||||
|
|
||||||
logEntry := models.RequestLog{
|
logEntry := models.RequestLog{
|
||||||
ID: fmt.Sprintf("req_%d", time.Now().UnixNano()),
|
ID: fmt.Sprintf("req_%d", time.Now().UnixNano()),
|
||||||
Timestamp: startTime,
|
Timestamp: startTime,
|
||||||
@@ -134,27 +118,6 @@ func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *mode
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateKeyStats atomically updates the request and failure counts for a key
|
|
||||||
func (ps *ProxyServer) updateKeyStats(keyID uint, success bool) {
|
|
||||||
// Always increment the request count
|
|
||||||
updates := map[string]any{
|
|
||||||
"request_count": gorm.Expr("request_count + 1"),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Additionally, increment the failure count if the request was not successful
|
|
||||||
if !success {
|
|
||||||
updates["failure_count"] = gorm.Expr("failure_count + 1")
|
|
||||||
}
|
|
||||||
|
|
||||||
result := ps.DB.Model(&models.APIKey{}).Where("id = ?", keyID).Updates(updates)
|
|
||||||
if result.Error != nil {
|
|
||||||
logrus.WithFields(logrus.Fields{
|
|
||||||
"keyID": keyID,
|
|
||||||
"error": result.Error,
|
|
||||||
}).Error("Failed to update key stats")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close cleans up resources
|
// Close cleans up resources
|
||||||
func (ps *ProxyServer) Close() {
|
func (ps *ProxyServer) Close() {
|
||||||
// Nothing to close for now
|
// Nothing to close for now
|
||||||
|
@@ -3,6 +3,7 @@ package services
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"gpt-load/internal/keypool"
|
||||||
"gpt-load/internal/models"
|
"gpt-load/internal/models"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -26,12 +27,16 @@ type DeleteKeysResult struct {
|
|||||||
|
|
||||||
// KeyService provides services related to API keys.
|
// KeyService provides services related to API keys.
|
||||||
type KeyService struct {
|
type KeyService struct {
|
||||||
DB *gorm.DB
|
DB *gorm.DB
|
||||||
|
KeyProvider *keypool.KeyProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewKeyService creates a new KeyService.
|
// NewKeyService creates a new KeyService.
|
||||||
func NewKeyService(db *gorm.DB) *KeyService {
|
func NewKeyService(db *gorm.DB, keyProvider *keypool.KeyProvider) *KeyService {
|
||||||
return &KeyService{DB: db}
|
return &KeyService{
|
||||||
|
DB: db,
|
||||||
|
KeyProvider: keyProvider,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddMultipleKeys handles the business logic of creating new keys from a text block.
|
// AddMultipleKeys handles the business logic of creating new keys from a text block.
|
||||||
@@ -42,13 +47,7 @@ func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysRes
|
|||||||
return nil, fmt.Errorf("no valid keys found in the input text")
|
return nil, fmt.Errorf("no valid keys found in the input text")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Get the group information for validation
|
// 2. Get existing keys in the group for deduplication
|
||||||
var group models.Group
|
|
||||||
if err := s.DB.First(&group, groupID).Error; err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to find group: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. Get existing keys in the group for deduplication
|
|
||||||
var existingKeys []models.APIKey
|
var existingKeys []models.APIKey
|
||||||
if err := s.DB.Where("group_id = ?", groupID).Select("key_value").Find(&existingKeys).Error; err != nil {
|
if err := s.DB.Where("group_id = ?", groupID).Select("key_value").Find(&existingKeys).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -58,7 +57,7 @@ func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysRes
|
|||||||
existingKeyMap[k.KeyValue] = true
|
existingKeyMap[k.KeyValue] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Prepare new keys with basic validation only
|
// 3. Prepare new keys for creation
|
||||||
var newKeysToCreate []models.APIKey
|
var newKeysToCreate []models.APIKey
|
||||||
uniqueNewKeys := make(map[string]bool)
|
uniqueNewKeys := make(map[string]bool)
|
||||||
|
|
||||||
@@ -67,43 +66,44 @@ func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysRes
|
|||||||
if trimmedKey == "" {
|
if trimmedKey == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if key already exists
|
|
||||||
if existingKeyMap[trimmedKey] || uniqueNewKeys[trimmedKey] {
|
if existingKeyMap[trimmedKey] || uniqueNewKeys[trimmedKey] {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 通用验证:只做基础格式检查,不做渠道特定验证
|
|
||||||
if s.isValidKeyFormat(trimmedKey) {
|
if s.isValidKeyFormat(trimmedKey) {
|
||||||
uniqueNewKeys[trimmedKey] = true
|
uniqueNewKeys[trimmedKey] = true
|
||||||
newKeysToCreate = append(newKeysToCreate, models.APIKey{
|
newKeysToCreate = append(newKeysToCreate, models.APIKey{
|
||||||
GroupID: groupID,
|
GroupID: groupID,
|
||||||
KeyValue: trimmedKey,
|
KeyValue: trimmedKey,
|
||||||
Status: "active",
|
Status: models.KeyStatusActive,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
addedCount := len(newKeysToCreate)
|
if len(newKeysToCreate) == 0 {
|
||||||
// 更准确的忽略计数:包括重复的和无效的
|
var totalInGroup int64
|
||||||
ignoredCount := len(keys) - addedCount
|
s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID).Count(&totalInGroup)
|
||||||
|
return &AddKeysResult{
|
||||||
// 5. Insert new keys if any
|
AddedCount: 0,
|
||||||
if addedCount > 0 {
|
IgnoredCount: len(keys),
|
||||||
if err := s.DB.Create(&newKeysToCreate).Error; err != nil {
|
TotalInGroup: totalInGroup,
|
||||||
return nil, err
|
}, nil
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 6. Get the new total count
|
// 4. Use KeyProvider to add keys, which handles DB and cache
|
||||||
|
err := s.KeyProvider.AddKeys(groupID, newKeysToCreate)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Get the new total count
|
||||||
var totalInGroup int64
|
var totalInGroup int64
|
||||||
if err := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID).Count(&totalInGroup).Error; err != nil {
|
if err := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID).Count(&totalInGroup).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &AddKeysResult{
|
return &AddKeysResult{
|
||||||
AddedCount: addedCount,
|
AddedCount: len(newKeysToCreate),
|
||||||
IgnoredCount: ignoredCount,
|
IgnoredCount: len(keys) - len(newKeysToCreate),
|
||||||
TotalInGroup: totalInGroup,
|
TotalInGroup: totalInGroup,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -161,14 +161,12 @@ func (s *KeyService) isValidKeyFormat(key string) bool {
|
|||||||
|
|
||||||
// RestoreAllInvalidKeys sets the status of all 'inactive' keys in a group to 'active'.
|
// RestoreAllInvalidKeys sets the status of all 'inactive' keys in a group to 'active'.
|
||||||
func (s *KeyService) RestoreAllInvalidKeys(groupID uint) (int64, error) {
|
func (s *KeyService) RestoreAllInvalidKeys(groupID uint) (int64, error) {
|
||||||
result := s.DB.Model(&models.APIKey{}).Where("group_id = ? AND status = ?", groupID, "inactive").Update("status", "active")
|
return s.KeyProvider.RestoreKeys(groupID)
|
||||||
return result.RowsAffected, result.Error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClearAllInvalidKeys deletes all 'inactive' keys from a group.
|
// ClearAllInvalidKeys deletes all 'inactive' keys from a group.
|
||||||
func (s *KeyService) ClearAllInvalidKeys(groupID uint) (int64, error) {
|
func (s *KeyService) ClearAllInvalidKeys(groupID uint) (int64, error) {
|
||||||
result := s.DB.Where("group_id = ? AND status = ?", groupID, "inactive").Delete(&models.APIKey{})
|
return s.KeyProvider.RemoveInvalidKeys(groupID)
|
||||||
return result.RowsAffected, result.Error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteMultipleKeys handles the business logic of deleting keys from a text block.
|
// DeleteMultipleKeys handles the business logic of deleting keys from a text block.
|
||||||
@@ -179,16 +177,13 @@ func (s *KeyService) DeleteMultipleKeys(groupID uint, keysText string) (*DeleteK
|
|||||||
return nil, fmt.Errorf("no valid keys found in the input text")
|
return nil, fmt.Errorf("no valid keys found in the input text")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Perform the deletion
|
// 2. Use KeyProvider to delete keys, which handles DB and cache
|
||||||
// GORM's batch delete doesn't easily return which ones were deleted vs. ignored.
|
deletedCount, err := s.KeyProvider.RemoveKeys(groupID, keysToDelete)
|
||||||
// We perform a bulk delete and then count the remaining to calculate the result.
|
if err != nil {
|
||||||
result := s.DB.Where("group_id = ? AND key_value IN ?", groupID, keysToDelete).Delete(&models.APIKey{})
|
return nil, err
|
||||||
if result.Error != nil {
|
|
||||||
return nil, result.Error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
deletedCount := int(result.RowsAffected)
|
ignoredCount := len(keysToDelete) - int(deletedCount)
|
||||||
ignoredCount := len(keysToDelete) - deletedCount
|
|
||||||
|
|
||||||
// 3. Get the new total count
|
// 3. Get the new total count
|
||||||
var totalInGroup int64
|
var totalInGroup int64
|
||||||
@@ -197,7 +192,7 @@ func (s *KeyService) DeleteMultipleKeys(groupID uint, keysText string) (*DeleteK
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &DeleteKeysResult{
|
return &DeleteKeysResult{
|
||||||
DeletedCount: deletedCount,
|
DeletedCount: int(deletedCount),
|
||||||
IgnoredCount: ignoredCount,
|
IgnoredCount: ignoredCount,
|
||||||
TotalInGroup: totalInGroup,
|
TotalInGroup: totalInGroup,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -219,4 +214,3 @@ func (s *KeyService) ListKeysInGroupQuery(groupID uint, statusFilter string, sea
|
|||||||
|
|
||||||
return query
|
return query
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1,6 +1,8 @@
|
|||||||
package store
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -12,50 +14,29 @@ type memoryStoreItem struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MemoryStore is an in-memory key-value store that is safe for concurrent use.
|
// MemoryStore is an in-memory key-value store that is safe for concurrent use.
|
||||||
|
// It now supports simple K/V, HASH, and LIST data types.
|
||||||
type MemoryStore struct {
|
type MemoryStore struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
data map[string]memoryStoreItem
|
// Using 'any' to store different data structures (memoryStoreItem, map[string]string, []string)
|
||||||
stopCh chan struct{} // Channel to stop the cleanup goroutine
|
data map[string]any
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMemoryStore creates and returns a new MemoryStore instance.
|
// NewMemoryStore creates and returns a new MemoryStore instance.
|
||||||
// It also starts a background goroutine to periodically clean up expired keys.
|
|
||||||
func NewMemoryStore() *MemoryStore {
|
func NewMemoryStore() *MemoryStore {
|
||||||
s := &MemoryStore{
|
s := &MemoryStore{
|
||||||
data: make(map[string]memoryStoreItem),
|
data: make(map[string]any),
|
||||||
stopCh: make(chan struct{}),
|
|
||||||
}
|
}
|
||||||
go s.cleanupLoop(1 * time.Minute)
|
// The cleanup loop was removed as it's not compatible with multiple data types
|
||||||
|
// without a unified expiration mechanism, and the KeyPool feature does not rely on TTLs.
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close stops the background cleanup goroutine.
|
// Close cleans up resources.
|
||||||
func (s *MemoryStore) Close() error {
|
func (s *MemoryStore) Close() error {
|
||||||
close(s.stopCh)
|
// Nothing to close for now.
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanupLoop periodically iterates through the store and removes expired keys.
|
|
||||||
func (s *MemoryStore) cleanupLoop(interval time.Duration) {
|
|
||||||
ticker := time.NewTicker(interval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
s.mu.Lock()
|
|
||||||
now := time.Now().UnixNano()
|
|
||||||
for key, item := range s.data {
|
|
||||||
if item.expiresAt > 0 && now > item.expiresAt {
|
|
||||||
delete(s.data, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.mu.Unlock()
|
|
||||||
case <-s.stopCh:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set stores a key-value pair.
|
// Set stores a key-value pair.
|
||||||
func (s *MemoryStore) Set(key string, value []byte, ttl time.Duration) error {
|
func (s *MemoryStore) Set(key string, value []byte, ttl time.Duration) error {
|
||||||
@@ -77,13 +58,18 @@ func (s *MemoryStore) Set(key string, value []byte, ttl time.Duration) error {
|
|||||||
// Get retrieves a value by its key.
|
// Get retrieves a value by its key.
|
||||||
func (s *MemoryStore) Get(key string) ([]byte, error) {
|
func (s *MemoryStore) Get(key string) ([]byte, error) {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
item, exists := s.data[key]
|
rawItem, exists := s.data[key]
|
||||||
s.mu.RUnlock()
|
s.mu.RUnlock()
|
||||||
|
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil, ErrNotFound
|
return nil, ErrNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
item, ok := rawItem.(memoryStoreItem)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
|
||||||
|
}
|
||||||
|
|
||||||
// Check for expiration
|
// Check for expiration
|
||||||
if item.expiresAt > 0 && time.Now().UnixNano() > item.expiresAt {
|
if item.expiresAt > 0 && time.Now().UnixNano() > item.expiresAt {
|
||||||
// Lazy deletion
|
// Lazy deletion
|
||||||
@@ -107,20 +93,213 @@ func (s *MemoryStore) Delete(key string) error {
|
|||||||
// Exists checks if a key exists.
|
// Exists checks if a key exists.
|
||||||
func (s *MemoryStore) Exists(key string) (bool, error) {
|
func (s *MemoryStore) Exists(key string) (bool, error) {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
item, exists := s.data[key]
|
rawItem, exists := s.data[key]
|
||||||
s.mu.RUnlock()
|
s.mu.RUnlock()
|
||||||
|
|
||||||
if !exists {
|
if !exists {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if item.expiresAt > 0 && time.Now().UnixNano() > item.expiresAt {
|
// Check for expiration only if it's a simple K/V item
|
||||||
// Lazy deletion
|
if item, ok := rawItem.(memoryStoreItem); ok {
|
||||||
s.mu.Lock()
|
if item.expiresAt > 0 && time.Now().UnixNano() > item.expiresAt {
|
||||||
delete(s.data, key)
|
// Lazy deletion
|
||||||
s.mu.Unlock()
|
s.mu.Lock()
|
||||||
return false, nil
|
delete(s.data, key)
|
||||||
|
s.mu.Unlock()
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetNX sets a key-value pair if the key does not already exist.
|
||||||
|
func (s *MemoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
// In memory store, we need to manually check for existence and expiration
|
||||||
|
rawItem, exists := s.data[key]
|
||||||
|
if exists {
|
||||||
|
if item, ok := rawItem.(memoryStoreItem); ok {
|
||||||
|
if item.expiresAt == 0 || time.Now().UnixNano() < item.expiresAt {
|
||||||
|
// Key exists and is not expired
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Key exists but is not a simple K/V item, treat as existing
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Key does not exist or is expired, so we can set it.
|
||||||
|
var expiresAt int64
|
||||||
|
if ttl > 0 {
|
||||||
|
expiresAt = time.Now().UnixNano() + ttl.Nanoseconds()
|
||||||
|
}
|
||||||
|
s.data[key] = memoryStoreItem{
|
||||||
|
value: value,
|
||||||
|
expiresAt: expiresAt,
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- HASH operations ---
|
||||||
|
|
||||||
|
func (s *MemoryStore) HSet(key, field string, value any) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
var hash map[string]string
|
||||||
|
rawHash, exists := s.data[key]
|
||||||
|
if !exists {
|
||||||
|
hash = make(map[string]string)
|
||||||
|
s.data[key] = hash
|
||||||
|
} else {
|
||||||
|
var ok bool
|
||||||
|
hash, ok = rawHash.(map[string]string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hash[field] = fmt.Sprint(value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MemoryStore) HGetAll(key string) (map[string]string, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
rawHash, exists := s.data[key]
|
||||||
|
if !exists {
|
||||||
|
// Per Redis convention, HGETALL on a non-existent key returns an empty map, not an error.
|
||||||
|
return make(map[string]string), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hash, ok := rawHash.(map[string]string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a copy to prevent race conditions on the returned map
|
||||||
|
result := make(map[string]string, len(hash))
|
||||||
|
for k, v := range hash {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MemoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
var hash map[string]string
|
||||||
|
rawHash, exists := s.data[key]
|
||||||
|
if !exists {
|
||||||
|
hash = make(map[string]string)
|
||||||
|
s.data[key] = hash
|
||||||
|
} else {
|
||||||
|
var ok bool
|
||||||
|
hash, ok = rawHash.(map[string]string)
|
||||||
|
if !ok {
|
||||||
|
return 0, fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
currentVal, _ := strconv.ParseInt(hash[field], 10, 64)
|
||||||
|
newVal := currentVal + incr
|
||||||
|
hash[field] = strconv.FormatInt(newVal, 10)
|
||||||
|
|
||||||
|
return newVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- LIST operations ---
|
||||||
|
|
||||||
|
func (s *MemoryStore) LPush(key string, values ...any) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
var list []string
|
||||||
|
rawList, exists := s.data[key]
|
||||||
|
if !exists {
|
||||||
|
list = make([]string, 0)
|
||||||
|
} else {
|
||||||
|
var ok bool
|
||||||
|
list, ok = rawList.([]string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
strValues := make([]string, len(values))
|
||||||
|
for i, v := range values {
|
||||||
|
strValues[i] = fmt.Sprint(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.data[key] = append(strValues, list...) // Prepend
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MemoryStore) LRem(key string, count int64, value any) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
rawList, exists := s.data[key]
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
list, ok := rawList.([]string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
strValue := fmt.Sprint(value)
|
||||||
|
newList := make([]string, 0, len(list))
|
||||||
|
|
||||||
|
// LREM with count = 0: Remove all elements equal to value.
|
||||||
|
if count != 0 {
|
||||||
|
// For now, only implement count = 0 behavior as it's what we need.
|
||||||
|
return fmt.Errorf("LRem with non-zero count is not implemented in MemoryStore")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range list {
|
||||||
|
if item != strValue {
|
||||||
|
newList = append(newList, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.data[key] = newList
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MemoryStore) Rotate(key string) (string, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
rawList, exists := s.data[key]
|
||||||
|
if !exists {
|
||||||
|
return "", ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
list, ok := rawList.([]string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(list) == 0 {
|
||||||
|
return "", ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// "RPOP"
|
||||||
|
lastIndex := len(list) - 1
|
||||||
|
item := list[lastIndex]
|
||||||
|
|
||||||
|
// "LPUSH"
|
||||||
|
newList := append([]string{item}, list[:lastIndex]...)
|
||||||
|
s.data[key] = newList
|
||||||
|
|
||||||
|
return item, nil
|
||||||
|
}
|
||||||
|
@@ -49,7 +49,71 @@ func (s *RedisStore) Exists(key string) (bool, error) {
|
|||||||
return val > 0, nil
|
return val > 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetNX sets a key-value pair in Redis if the key does not already exist.
|
||||||
|
func (s *RedisStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
|
||||||
|
return s.client.SetNX(context.Background(), key, value, ttl).Result()
|
||||||
|
}
|
||||||
|
|
||||||
// Close closes the Redis client connection.
|
// Close closes the Redis client connection.
|
||||||
func (s *RedisStore) Close() error {
|
func (s *RedisStore) Close() error {
|
||||||
return s.client.Close()
|
return s.client.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- HASH operations ---
|
||||||
|
|
||||||
|
func (s *RedisStore) HSet(key, field string, value any) error {
|
||||||
|
return s.client.HSet(context.Background(), key, field, value).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedisStore) HGetAll(key string) (map[string]string, error) {
|
||||||
|
return s.client.HGetAll(context.Background(), key).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedisStore) HIncrBy(key, field string, incr int64) (int64, error) {
|
||||||
|
return s.client.HIncrBy(context.Background(), key, field, incr).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- LIST operations ---
|
||||||
|
|
||||||
|
func (s *RedisStore) LPush(key string, values ...any) error {
|
||||||
|
return s.client.LPush(context.Background(), key, values...).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedisStore) LRem(key string, count int64, value any) error {
|
||||||
|
return s.client.LRem(context.Background(), key, count, value).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedisStore) Rotate(key string) (string, error) {
|
||||||
|
val, err := s.client.RPopLPush(context.Background(), key, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, redis.Nil) {
|
||||||
|
return "", ErrNotFound
|
||||||
|
}
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Pipeliner implementation ---
|
||||||
|
|
||||||
|
type redisPipeliner struct {
|
||||||
|
pipe redis.Pipeliner
|
||||||
|
}
|
||||||
|
|
||||||
|
// HSet adds an HSET command to the pipeline.
|
||||||
|
func (p *redisPipeliner) HSet(key string, values map[string]any) {
|
||||||
|
p.pipe.HSet(context.Background(), key, values)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes all commands in the pipeline.
|
||||||
|
func (p *redisPipeliner) Exec() error {
|
||||||
|
_, err := p.pipe.Exec(context.Background())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pipeline creates a new pipeline.
|
||||||
|
func (s *RedisStore) Pipeline() Pipeliner {
|
||||||
|
return &redisPipeliner{
|
||||||
|
pipe: s.client.Pipeline(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -28,6 +28,31 @@ type Store interface {
|
|||||||
// Exists checks if a key exists in the store.
|
// Exists checks if a key exists in the store.
|
||||||
Exists(key string) (bool, error)
|
Exists(key string) (bool, error)
|
||||||
|
|
||||||
|
// SetNX sets a key-value pair if the key does not already exist.
|
||||||
|
// It returns true if the key was set, false otherwise.
|
||||||
|
SetNX(key string, value []byte, ttl time.Duration) (bool, error)
|
||||||
|
|
||||||
|
// HASH operations
|
||||||
|
HSet(key, field string, value any) error
|
||||||
|
HGetAll(key string) (map[string]string, error)
|
||||||
|
HIncrBy(key, field string, incr int64) (int64, error)
|
||||||
|
|
||||||
|
// LIST operations
|
||||||
|
LPush(key string, values ...any) error
|
||||||
|
LRem(key string, count int64, value any) error
|
||||||
|
Rotate(key string) (string, error)
|
||||||
|
|
||||||
// Close closes the store and releases any underlying resources.
|
// Close closes the store and releases any underlying resources.
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pipeliner defines an interface for executing a batch of commands.
|
||||||
|
type Pipeliner interface {
|
||||||
|
HSet(key string, values map[string]any)
|
||||||
|
Exec() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// RedisPipeliner is an optional interface that a Store can implement to provide pipelining.
|
||||||
|
type RedisPipeliner interface {
|
||||||
|
Pipeline() Pipeliner
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user