Refactor configuration and key management
- Removed key management configuration from .env.example and related code.
- Updated Makefile to load environment variables for HOST and PORT.
- Modified main.go to handle request logging with a wait group for graceful shutdown.
- Simplified dashboard statistics handler to focus on key counts and request metrics.
- Removed key manager implementation and related interfaces.
- Updated proxy server to use atomic counters for round-robin key selection.
- Cleaned up unused types and configurations in types.go.
- Added package-lock.json for frontend dependencies.
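Note: the main.go change mentioned above is not included in the hunks below. As a rough sketch only (not the actual implementation), the wait-group pattern for flushing a request-log channel on graceful shutdown looks like this; the requestLog type, channel size, and handler call are illustrative stand-ins for the real models.RequestLog plumbing:

package main

import (
    "fmt"
    "sync"
)

// requestLog stands in for models.RequestLog; the field set here is illustrative only.
type requestLog struct{ Path string }

func main() {
    logChan := make(chan requestLog, 1024) // buffered so handlers rarely block
    var wg sync.WaitGroup

    // Writer goroutine: drains the channel until it is closed.
    wg.Add(1)
    go func() {
        defer wg.Done()
        for entry := range logChan {
            fmt.Println("persist", entry.Path) // the real code would write via GORM
        }
    }()

    // Handlers enqueue entries without waiting on the database.
    logChan <- requestLog{Path: "/v1/chat/completions"}

    // Graceful shutdown: stop accepting entries, then wait for the drain to finish.
    close(logChan)
    wg.Wait()
}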
@@ -45,7 +45,6 @@ type Manager struct {
 // Config represents the application configuration
 type Config struct {
     Server types.ServerConfig `json:"server"`
-    Keys types.KeysConfig `json:"keys"`
     OpenAI types.OpenAIConfig `json:"openai"`
     Auth types.AuthConfig `json:"auth"`
     CORS types.CORSConfig `json:"cors"`
@@ -78,12 +77,6 @@ func (m *Manager) ReloadConfig() error {
             IdleTimeout: parseInteger(os.Getenv("SERVER_IDLE_TIMEOUT"), 120),
             GracefulShutdownTimeout: parseInteger(os.Getenv("SERVER_GRACEFUL_SHUTDOWN_TIMEOUT"), 60),
         },
-        Keys: types.KeysConfig{
-            FilePath: getEnvOrDefault("KEYS_FILE", "keys.txt"),
-            StartIndex: parseInteger(os.Getenv("START_INDEX"), 0),
-            BlacklistThreshold: parseInteger(os.Getenv("BLACKLIST_THRESHOLD"), 1),
-            MaxRetries: parseInteger(os.Getenv("MAX_RETRIES"), 3),
-        },
         OpenAI: types.OpenAIConfig{
             BaseURLs: parseArray(os.Getenv("OPENAI_BASE_URL"), []string{"https://api.openai.com"}),
             RequestTimeout: parseInteger(os.Getenv("REQUEST_TIMEOUT"), DefaultConstants.DefaultTimeout),
@@ -131,11 +124,6 @@ func (m *Manager) GetServerConfig() types.ServerConfig {
     return m.config.Server
 }
 
-// GetKeysConfig returns keys configuration
-func (m *Manager) GetKeysConfig() types.KeysConfig {
-    return m.config.Keys
-}
-
 // GetOpenAIConfig returns OpenAI configuration
 func (m *Manager) GetOpenAIConfig() types.OpenAIConfig {
     config := m.config.OpenAI
@@ -178,16 +166,6 @@ func (m *Manager) Validate() error {
         validationErrors = append(validationErrors, fmt.Sprintf("port must be between %d-%d", DefaultConstants.MinPort, DefaultConstants.MaxPort))
     }
 
-    // Validate start index
-    if m.config.Keys.StartIndex < 0 {
-        validationErrors = append(validationErrors, "start index cannot be less than 0")
-    }
-
-    // Validate blacklist threshold
-    if m.config.Keys.BlacklistThreshold < 1 {
-        validationErrors = append(validationErrors, "blacklist threshold cannot be less than 1")
-    }
-
     // Validate timeout
     if m.config.OpenAI.RequestTimeout < DefaultConstants.MinTimeout {
         validationErrors = append(validationErrors, fmt.Sprintf("request timeout cannot be less than %ds", DefaultConstants.MinTimeout))
@@ -223,10 +201,6 @@ func (m *Manager) Validate() error {
 func (m *Manager) DisplayConfig() {
     logrus.Info("Current Configuration:")
     logrus.Infof(" Server: %s:%d", m.config.Server.Host, m.config.Server.Port)
-    logrus.Infof(" Keys file: %s", m.config.Keys.FilePath)
-    logrus.Infof(" Start index: %d", m.config.Keys.StartIndex)
-    logrus.Infof(" Blacklist threshold: %d errors", m.config.Keys.BlacklistThreshold)
-    logrus.Infof(" Max retries: %d", m.config.Keys.MaxRetries)
     logrus.Infof(" Upstream URLs: %s", strings.Join(m.config.OpenAI.BaseURLs, ", "))
     logrus.Infof(" Request timeout: %ds", m.config.OpenAI.RequestTimeout)
     logrus.Infof(" Response timeout: %ds", m.config.OpenAI.ResponseTimeout)
@@ -1,54 +1,39 @@
 package handler
 
 import (
-    "net/http"
 
     "github.com/gin-gonic/gin"
-    "gpt-load/internal/db"
     "gpt-load/internal/models"
     "gpt-load/internal/response"
 )
 
-// GetDashboardStats godoc
 // @Summary Get dashboard statistics
-// @Description Get statistics for the dashboard, including total requests, success rate, and group distribution.
+// @Description Get statistics for the dashboard, including key counts and request metrics.
 // @Tags Dashboard
 // @Accept json
 // @Produce json
-// @Success 200 {object} models.DashboardStats
+// @Success 200 {object} map[string]interface{}
 // @Router /api/dashboard/stats [get]
-func GetDashboardStats(c *gin.Context) {
+func (s *Server) Stats(c *gin.Context) {
     var totalRequests, successRequests int64
     var groupStats []models.GroupRequestStat
 
-    // Get total requests
-    if err := db.DB.Model(&models.RequestLog{}).Count(&totalRequests).Error; err != nil {
-        response.Error(c, http.StatusInternalServerError, "Failed to get total requests")
-        return
-    }
+    // 1. Get total and successful requests from the api_keys table
+    s.DB.Model(&models.APIKey{}).Select("SUM(request_count)").Row().Scan(&totalRequests)
+    s.DB.Model(&models.APIKey{}).Select("SUM(request_count) - SUM(failure_count)").Row().Scan(&successRequests)
 
-    // Get success requests (status code 2xx)
-    if err := db.DB.Model(&models.RequestLog{}).Where("status_code >= ? AND status_code < ?", 200, 300).Count(&successRequests).Error; err != nil {
-        response.Error(c, http.StatusInternalServerError, "Failed to get success requests")
-        return
-    }
+    // 2. Get request counts per group
+    s.DB.Table("api_keys").
+        Select("groups.name as group_name, SUM(api_keys.request_count) as request_count").
+        Joins("join groups on groups.id = api_keys.group_id").
+        Group("groups.name").
+        Order("request_count DESC").
+        Scan(&groupStats)
 
-    // Calculate success rate
+    // 3. Calculate success rate
     var successRate float64
     if totalRequests > 0 {
-        successRate = float64(successRequests) / float64(totalRequests)
-    }
-
-    // Get group stats
-    err := db.DB.Table("request_logs").
-        Select("groups.name as group_name, count(request_logs.id) as request_count").
-        Joins("join groups on groups.id = request_logs.group_id").
-        Group("groups.name").
-        Order("request_count desc").
-        Scan(&groupStats).Error
-    if err != nil {
-        response.Error(c, http.StatusInternalServerError, "Failed to get group stats")
-        return
+        successRate = float64(successRequests) / float64(totalRequests) * 100
     }
 
     stats := models.DashboardStats{
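For reference (not part of the diff), the reworked handler derives successes as request_count minus failure_count and reports the rate as a percentage, hence the * 100 factor above. A tiny illustration with made-up numbers:

package main

import "fmt"

func main() {
    // Made-up counts: 200 requests, 10 failures -> 95.0% success rate.
    var requestCount, failureCount int64 = 200, 10
    successCount := requestCount - failureCount
    rate := float64(successCount) / float64(requestCount) * 100
    fmt.Printf("success rate: %.1f%%\n", rate)
}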
@@ -3,7 +3,6 @@ package handler
 
 import (
     "net/http"
-    "runtime"
     "time"
 
     "gpt-load/internal/models"
@@ -53,7 +52,7 @@ func (s *Server) RegisterAPIRoutes(api *gin.RouterGroup) {
     // Dashboard and logs routes
     dashboard := api.Group("/dashboard")
     {
-        dashboard.GET("/stats", GetDashboardStats)
+        dashboard.GET("/stats", s.Stats)
     }
 
     api.GET("/logs", GetLogs)
@@ -101,53 +100,6 @@ func (s *Server) Health(c *gin.Context) {
     })
 }
 
-// Stats handles statistics requests
-func (s *Server) Stats(c *gin.Context) {
-    var totalKeys, healthyKeys, disabledKeys int64
-    s.DB.Model(&models.APIKey{}).Count(&totalKeys)
-    s.DB.Model(&models.APIKey{}).Where("status = ?", "active").Count(&healthyKeys)
-    s.DB.Model(&models.APIKey{}).Where("status != ?", "active").Count(&disabledKeys)
-
-    // TODO: Get request counts from the database
-    var successCount, failureCount int64
-    s.DB.Model(&models.RequestLog{}).Where("status_code = ?", http.StatusOK).Count(&successCount)
-    s.DB.Model(&models.RequestLog{}).Where("status_code != ?", http.StatusOK).Count(&failureCount)
-
-    // Add additional system information
-    var m runtime.MemStats
-    runtime.ReadMemStats(&m)
-
-    response := gin.H{
-        "keys": gin.H{
-            "total": totalKeys,
-            "healthy": healthyKeys,
-            "disabled": disabledKeys,
-        },
-        "requests": gin.H{
-            "success_count": successCount,
-            "failure_count": failureCount,
-            "total_count": successCount + failureCount,
-        },
-        "memory": gin.H{
-            "alloc_mb": bToMb(m.Alloc),
-            "total_alloc_mb": bToMb(m.TotalAlloc),
-            "sys_mb": bToMb(m.Sys),
-            "num_gc": m.NumGC,
-            "last_gc": time.Unix(0, int64(m.LastGC)).Format("2006-01-02 15:04:05"),
-            "next_gc_mb": bToMb(m.NextGC),
-        },
-        "system": gin.H{
-            "goroutines": runtime.NumGoroutine(),
-            "cpu_count": runtime.NumCPU(),
-            "go_version": runtime.Version(),
-        },
-        "timestamp": time.Now().UTC().Format(time.RFC3339),
-    }
-
-    c.JSON(http.StatusOK, response)
-}
-
-
 // MethodNotAllowed handles 405 requests
 func (s *Server) MethodNotAllowed(c *gin.Context) {
     c.JSON(http.StatusMethodNotAllowed, gin.H{
@@ -169,7 +121,6 @@ func (s *Server) GetConfig(c *gin.Context) {
     }
 
     serverConfig := s.config.GetServerConfig()
-    keysConfig := s.config.GetKeysConfig()
     openaiConfig := s.config.GetOpenAIConfig()
     authConfig := s.config.GetAuthConfig()
     corsConfig := s.config.GetCORSConfig()
@@ -182,12 +133,6 @@ func (s *Server) GetConfig(c *gin.Context) {
             "host": serverConfig.Host,
             "port": serverConfig.Port,
         },
-        "keys": gin.H{
-            "file_path": keysConfig.FilePath,
-            "start_index": keysConfig.StartIndex,
-            "blacklist_threshold": keysConfig.BlacklistThreshold,
-            "max_retries": keysConfig.MaxRetries,
-        },
         "openai": gin.H{
             "base_url": openaiConfig.BaseURL,
             "request_timeout": openaiConfig.RequestTimeout,
@@ -230,8 +175,3 @@ func (s *Server) GetConfig(c *gin.Context) {
 
     c.JSON(http.StatusOK, sanitizedConfig)
 }
-
-// Helper function to convert bytes to megabytes
-func bToMb(b uint64) uint64 {
-    return b / 1024 / 1024
-}
@@ -1,336 +0,0 @@
-// Package keymanager provides high-performance API key management
-package keymanager
-
-import (
-    "bufio"
-    "os"
-    "regexp"
-    "runtime"
-    "strings"
-    "sync"
-    "sync/atomic"
-    "time"
-
-    "gpt-load/internal/errors"
-    "gpt-load/internal/types"
-
-    "github.com/sirupsen/logrus"
-)
-
-// Manager implements the KeyManager interface
-type Manager struct {
-    keysFilePath string
-    keys []string
-    keyPreviews []string
-    currentIndex int64
-    blacklistedKeys sync.Map
-    successCount int64
-    failureCount int64
-    keyFailureCounts sync.Map
-    config types.KeysConfig
-
-    // Performance optimization: pre-compiled regex patterns
-    permanentErrorPatterns []*regexp.Regexp
-
-    // Memory management
-    cleanupTicker *time.Ticker
-    stopCleanup chan bool
-
-    // Read-write lock to protect key list
-    keysMutex sync.RWMutex
-}
-
-// NewManager creates a new key manager
-func NewManager(config types.KeysConfig) (types.KeyManager, error) {
-    if config.FilePath == "" {
-        return nil, errors.NewAppError(errors.ErrKeyFileNotFound, "Keys file path is required")
-    }
-
-    km := &Manager{
-        keysFilePath: config.FilePath,
-        currentIndex: int64(config.StartIndex),
-        stopCleanup: make(chan bool),
-        config: config,
-
-        // Pre-compile regex patterns
-        permanentErrorPatterns: []*regexp.Regexp{
-            regexp.MustCompile(`(?i)invalid api key`),
-            regexp.MustCompile(`(?i)incorrect api key`),
-            regexp.MustCompile(`(?i)api key not found`),
-            regexp.MustCompile(`(?i)unauthorized`),
-            regexp.MustCompile(`(?i)account deactivated`),
-            regexp.MustCompile(`(?i)billing`),
-        },
-    }
-
-    // Start memory cleanup
-    km.setupMemoryCleanup()
-
-    // Load keys
-    if err := km.LoadKeys(); err != nil {
-        return nil, err
-    }
-
-    return km, nil
-}
-
-// LoadKeys loads API keys from file
-func (km *Manager) LoadKeys() error {
-    file, err := os.Open(km.keysFilePath)
-    if err != nil {
-        return errors.NewAppErrorWithCause(errors.ErrKeyFileNotFound, "Failed to open keys file", err)
-    }
-    defer file.Close()
-
-    var keys []string
-    var keyPreviews []string
-
-    scanner := bufio.NewScanner(file)
-    for scanner.Scan() {
-        line := strings.TrimSpace(scanner.Text())
-        if line != "" && !strings.HasPrefix(line, "#") {
-            keys = append(keys, line)
-            // Create preview (first 8 chars + "..." + last 4 chars)
-            if len(line) > 12 {
-                preview := line[:8] + "..." + line[len(line)-4:]
-                keyPreviews = append(keyPreviews, preview)
-            } else {
-                keyPreviews = append(keyPreviews, line)
-            }
-        }
-    }
-
-    if err := scanner.Err(); err != nil {
-        return errors.NewAppErrorWithCause(errors.ErrKeyFileInvalid, "Failed to read keys file", err)
-    }
-
-    if len(keys) == 0 {
-        return errors.NewAppError(errors.ErrNoKeysAvailable, "No valid API keys found in file")
-    }
-
-    km.keysMutex.Lock()
-    km.keys = keys
-    km.keyPreviews = keyPreviews
-    km.keysMutex.Unlock()
-
-    logrus.Infof("Successfully loaded %d API keys", len(keys))
-    return nil
-}
-
-// GetNextKey gets the next available key (high-performance version)
-func (km *Manager) GetNextKey() (*types.KeyInfo, error) {
-    km.keysMutex.RLock()
-    keysLen := len(km.keys)
-    if keysLen == 0 {
-        km.keysMutex.RUnlock()
-        return nil, errors.ErrNoAPIKeysAvailable
-    }
-
-    // Fast path: directly get next key, avoid blacklist check overhead
-    currentIdx := atomic.AddInt64(&km.currentIndex, 1) - 1
-    keyIndex := int(currentIdx) % keysLen
-    selectedKey := km.keys[keyIndex]
-    keyPreview := km.keyPreviews[keyIndex]
-    km.keysMutex.RUnlock()
-
-    // Check if blacklisted
-    if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted {
-        return &types.KeyInfo{
-            Key: selectedKey,
-            Index: keyIndex,
-            Preview: keyPreview,
-        }, nil
-    }
-
-    // Slow path: find next available key
-    return km.findNextAvailableKey(keyIndex, keysLen)
-}
-
-// findNextAvailableKey finds the next available non-blacklisted key
-func (km *Manager) findNextAvailableKey(startIndex, keysLen int) (*types.KeyInfo, error) {
-    km.keysMutex.RLock()
-    defer km.keysMutex.RUnlock()
-
-    blacklistedCount := 0
-    for i := 0; i < keysLen; i++ {
-        keyIndex := (startIndex + i) % keysLen
-        selectedKey := km.keys[keyIndex]
-
-        if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted {
-            return &types.KeyInfo{
-                Key: selectedKey,
-                Index: keyIndex,
-                Preview: km.keyPreviews[keyIndex],
-            }, nil
-        }
-        blacklistedCount++
-    }
-
-    if blacklistedCount >= keysLen {
-        logrus.Warn("All keys are blacklisted, resetting blacklist")
-        km.blacklistedKeys = sync.Map{}
-        km.keyFailureCounts = sync.Map{}
-
-        // Return first key after reset
-        firstKey := km.keys[0]
-        firstPreview := km.keyPreviews[0]
-
-        return &types.KeyInfo{
-            Key: firstKey,
-            Index: 0,
-            Preview: firstPreview,
-        }, nil
-    }
-
-    return nil, errors.ErrAllAPIKeysBlacklisted
-}
-
-// RecordSuccess records successful key usage
-func (km *Manager) RecordSuccess(key string) {
-    atomic.AddInt64(&km.successCount, 1)
-    // Reset failure count for this key on success
-    km.keyFailureCounts.Delete(key)
-}
-
-// RecordFailure records key failure and potentially blacklists it
-func (km *Manager) RecordFailure(key string, err error) {
-    atomic.AddInt64(&km.failureCount, 1)
-
-    // Check if this is a permanent error
-    if km.isPermanentError(err) {
-        km.blacklistedKeys.Store(key, time.Now())
-        logrus.Debugf("Key blacklisted due to permanent error: %v", err)
-        return
-    }
-
-    // Increment failure count
-    failCount, _ := km.keyFailureCounts.LoadOrStore(key, new(int64))
-    if counter, ok := failCount.(*int64); ok {
-        newFailCount := atomic.AddInt64(counter, 1)
-
-        // Blacklist if threshold exceeded
-        if int(newFailCount) >= km.config.BlacklistThreshold {
-            km.blacklistedKeys.Store(key, time.Now())
-            logrus.Debugf("Key blacklisted after %d failures", newFailCount)
-        }
-    }
-}
-
-// isPermanentError checks if an error is permanent
-func (km *Manager) isPermanentError(err error) bool {
-    if err == nil {
-        return false
-    }
-
-    errorStr := strings.ToLower(err.Error())
-    for _, pattern := range km.permanentErrorPatterns {
-        if pattern.MatchString(errorStr) {
-            return true
-        }
-    }
-    return false
-}
-
-// GetStats returns current statistics
-func (km *Manager) GetStats() types.Stats {
-    km.keysMutex.RLock()
-    totalKeys := len(km.keys)
-    km.keysMutex.RUnlock()
-
-    blacklistedCount := 0
-    km.blacklistedKeys.Range(func(key, value any) bool {
-        blacklistedCount++
-        return true
-    })
-
-    var m runtime.MemStats
-    runtime.ReadMemStats(&m)
-
-    return types.Stats{
-        CurrentIndex: atomic.LoadInt64(&km.currentIndex),
-        TotalKeys: totalKeys,
-        HealthyKeys: totalKeys - blacklistedCount,
-        BlacklistedKeys: blacklistedCount,
-        SuccessCount: atomic.LoadInt64(&km.successCount),
-        FailureCount: atomic.LoadInt64(&km.failureCount),
-        MemoryUsage: types.MemoryUsage{
-            Alloc: m.Alloc,
-            TotalAlloc: m.TotalAlloc,
-            Sys: m.Sys,
-            NumGC: m.NumGC,
-            LastGCTime: time.Unix(0, int64(m.LastGC)).Format("2006-01-02 15:04:05"),
-            NextGCTarget: m.NextGC,
-        },
-    }
-}
-
-// ResetBlacklist resets the blacklist
-func (km *Manager) ResetBlacklist() {
-    km.blacklistedKeys = sync.Map{}
-    km.keyFailureCounts = sync.Map{}
-    logrus.Info("Blacklist reset successfully")
-}
-
-// GetBlacklist returns current blacklisted keys
-func (km *Manager) GetBlacklist() []types.BlacklistEntry {
-    var blacklist []types.BlacklistEntry
-
-    km.blacklistedKeys.Range(func(key, value any) bool {
-        keyStr := key.(string)
-        blacklistTime := value.(time.Time)
-
-        // Create preview
-        preview := keyStr
-        if len(keyStr) > 12 {
-            preview = keyStr[:8] + "..." + keyStr[len(keyStr)-4:]
-        }
-
-        // Get failure count
-        failCount := 0
-        if count, exists := km.keyFailureCounts.Load(keyStr); exists {
-            failCount = int(atomic.LoadInt64(count.(*int64)))
-        }
-
-        blacklist = append(blacklist, types.BlacklistEntry{
-            Key: keyStr,
-            Preview: preview,
-            Reason: "Exceeded failure threshold",
-            BlacklistAt: blacklistTime,
-            FailCount: failCount,
-        })
-        return true
-    })
-
-    return blacklist
-}
-
-// setupMemoryCleanup sets up periodic memory cleanup
-func (km *Manager) setupMemoryCleanup() {
-    // Reduce GC frequency to every 15 minutes to avoid performance impact
-    km.cleanupTicker = time.NewTicker(15 * time.Minute)
-    go func() {
-        for {
-            select {
-            case <-km.cleanupTicker.C:
-                // Only trigger GC if memory usage is high
-                var m runtime.MemStats
-                runtime.ReadMemStats(&m)
-                // Trigger GC only if allocated memory is above 100MB
-                if m.Alloc > 100*1024*1024 {
-                    runtime.GC()
-                    logrus.Debugf("Manual GC triggered, memory usage: %d MB", m.Alloc/1024/1024)
-                }
-            case <-km.stopCleanup:
-                return
-            }
-        }
-    }()
-}
-
-// Close closes the key manager and cleans up resources
-func (km *Manager) Close() {
-    if km.cleanupTicker != nil {
-        km.cleanupTicker.Stop()
-    }
-    close(km.stopCleanup)
-}
@@ -8,6 +8,7 @@ import (
     "gpt-load/internal/response"
     "net/http"
     "sync"
+    "sync/atomic"
     "time"
 
     "github.com/gin-gonic/gin"
@@ -18,7 +19,7 @@ import (
 // ProxyServer represents the proxy server
 type ProxyServer struct {
     DB *gorm.DB
-    groupCounters sync.Map // For round-robin key selection
+    groupCounters sync.Map // map[uint]*atomic.Uint64
     requestLogChan chan models.RequestLog
 }
 
@@ -82,18 +83,22 @@ func (ps *ProxyServer) selectAPIKey(group *models.Group) (*models.APIKey, error)
         return nil, fmt.Errorf("no active API keys available in group '%s'", group.Name)
     }
 
-    // Get the current counter for the group
-    counter, _ := ps.groupCounters.LoadOrStore(group.ID, uint64(0))
-    currentCounter := counter.(uint64)
+    // Get or create a counter for the group. The value is a pointer to a uint64.
+    val, _ := ps.groupCounters.LoadOrStore(group.ID, new(atomic.Uint64))
+    counter := val.(*atomic.Uint64)
 
-    // Select the key and increment the counter
-    selectedKey := activeKeys[int(currentCounter%uint64(len(activeKeys)))]
-    ps.groupCounters.Store(group.ID, currentCounter+1)
+    // Atomically increment the counter and get the index for this request.
+    index := counter.Add(1) - 1
+    selectedKey := activeKeys[int(index%uint64(len(activeKeys)))]
 
     return &selectedKey, nil
 }
 
 func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *models.APIKey, startTime time.Time) {
+    // Update key stats based on request success
+    isSuccess := c.Writer.Status() < 400
+    go ps.updateKeyStats(key.ID, isSuccess)
+
     logEntry := models.RequestLog{
         ID: fmt.Sprintf("req_%d", time.Now().UnixNano()),
         Timestamp: startTime,
@@ -113,6 +118,27 @@ func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *mode
     }
 }
 
+// updateKeyStats atomically updates the request and failure counts for a key
+func (ps *ProxyServer) updateKeyStats(keyID uint, success bool) {
+    // Always increment the request count
+    updates := map[string]interface{}{
+        "request_count": gorm.Expr("request_count + 1"),
+    }
+
+    // Additionally, increment the failure count if the request was not successful
+    if !success {
+        updates["failure_count"] = gorm.Expr("failure_count + 1")
+    }
+
+    result := ps.DB.Model(&models.APIKey{}).Where("id = ?", keyID).Updates(updates)
+    if result.Error != nil {
+        logrus.WithFields(logrus.Fields{
+            "keyID": keyID,
+            "error": result.Error,
+        }).Error("Failed to update key stats")
+    }
+}
+
 // Close cleans up resources
 func (ps *ProxyServer) Close() {
     // Nothing to close for now
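A note on the counter change above: storing a plain uint64 and writing it back with Store can lose increments when two requests race between LoadOrStore and Store, whereas sharing a *atomic.Uint64 makes the increment a single atomic operation. A self-contained sketch of the same round-robin pattern, with an arbitrary group ID and key count:

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

func main() {
    var counters sync.Map // map[uint]*atomic.Uint64, one counter per group

    next := func(groupID uint, keyCount int) int {
        // LoadOrStore hands every caller the same *atomic.Uint64 for a group,
        // so Add is one atomic read-modify-write instead of a racy Load+Store pair.
        val, _ := counters.LoadOrStore(groupID, new(atomic.Uint64))
        idx := val.(*atomic.Uint64).Add(1) - 1
        return int(idx % uint64(keyCount))
    }

    for i := 0; i < 5; i++ {
        fmt.Print(next(1, 3), " ") // prints: 0 1 2 0 1
    }
    fmt.Println()
}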
@@ -2,15 +2,12 @@
 package types
 
 import (
-    "time"
-
     "github.com/gin-gonic/gin"
 )
 
 // ConfigManager defines the interface for configuration management
 type ConfigManager interface {
     GetServerConfig() ServerConfig
-    GetKeysConfig() KeysConfig
     GetOpenAIConfig() OpenAIConfig
     GetAuthConfig() AuthConfig
     GetCORSConfig() CORSConfig
@@ -21,18 +18,6 @@ type ConfigManager interface {
     ReloadConfig() error
 }
 
-// KeyManager defines the interface for API key management
-type KeyManager interface {
-    LoadKeys() error
-    GetNextKey() (*KeyInfo, error)
-    RecordSuccess(key string)
-    RecordFailure(key string, err error)
-    GetStats() Stats
-    ResetBlacklist()
-    GetBlacklist() []BlacklistEntry
-    Close()
-}
-
 // ProxyServer defines the interface for proxy server
 type ProxyServer interface {
     HandleProxy(c *gin.Context)
@@ -49,14 +34,6 @@ type ServerConfig struct {
     GracefulShutdownTimeout int `json:"gracefulShutdownTimeout"`
 }
 
-// KeysConfig represents keys configuration
-type KeysConfig struct {
-    FilePath string `json:"filePath"`
-    StartIndex int `json:"startIndex"`
-    BlacklistThreshold int `json:"blacklistThreshold"`
-    MaxRetries int `json:"maxRetries"`
-}
-
 // OpenAIConfig represents OpenAI API configuration
 type OpenAIConfig struct {
     BaseURL string `json:"baseUrl"`
@@ -95,48 +72,3 @@ type LogConfig struct {
     FilePath string `json:"filePath"`
     EnableRequest bool `json:"enableRequest"`
 }
-
-// KeyInfo represents API key information
-type KeyInfo struct {
-    Key string `json:"key"`
-    Index int `json:"index"`
-    Preview string `json:"preview"`
-}
-
-// Stats represents system statistics
-type Stats struct {
-    CurrentIndex int64 `json:"currentIndex"`
-    TotalKeys int `json:"totalKeys"`
-    HealthyKeys int `json:"healthyKeys"`
-    BlacklistedKeys int `json:"blacklistedKeys"`
-    SuccessCount int64 `json:"successCount"`
-    FailureCount int64 `json:"failureCount"`
-    MemoryUsage MemoryUsage `json:"memoryUsage"`
-}
-
-// MemoryUsage represents memory usage statistics
-type MemoryUsage struct {
-    Alloc uint64 `json:"alloc"`
-    TotalAlloc uint64 `json:"totalAlloc"`
-    Sys uint64 `json:"sys"`
-    NumGC uint32 `json:"numGC"`
-    LastGCTime string `json:"lastGCTime"`
-    NextGCTarget uint64 `json:"nextGCTarget"`
-}
-
-// BlacklistEntry represents a blacklisted key entry
-type BlacklistEntry struct {
-    Key string `json:"key"`
-    Preview string `json:"preview"`
-    Reason string `json:"reason"`
-    BlacklistAt time.Time `json:"blacklistAt"`
-    FailCount int `json:"failCount"`
-}
-
-// RetryError represents retry error information
-type RetryError struct {
-    StatusCode int `json:"statusCode"`
-    ErrorMessage string `json:"errorMessage"`
-    KeyIndex int `json:"keyIndex"`
-    Attempt int `json:"attempt"`
-}