// Package proxy provides a high-performance OpenAI multi-key proxy server.
package proxy
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"time"
|
||
|
||
"gpt-load/internal/channel"
|
||
"gpt-load/internal/config"
|
||
app_errors "gpt-load/internal/errors"
|
||
"gpt-load/internal/keypool"
|
||
"gpt-load/internal/models"
|
||
"gpt-load/internal/response"
|
||
"gpt-load/internal/services"
|
||
"gpt-load/internal/types"
|
||
"gpt-load/internal/utils"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/sirupsen/logrus"
|
||
)
|
||
|
||
// ProxyServer represents the proxy server
|
||
type ProxyServer struct {
|
||
keyProvider *keypool.KeyProvider
|
||
groupManager *services.GroupManager
|
||
settingsManager *config.SystemSettingsManager
|
||
channelFactory *channel.Factory
|
||
requestLogService *services.RequestLogService
|
||
}
|
||
|
||
// NewProxyServer creates a new proxy server
|
||
func NewProxyServer(
|
||
keyProvider *keypool.KeyProvider,
|
||
groupManager *services.GroupManager,
|
||
settingsManager *config.SystemSettingsManager,
|
||
channelFactory *channel.Factory,
|
||
requestLogService *services.RequestLogService,
|
||
) (*ProxyServer, error) {
|
||
return &ProxyServer{
|
||
keyProvider: keyProvider,
|
||
groupManager: groupManager,
|
||
settingsManager: settingsManager,
|
||
channelFactory: channelFactory,
|
||
requestLogService: requestLogService,
|
||
}, nil
|
||
}
|
||
|
||
// HandleProxy is the main entry point for proxy requests, refactored based on the stable .bak logic.
|
||
func (ps *ProxyServer) HandleProxy(c *gin.Context) {
|
||
startTime := time.Now()
|
||
groupName := c.Param("group_name")
|
||
|
||
group, err := ps.groupManager.GetGroupByName(groupName)
|
||
if err != nil {
|
||
response.Error(c, app_errors.ParseDBError(err))
|
||
return
|
||
}
|
||
|
||
channelHandler, err := ps.channelFactory.GetChannel(group)
|
||
if err != nil {
|
||
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to get channel for group '%s': %v", groupName, err)))
|
||
return
|
||
}
|
||
|
||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||
if err != nil {
|
||
logrus.Errorf("Failed to read request body: %v", err)
|
||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Failed to read request body"))
|
||
return
|
||
}
|
||
c.Request.Body.Close()
|
||
|
||
finalBodyBytes, err := ps.applyParamOverrides(bodyBytes, group)
|
||
if err != nil {
|
||
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to apply parameter overrides: %v", err)))
|
||
return
|
||
}
|
||
|
||
isStream := channelHandler.IsStreamRequest(c, bodyBytes)
|
||
|
||
ps.executeRequestWithRetry(c, channelHandler, group, finalBodyBytes, isStream, startTime, 0, nil)
|
||
}
|
||
|
||
// executeRequestWithRetry is the core recursive function for handling requests and retries.
|
||
func (ps *ProxyServer) executeRequestWithRetry(
|
||
c *gin.Context,
|
||
channelHandler channel.ChannelProxy,
|
||
group *models.Group,
|
||
bodyBytes []byte,
|
||
isStream bool,
|
||
startTime time.Time,
|
||
retryCount int,
|
||
retryErrors []types.RetryError,
|
||
) {
|
||
cfg := group.EffectiveConfig
|
||
if retryCount > cfg.MaxRetries {
|
||
if len(retryErrors) > 0 {
|
||
lastError := retryErrors[len(retryErrors)-1]
|
||
var errorJSON map[string]any
|
||
if err := json.Unmarshal([]byte(lastError.ErrorMessage), &errorJSON); err == nil {
|
||
c.JSON(lastError.StatusCode, errorJSON)
|
||
} else {
|
||
response.Error(c, app_errors.NewAPIErrorWithUpstream(lastError.StatusCode, "UPSTREAM_ERROR", lastError.ErrorMessage))
|
||
}
|
||
logMessage := lastError.ParsedErrorMessage
|
||
if logMessage == "" {
|
||
logMessage = lastError.ErrorMessage
|
||
}
|
||
logrus.Debugf("Max retries exceeded for group %s after %d attempts. Parsed Error: %s", group.Name, retryCount, logMessage)
|
||
|
||
ps.logRequest(c, group, &models.APIKey{KeyValue: lastError.KeyValue}, startTime, lastError.StatusCode, retryCount, errors.New(logMessage), isStream, lastError.UpstreamAddr, channelHandler, bodyBytes)
|
||
} else {
|
||
response.Error(c, app_errors.ErrMaxRetriesExceeded)
|
||
logrus.Debugf("Max retries exceeded for group %s after %d attempts.", group.Name, retryCount)
|
||
ps.logRequest(c, group, nil, startTime, http.StatusServiceUnavailable, retryCount, app_errors.ErrMaxRetriesExceeded, isStream, "", channelHandler, bodyBytes)
|
||
}
|
||
return
|
||
}
|
||
|
||
apiKey, err := ps.keyProvider.SelectKey(group.ID)
|
||
if err != nil {
|
||
logrus.Errorf("Failed to select a key for group %s on attempt %d: %v", group.Name, retryCount+1, err)
|
||
response.Error(c, app_errors.NewAPIError(app_errors.ErrNoKeysAvailable, err.Error()))
|
||
ps.logRequest(c, group, nil, startTime, http.StatusServiceUnavailable, retryCount, err, isStream, "", channelHandler, bodyBytes)
|
||
return
|
||
}
|
||
|
||
upstreamURL, err := channelHandler.BuildUpstreamURL(c.Request.URL, group)
|
||
if err != nil {
|
||
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to build upstream URL: %v", err)))
|
||
return
|
||
}
|
||
|
||
var ctx context.Context
|
||
var cancel context.CancelFunc
|
||
if isStream {
|
||
ctx, cancel = context.WithCancel(c.Request.Context())
|
||
} else {
|
||
timeout := time.Duration(cfg.RequestTimeout) * time.Second
|
||
ctx, cancel = context.WithTimeout(c.Request.Context(), timeout)
|
||
}
|
||
defer cancel()
|
||
|
||
req, err := http.NewRequestWithContext(ctx, c.Request.Method, upstreamURL, bytes.NewReader(bodyBytes))
|
||
if err != nil {
|
||
logrus.Errorf("Failed to create upstream request: %v", err)
|
||
response.Error(c, app_errors.ErrInternalServer)
|
||
return
|
||
}
|
||
req.ContentLength = int64(len(bodyBytes))
|
||
|
||
req.Header = c.Request.Header.Clone()
|
||
|
||
// Clean up client auth key
|
||
req.Header.Del("Authorization")
|
||
req.Header.Del("X-Api-Key")
|
||
req.Header.Del("X-Goog-Api-Key")
|
||
q := req.URL.Query()
|
||
q.Del("key")
|
||
req.URL.RawQuery = q.Encode()
|
||
|
||
channelHandler.ModifyRequest(req, apiKey, group)
|
||
|
||
var client *http.Client
|
||
if isStream {
|
||
client = channelHandler.GetStreamClient()
|
||
req.Header.Set("X-Accel-Buffering", "no")
|
||
} else {
|
||
client = channelHandler.GetHTTPClient()
|
||
}
|
||
|
||
resp, err := client.Do(req)
|
||
if resp != nil {
|
||
defer resp.Body.Close()
|
||
}
|
||
|
||
// Unified error handling for retries.
|
||
// Exclude 404 from being a retryable error.
|
||
if err != nil || (resp != nil && resp.StatusCode >= 400 && resp.StatusCode != http.StatusNotFound) {
|
||
if err != nil && app_errors.IsIgnorableError(err) {
|
||
logrus.Debugf("Client-side ignorable error for key %s, aborting retries: %v", utils.MaskAPIKey(apiKey.KeyValue), err)
|
||
ps.logRequest(c, group, apiKey, startTime, 499, retryCount+1, err, isStream, upstreamURL, channelHandler, bodyBytes)
|
||
return
|
||
}
|
||
|
||
ps.keyProvider.UpdateStatus(apiKey, group, false)
|
||
|
||
var statusCode int
|
||
var errorMessage string
|
||
var parsedError string
|
||
|
||
if err != nil {
|
||
statusCode = 500
|
||
errorMessage = err.Error()
|
||
logrus.Debugf("Request failed (attempt %d/%d) for key %s: %v", retryCount+1, cfg.MaxRetries, utils.MaskAPIKey(apiKey.KeyValue), err)
|
||
} else {
|
||
// HTTP-level error (status >= 400)
|
||
statusCode = resp.StatusCode
|
||
errorBody, readErr := io.ReadAll(resp.Body)
|
||
if readErr != nil {
|
||
logrus.Errorf("Failed to read error body: %v", readErr)
|
||
errorBody = []byte("Failed to read error body")
|
||
}
|
||
|
||
errorBody = handleGzipCompression(resp, errorBody)
|
||
errorMessage = string(errorBody)
|
||
parsedError = app_errors.ParseUpstreamError(errorBody)
|
||
logrus.Debugf("Request failed with status %d (attempt %d/%d) for key %s. Parsed Error: %s", statusCode, retryCount+1, cfg.MaxRetries, utils.MaskAPIKey(apiKey.KeyValue), parsedError)
|
||
}
|
||
|
||
newRetryErrors := append(retryErrors, types.RetryError{
|
||
StatusCode: statusCode,
|
||
ErrorMessage: errorMessage,
|
||
ParsedErrorMessage: parsedError,
|
||
KeyValue: apiKey.KeyValue,
|
||
Attempt: retryCount + 1,
|
||
UpstreamAddr: upstreamURL,
|
||
})
|
||
ps.executeRequestWithRetry(c, channelHandler, group, bodyBytes, isStream, startTime, retryCount+1, newRetryErrors)
|
||
return
|
||
}
|
||
|
||
// ps.keyProvider.UpdateStatus(apiKey, group, true) // 请求成功不再重置成功次数,减少IO消耗
|
||
logrus.Debugf("Request for group %s succeeded on attempt %d with key %s", group.Name, retryCount+1, utils.MaskAPIKey(apiKey.KeyValue))
|
||
ps.logRequest(c, group, apiKey, startTime, resp.StatusCode, retryCount+1, nil, isStream, upstreamURL, channelHandler, bodyBytes)
|
||
|
||
for key, values := range resp.Header {
|
||
for _, value := range values {
|
||
c.Header(key, value)
|
||
}
|
||
}
|
||
c.Status(resp.StatusCode)
|
||
|
||
if isStream {
|
||
ps.handleStreamingResponse(c, resp)
|
||
} else {
|
||
ps.handleNormalResponse(c, resp)
|
||
}
|
||
}
|
||
|
||
// logRequest is a helper function to create and record a request log.
|
||
func (ps *ProxyServer) logRequest(
|
||
c *gin.Context,
|
||
group *models.Group,
|
||
apiKey *models.APIKey,
|
||
startTime time.Time,
|
||
statusCode int,
|
||
retries int,
|
||
finalError error,
|
||
isStream bool,
|
||
upstreamAddr string,
|
||
channelHandler channel.ChannelProxy,
|
||
bodyBytes []byte,
|
||
) {
|
||
if ps.requestLogService == nil {
|
||
return
|
||
}
|
||
|
||
duration := time.Since(startTime).Milliseconds()
|
||
|
||
logEntry := &models.RequestLog{
|
||
GroupID: group.ID,
|
||
GroupName: group.Name,
|
||
IsSuccess: finalError == nil && statusCode < 400,
|
||
SourceIP: c.ClientIP(),
|
||
StatusCode: statusCode,
|
||
RequestPath: utils.TruncateString(c.Request.URL.String(), 500),
|
||
Duration: duration,
|
||
UserAgent: c.Request.UserAgent(),
|
||
Retries: retries,
|
||
IsStream: isStream,
|
||
UpstreamAddr: utils.TruncateString(upstreamAddr, 500),
|
||
}
|
||
|
||
if channelHandler != nil && bodyBytes != nil {
|
||
logEntry.Model = channelHandler.ExtractModel(c, bodyBytes)
|
||
}
|
||
|
||
if apiKey != nil {
|
||
logEntry.KeyValue = apiKey.KeyValue
|
||
}
|
||
|
||
if finalError != nil {
|
||
logEntry.ErrorMessage = finalError.Error()
|
||
}
|
||
|
||
if err := ps.requestLogService.Record(logEntry); err != nil {
|
||
logrus.Errorf("Failed to record request log: %v", err)
|
||
}
|
||
}
|