diff --git a/go.mod b/go.mod index ed1e946..dd5dc84 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/gin-contrib/static v1.1.5 github.com/gin-gonic/gin v1.10.1 github.com/go-sql-driver/mysql v1.8.1 + github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 github.com/redis/go-redis/v9 v9.5.3 github.com/sirupsen/logrus v1.9.3 diff --git a/go.sum b/go.sum index 0de0a56..679cbd6 100644 --- a/go.sum +++ b/go.sum @@ -49,6 +49,8 @@ github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EO github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= diff --git a/internal/app/app.go b/internal/app/app.go index 8556756..b7dab1f 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -29,6 +29,7 @@ type App struct { settingsManager *config.SystemSettingsManager groupManager *services.GroupManager logCleanupService *services.LogCleanupService + requestLogService *services.RequestLogService keyCronService *services.KeyCronService keyValidationPool *services.KeyValidationPool keyPoolProvider *keypool.KeyProvider @@ -37,7 +38,6 @@ type App struct { storage store.Store db *gorm.DB httpServer *http.Server - requestLogChan chan models.RequestLog wg sync.WaitGroup } @@ -49,6 +49,7 @@ type AppParams struct { SettingsManager *config.SystemSettingsManager GroupManager *services.GroupManager LogCleanupService *services.LogCleanupService + RequestLogService *services.RequestLogService KeyCronService *services.KeyCronService KeyValidationPool *services.KeyValidationPool KeyPoolProvider *keypool.KeyProvider @@ -56,7 +57,6 @@ type AppParams struct { ProxyServer *proxy.ProxyServer Storage store.Store DB *gorm.DB - RequestLogChan chan models.RequestLog } // NewApp is the constructor for App, with dependencies injected by dig. @@ -67,6 +67,7 @@ func NewApp(params AppParams) *App { settingsManager: params.SettingsManager, groupManager: params.GroupManager, logCleanupService: params.LogCleanupService, + requestLogService: params.RequestLogService, keyCronService: params.KeyCronService, keyValidationPool: params.KeyValidationPool, keyPoolProvider: params.KeyPoolProvider, @@ -74,7 +75,6 @@ func NewApp(params AppParams) *App { proxyServer: params.ProxyServer, storage: params.Storage, db: params.DB, - requestLogChan: params.RequestLogChan, } } @@ -139,7 +139,7 @@ func (a *App) Start() error { a.groupManager.Initialize() - a.startRequestLogger() + a.requestLogService.Start() a.logCleanupService.Start() a.keyValidationPool.Start() a.keyCronService.Start() @@ -182,58 +182,10 @@ func (a *App) Stop(ctx context.Context) { a.keyValidationPool.Stop() a.leaderService.Stop() a.logCleanupService.Stop() + a.requestLogService.Stop() a.groupManager.Stop() a.settingsManager.Stop() a.storage.Close() - // Wait for the logger to finish writing all logs - logrus.Info("Closing request log channel...") - close(a.requestLogChan) - a.wg.Wait() - logrus.Info("All logs have been written.") - logrus.Info("Server exited gracefully") } - -// startRequestLogger runs a background goroutine to batch-insert request logs. -func (a *App) startRequestLogger() { - a.wg.Add(1) - go func() { - defer a.wg.Done() - ticker := time.NewTicker(5 * time.Second) - defer ticker.Stop() - - logBuffer := make([]models.RequestLog, 0, 100) - - for { - select { - case logEntry, ok := <-a.requestLogChan: - if !ok { - // Channel closed, flush remaining logs and exit - if len(logBuffer) > 0 { - if err := a.db.Create(&logBuffer).Error; err != nil { - logrus.Errorf("Failed to write remaining request logs: %v", err) - } - } - logrus.Info("Request logger stopped.") - return - } - logBuffer = append(logBuffer, logEntry) - if len(logBuffer) >= 100 { - if err := a.db.Create(&logBuffer).Error; err != nil { - logrus.Errorf("Failed to write request logs: %v", err) - } - logBuffer = make([]models.RequestLog, 0, 100) // Reset buffer - } - case <-ticker.C: - // Flush logs periodically - if len(logBuffer) > 0 { - if err := a.db.Create(&logBuffer).Error; err != nil { - logrus.Errorf("Failed to write request logs on tick: %v", err) - } - logBuffer = make([]models.RequestLog, 0, 100) // Reset buffer - } - } - } - }() -} diff --git a/internal/config/system_settings.go b/internal/config/system_settings.go index d3b67f9..79450cf 100644 --- a/internal/config/system_settings.go +++ b/internal/config/system_settings.go @@ -351,6 +351,7 @@ func (sm *SystemSettingsManager) DisplaySystemConfig(settings types.SystemSettin logrus.Info("--- System Settings ---") logrus.Infof(" App URL: %s", settings.AppUrl) logrus.Infof(" Request Log Retention: %d days", settings.RequestLogRetentionDays) + logrus.Infof(" Request Log Write Interval: %d minutes", settings.RequestLogWriteIntervalMinutes) logrus.Info("--- Request Behavior ---") logrus.Infof(" Request Timeout: %d seconds", settings.RequestTimeout) diff --git a/internal/container/container.go b/internal/container/container.go index 76eb52f..855a756 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -63,6 +63,9 @@ func BuildContainer() (*dig.Container, error) { if err := container.Provide(services.NewLogCleanupService); err != nil { return nil, err } + if err := container.Provide(services.NewRequestLogService); err != nil { + return nil, err + } if err := container.Provide(services.NewGroupManager); err != nil { return nil, err } diff --git a/internal/models/types.go b/internal/models/types.go index ec33a6d..1571ceb 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -72,14 +72,18 @@ type APIKey struct { // RequestLog 对应 request_logs 表 type RequestLog struct { - ID string `gorm:"type:varchar(36);primaryKey" json:"id"` - Timestamp time.Time `gorm:"type:datetime(3);not null" json:"timestamp"` - GroupID uint `gorm:"not null" json:"group_id"` - KeyID uint `gorm:"not null" json:"key_id"` - SourceIP string `gorm:"type:varchar(45)" json:"source_ip"` - StatusCode int `gorm:"not null" json:"status_code"` - RequestPath string `gorm:"type:varchar(1024)" json:"request_path"` - RequestBodySnippet string `gorm:"type:text" json:"request_body_snippet"` + ID string `gorm:"type:varchar(36);primaryKey" json:"id"` + Timestamp time.Time `gorm:"type:datetime(3);not null;index" json:"timestamp"` + GroupID uint `gorm:"not null;index" json:"group_id"` + KeyID uint `gorm:"not null;index" json:"key_id"` + IsSuccess bool `gorm:"not null" json:"is_success"` + SourceIP string `gorm:"type:varchar(45)" json:"source_ip"` + StatusCode int `gorm:"not null" json:"status_code"` + RequestPath string `gorm:"type:varchar(1024)" json:"request_path"` + Duration int64 `gorm:"not null" json:"duration_ms"` + ErrorMessage string `gorm:"type:text" json:"error_message"` + UserAgent string `gorm:"type:varchar(512)" json:"user_agent"` + Retries int `gorm:"not null" json:"retries"` } // GroupRequestStat 用于表示每个分组的请求统计 diff --git a/internal/proxy/server.go b/internal/proxy/server.go index 38e8e14..c0c3f51 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -5,9 +5,11 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" + "strconv" "time" "gpt-load/internal/channel" @@ -26,10 +28,11 @@ import ( // ProxyServer represents the proxy server type ProxyServer struct { - keyProvider *keypool.KeyProvider - groupManager *services.GroupManager - settingsManager *config.SystemSettingsManager - channelFactory *channel.Factory + keyProvider *keypool.KeyProvider + groupManager *services.GroupManager + settingsManager *config.SystemSettingsManager + channelFactory *channel.Factory + requestLogService *services.RequestLogService } // NewProxyServer creates a new proxy server @@ -38,12 +41,14 @@ func NewProxyServer( groupManager *services.GroupManager, settingsManager *config.SystemSettingsManager, channelFactory *channel.Factory, + requestLogService *services.RequestLogService, ) (*ProxyServer, error) { return &ProxyServer{ - keyProvider: keyProvider, - groupManager: groupManager, - settingsManager: settingsManager, - channelFactory: channelFactory, + keyProvider: keyProvider, + groupManager: groupManager, + settingsManager: settingsManager, + channelFactory: channelFactory, + requestLogService: requestLogService, }, nil } @@ -105,9 +110,13 @@ func (ps *ProxyServer) executeRequestWithRetry( response.Error(c, app_errors.NewAPIErrorWithUpstream(lastError.StatusCode, "UPSTREAM_ERROR", lastError.ErrorMessage)) } logrus.Debugf("Max retries exceeded for group %s after %d attempts. Parsed Error: %s", group.Name, retryCount, lastError.ErrorMessage) + + keyID, _ := strconv.ParseUint(lastError.KeyID, 10, 64) + ps.logRequest(c, group, uint(keyID), startTime, lastError.StatusCode, retryCount, errors.New(lastError.ErrorMessage)) } else { response.Error(c, app_errors.ErrMaxRetriesExceeded) logrus.Debugf("Max retries exceeded for group %s after %d attempts.", group.Name, retryCount) + ps.logRequest(c, group, 0, startTime, http.StatusServiceUnavailable, retryCount, app_errors.ErrMaxRetriesExceeded) } return } @@ -116,6 +125,7 @@ func (ps *ProxyServer) executeRequestWithRetry( if err != nil { logrus.Errorf("Failed to select a key for group %s on attempt %d: %v", group.Name, retryCount+1, err) response.Error(c, app_errors.NewAPIError(app_errors.ErrNoKeysAvailable, err.Error())) + ps.logRequest(c, group, 0, startTime, http.StatusServiceUnavailable, retryCount, err) return } @@ -163,6 +173,7 @@ func (ps *ProxyServer) executeRequestWithRetry( if err != nil || (resp != nil && resp.StatusCode >= 400) { if err != nil && app_errors.IsIgnorableError(err) { logrus.Debugf("Client-side ignorable error for key %s, aborting retries: %v", utils.MaskAPIKey(apiKey.KeyValue), err) + ps.logRequest(c, group, apiKey.ID, startTime, 499, retryCount+1, err) return } @@ -203,6 +214,7 @@ func (ps *ProxyServer) executeRequestWithRetry( ps.keyProvider.UpdateStatus(apiKey, group, true) logrus.Debugf("Request for group %s succeeded on attempt %d with key %s", group.Name, retryCount+1, utils.MaskAPIKey(apiKey.KeyValue)) + ps.logRequest(c, group, apiKey.ID, startTime, resp.StatusCode, retryCount+1, nil) for key, values := range resp.Header { for _, value := range values { @@ -217,3 +229,40 @@ func (ps *ProxyServer) executeRequestWithRetry( ps.handleNormalResponse(c, resp) } } + +// logRequest is a helper function to create and record a request log. +func (ps *ProxyServer) logRequest( + c *gin.Context, + group *models.Group, + keyID uint, + startTime time.Time, + statusCode int, + retries int, + finalError error, +) { + if ps.requestLogService == nil { + return + } + + duration := time.Since(startTime).Milliseconds() + + logEntry := &models.RequestLog{ + GroupID: group.ID, + KeyID: keyID, + IsSuccess: finalError == nil && statusCode < 400, + SourceIP: c.ClientIP(), + StatusCode: statusCode, + RequestPath: c.Request.URL.String(), + Duration: duration, + UserAgent: c.Request.UserAgent(), + Retries: retries, + } + + if finalError != nil { + logEntry.ErrorMessage = finalError.Error() + } + + if err := ps.requestLogService.Record(logEntry); err != nil { + logrus.Errorf("Failed to record request log: %v", err) + } +} diff --git a/internal/services/request_log_service.go b/internal/services/request_log_service.go new file mode 100644 index 0000000..b66f3fc --- /dev/null +++ b/internal/services/request_log_service.go @@ -0,0 +1,227 @@ +package services + +import ( + "context" + "encoding/json" + "fmt" + "gpt-load/internal/config" + "gpt-load/internal/models" + "gpt-load/internal/store" + "strings" + "time" + + "github.com/google/uuid" + "github.com/sirupsen/logrus" + "gorm.io/gorm" +) + +const ( + RequestLogCachePrefix = "request_log:" + PendingLogKeysSet = "pending_log_keys" + DefaultLogFlushBatchSize = 500 +) + +// RequestLogService is responsible for managing request logs. +type RequestLogService struct { + db *gorm.DB + store store.Store + settingsManager *config.SystemSettingsManager + leaderService *LeaderService + ctx context.Context + cancel context.CancelFunc + ticker *time.Ticker +} + +// NewRequestLogService creates a new RequestLogService instance +func NewRequestLogService(db *gorm.DB, store store.Store, sm *config.SystemSettingsManager, ls *LeaderService) *RequestLogService { + ctx, cancel := context.WithCancel(context.Background()) + return &RequestLogService{ + db: db, + store: store, + settingsManager: sm, + leaderService: ls, + ctx: ctx, + cancel: cancel, + } +} + +// Start initializes the service and starts the periodic flush routine +func (s *RequestLogService) Start() { + go s.flush() + + interval := time.Duration(s.settingsManager.GetSettings().RequestLogWriteIntervalMinutes) * time.Minute + if interval <= 0 { + interval = time.Minute + } + s.ticker = time.NewTicker(interval) + + go func() { + for { + select { + case <-s.ticker.C: + newInterval := time.Duration(s.settingsManager.GetSettings().RequestLogWriteIntervalMinutes) * time.Minute + if newInterval <= 0 { + newInterval = time.Minute + } + if newInterval != interval { + s.ticker.Reset(newInterval) + interval = newInterval + logrus.Debugf("Request log write interval updated to: %v", interval) + } + s.flush() + case <-s.ctx.Done(): + s.ticker.Stop() + logrus.Info("RequestLogService stopped.") + return + } + } + }() +} + +// Stop gracefully stops the RequestLogService +func (s *RequestLogService) Stop() { + s.flush() + s.cancel() +} + +// Record logs a request to the database and cache +func (s *RequestLogService) Record(log *models.RequestLog) error { + log.ID = uuid.NewString() + log.Timestamp = time.Now() + + if s.settingsManager.GetSettings().RequestLogWriteIntervalMinutes == 0 { + return s.writeLogsToDB([]*models.RequestLog{log}) + } + + cacheKey := RequestLogCachePrefix + log.ID + + logBytes, err := json.Marshal(log) + if err != nil { + return fmt.Errorf("failed to marshal request log: %w", err) + } + + ttl := time.Duration(s.settingsManager.GetSettings().RequestLogWriteIntervalMinutes*5) * time.Minute + if err := s.store.Set(cacheKey, logBytes, ttl); err != nil { + return err + } + + return s.store.SAdd(PendingLogKeysSet, cacheKey) +} + +// flush data from cache to database +func (s *RequestLogService) flush() { + if s.settingsManager.GetSettings().RequestLogWriteIntervalMinutes == 0 { + logrus.Debug("Sync mode enabled, skipping scheduled log flush.") + return + } + + if !s.leaderService.IsLeader() { + logrus.Debug("Not a leader, skipping log flush.") + return + } + + logrus.Debug("Leader starting to flush request logs...") + + for { + keys, err := s.store.SPopN(PendingLogKeysSet, DefaultLogFlushBatchSize) + if err != nil { + logrus.Errorf("Failed to pop pending log keys from store: %v", err) + return + } + + if len(keys) == 0 { + logrus.Debug("No more request logs to flush in this cycle.") + return + } + + logrus.Debugf("Popped %d request logs to flush.", len(keys)) + + var logs []*models.RequestLog + var processedKeys []string + for _, key := range keys { + logBytes, err := s.store.Get(key) + if err != nil { + if err == store.ErrNotFound { + logrus.Warnf("Log key %s found in set but not in store, skipping.", key) + } else { + logrus.Warnf("Failed to get log for key %s: %v", key, err) + } + continue + } + var log models.RequestLog + if err := json.Unmarshal(logBytes, &log); err != nil { + logrus.Warnf("Failed to unmarshal log for key %s: %v", key, err) + continue + } + logs = append(logs, &log) + processedKeys = append(processedKeys, key) + } + + if len(logs) == 0 { + continue + } + + err = s.writeLogsToDB(logs) + + if err != nil { + logrus.Errorf("Failed to flush request logs batch, will retry next time. Error: %v", err) + if len(keys) > 0 { + keysToRetry := make([]any, len(keys)) + for i, k := range keys { + keysToRetry[i] = k + } + if saddErr := s.store.SAdd(PendingLogKeysSet, keysToRetry...); saddErr != nil { + logrus.Errorf("CRITICAL: Failed to re-add failed log keys to set: %v", saddErr) + } + } + return + } + + if len(processedKeys) > 0 { + if err := s.store.Del(processedKeys...); err != nil { + logrus.Errorf("Failed to delete flushed log bodies from store: %v", err) + } + } + logrus.Infof("Successfully flushed %d request logs.", len(logs)) + } +} + +// writeLogsToDB writes a batch of request logs to the database +func (s *RequestLogService) writeLogsToDB(logs []*models.RequestLog) error { + if len(logs) == 0 { + return nil + } + + return s.db.Transaction(func(tx *gorm.DB) error { + if err := tx.CreateInBatches(logs, len(logs)).Error; err != nil { + return fmt.Errorf("failed to batch insert request logs: %w", err) + } + + keyStats := make(map[uint]int64) + for _, log := range logs { + if log.IsSuccess { + keyStats[log.KeyID]++ + } + } + + if len(keyStats) > 0 { + var caseStmt strings.Builder + var keyIDs []uint + caseStmt.WriteString("CASE id ") + for keyID, count := range keyStats { + caseStmt.WriteString(fmt.Sprintf("WHEN %d THEN request_count + %d ", keyID, count)) + keyIDs = append(keyIDs, keyID) + } + caseStmt.WriteString("END") + + if err := tx.Model(&models.APIKey{}).Where("id IN ?", keyIDs). + Updates(map[string]any{ + "request_count": gorm.Expr(caseStmt.String()), + "last_used_at": time.Now(), + }).Error; err != nil { + return fmt.Errorf("failed to batch update api_key stats: %w", err) + } + } + return nil + }) +} diff --git a/internal/store/memory.go b/internal/store/memory.go index 574ca06..76b54a3 100644 --- a/internal/store/memory.go +++ b/internal/store/memory.go @@ -14,12 +14,9 @@ type memoryStoreItem struct { } // MemoryStore is an in-memory key-value store that is safe for concurrent use. -// It now supports simple K/V, HASH, and LIST data types. type MemoryStore struct { - mu sync.RWMutex - data map[string]any - - // For Pub/Sub + mu sync.RWMutex + data map[string]any muSubscribers sync.RWMutex subscribers map[string]map[chan *Message]struct{} } @@ -30,14 +27,11 @@ func NewMemoryStore() *MemoryStore { data: make(map[string]any), subscribers: make(map[string]map[chan *Message]struct{}), } - // The cleanup loop was removed as it's not compatible with multiple data types - // without a unified expiration mechanism, and the KeyPool feature does not rely on TTLs. return s } // Close cleans up resources. func (s *MemoryStore) Close() error { - // Nothing to close for now. return nil } @@ -73,9 +67,7 @@ func (s *MemoryStore) Get(key string) ([]byte, error) { return nil, fmt.Errorf("type mismatch: key '%s' holds a different data type", key) } - // Check for expiration if item.expiresAt > 0 && time.Now().UnixNano() > item.expiresAt { - // Lazy deletion s.mu.Lock() delete(s.data, key) s.mu.Unlock() @@ -93,6 +85,16 @@ func (s *MemoryStore) Delete(key string) error { return nil } +// Del removes multiple values by their keys. +func (s *MemoryStore) Del(keys ...string) error { + s.mu.Lock() + defer s.mu.Unlock() + for _, key := range keys { + delete(s.data, key) + } + return nil +} + // Exists checks if a key exists. func (s *MemoryStore) Exists(key string) (bool, error) { s.mu.RLock() @@ -103,10 +105,8 @@ func (s *MemoryStore) Exists(key string) (bool, error) { return false, nil } - // Check for expiration only if it's a simple K/V item if item, ok := rawItem.(memoryStoreItem); ok { if item.expiresAt > 0 && time.Now().UnixNano() > item.expiresAt { - // Lazy deletion s.mu.Lock() delete(s.data, key) s.mu.Unlock() @@ -122,12 +122,10 @@ func (s *MemoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, s.mu.Lock() defer s.mu.Unlock() - // In memory store, we need to manually check for existence and expiration rawItem, exists := s.data[key] if exists { if item, ok := rawItem.(memoryStoreItem); ok { if item.expiresAt == 0 || time.Now().UnixNano() < item.expiresAt { - // Key exists and is not expired return false, nil } } else { @@ -179,7 +177,6 @@ func (s *MemoryStore) HGetAll(key string) (map[string]string, error) { rawHash, exists := s.data[key] if !exists { - // Per Redis convention, HGETALL on a non-existent key returns an empty map, not an error. return make(map[string]string), nil } @@ -188,7 +185,6 @@ func (s *MemoryStore) HGetAll(key string) (map[string]string, error) { return nil, fmt.Errorf("type mismatch: key '%s' holds a different data type", key) } - // Return a copy to prevent race conditions on the returned map result := make(map[string]string, len(hash)) for k, v := range hash { result[k] = v @@ -265,9 +261,7 @@ func (s *MemoryStore) LRem(key string, count int64, value any) error { strValue := fmt.Sprint(value) newList := make([]string, 0, len(list)) - // LREM with count = 0: Remove all elements equal to value. if count != 0 { - // For now, only implement count = 0 behavior as it's what we need. return fmt.Errorf("LRem with non-zero count is not implemented in MemoryStore") } @@ -298,7 +292,6 @@ func (s *MemoryStore) Rotate(key string) (string, error) { return "", ErrNotFound } - // "RPOP" lastIndex := len(list) - 1 item := list[lastIndex] @@ -309,6 +302,63 @@ func (s *MemoryStore) Rotate(key string) (string, error) { return item, nil } +// --- SET operations --- + +// SAdd adds members to a set. +func (s *MemoryStore) SAdd(key string, members ...any) error { + s.mu.Lock() + defer s.mu.Unlock() + + var set map[string]struct{} + rawSet, exists := s.data[key] + if !exists { + set = make(map[string]struct{}) + s.data[key] = set + } else { + var ok bool + set, ok = rawSet.(map[string]struct{}) + if !ok { + return fmt.Errorf("type mismatch: key '%s' holds a different data type", key) + } + } + + for _, member := range members { + set[fmt.Sprint(member)] = struct{}{} + } + return nil +} + +// SPopN randomly removes and returns the given number of members from a set. +func (s *MemoryStore) SPopN(key string, count int64) ([]string, error) { + s.mu.Lock() + defer s.mu.Unlock() + + rawSet, exists := s.data[key] + if !exists { + return []string{}, nil + } + + set, ok := rawSet.(map[string]struct{}) + if !ok { + return nil, fmt.Errorf("type mismatch: key '%s' holds a different data type", key) + } + + if count > int64(len(set)) { + count = int64(len(set)) + } + + popped := make([]string, 0, count) + for member := range set { + if int64(len(popped)) >= count { + break + } + popped = append(popped, member) + delete(set, member) + } + + return popped, nil +} + // --- Pub/Sub operations --- // memorySubscription implements the Subscription interface for the in-memory store. @@ -350,11 +400,10 @@ func (s *MemoryStore) Publish(channel string, message []byte) error { if subs, ok := s.subscribers[channel]; ok { for subCh := range subs { - // Non-blocking send go func(c chan *Message) { select { case c <- msg: - case <-time.After(1 * time.Second): // Prevent goroutine leak if receiver is stuck + case <-time.After(1 * time.Second): } }(subCh) } diff --git a/internal/store/redis.go b/internal/store/redis.go index 1903256..36813c2 100644 --- a/internal/store/redis.go +++ b/internal/store/redis.go @@ -42,6 +42,14 @@ func (s *RedisStore) Delete(key string) error { return s.client.Del(context.Background(), key).Err() } +// Del removes multiple values from Redis. +func (s *RedisStore) Del(keys ...string) error { + if len(keys) == 0 { + return nil + } + return s.client.Del(context.Background(), keys...).Err() +} + // Exists checks if a key exists in Redis. func (s *RedisStore) Exists(key string) (bool, error) { val, err := s.client.Exists(context.Background(), key).Result() @@ -96,6 +104,16 @@ func (s *RedisStore) Rotate(key string) (string, error) { return val, nil } +// --- SET operations --- + +func (s *RedisStore) SAdd(key string, members ...any) error { + return s.client.SAdd(context.Background(), key, members...).Err() +} + +func (s *RedisStore) SPopN(key string, count int64) ([]string, error) { + return s.client.SPopN(context.Background(), key, count).Result() +} + // --- Pipeliner implementation --- type redisPipeliner struct { @@ -121,7 +139,7 @@ func (s *RedisStore) Pipeline() Pipeliner { } // Eval executes a Lua script on Redis. -func (s *RedisStore) Eval(script string, keys []string, args ...interface{}) (interface{}, error) { +func (s *RedisStore) Eval(script string, keys []string, args ...any) (any, error) { return s.client.Eval(context.Background(), script, keys, args...).Result() } @@ -165,7 +183,6 @@ func (s *RedisStore) Publish(channel string, message []byte) error { func (s *RedisStore) Subscribe(channel string) (Subscription, error) { pubsub := s.client.Subscribe(context.Background(), channel) - // Wait for confirmation that subscription is created. _, err := pubsub.Receive(context.Background()) if err != nil { return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err) diff --git a/internal/store/store.go b/internal/store/store.go index 0fba62a..9a045e6 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -16,34 +16,28 @@ type Message struct { // Subscription represents an active subscription to a pub/sub channel. type Subscription interface { - // Channel returns the channel for receiving messages. Channel() <-chan *Message - // Close unsubscribes and releases any resources associated with the subscription. Close() error } // Store is a generic key-value store interface. -// Implementations of this interface must be safe for concurrent use. type Store interface { // Set stores a key-value pair with an optional TTL. - // - key: The key (string). - // - value: The value ([]byte). - // - ttl: The expiration time. If ttl is 0, the key never expires. Set(key string, value []byte, ttl time.Duration) error // Get retrieves a value by its key. - // It must return store.ErrNotFound if the key does not exist. Get(key string) ([]byte, error) // Delete removes a value by its key. - // If the key does not exist, this operation should be considered successful (idempotent) and not return an error. Delete(key string) error + // Del deletes multiple keys. + Del(keys ...string) error + // Exists checks if a key exists in the store. Exists(key string) (bool, error) // SetNX sets a key-value pair if the key does not already exist. - // It returns true if the key was set, false otherwise. SetNX(key string, value []byte, ttl time.Duration) (bool, error) // HASH operations @@ -56,6 +50,10 @@ type Store interface { LRem(key string, count int64, value any) error Rotate(key string) (string, error) + // SET operations + SAdd(key string, members ...any) error + SPopN(key string, count int64) ([]string, error) + // Close closes the store and releases any underlying resources. Close() error @@ -63,7 +61,6 @@ type Store interface { Publish(channel string, message []byte) error // Subscribe listens for messages on a given channel. - // It returns a Subscription object that can be used to receive messages and to close the subscription. Subscribe(channel string) (Subscription, error) } @@ -80,5 +77,5 @@ type RedisPipeliner interface { // LuaScripter is an optional interface that a Store can implement to provide Lua script execution. type LuaScripter interface { - Eval(script string, keys []string, args ...interface{}) (interface{}, error) + Eval(script string, keys []string, args ...any) (any, error) } diff --git a/internal/types/types.go b/internal/types/types.go index 39f7d4c..fe6e5a7 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -17,8 +17,9 @@ type ConfigManager interface { // SystemSettings 定义所有系统配置项 type SystemSettings struct { // 基础参数 - AppUrl string `json:"app_url" default:"http://localhost:3000" name:"项目地址" category:"基础参数" desc:"项目的基础 URL,用于拼接分组终端节点地址。系统配置优先于环境变量 APP_URL。"` - RequestLogRetentionDays int `json:"request_log_retention_days" default:"7" name:"日志保留时长(天)" category:"基础参数" desc:"请求日志在数据库中的保留天数" validate:"min=0"` + AppUrl string `json:"app_url" default:"http://localhost:3000" name:"项目地址" category:"基础参数" desc:"项目的基础 URL,用于拼接分组终端节点地址。系统配置优先于环境变量 APP_URL。"` + RequestLogRetentionDays int `json:"request_log_retention_days" default:"7" name:"日志保留时长(天)" category:"基础参数" desc:"请求日志在数据库中的保留天数,0为不清理日志。" validate:"min=0"` + RequestLogWriteIntervalMinutes int `json:"request_log_write_interval_minutes" default:"5" name:"日志延迟写入周期(分钟)" category:"基础参数" desc:"请求日志从缓存写入数据库的周期(分钟),0为实时写入数据。" validate:"min=0"` // 请求设置 RequestTimeout int `json:"request_timeout" default:"600" name:"请求超时(秒)" category:"请求设置" desc:"转发请求的完整生命周期超时(秒),包括连接、重试等。" validate:"min=1"` @@ -29,9 +30,9 @@ type SystemSettings struct { MaxIdleConnsPerHost int `json:"max_idle_conns_per_host" default:"50" name:"每主机最大空闲连接数" category:"请求设置" desc:"HTTP 客户端连接池对每个上游主机允许的最大空闲连接数。" validate:"min=1"` // 密钥配置 - MaxRetries int `json:"max_retries" default:"3" name:"最大重试次数" category:"密钥配置" desc:"单个请求使用不同 Key 的最大重试次数" validate:"min=0"` - BlacklistThreshold int `json:"blacklist_threshold" default:"3" name:"黑名单阈值" category:"密钥配置" desc:"一个 Key 连续失败多少次后进入黑名单" validate:"min=0"` - KeyValidationIntervalMinutes int `json:"key_validation_interval_minutes" default:"60" name:"定时验证周期(分钟)" category:"密钥配置" desc:"后台定时验证密钥的默认周期(分钟)" validate:"min=30"` + MaxRetries int `json:"max_retries" default:"3" name:"最大重试次数" category:"密钥配置" desc:"单个请求使用不同 Key 的最大重试次数,0为不重试。" validate:"min=0"` + BlacklistThreshold int `json:"blacklist_threshold" default:"3" name:"黑名单阈值" category:"密钥配置" desc:"一个 Key 连续失败多少次后进入黑名单,0为不拉黑。" validate:"min=0"` + KeyValidationIntervalMinutes int `json:"key_validation_interval_minutes" default:"60" name:"定时验证周期(分钟)" category:"密钥配置" desc:"后台定时验证密钥的默认周期(分钟)。" validate:"min=30"` } // ServerConfig represents server configuration diff --git a/main.go b/main.go index dfb910b..cbb6941 100644 --- a/main.go +++ b/main.go @@ -11,7 +11,6 @@ import ( "gpt-load/internal/app" "gpt-load/internal/container" - "gpt-load/internal/models" "gpt-load/internal/types" "gpt-load/internal/utils" @@ -39,12 +38,6 @@ func main() { logrus.Fatalf("Failed to provide indexPage: %v", err) } - // Provide the request log channel as a value - requestLogChan := make(chan models.RequestLog, 1000) - if err := container.Provide(func() chan models.RequestLog { return requestLogChan }); err != nil { - logrus.Fatalf("Failed to provide request log channel: %v", err) - } - // Initialzie global logger if err := container.Invoke(func(configManager types.ConfigManager) { utils.SetupLogger(configManager)