feat: Multi-Target Load Balancing
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"gpt-load/internal/errors"
|
||||
"gpt-load/pkg/types"
|
||||
@@ -37,7 +38,8 @@ var DefaultConstants = Constants{
|
||||
|
||||
// Manager implements the ConfigManager interface
|
||||
type Manager struct {
|
||||
config *Config
|
||||
config *Config
|
||||
roundRobinCounter uint64
|
||||
}
|
||||
|
||||
// Config represents the application configuration
|
||||
@@ -70,8 +72,8 @@ func NewManager() (types.ConfigManager, error) {
|
||||
MaxRetries: parseInteger(os.Getenv("MAX_RETRIES"), 3),
|
||||
},
|
||||
OpenAI: types.OpenAIConfig{
|
||||
BaseURL: getEnvOrDefault("OPENAI_BASE_URL", "https://api.openai.com"),
|
||||
Timeout: parseInteger(os.Getenv("REQUEST_TIMEOUT"), DefaultConstants.DefaultTimeout),
|
||||
BaseURLs: parseArray(os.Getenv("OPENAI_BASE_URL"), []string{"https://api.openai.com"}),
|
||||
Timeout: parseInteger(os.Getenv("REQUEST_TIMEOUT"), DefaultConstants.DefaultTimeout),
|
||||
},
|
||||
Auth: types.AuthConfig{
|
||||
Key: os.Getenv("AUTH_KEY"),
|
||||
@@ -120,7 +122,15 @@ func (m *Manager) GetKeysConfig() types.KeysConfig {
|
||||
|
||||
// GetOpenAIConfig returns OpenAI configuration
|
||||
func (m *Manager) GetOpenAIConfig() types.OpenAIConfig {
|
||||
return m.config.OpenAI
|
||||
config := m.config.OpenAI
|
||||
if len(config.BaseURLs) > 1 {
|
||||
// Use atomic counter for thread-safe round-robin
|
||||
index := atomic.AddUint64(&m.roundRobinCounter, 1) - 1
|
||||
config.BaseURL = config.BaseURLs[index%uint64(len(config.BaseURLs))]
|
||||
} else if len(config.BaseURLs) == 1 {
|
||||
config.BaseURL = config.BaseURLs[0]
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
// GetAuthConfig returns authentication configuration
|
||||
@@ -168,8 +178,13 @@ func (m *Manager) Validate() error {
|
||||
}
|
||||
|
||||
// Validate upstream URL format
|
||||
if _, err := url.Parse(m.config.OpenAI.BaseURL); err != nil {
|
||||
validationErrors = append(validationErrors, "invalid upstream API URL format")
|
||||
if len(m.config.OpenAI.BaseURLs) == 0 {
|
||||
validationErrors = append(validationErrors, "at least one upstream API URL is required")
|
||||
}
|
||||
for _, baseURL := range m.config.OpenAI.BaseURLs {
|
||||
if _, err := url.Parse(baseURL); err != nil {
|
||||
validationErrors = append(validationErrors, fmt.Sprintf("invalid upstream API URL format: %s", baseURL))
|
||||
}
|
||||
}
|
||||
|
||||
// Validate performance configuration
|
||||
@@ -196,7 +211,7 @@ func (m *Manager) DisplayConfig() {
|
||||
logrus.Infof(" Start index: %d", m.config.Keys.StartIndex)
|
||||
logrus.Infof(" Blacklist threshold: %d errors", m.config.Keys.BlacklistThreshold)
|
||||
logrus.Infof(" Max retries: %d", m.config.Keys.MaxRetries)
|
||||
logrus.Infof(" Upstream URL: %s", m.config.OpenAI.BaseURL)
|
||||
logrus.Infof(" Upstream URLs: %s", strings.Join(m.config.OpenAI.BaseURLs, ", "))
|
||||
logrus.Infof(" Request timeout: %dms", m.config.OpenAI.Timeout)
|
||||
|
||||
authStatus := "disabled"
|
||||
|
@@ -44,7 +44,6 @@ type ProxyServer struct {
|
||||
configManager types.ConfigManager
|
||||
httpClient *http.Client
|
||||
streamClient *http.Client // Dedicated client for streaming
|
||||
upstreamURL *url.URL
|
||||
requestCount int64
|
||||
startTime time.Time
|
||||
}
|
||||
@@ -54,11 +53,6 @@ func NewProxyServer(keyManager types.KeyManager, configManager types.ConfigManag
|
||||
openaiConfig := configManager.GetOpenAIConfig()
|
||||
perfConfig := configManager.GetPerformanceConfig()
|
||||
|
||||
// Parse upstream URL
|
||||
upstreamURL, err := url.Parse(openaiConfig.BaseURL)
|
||||
if err != nil {
|
||||
return nil, errors.NewAppErrorWithCause(errors.ErrConfigInvalid, "Failed to parse upstream URL", err)
|
||||
}
|
||||
|
||||
// Create high-performance HTTP client
|
||||
transport := &http.Transport{
|
||||
@@ -104,7 +98,6 @@ func NewProxyServer(keyManager types.KeyManager, configManager types.ConfigManag
|
||||
configManager: configManager,
|
||||
httpClient: httpClient,
|
||||
streamClient: streamClient,
|
||||
upstreamURL: upstreamURL,
|
||||
startTime: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
@@ -205,8 +198,20 @@ func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Ti
|
||||
c.Set("retryCount", retryCount)
|
||||
}
|
||||
|
||||
// Get a base URL from the config manager (handles round-robin)
|
||||
openaiConfig := ps.configManager.GetOpenAIConfig()
|
||||
upstreamURL, err := url.Parse(openaiConfig.BaseURL)
|
||||
if err != nil {
|
||||
logrus.Errorf("Failed to parse upstream URL: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "Invalid upstream URL configured",
|
||||
"code": errors.ErrConfigInvalid,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Build upstream request URL
|
||||
targetURL := *ps.upstreamURL
|
||||
targetURL := *upstreamURL
|
||||
// Correctly append path instead of replacing it
|
||||
if strings.HasSuffix(targetURL.Path, "/") {
|
||||
targetURL.Path = targetURL.Path + strings.TrimPrefix(c.Request.URL.Path, "/")
|
||||
@@ -223,8 +228,7 @@ func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Ti
|
||||
// Streaming requests only set response header timeout, no overall timeout
|
||||
ctx, cancel = context.WithCancel(c.Request.Context())
|
||||
} else {
|
||||
// Non-streaming requests use configured timeout
|
||||
openaiConfig := ps.configManager.GetOpenAIConfig()
|
||||
// Non-streaming requests use configured timeout from the already fetched config
|
||||
timeout := time.Duration(openaiConfig.Timeout) * time.Millisecond
|
||||
ctx, cancel = context.WithTimeout(c.Request.Context(), timeout)
|
||||
}
|
||||
|
Reference in New Issue
Block a user