Files
gpt-load/internal/proxy/server.go
enoch 2ab9ded58f
Some checks failed
Release Linux Build / release (push) Has been cancelled
Release Windows Build / release (push) Has been cancelled
Release MacOS Build / release (push) Has been cancelled
Build Docker Image / Build Docker Image (push) Has been cancelled
fix(proxy): 排除404状态码避免无意义重试
当请求资源不存在时,不代表key无效,反而可能错误地将key判断为无效,而且此时重试没有意义
2025-08-10 13:06:25 +08:00

297 lines
9.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package proxy provides a high-performance OpenAI multi-key proxy server.
package proxy
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"
"gpt-load/internal/channel"
"gpt-load/internal/config"
app_errors "gpt-load/internal/errors"
"gpt-load/internal/keypool"
"gpt-load/internal/models"
"gpt-load/internal/response"
"gpt-load/internal/services"
"gpt-load/internal/types"
"gpt-load/internal/utils"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
// ProxyServer is the proxy request handler. It bundles the collaborators used
// to resolve a group, obtain its channel, select an API key, forward the
// request upstream and record a request log entry.
type ProxyServer struct {
// keyProvider selects an API key for a group (SelectKey) and is told about
// failed attempts (UpdateStatus).
keyProvider *keypool.KeyProvider
// groupManager resolves a group from the "group_name" route parameter.
groupManager *services.GroupManager
// settingsManager is stored by the constructor but not referenced in the
// visible part of this file; presumably it exposes system-wide settings —
// TODO confirm against the rest of the package.
settingsManager *config.SystemSettingsManager
// channelFactory returns the channel implementation for a group.
channelFactory *channel.Factory
// requestLogService records request log entries; logRequest is a no-op when
// this is nil.
requestLogService *services.RequestLogService
}
// NewProxyServer wires the given collaborators into a fresh ProxyServer.
// The returned error is always nil; it is part of the signature so callers
// can treat construction uniformly with other components.
func NewProxyServer(
	keyProvider *keypool.KeyProvider,
	groupManager *services.GroupManager,
	settingsManager *config.SystemSettingsManager,
	channelFactory *channel.Factory,
	requestLogService *services.RequestLogService,
) (*ProxyServer, error) {
	server := new(ProxyServer)
	server.keyProvider = keyProvider
	server.groupManager = groupManager
	server.settingsManager = settingsManager
	server.channelFactory = channelFactory
	server.requestLogService = requestLogService
	return server, nil
}
// HandleProxy is the gin handler for proxied requests. It resolves the group
// named by the "group_name" route parameter, obtains that group's channel,
// reads the whole request body, runs applyParamOverrides on it and then starts
// the first upstream attempt via executeRequestWithRetry.
//
// Any failure before the first attempt is answered directly through
// response.Error and the handler returns.
func (ps *ProxyServer) HandleProxy(c *gin.Context) {
	startTime := time.Now()

	groupName := c.Param("group_name")
	group, lookupErr := ps.groupManager.GetGroupByName(groupName)
	if lookupErr != nil {
		response.Error(c, app_errors.ParseDBError(lookupErr))
		return
	}

	channelHandler, channelErr := ps.channelFactory.GetChannel(group)
	if channelErr != nil {
		detail := fmt.Sprintf("Failed to get channel for group '%s': %v", groupName, channelErr)
		response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, detail))
		return
	}

	// The body is buffered in full so it can be replayed on every retry.
	rawBody, readErr := io.ReadAll(c.Request.Body)
	if readErr != nil {
		logrus.Errorf("Failed to read request body: %v", readErr)
		response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Failed to read request body"))
		return
	}
	c.Request.Body.Close()

	overriddenBody, overrideErr := ps.applyParamOverrides(rawBody, group)
	if overrideErr != nil {
		detail := fmt.Sprintf("Failed to apply parameter overrides: %v", overrideErr)
		response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, detail))
		return
	}

	// Stream detection looks at the body as the client sent it, not at the
	// overridden body that is forwarded upstream.
	streaming := channelHandler.IsStreamRequest(c, rawBody)
	ps.executeRequestWithRetry(c, channelHandler, group, overriddenBody, streaming, startTime, 0, nil)
}
// executeRequestWithRetry performs one upstream attempt and, when that attempt
// fails in a retryable way, calls itself again with retryCount+1 until
// group.EffectiveConfig.MaxRetries is exceeded.
//
// Parameters:
//   - c: the client request/response context.
//   - channelHandler: builds the upstream URL, decorates the outgoing request
//     and supplies the HTTP clients.
//   - group: the target group; its EffectiveConfig provides MaxRetries and
//     RequestTimeout (seconds, applied to non-stream requests only).
//   - bodyBytes: the request body, re-sent unchanged on every attempt.
//   - isStream: selects the stream client and a cancel-only context.
//   - startTime: start of the client request, used for the logged duration.
//   - retryCount: number of attempts already made (0 on the first call).
//   - retryErrors: failures collected from earlier attempts; the most recent
//     one is relayed to the client once retries are exhausted.
//
// A transport error, or an upstream status >= 400 other than 404, counts as a
// failed attempt: the key is reported via keyProvider.UpdateStatus and the
// request is retried. Errors recognised by app_errors.IsIgnorableError abort
// the retry chain and are logged with status 499. Any other response,
// including 404, is relayed to the client as-is.
func (ps *ProxyServer) executeRequestWithRetry(
	c *gin.Context,
	channelHandler channel.ChannelProxy,
	group *models.Group,
	bodyBytes []byte,
	isStream bool,
	startTime time.Time,
	retryCount int,
	retryErrors []types.RetryError,
) {
	cfg := group.EffectiveConfig
	// retryCount starts at 0 and attempts continue while retryCount <= MaxRetries,
	// so the total number of attempts is MaxRetries+1. Used for log output only.
	maxAttempts := cfg.MaxRetries + 1

	if retryCount > cfg.MaxRetries {
		if len(retryErrors) > 0 {
			lastError := retryErrors[len(retryErrors)-1]
			// Relay the upstream error body verbatim when it is a JSON object;
			// otherwise wrap it in the proxy's own error envelope.
			var errorJSON map[string]any
			if err := json.Unmarshal([]byte(lastError.ErrorMessage), &errorJSON); err == nil {
				c.JSON(lastError.StatusCode, errorJSON)
			} else {
				response.Error(c, app_errors.NewAPIErrorWithUpstream(lastError.StatusCode, "UPSTREAM_ERROR", lastError.ErrorMessage))
			}

			logMessage := lastError.ParsedErrorMessage
			if logMessage == "" {
				logMessage = lastError.ErrorMessage
			}
			logrus.Debugf("Max retries exceeded for group %s after %d attempts. Parsed Error: %s", group.Name, retryCount, logMessage)
			ps.logRequest(c, group, &models.APIKey{KeyValue: lastError.KeyValue}, startTime, lastError.StatusCode, retryCount, errors.New(logMessage), isStream, lastError.UpstreamAddr, channelHandler, bodyBytes)
		} else {
			response.Error(c, app_errors.ErrMaxRetriesExceeded)
			logrus.Debugf("Max retries exceeded for group %s after %d attempts.", group.Name, retryCount)
			ps.logRequest(c, group, nil, startTime, http.StatusServiceUnavailable, retryCount, app_errors.ErrMaxRetriesExceeded, isStream, "", channelHandler, bodyBytes)
		}
		return
	}

	apiKey, err := ps.keyProvider.SelectKey(group.ID)
	if err != nil {
		logrus.Errorf("Failed to select a key for group %s on attempt %d: %v", group.Name, retryCount+1, err)
		response.Error(c, app_errors.NewAPIError(app_errors.ErrNoKeysAvailable, err.Error()))
		ps.logRequest(c, group, nil, startTime, http.StatusServiceUnavailable, retryCount, err, isStream, "", channelHandler, bodyBytes)
		return
	}

	upstreamURL, err := channelHandler.BuildUpstreamURL(c.Request.URL, group)
	if err != nil {
		response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to build upstream URL: %v", err)))
		return
	}

	// Stream requests are bounded only by client cancellation; normal requests
	// additionally get the group's request timeout.
	var ctx context.Context
	var cancel context.CancelFunc
	if isStream {
		ctx, cancel = context.WithCancel(c.Request.Context())
	} else {
		timeout := time.Duration(cfg.RequestTimeout) * time.Second
		ctx, cancel = context.WithTimeout(c.Request.Context(), timeout)
	}
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, c.Request.Method, upstreamURL, bytes.NewReader(bodyBytes))
	if err != nil {
		logrus.Errorf("Failed to create upstream request: %v", err)
		response.Error(c, app_errors.ErrInternalServer)
		return
	}
	req.ContentLength = int64(len(bodyBytes))

	req.Header = c.Request.Header.Clone()

	// Clean up client auth key
	req.Header.Del("Authorization")
	req.Header.Del("X-Api-Key")
	req.Header.Del("X-Goog-Api-Key")
	q := req.URL.Query()
	q.Del("key")
	req.URL.RawQuery = q.Encode()

	channelHandler.ModifyRequest(req, apiKey, group)

	var client *http.Client
	if isStream {
		client = channelHandler.GetStreamClient()
		req.Header.Set("X-Accel-Buffering", "no")
	} else {
		client = channelHandler.GetHTTPClient()
	}

	resp, err := client.Do(req)
	if resp != nil {
		defer resp.Body.Close()
	}

	// Unified error handling for retries.
	// Exclude 404 from being a retryable error.
	if err != nil || (resp != nil && resp.StatusCode >= 400 && resp.StatusCode != http.StatusNotFound) {
		if err != nil && app_errors.IsIgnorableError(err) {
			logrus.Debugf("Client-side ignorable error for key %s, aborting retries: %v", utils.MaskAPIKey(apiKey.KeyValue), err)
			ps.logRequest(c, group, apiKey, startTime, 499, retryCount+1, err, isStream, upstreamURL, channelHandler, bodyBytes)
			return
		}

		ps.keyProvider.UpdateStatus(apiKey, group, false)

		var statusCode int
		var errorMessage string
		var parsedError string

		if err != nil {
			statusCode = 500
			errorMessage = err.Error()
			logrus.Debugf("Request failed (attempt %d/%d) for key %s: %v", retryCount+1, maxAttempts, utils.MaskAPIKey(apiKey.KeyValue), err)
		} else {
			// HTTP-level error (status >= 400)
			statusCode = resp.StatusCode
			errorBody, readErr := io.ReadAll(resp.Body)
			if readErr != nil {
				logrus.Errorf("Failed to read error body: %v", readErr)
				errorBody = []byte("Failed to read error body")
			}

			errorBody = handleGzipCompression(resp, errorBody)
			errorMessage = string(errorBody)
			parsedError = app_errors.ParseUpstreamError(errorBody)
			logrus.Debugf("Request failed with status %d (attempt %d/%d) for key %s. Parsed Error: %s", statusCode, retryCount+1, maxAttempts, utils.MaskAPIKey(apiKey.KeyValue), parsedError)
		}

		newRetryErrors := append(retryErrors, types.RetryError{
			StatusCode:         statusCode,
			ErrorMessage:       errorMessage,
			ParsedErrorMessage: parsedError,
			KeyValue:           apiKey.KeyValue,
			Attempt:            retryCount + 1,
			UpstreamAddr:       upstreamURL,
		})

		// Release this attempt's resources before recursing. The deferred
		// cleanup above only runs once the whole retry chain has returned, which
		// would keep every failed attempt's response body and context alive
		// for the full duration of the final (possibly streaming) attempt.
		// Both calls are safe to repeat when the defers eventually run.
		if resp != nil {
			resp.Body.Close()
		}
		cancel()

		ps.executeRequestWithRetry(c, channelHandler, group, bodyBytes, isStream, startTime, retryCount+1, newRetryErrors)
		return
	}

	// Successful requests deliberately skip
	// ps.keyProvider.UpdateStatus(apiKey, group, true) to reduce I/O.
	logrus.Debugf("Request for group %s succeeded on attempt %d with key %s", group.Name, retryCount+1, utils.MaskAPIKey(apiKey.KeyValue))
	ps.logRequest(c, group, apiKey, startTime, resp.StatusCode, retryCount+1, nil, isStream, upstreamURL, channelHandler, bodyBytes)

	// Copy upstream headers, replacing any value already set for the same key.
	// The whole value slice is assigned because c.Header has Set semantics:
	// calling it once per value would keep only the last value of a
	// multi-valued header such as Set-Cookie or Vary.
	clientHeader := c.Writer.Header()
	for key, values := range resp.Header {
		clientHeader[key] = append([]string(nil), values...)
	}
	c.Status(resp.StatusCode)

	if isStream {
		ps.handleStreamingResponse(c, resp)
	} else {
		ps.handleNormalResponse(c, resp)
	}
}
// logRequest builds a models.RequestLog for the current request and hands it
// to the request log service. It does nothing when no service is configured.
//
// The entry is marked successful only when finalError is nil and statusCode is
// below 400. apiKey may be nil (no key was selected); channelHandler and
// bodyBytes may be nil, in which case the model name is left empty. Failures
// to record the entry are logged and otherwise ignored.
func (ps *ProxyServer) logRequest(
	c *gin.Context,
	group *models.Group,
	apiKey *models.APIKey,
	startTime time.Time,
	statusCode int,
	retries int,
	finalError error,
	isStream bool,
	upstreamAddr string,
	channelHandler channel.ChannelProxy,
	bodyBytes []byte,
) {
	recorder := ps.requestLogService
	if recorder == nil {
		return
	}

	elapsedMs := time.Since(startTime).Milliseconds()
	succeeded := finalError == nil && statusCode < 400

	entry := &models.RequestLog{
		GroupID:      group.ID,
		GroupName:    group.Name,
		IsSuccess:    succeeded,
		SourceIP:     c.ClientIP(),
		StatusCode:   statusCode,
		RequestPath:  utils.TruncateString(c.Request.URL.String(), 500),
		Duration:     elapsedMs,
		UserAgent:    c.Request.UserAgent(),
		Retries:      retries,
		IsStream:     isStream,
		UpstreamAddr: utils.TruncateString(upstreamAddr, 500),
	}

	// Optional fields are filled in only when their source is available.
	if channelHandler != nil && bodyBytes != nil {
		entry.Model = channelHandler.ExtractModel(c, bodyBytes)
	}
	if apiKey != nil {
		entry.KeyValue = apiKey.KeyValue
	}
	if finalError != nil {
		entry.ErrorMessage = finalError.Error()
	}

	recordErr := recorder.Record(entry)
	if recordErr != nil {
		logrus.Errorf("Failed to record request log: %v", recordErr)
	}
}