feat: 代理调试版本

This commit is contained in:
tbphp
2025-07-11 14:01:54 +08:00
parent 395d48c3e7
commit 6ffbb7e9a1
13 changed files with 568 additions and 255 deletions

View File

@@ -3,16 +3,11 @@ package channel
import (
"bytes"
"encoding/json"
"fmt"
"gpt-load/internal/models"
"io"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"sync"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"gorm.io/datatypes"
)
@@ -29,15 +24,13 @@ type BaseChannel struct {
Name string
Upstreams []UpstreamInfo
HTTPClient *http.Client
StreamClient *http.Client
TestModel string
upstreamLock sync.Mutex
groupUpstreams datatypes.JSON
groupConfig datatypes.JSONMap
}
// RequestModifier is a function that can modify the request before it's sent.
type RequestModifier func(req *http.Request, key *models.APIKey)
// getUpstreamURL selects an upstream URL using a smooth weighted round-robin algorithm.
func (b *BaseChannel) getUpstreamURL() *url.URL {
b.upstreamLock.Lock()
@@ -99,100 +92,12 @@ func (b *BaseChannel) IsConfigStale(group *models.Group) bool {
return false
}
// ProcessRequest handles the common logic of processing and forwarding a request.
func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier, ch ChannelProxy) error {
upstreamURL := b.getUpstreamURL()
if upstreamURL == nil {
return fmt.Errorf("no upstream URL configured for channel %s", b.Name)
}
director := func(req *http.Request) {
req.URL.Scheme = upstreamURL.Scheme
req.URL.Host = upstreamURL.Host
req.URL.Path = singleJoiningSlash(upstreamURL.Path, req.URL.Path)
req.Host = upstreamURL.Host
// Apply the channel-specific modifications
if modifier != nil {
modifier(req, apiKey)
}
// Remove headers that should not be forwarded
req.Header.Del("Cookie")
req.Header.Del("X-Real-Ip")
req.Header.Del("X-Forwarded-For")
}
errorHandler := func(rw http.ResponseWriter, req *http.Request, err error) {
logrus.WithFields(logrus.Fields{
"channel": b.Name,
"key_id": apiKey.ID,
"error": err,
}).Error("HTTP proxy error")
rw.WriteHeader(http.StatusBadGateway)
}
proxy := &httputil.ReverseProxy{
Director: director,
ErrorHandler: errorHandler,
Transport: b.HTTPClient.Transport,
}
// Check if the client request is for a streaming endpoint
if ch.IsStreamingRequest(c) {
return b.handleStreaming(c, proxy)
}
proxy.ServeHTTP(c.Writer, c.Request)
return nil
// GetHTTPClient returns the client this channel uses for standard requests,
// i.e. the value stored in the HTTPClient field. It performs no lazy
// initialisation, so it returns nil when the field was never set.
func (b *BaseChannel) GetHTTPClient() *http.Client {
return b.HTTPClient
}
func (b *BaseChannel) handleStreaming(c *gin.Context, proxy *httputil.ReverseProxy) error {
var wg sync.WaitGroup
wg.Add(1)
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
// Use a pipe to avoid buffering the entire response
pr, pw := io.Pipe()
defer pr.Close()
req := c.Request.Clone(c.Request.Context())
req.Body = pr
// Start the proxy in a goroutine
go func() {
defer wg.Done()
defer pw.Close()
proxy.ServeHTTP(c.Writer, req)
}()
// Copy the original request body to the pipe writer
_, err := io.Copy(pw, c.Request.Body)
if err != nil {
logrus.Errorf("Error copying request body to pipe: %v", err)
wg.Wait() // Wait for the goroutine to finish even if copy fails
return err
}
// Wait for the proxy to finish
wg.Wait()
return nil
}
// singleJoiningSlash joins two URL paths with a single slash.
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
// GetStreamClient returns the client this channel uses for streaming requests,
// i.e. the value stored in the StreamClient field. It performs no lazy
// initialisation, so it returns nil when the field was never set.
func (b *BaseChannel) GetStreamClient() *http.Client {
return b.StreamClient
}

View File

@@ -3,22 +3,35 @@ package channel
import (
"context"
"gpt-load/internal/models"
"net/http"
"net/url"
"github.com/gin-gonic/gin"
)
// ChannelProxy defines the interface for different API channel proxies.
// It's responsible for channel-specific logic like URL building and request modification.
type ChannelProxy interface {
// Handle takes a context, an API key, and the original request,
// then forwards the request to the upstream service.
Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error
// BuildUpstreamURL constructs the target URL for the upstream service.
BuildUpstreamURL(originalURL *url.URL, group *models.Group) (string, error)
// ModifyRequest allows the channel to add specific headers or modify the request
// before it's sent to the upstream service.
ModifyRequest(req *http.Request, apiKey *models.APIKey, group *models.Group)
// IsStreamRequest checks if the request is for a streaming response,
// now using the cached request body to avoid re-reading the stream.
IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
// ValidateKey checks if the given API key is valid.
ValidateKey(ctx context.Context, key string) (bool, error)
// IsStreamingRequest checks if the request is for a streaming response.
IsStreamingRequest(c *gin.Context) bool
// IsConfigStale checks if the channel's configuration is stale compared to the provided group.
IsConfigStale(group *models.Group) bool
// GetHTTPClient returns the client for standard requests.
GetHTTPClient() *http.Client
// GetStreamClient returns the client for streaming requests.
GetStreamClient() *http.Client
}

View File

@@ -4,8 +4,8 @@ import (
"encoding/json"
"fmt"
"gpt-load/internal/config"
"gpt-load/internal/httpclient"
"gpt-load/internal/models"
"net/http"
"net/url"
"sync"
"time"
@@ -42,14 +42,16 @@ func GetChannels() []string {
// Factory is responsible for creating channel proxies.
type Factory struct {
settingsManager *config.SystemSettingsManager
clientManager *httpclient.HTTPClientManager
channelCache map[uint]ChannelProxy
cacheLock sync.Mutex
}
// NewFactory creates a new channel factory.
func NewFactory(settingsManager *config.SystemSettingsManager) *Factory {
func NewFactory(settingsManager *config.SystemSettingsManager, clientManager *httpclient.HTTPClientManager) *Factory {
return &Factory{
settingsManager: settingsManager,
clientManager: clientManager,
channelCache: make(map[uint]ChannelProxy),
}
}
@@ -109,21 +111,39 @@ func (f *Factory) newBaseChannel(name string, group *models.Group) (*BaseChannel
upstreamInfos = append(upstreamInfos, UpstreamInfo{URL: u, Weight: weight})
}
// Get effective settings by merging system and group configs
effectiveSettings := f.settingsManager.GetEffectiveConfig(group.Config)
// Configure the HTTP client with the effective timeouts
httpClient := &http.Client{
Transport: &http.Transport{
IdleConnTimeout: time.Duration(effectiveSettings.IdleConnTimeout) * time.Second,
},
Timeout: time.Duration(effectiveSettings.RequestTimeout) * time.Second,
// Base configuration for regular requests, derived from the group's effective settings.
clientConfig := &httpclient.Config{
ConnectTimeout: time.Duration(group.EffectiveConfig.ConnectTimeout) * time.Second,
RequestTimeout: time.Duration(group.EffectiveConfig.RequestTimeout) * time.Second,
IdleConnTimeout: time.Duration(group.EffectiveConfig.IdleConnTimeout) * time.Second,
MaxIdleConns: group.EffectiveConfig.MaxIdleConns,
MaxIdleConnsPerHost: group.EffectiveConfig.MaxIdleConnsPerHost,
ResponseHeaderTimeout: time.Duration(group.EffectiveConfig.ResponseHeaderTimeout) * time.Second,
DisableCompression: group.EffectiveConfig.DisableCompression,
WriteBufferSize: 32 * 1024, // Use a reasonable default buffer size for regular requests
ReadBufferSize: 32 * 1024,
}
// Create a dedicated configuration for streaming requests.
// This configuration is optimized for low-latency, long-running connections.
streamConfig := *clientConfig
streamConfig.RequestTimeout = 0 // No overall timeout for the entire request.
streamConfig.DisableCompression = true // Always disable compression for streaming to reduce latency.
streamConfig.WriteBufferSize = 0 // Disable buffering for real-time data transfer.
streamConfig.ReadBufferSize = 0
// For stream-specific connection pool, we can use a simple heuristic like doubling the regular one.
streamConfig.MaxIdleConns = group.EffectiveConfig.MaxIdleConns * 2
streamConfig.MaxIdleConnsPerHost = group.EffectiveConfig.MaxIdleConnsPerHost * 2
// Get both clients from the manager using their respective configurations.
httpClient := f.clientManager.GetClient(clientConfig)
streamClient := f.clientManager.GetClient(&streamConfig)
return &BaseChannel{
Name: name,
Upstreams: upstreamInfos,
HTTPClient: httpClient,
StreamClient: streamClient,
TestModel: group.TestModel,
groupUpstreams: group.Upstreams,
groupConfig: group.Config,

View File

@@ -9,6 +9,7 @@ import (
"gpt-load/internal/models"
"io"
"net/http"
"net/url"
"strings"
"github.com/gin-gonic/gin"
@@ -33,13 +34,36 @@ func newGeminiChannel(f *Factory, group *models.Group) (ChannelProxy, error) {
}, nil
}
func (ch *GeminiChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error {
modifier := func(req *http.Request, key *models.APIKey) {
q := req.URL.Query()
q.Set("key", key.KeyValue)
req.URL.RawQuery = q.Encode()
// BuildUpstreamURL constructs the target URL for the Gemini service.
func (ch *GeminiChannel) BuildUpstreamURL(originalURL *url.URL, group *models.Group) (string, error) {
base := ch.getUpstreamURL()
if base == nil {
// Fallback to default Gemini URL
base, _ = url.Parse("https://generativelanguage.googleapis.com")
}
return ch.ProcessRequest(c, apiKey, modifier, ch)
finalURL := *base
// The originalURL.Path contains the full path, e.g., "/proxy/gemini/v1beta/models/gemini-pro:generateContent".
// We need to strip the proxy prefix to get the correct upstream path.
proxyPrefix := "/proxy/" + group.Name
if strings.HasPrefix(originalURL.Path, proxyPrefix) {
finalURL.Path = strings.TrimPrefix(originalURL.Path, proxyPrefix)
} else {
// Fallback for safety.
finalURL.Path = originalURL.Path
}
// The API key will be added to RawQuery in ModifyRequest.
finalURL.RawQuery = originalURL.RawQuery
return finalURL.String(), nil
}
// ModifyRequest authenticates a Gemini request by writing the API key into the
// "key" query parameter, replacing any value already present under that name.
// The remaining query parameters are kept; the query string is re-encoded.
func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey, group *models.Group) {
	params := req.URL.Query()
	params["key"] = []string{apiKey.KeyValue}
	req.URL.RawQuery = params.Encode()
}
// ValidateKey checks if the given API key is valid by making a generateContent request.
@@ -95,12 +119,21 @@ func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, err
return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError)
}
// IsStreamingRequest checks if the request is for a streaming response.
func (ch *GeminiChannel) IsStreamingRequest(c *gin.Context) bool {
// For Gemini, streaming is indicated by the path containing streaming keywords
// IsStreamRequest checks if the request is for a streaming response.
// For Gemini, this is primarily determined by the URL path.
func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool {
path := c.Request.URL.Path
return strings.Contains(path, ":streamGenerateContent") ||
strings.Contains(path, "streamGenerateContent") ||
strings.Contains(path, ":stream") ||
strings.Contains(path, "/stream")
if strings.HasSuffix(path, ":streamGenerateContent") {
return true
}
// Also check for standard streaming indicators as a fallback.
if strings.Contains(c.GetHeader("Accept"), "text/event-stream") {
return true
}
if c.Query("stream") == "true" {
return true
}
return false
}

View File

@@ -9,9 +9,10 @@ import (
"gpt-load/internal/models"
"io"
"net/http"
"net/url"
"strings"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
)
func init() {
@@ -33,11 +34,38 @@ func newOpenAIChannel(f *Factory, group *models.Group) (ChannelProxy, error) {
}, nil
}
func (ch *OpenAIChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error {
modifier := func(req *http.Request, key *models.APIKey) {
req.Header.Set("Authorization", "Bearer "+key.KeyValue)
// BuildUpstreamURL constructs the target URL for the OpenAI service.
func (ch *OpenAIChannel) BuildUpstreamURL(originalURL *url.URL, group *models.Group) (string, error) {
// Use the weighted round-robin selection from the base channel.
// This method already handles parsing the group's Upstreams JSON.
base := ch.getUpstreamURL()
if base == nil {
// If no upstreams are configured in the group, fallback to a default.
// This can be considered an error or a feature depending on requirements.
// For now, we'll use the official OpenAI URL as a last resort.
base, _ = url.Parse("https://api.openai.com")
}
return ch.ProcessRequest(c, apiKey, modifier, ch)
// It's crucial to create a copy to avoid modifying the cached URL object in BaseChannel.
finalURL := *base
// The originalURL.Path contains the full path, e.g., "/proxy/openai/v1/chat/completions".
// We need to strip the proxy prefix to get the correct upstream path.
proxyPrefix := "/proxy/" + group.Name
if strings.HasPrefix(originalURL.Path, proxyPrefix) {
finalURL.Path = strings.TrimPrefix(originalURL.Path, proxyPrefix)
} else {
// Fallback for safety, though this case should ideally not be hit.
finalURL.Path = originalURL.Path
}
finalURL.RawQuery = originalURL.RawQuery
return finalURL.String(), nil
}
// ModifyRequest authenticates an OpenAI request by overwriting its
// Authorization header with a Bearer token built from the API key.
func (ch *OpenAIChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey, group *models.Group) {
	bearer := "Bearer " + apiKey.KeyValue
	req.Header.Set("Authorization", bearer)
}
// ValidateKey checks if the given API key is valid by making a chat completion request.
@@ -92,16 +120,23 @@ func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, err
return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError)
}
// IsStreamingRequest checks if the request is for a streaming response.
func (ch *OpenAIChannel) IsStreamingRequest(c *gin.Context) bool {
// For OpenAI, streaming is indicated by a "stream": true field in the JSON body.
// We use ShouldBindBodyWith to check the body without consuming it, so it can be read again by the proxy.
// IsStreamRequest checks if the request is for a streaming response using the pre-read body.
func (ch *OpenAIChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool {
if strings.Contains(c.GetHeader("Accept"), "text/event-stream") {
return true
}
if c.Query("stream") == "true" {
return true
}
type streamPayload struct {
Stream bool `json:"stream"`
}
var p streamPayload
if err := c.ShouldBindBodyWith(&p, binding.JSON); err == nil {
if err := json.Unmarshal(bodyBytes, &p); err == nil {
return p.Stream
}
return false
}

View File

@@ -7,6 +7,7 @@ import (
"gpt-load/internal/config"
"gpt-load/internal/db"
"gpt-load/internal/handler"
"gpt-load/internal/httpclient"
"gpt-load/internal/keypool"
"gpt-load/internal/proxy"
"gpt-load/internal/router"
@@ -36,6 +37,9 @@ func BuildContainer() (*dig.Container, error) {
if err := container.Provide(store.NewStore); err != nil {
return nil, err
}
if err := container.Provide(httpclient.NewHTTPClientManager); err != nil {
return nil, err
}
if err := container.Provide(channel.NewFactory); err != nil {
return nil, err
}

View File

@@ -34,6 +34,8 @@ var (
ErrTaskInProgress = &APIError{HTTPStatus: http.StatusConflict, Code: "TASK_IN_PROGRESS", Message: "A task is already in progress"}
ErrBadGateway = &APIError{HTTPStatus: http.StatusBadGateway, Code: "BAD_GATEWAY", Message: "Upstream service error"}
ErrNoActiveKeys = &APIError{HTTPStatus: http.StatusServiceUnavailable, Code: "NO_ACTIVE_KEYS", Message: "No active API keys available for this group"}
ErrMaxRetriesExceeded = &APIError{HTTPStatus: http.StatusBadGateway, Code: "MAX_RETRIES_EXCEEDED", Message: "Request failed after maximum retries"}
ErrNoKeysAvailable = &APIError{HTTPStatus: http.StatusServiceUnavailable, Code: "NO_KEYS_AVAILABLE", Message: "No API keys available to process the request"}
)
// NewAPIError creates a new APIError with a custom message.
@@ -45,6 +47,15 @@ func NewAPIError(base *APIError, message string) *APIError {
}
}
// NewAPIErrorWithUpstream builds an APIError that carries an upstream failure
// through unchanged: statusCode becomes HTTPStatus, code becomes Code and
// upstreamMessage becomes Message.
func NewAPIErrorWithUpstream(statusCode int, code string, upstreamMessage string) *APIError {
	apiErr := new(APIError)
	apiErr.HTTPStatus = statusCode
	apiErr.Code = code
	apiErr.Message = upstreamMessage
	return apiErr
}
// ParseDBError intelligently converts a GORM error into a standard APIError.
func ParseDBError(err error) *APIError {
if err == nil {

View File

@@ -0,0 +1,31 @@
package errors
import (
"strings"
)
// ignorableErrorSubstrings lists the message fragments that mark an error as
// safe to ignore. The original authors associate them with a client that
// disconnects before the exchange completes.
var ignorableErrorSubstrings = []string{
	"context canceled",
	"connection reset by peer",
	"broken pipe",
	"use of closed network connection",
	"request canceled",
}

// IsIgnorableError reports whether err's message contains any fragment from
// ignorableErrorSubstrings. The match is a case-sensitive substring test on
// err.Error(); a nil error is never ignorable.
func IsIgnorableError(err error) bool {
	if err != nil {
		msg := err.Error()
		for i := range ignorableErrorSubstrings {
			if strings.Contains(msg, ignorableErrorSubstrings[i]) {
				return true
			}
		}
	}
	return false
}

View File

@@ -0,0 +1,105 @@
package httpclient
import (
"fmt"
"net"
"net/http"
"sync"
"time"
)
// Config defines the parameters for creating an HTTP client.
// HTTPClientManager.GetClient copies these values into the http.Client and
// http.Transport it builds, and getFingerprint derives the cache key from them,
// so two Configs with the same fingerprint share one client.
type Config struct {
// ConnectTimeout is used as the net.Dialer timeout when dialing upstream.
ConnectTimeout time.Duration
// RequestTimeout is used as http.Client.Timeout.
RequestTimeout time.Duration
// IdleConnTimeout is used as http.Transport.IdleConnTimeout.
IdleConnTimeout time.Duration
// MaxIdleConns is used as http.Transport.MaxIdleConns.
MaxIdleConns int
// MaxIdleConnsPerHost is used as http.Transport.MaxIdleConnsPerHost.
MaxIdleConnsPerHost int
// ResponseHeaderTimeout is used as http.Transport.ResponseHeaderTimeout.
ResponseHeaderTimeout time.Duration
// DisableCompression is used as http.Transport.DisableCompression.
DisableCompression bool
// WriteBufferSize is used as http.Transport.WriteBufferSize.
WriteBufferSize int
// ReadBufferSize is used as http.Transport.ReadBufferSize.
ReadBufferSize int
}
// HTTPClientManager hands out HTTP clients keyed by configuration fingerprint,
// so that callers asking for the same configuration receive the same client.
// The clients map is guarded by lock.
type HTTPClientManager struct {
	clients map[string]*http.Client
	lock    sync.RWMutex
}

// NewHTTPClientManager returns a manager whose client cache is initialised and
// empty.
func NewHTTPClientManager() *HTTPClientManager {
	manager := new(HTTPClientManager)
	manager.clients = map[string]*http.Client{}
	return manager
}
// GetClient returns the HTTP client associated with config's fingerprint.
// A client already present in the cache is returned as-is; otherwise a new
// client is built from config, stored under the fingerprint and returned.
func (m *HTTPClientManager) GetClient(config *Config) *http.Client {
	key := config.getFingerprint()

	// First look under the read lock only.
	m.lock.RLock()
	cached, found := m.clients[key]
	m.lock.RUnlock()
	if found {
		return cached
	}

	m.lock.Lock()
	defer m.lock.Unlock()

	// Look again under the write lock: the entry may have been stored by
	// another goroutine between releasing the read lock and acquiring this one.
	if existing, ok := m.clients[key]; ok {
		return existing
	}

	dialer := &net.Dialer{
		Timeout:   config.ConnectTimeout,
		KeepAlive: 30 * time.Second, // fixed, not configurable
	}
	created := &http.Client{
		Timeout: config.RequestTimeout,
		Transport: &http.Transport{
			Proxy:                 http.ProxyFromEnvironment,
			DialContext:           dialer.DialContext,
			ForceAttemptHTTP2:     true,
			MaxIdleConns:          config.MaxIdleConns,
			MaxIdleConnsPerHost:   config.MaxIdleConnsPerHost,
			IdleConnTimeout:       config.IdleConnTimeout,
			TLSHandshakeTimeout:   10 * time.Second, // fixed, not configurable
			ExpectContinueTimeout: 1 * time.Second,  // fixed, not configurable
			ResponseHeaderTimeout: config.ResponseHeaderTimeout,
			DisableCompression:    config.DisableCompression,
			WriteBufferSize:       config.WriteBufferSize,
			ReadBufferSize:        config.ReadBufferSize,
		},
	}

	m.clients[key] = created
	return created
}
// getFingerprint returns the cache key for this configuration. Every field of
// Config takes part in the key, and two configurations yield the same key only
// when all fields are equal.
//
// Durations are rendered with time.Duration's String method, which is exact.
// Formatting them as whole seconds would round 1.2s and 0.8s to the same text,
// and GetClient would then hand back a client built with the wrong timeouts.
func (c *Config) getFingerprint() string {
	return fmt.Sprintf(
		"ct:%s|rt:%s|it:%s|mic:%d|mich:%d|rht:%s|dc:%t|wbs:%d|rbs:%d",
		c.ConnectTimeout,
		c.RequestTimeout,
		c.IdleConnTimeout,
		c.MaxIdleConns,
		c.MaxIdleConnsPerHost,
		c.ResponseHeaderTimeout,
		c.DisableCompression,
		c.WriteBufferSize,
		c.ReadBufferSize,
	)
}

View File

@@ -35,6 +35,11 @@ type GroupConfig struct {
ResponseTimeout *int `json:"response_timeout,omitempty"`
IdleConnTimeout *int `json:"idle_conn_timeout,omitempty"`
KeyValidationIntervalMinutes *int `json:"key_validation_interval_minutes,omitempty"`
ConnectTimeout *int `json:"connect_timeout,omitempty"`
MaxIdleConns *int `json:"max_idle_conns,omitempty"`
MaxIdleConnsPerHost *int `json:"max_idle_conns_per_host,omitempty"`
ResponseHeaderTimeout *int `json:"response_header_timeout,omitempty"`
DisableCompression *bool `json:"disable_compression,omitempty"`
}
// Group 对应 groups 表

View File

@@ -2,171 +2,311 @@
package proxy
import (
"bufio"
"bytes"
"compress/gzip"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"gpt-load/internal/channel"
"gpt-load/internal/config"
app_errors "gpt-load/internal/errors"
"gpt-load/internal/keypool"
"gpt-load/internal/models"
"gpt-load/internal/response"
"io"
"time"
"gpt-load/internal/services"
"gpt-load/internal/types"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// ignorableStreamErrors lists the message fragments of errors the proxy treats
// as normal while streaming, which the original authors attribute to the
// client disconnecting.
var ignorableStreamErrors = []string{
	"context canceled",
	"connection reset by peer",
	"broken pipe",
	"use of closed network connection",
}

// isIgnorableStreamError reports whether err's message contains any fragment
// from ignorableStreamErrors. The match is a case-sensitive substring test on
// err.Error(); a nil error is never ignorable.
func isIgnorableStreamError(err error) bool {
	if err != nil {
		msg := err.Error()
		for i := range ignorableStreamErrors {
			if strings.Contains(msg, ignorableStreamErrors[i]) {
				return true
			}
		}
	}
	return false
}
// ProxyServer represents the proxy server
type ProxyServer struct {
DB *gorm.DB
channelFactory *channel.Factory
keyProvider *keypool.KeyProvider
requestLogChan chan models.RequestLog
keyProvider *keypool.KeyProvider
groupManager *services.GroupManager
settingsManager *config.SystemSettingsManager
channelFactory *channel.Factory
}
// NewProxyServer creates a new proxy server
func NewProxyServer(
db *gorm.DB,
channelFactory *channel.Factory,
keyProvider *keypool.KeyProvider,
requestLogChan chan models.RequestLog,
groupManager *services.GroupManager,
settingsManager *config.SystemSettingsManager,
channelFactory *channel.Factory,
) (*ProxyServer, error) {
return &ProxyServer{
DB: db,
channelFactory: channelFactory,
keyProvider: keyProvider,
requestLogChan: requestLogChan,
keyProvider: keyProvider,
groupManager: groupManager,
settingsManager: settingsManager,
channelFactory: channelFactory,
}, nil
}
// HandleProxy handles the main proxy logic
// HandleProxy is the main entry point for proxy requests, refactored based on the stable .bak logic.
func (ps *ProxyServer) HandleProxy(c *gin.Context) {
startTime := time.Now()
groupName := c.Param("group_name")
// 1. Find the group by name (without preloading keys)
var group models.Group
if err := ps.DB.Where("name = ?", groupName).First(&group).Error; err != nil {
group, err := ps.groupManager.GetGroupByName(groupName)
if err != nil {
response.Error(c, app_errors.ParseDBError(err))
return
}
// 2. Select an available API key from the KeyPool
apiKey, err := ps.keyProvider.SelectKey(group.ID)
if err != nil {
// Properly handle the case where no keys are available
if apiErr, ok := err.(*app_errors.APIError); ok {
response.Error(c, apiErr)
} else {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, err.Error()))
}
return
}
// 3. Get the appropriate channel handler from the factory
channelHandler, err := ps.channelFactory.GetChannel(&group)
channelHandler, err := ps.channelFactory.GetChannel(group)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to get channel for group '%s': %v", groupName, err)))
return
}
// 4. Apply parameter overrides if they exist
if len(group.ParamOverrides) > 0 {
err := ps.applyParamOverrides(c, &group)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to apply parameter overrides: %v", err)))
return
}
}
// 5. Forward the request using the channel handler
err = channelHandler.Handle(c, apiKey, &group)
// 6. Update key status and log the request asynchronously
isSuccess := err == nil
ps.keyProvider.UpdateStatus(apiKey.ID, group.ID, isSuccess)
if !isSuccess {
logrus.WithFields(logrus.Fields{
"group": group.Name,
"key_id": apiKey.ID,
"error": err.Error(),
}).Error("Channel handler failed")
}
go ps.logRequest(c, &group, apiKey, startTime)
}
func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *models.APIKey, startTime time.Time) {
logEntry := models.RequestLog{
ID: fmt.Sprintf("req_%d", time.Now().UnixNano()),
Timestamp: startTime,
GroupID: group.ID,
KeyID: key.ID,
SourceIP: c.ClientIP(),
StatusCode: c.Writer.Status(),
RequestPath: c.Request.URL.Path,
RequestBodySnippet: "", // Can be implemented later if needed
}
// Send to the logging channel without blocking
select {
case ps.requestLogChan <- logEntry:
default:
logrus.Warn("Request log channel is full. Dropping log entry.")
}
}
// Close cleans up resources
func (ps *ProxyServer) Close() {
close(ps.requestLogChan)
}
func (ps *ProxyServer) applyParamOverrides(c *gin.Context, group *models.Group) error {
// Read the original request body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
return fmt.Errorf("failed to read request body: %w", err)
logrus.Errorf("Failed to read request body: %v", err)
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Failed to read request body"))
return
}
c.Request.Body.Close()
// If body is empty, nothing to override, just restore the body
if len(bodyBytes) == 0 {
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
return nil
// 4. Apply parameter overrides if any.
finalBodyBytes, err := ps.applyParamOverrides(bodyBytes, group)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to apply parameter overrides: %v", err)))
return
}
// Save the original Content-Type
originalContentType := c.GetHeader("Content-Type")
// 5. Determine if this is a streaming request.
isStream := channelHandler.IsStreamRequest(c, bodyBytes)
// 6. Execute the request using the recursive retry logic.
ps.executeRequestWithRetry(c, channelHandler, group, finalBodyBytes, isStream, startTime, 0, nil)
}
// executeRequestWithRetry performs one upstream attempt and, when that attempt
// fails, calls itself with retryCount+1 until group.EffectiveConfig.MaxRetries
// retries have been used (MaxRetries+1 attempts in total).
//
// Each attempt selects a key, builds the upstream URL, lets the channel modify
// the request and sends bodyBytes with the channel's regular or streaming
// client depending on isStream. A transport error or an upstream status >= 400
// marks the key as failed and is appended to retryErrors. On success the key is
// marked healthy and the upstream status, headers and body are relayed to the
// client. Once retries are exhausted the most recent recorded failure is
// returned to the client.
//
// The resources of a failed attempt (its context and response body) are
// released before the next attempt starts, so they are not held for the
// duration of the whole retry chain.
func (ps *ProxyServer) executeRequestWithRetry(
	c *gin.Context,
	channelHandler channel.ChannelProxy,
	group *models.Group,
	bodyBytes []byte,
	isStream bool,
	startTime time.Time,
	retryCount int,
	retryErrors []types.RetryError,
) {
	cfg := group.EffectiveConfig
	if retryCount > cfg.MaxRetries {
		logrus.Errorf("Max retries exceeded for group %s after %d attempts.", group.Name, retryCount)
		writeRetriesExhausted(c, retryErrors)
		return
	}

	attempt := retryCount + 1
	// retryCount runs from 0 through MaxRetries inclusive.
	maxAttempts := cfg.MaxRetries + 1

	apiKey, err := ps.keyProvider.SelectKey(group.ID)
	if err != nil {
		logrus.Errorf("Failed to select a key for group %s on attempt %d: %v", group.Name, attempt, err)
		response.Error(c, app_errors.NewAPIError(app_errors.ErrNoKeysAvailable, err.Error()))
		return
	}
	keyTag := keyLogPrefix(apiKey.KeyValue)

	upstreamURL, err := channelHandler.BuildUpstreamURL(c.Request.URL, group)
	if err != nil {
		response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to build upstream URL: %v", err)))
		return
	}

	// Streaming requests are bounded only by the client's own context; regular
	// requests additionally get the configured overall timeout.
	var ctx context.Context
	var cancel context.CancelFunc
	if isStream {
		ctx, cancel = context.WithCancel(c.Request.Context())
	} else {
		timeout := time.Duration(cfg.RequestTimeout) * time.Second
		ctx, cancel = context.WithTimeout(c.Request.Context(), timeout)
	}
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, c.Request.Method, upstreamURL, bytes.NewReader(bodyBytes))
	if err != nil {
		logrus.Errorf("Failed to create upstream request: %v", err)
		response.Error(c, app_errors.ErrInternalServer)
		return
	}
	req.ContentLength = int64(len(bodyBytes))
	req.Header = c.Request.Header.Clone()

	channelHandler.ModifyRequest(req, apiKey, group)

	client := channelHandler.GetHTTPClient()
	if isStream {
		client = channelHandler.GetStreamClient()
		req.Header.Set("X-Accel-Buffering", "no")
	}

	resp, err := client.Do(req)
	if err != nil {
		// When the inbound request's own context is done, the client has gone
		// away. That is not the key's fault and nobody is left to answer, so
		// neither penalise the key nor retry.
		if c.Request.Context().Err() != nil {
			logrus.Debugf("Client disconnected during attempt %d/%d for key %s: %v", attempt, maxAttempts, keyTag, err)
			return
		}

		ps.keyProvider.UpdateStatus(apiKey.ID, group.ID, false)
		logrus.Warnf("Request failed (attempt %d/%d) for key %s: %v", attempt, maxAttempts, keyTag, err)

		nextErrors := append(retryErrors, types.RetryError{
			StatusCode:   0,
			ErrorMessage: err.Error(),
			KeyID:        fmt.Sprintf("%d", apiKey.ID),
			Attempt:      attempt,
		})
		cancel() // release this attempt's context before recursing
		ps.executeRequestWithRetry(c, channelHandler, group, bodyBytes, isStream, startTime, retryCount+1, nextErrors)
		return
	}

	if resp.StatusCode >= 400 {
		ps.keyProvider.UpdateStatus(apiKey.ID, group.ID, false)

		errorBody := readUpstreamErrorBody(resp)
		// Release the failed attempt before recursing instead of deferring to
		// the end of the whole retry chain.
		resp.Body.Close()
		cancel()

		logrus.Warnf("Request failed with status %d (attempt %d/%d) for key %s. Body: %s", resp.StatusCode, attempt, maxAttempts, keyTag, string(errorBody))

		nextErrors := append(retryErrors, types.RetryError{
			StatusCode:   resp.StatusCode,
			ErrorMessage: string(errorBody),
			KeyID:        fmt.Sprintf("%d", apiKey.ID),
			Attempt:      attempt,
		})
		ps.executeRequestWithRetry(c, channelHandler, group, bodyBytes, isStream, startTime, retryCount+1, nextErrors)
		return
	}
	defer resp.Body.Close()

	ps.keyProvider.UpdateStatus(apiKey.ID, group.ID, true)
	logrus.Debugf("Request for group %s succeeded on attempt %d with key %s", group.Name, attempt, keyTag)

	// Relay upstream headers. Assigning the full value slice keeps every value
	// of a multi-valued header (for example Set-Cookie) rather than only the
	// last one.
	outHeader := c.Writer.Header()
	for name, values := range resp.Header {
		outHeader[name] = append([]string(nil), values...)
	}
	c.Status(resp.StatusCode)

	if isStream {
		ps.handleStreamingResponse(c, resp)
	} else {
		ps.handleNormalResponse(c, resp)
	}
}

// writeRetriesExhausted answers the client after the last permitted attempt has
// failed, using the most recent entry of retryErrors when there is one.
func writeRetriesExhausted(c *gin.Context, retryErrors []types.RetryError) {
	if len(retryErrors) == 0 {
		response.Error(c, app_errors.ErrMaxRetriesExceeded)
		return
	}

	last := retryErrors[len(retryErrors)-1]

	// Transport-level failures are recorded with StatusCode 0, which is not a
	// usable HTTP status for the reply; report them as a bad gateway.
	status := last.StatusCode
	if status < 100 {
		status = http.StatusBadGateway
	}

	// Pass an upstream JSON error document through unchanged when possible.
	var errorJSON map[string]any
	if err := json.Unmarshal([]byte(last.ErrorMessage), &errorJSON); err == nil {
		c.JSON(status, errorJSON)
		return
	}
	response.Error(c, app_errors.NewAPIErrorWithUpstream(status, "UPSTREAM_ERROR", last.ErrorMessage))
}

// readUpstreamErrorBody reads the whole body of a failed upstream response and
// gunzips it when the response declares "Content-Encoding: gzip". Read or
// decompression problems are logged; the best available bytes are returned.
func readUpstreamErrorBody(resp *http.Response) []byte {
	errorBody, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		logrus.Errorf("Failed to read error body: %v", readErr)
		// Even if reading fails, the caller still proceeds with retry logic.
		errorBody = []byte("Failed to read error body")
	}

	if resp.Header.Get("Content-Encoding") != "gzip" {
		return errorBody
	}

	reader, err := gzip.NewReader(bytes.NewReader(errorBody))
	if err != nil {
		logrus.Warnf("Failed to create gzip reader for error body: %v", err)
		return errorBody
	}
	defer reader.Close()

	decompressed, err := io.ReadAll(reader)
	if err != nil {
		logrus.Warnf("Failed to decompress gzip error body: %v", err)
		return errorBody
	}
	return decompressed
}

// keyLogPrefix returns at most the first eight bytes of key for use in log
// lines. Keys shorter than that are returned whole instead of causing an
// out-of-range slice panic.
func keyLogPrefix(key string) string {
	const prefixLen = 8
	if len(key) <= prefixLen {
		return key
	}
	return key[:prefixLen]
}
// handleStreamingResponse relays a streaming upstream body to the client one
// line at a time, flushing after every line so that events reach the client as
// soon as they are complete.
//
// Lines are read with bufio.Reader.ReadBytes, which has no upper bound on line
// length and returns the bytes exactly as received: line endings are kept as
// sent and nothing is appended after a final unterminated line. A
// bufio.Scanner would stop with ErrTooLong on any line above its 64 KiB token
// limit and cut the stream short.
//
// If the writer cannot flush, the body is copied without streaming. Read and
// write errors that isIgnorableStreamError accepts are not logged.
func (ps *ProxyServer) handleStreamingResponse(c *gin.Context, resp *http.Response) {
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		logrus.Error("Streaming unsupported by the writer")
		ps.handleNormalResponse(c, resp)
		return
	}

	reader := bufio.NewReader(resp.Body)
	for {
		line, readErr := reader.ReadBytes('\n')

		// ReadBytes may return data together with an error (for example a
		// final line without a trailing newline), so forward data first.
		if len(line) > 0 {
			if _, err := c.Writer.Write(line); err != nil {
				if !isIgnorableStreamError(err) {
					logrus.Errorf("Error writing to client: %v", err)
				}
				return
			}
			flusher.Flush()
		}

		if readErr != nil {
			if readErr != io.EOF && !isIgnorableStreamError(readErr) {
				logrus.Errorf("Error reading from upstream: %v", readErr)
			}
			return
		}
	}
}
// handleNormalResponse copies the upstream response body to the client in one
// pass. A copy error is logged unless isIgnorableStreamError accepts it.
func (ps *ProxyServer) handleNormalResponse(c *gin.Context, resp *http.Response) {
	_, copyErr := io.Copy(c.Writer, resp.Body)
	if copyErr == nil || isIgnorableStreamError(copyErr) {
		return
	}
	logrus.Errorf("Failed to copy response body to client: %v", copyErr)
}
func (ps *ProxyServer) applyParamOverrides(bodyBytes []byte, group *models.Group) ([]byte, error) {
if len(group.ParamOverrides) == 0 || len(bodyBytes) == 0 {
return bodyBytes, nil
}
// Unmarshal the body into a map
var requestData map[string]any
if err := json.Unmarshal(bodyBytes, &requestData); err != nil {
// If not a valid JSON, just pass it through
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
return nil
logrus.Warnf("failed to unmarshal request body for param override, passing through: %v", err)
return bodyBytes, nil
}
// Merge the overrides into the request data
for key, value := range group.ParamOverrides {
requestData[key] = value
}
// Marshal the new data back to JSON
newBodyBytes, err := json.Marshal(requestData)
if err != nil {
return fmt.Errorf("failed to marshal new request body: %w", err)
}
// Replace the request body with the new one
c.Request.Body = io.NopCloser(bytes.NewBuffer(newBodyBytes))
c.Request.ContentLength = int64(len(newBodyBytes))
// Restore the original Content-Type header
if originalContentType != "" {
c.Request.Header.Set("Content-Type", originalContentType)
}
return nil
return json.Marshal(requestData)
}
func (ps *ProxyServer) Close() {
// The HTTP clients are now managed by the channel factory and httpclient manager,
// so the proxy server itself doesn't need to close them.
// The httpclient manager will handle closing idle connections for all its clients.
}

9
internal/types/retry.go Normal file
View File

@@ -0,0 +1,9 @@
package types
// RetryError captures detailed information about a failed request attempt during retries.
// The field notes below describe how the proxy's retry logic fills the struct.
type RetryError struct {
// StatusCode is the upstream HTTP status, or 0 when the request failed
// before any response was received.
StatusCode int `json:"status_code"`
// ErrorMessage is the upstream error body, or the Go error text when no
// response was received.
ErrorMessage string `json:"error_message"`
// KeyID is the decimal ID of the API key used for the attempt.
KeyID string `json:"key_id"`
// Attempt is the 1-based number of the attempt that failed.
Attempt int `json:"attempt"`
}

View File

@@ -32,6 +32,8 @@ type SystemSettings struct {
IdleConnTimeout int `json:"idle_conn_timeout" default:"120" name:"空闲连接超时" category:"请求超时" desc:"HTTP 客户端中空闲连接的超时时间(秒)。" validate:"min=1"`
MaxIdleConns int `json:"max_idle_conns" default:"100" name:"最大空闲连接数" category:"请求超时" desc:"HTTP 客户端连接池中允许的最大空闲连接总数。" validate:"min=1"`
MaxIdleConnsPerHost int `json:"max_idle_conns_per_host" default:"10" name:"每主机最大空闲连接数" category:"请求超时" desc:"HTTP 客户端连接池对每个上游主机允许的最大空闲连接数。" validate:"min=1"`
ResponseHeaderTimeout int `json:"response_header_timeout" default:"120" name:"响应头超时" category:"请求超时" desc:"等待上游服务响应头的最长时间(秒),用于流式请求。" validate:"min=1"`
DisableCompression bool `json:"disable_compression" default:"false" name:"禁用压缩" category:"请求超时" desc:"是否禁用对上游请求的传输压缩Gzip。对于流式请求建议开启以降低延迟。"`
// 密钥配置
MaxRetries int `json:"max_retries" default:"3" name:"最大重试次数" category:"密钥配置" desc:"单个请求使用不同 Key 的最大重试次数" validate:"min=0"`