feat: Multi-Target Load Balancing

This commit is contained in:
tbphp
2025-06-11 11:50:49 +08:00
parent b97bd1146b
commit 3450a05615
5 changed files with 63 additions and 27 deletions

View File

@@ -7,6 +7,7 @@ import (
"os"
"strconv"
"strings"
"sync/atomic"
"gpt-load/internal/errors"
"gpt-load/pkg/types"
@@ -37,7 +38,8 @@ var DefaultConstants = Constants{
// Manager implements the ConfigManager interface
type Manager struct {
config *Config
config *Config
roundRobinCounter uint64
}
// Config represents the application configuration
@@ -70,8 +72,8 @@ func NewManager() (types.ConfigManager, error) {
MaxRetries: parseInteger(os.Getenv("MAX_RETRIES"), 3),
},
OpenAI: types.OpenAIConfig{
BaseURL: getEnvOrDefault("OPENAI_BASE_URL", "https://api.openai.com"),
Timeout: parseInteger(os.Getenv("REQUEST_TIMEOUT"), DefaultConstants.DefaultTimeout),
BaseURLs: parseArray(os.Getenv("OPENAI_BASE_URL"), []string{"https://api.openai.com"}),
Timeout: parseInteger(os.Getenv("REQUEST_TIMEOUT"), DefaultConstants.DefaultTimeout),
},
Auth: types.AuthConfig{
Key: os.Getenv("AUTH_KEY"),
@@ -120,7 +122,15 @@ func (m *Manager) GetKeysConfig() types.KeysConfig {
// GetOpenAIConfig returns OpenAI configuration
func (m *Manager) GetOpenAIConfig() types.OpenAIConfig {
return m.config.OpenAI
config := m.config.OpenAI
if len(config.BaseURLs) > 1 {
// Use atomic counter for thread-safe round-robin
index := atomic.AddUint64(&m.roundRobinCounter, 1) - 1
config.BaseURL = config.BaseURLs[index%uint64(len(config.BaseURLs))]
} else if len(config.BaseURLs) == 1 {
config.BaseURL = config.BaseURLs[0]
}
return config
}
// GetAuthConfig returns authentication configuration
@@ -168,8 +178,13 @@ func (m *Manager) Validate() error {
}
// Validate upstream URL format
if _, err := url.Parse(m.config.OpenAI.BaseURL); err != nil {
validationErrors = append(validationErrors, "invalid upstream API URL format")
if len(m.config.OpenAI.BaseURLs) == 0 {
validationErrors = append(validationErrors, "at least one upstream API URL is required")
}
for _, baseURL := range m.config.OpenAI.BaseURLs {
if _, err := url.Parse(baseURL); err != nil {
validationErrors = append(validationErrors, fmt.Sprintf("invalid upstream API URL format: %s", baseURL))
}
}
// Validate performance configuration
@@ -196,7 +211,7 @@ func (m *Manager) DisplayConfig() {
logrus.Infof(" Start index: %d", m.config.Keys.StartIndex)
logrus.Infof(" Blacklist threshold: %d errors", m.config.Keys.BlacklistThreshold)
logrus.Infof(" Max retries: %d", m.config.Keys.MaxRetries)
logrus.Infof(" Upstream URL: %s", m.config.OpenAI.BaseURL)
logrus.Infof(" Upstream URLs: %s", strings.Join(m.config.OpenAI.BaseURLs, ", "))
logrus.Infof(" Request timeout: %dms", m.config.OpenAI.Timeout)
authStatus := "disabled"