diff --git a/internal/channel/base_channel.go b/internal/channel/base_channel.go index 7257d8a..77f5647 100644 --- a/internal/channel/base_channel.go +++ b/internal/channel/base_channel.go @@ -3,16 +3,11 @@ package channel import ( "bytes" "encoding/json" - "fmt" "gpt-load/internal/models" - "io" "net/http" - "net/http/httputil" "net/url" - "strings" "sync" - "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" "gorm.io/datatypes" ) @@ -29,15 +24,13 @@ type BaseChannel struct { Name string Upstreams []UpstreamInfo HTTPClient *http.Client + StreamClient *http.Client TestModel string upstreamLock sync.Mutex groupUpstreams datatypes.JSON groupConfig datatypes.JSONMap } -// RequestModifier is a function that can modify the request before it's sent. -type RequestModifier func(req *http.Request, key *models.APIKey) - // getUpstreamURL selects an upstream URL using a smooth weighted round-robin algorithm. func (b *BaseChannel) getUpstreamURL() *url.URL { b.upstreamLock.Lock() @@ -99,100 +92,12 @@ func (b *BaseChannel) IsConfigStale(group *models.Group) bool { return false } -// ProcessRequest handles the common logic of processing and forwarding a request. -func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier, ch ChannelProxy) error { - upstreamURL := b.getUpstreamURL() - if upstreamURL == nil { - return fmt.Errorf("no upstream URL configured for channel %s", b.Name) - } - - director := func(req *http.Request) { - req.URL.Scheme = upstreamURL.Scheme - req.URL.Host = upstreamURL.Host - req.URL.Path = singleJoiningSlash(upstreamURL.Path, req.URL.Path) - req.Host = upstreamURL.Host - - // Apply the channel-specific modifications - if modifier != nil { - modifier(req, apiKey) - } - - // Remove headers that should not be forwarded - req.Header.Del("Cookie") - req.Header.Del("X-Real-Ip") - req.Header.Del("X-Forwarded-For") - } - - errorHandler := func(rw http.ResponseWriter, req *http.Request, err error) { - logrus.WithFields(logrus.Fields{ - "channel": b.Name, - "key_id": apiKey.ID, - "error": err, - }).Error("HTTP proxy error") - rw.WriteHeader(http.StatusBadGateway) - } - - proxy := &httputil.ReverseProxy{ - Director: director, - ErrorHandler: errorHandler, - Transport: b.HTTPClient.Transport, - } - - // Check if the client request is for a streaming endpoint - if ch.IsStreamingRequest(c) { - return b.handleStreaming(c, proxy) - } - - proxy.ServeHTTP(c.Writer, c.Request) - return nil +// GetHTTPClient returns the client for standard requests. +func (b *BaseChannel) GetHTTPClient() *http.Client { + return b.HTTPClient } -func (b *BaseChannel) handleStreaming(c *gin.Context, proxy *httputil.ReverseProxy) error { - var wg sync.WaitGroup - wg.Add(1) - - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Access-Control-Allow-Origin", "*") - - // Use a pipe to avoid buffering the entire response - pr, pw := io.Pipe() - defer pr.Close() - - req := c.Request.Clone(c.Request.Context()) - req.Body = pr - - // Start the proxy in a goroutine - go func() { - defer wg.Done() - defer pw.Close() - proxy.ServeHTTP(c.Writer, req) - }() - - // Copy the original request body to the pipe writer - _, err := io.Copy(pw, c.Request.Body) - if err != nil { - logrus.Errorf("Error copying request body to pipe: %v", err) - wg.Wait() // Wait for the goroutine to finish even if copy fails - return err - } - - // Wait for the proxy to finish - wg.Wait() - - return nil -} - -// singleJoiningSlash joins two URL paths with a single slash. -func singleJoiningSlash(a, b string) string { - aslash := strings.HasSuffix(a, "/") - bslash := strings.HasPrefix(b, "/") - switch { - case aslash && bslash: - return a + b[1:] - case !aslash && !bslash: - return a + "/" + b - } - return a + b +// GetStreamClient returns the client for streaming requests. +func (b *BaseChannel) GetStreamClient() *http.Client { + return b.StreamClient } diff --git a/internal/channel/channel.go b/internal/channel/channel.go index 96604e2..53ae3bf 100644 --- a/internal/channel/channel.go +++ b/internal/channel/channel.go @@ -3,22 +3,35 @@ package channel import ( "context" "gpt-load/internal/models" + "net/http" + "net/url" "github.com/gin-gonic/gin" ) // ChannelProxy defines the interface for different API channel proxies. +// It's responsible for channel-specific logic like URL building and request modification. type ChannelProxy interface { - // Handle takes a context, an API key, and the original request, - // then forwards the request to the upstream service. - Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error + // BuildUpstreamURL constructs the target URL for the upstream service. + BuildUpstreamURL(originalURL *url.URL, group *models.Group) (string, error) + + // ModifyRequest allows the channel to add specific headers or modify the request + // before it's sent to the upstream service. + ModifyRequest(req *http.Request, apiKey *models.APIKey, group *models.Group) + + // IsStreamRequest checks if the request is for a streaming response, + // now using the cached request body to avoid re-reading the stream. + IsStreamRequest(c *gin.Context, bodyBytes []byte) bool // ValidateKey checks if the given API key is valid. ValidateKey(ctx context.Context, key string) (bool, error) - // IsStreamingRequest checks if the request is for a streaming response. - IsStreamingRequest(c *gin.Context) bool - // IsConfigStale checks if the channel's configuration is stale compared to the provided group. IsConfigStale(group *models.Group) bool + + // GetHTTPClient returns the client for standard requests. + GetHTTPClient() *http.Client + + // GetStreamClient returns the client for streaming requests. + GetStreamClient() *http.Client } diff --git a/internal/channel/factory.go b/internal/channel/factory.go index ba8addf..c647765 100644 --- a/internal/channel/factory.go +++ b/internal/channel/factory.go @@ -4,8 +4,8 @@ import ( "encoding/json" "fmt" "gpt-load/internal/config" + "gpt-load/internal/httpclient" "gpt-load/internal/models" - "net/http" "net/url" "sync" "time" @@ -42,14 +42,16 @@ func GetChannels() []string { // Factory is responsible for creating channel proxies. type Factory struct { settingsManager *config.SystemSettingsManager + clientManager *httpclient.HTTPClientManager channelCache map[uint]ChannelProxy cacheLock sync.Mutex } // NewFactory creates a new channel factory. -func NewFactory(settingsManager *config.SystemSettingsManager) *Factory { +func NewFactory(settingsManager *config.SystemSettingsManager, clientManager *httpclient.HTTPClientManager) *Factory { return &Factory{ settingsManager: settingsManager, + clientManager: clientManager, channelCache: make(map[uint]ChannelProxy), } } @@ -109,21 +111,39 @@ func (f *Factory) newBaseChannel(name string, group *models.Group) (*BaseChannel upstreamInfos = append(upstreamInfos, UpstreamInfo{URL: u, Weight: weight}) } - // Get effective settings by merging system and group configs - effectiveSettings := f.settingsManager.GetEffectiveConfig(group.Config) - - // Configure the HTTP client with the effective timeouts - httpClient := &http.Client{ - Transport: &http.Transport{ - IdleConnTimeout: time.Duration(effectiveSettings.IdleConnTimeout) * time.Second, - }, - Timeout: time.Duration(effectiveSettings.RequestTimeout) * time.Second, + // Base configuration for regular requests, derived from the group's effective settings. + clientConfig := &httpclient.Config{ + ConnectTimeout: time.Duration(group.EffectiveConfig.ConnectTimeout) * time.Second, + RequestTimeout: time.Duration(group.EffectiveConfig.RequestTimeout) * time.Second, + IdleConnTimeout: time.Duration(group.EffectiveConfig.IdleConnTimeout) * time.Second, + MaxIdleConns: group.EffectiveConfig.MaxIdleConns, + MaxIdleConnsPerHost: group.EffectiveConfig.MaxIdleConnsPerHost, + ResponseHeaderTimeout: time.Duration(group.EffectiveConfig.ResponseHeaderTimeout) * time.Second, + DisableCompression: group.EffectiveConfig.DisableCompression, + WriteBufferSize: 32 * 1024, // Use a reasonable default buffer size for regular requests + ReadBufferSize: 32 * 1024, } + // Create a dedicated configuration for streaming requests. + // This configuration is optimized for low-latency, long-running connections. + streamConfig := *clientConfig + streamConfig.RequestTimeout = 0 // No overall timeout for the entire request. + streamConfig.DisableCompression = true // Always disable compression for streaming to reduce latency. + streamConfig.WriteBufferSize = 0 // Disable buffering for real-time data transfer. + streamConfig.ReadBufferSize = 0 + // For stream-specific connection pool, we can use a simple heuristic like doubling the regular one. + streamConfig.MaxIdleConns = group.EffectiveConfig.MaxIdleConns * 2 + streamConfig.MaxIdleConnsPerHost = group.EffectiveConfig.MaxIdleConnsPerHost * 2 + + // Get both clients from the manager using their respective configurations. + httpClient := f.clientManager.GetClient(clientConfig) + streamClient := f.clientManager.GetClient(&streamConfig) + return &BaseChannel{ Name: name, Upstreams: upstreamInfos, HTTPClient: httpClient, + StreamClient: streamClient, TestModel: group.TestModel, groupUpstreams: group.Upstreams, groupConfig: group.Config, diff --git a/internal/channel/gemini_channel.go b/internal/channel/gemini_channel.go index 9f9077b..4d8b957 100644 --- a/internal/channel/gemini_channel.go +++ b/internal/channel/gemini_channel.go @@ -9,6 +9,7 @@ import ( "gpt-load/internal/models" "io" "net/http" + "net/url" "strings" "github.com/gin-gonic/gin" @@ -33,13 +34,36 @@ func newGeminiChannel(f *Factory, group *models.Group) (ChannelProxy, error) { }, nil } -func (ch *GeminiChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error { - modifier := func(req *http.Request, key *models.APIKey) { - q := req.URL.Query() - q.Set("key", key.KeyValue) - req.URL.RawQuery = q.Encode() +// BuildUpstreamURL constructs the target URL for the Gemini service. +func (ch *GeminiChannel) BuildUpstreamURL(originalURL *url.URL, group *models.Group) (string, error) { + base := ch.getUpstreamURL() + if base == nil { + // Fallback to default Gemini URL + base, _ = url.Parse("https://generativelanguage.googleapis.com") } - return ch.ProcessRequest(c, apiKey, modifier, ch) + + finalURL := *base + // The originalURL.Path contains the full path, e.g., "/proxy/gemini/v1beta/models/gemini-pro:generateContent". + // We need to strip the proxy prefix to get the correct upstream path. + proxyPrefix := "/proxy/" + group.Name + if strings.HasPrefix(originalURL.Path, proxyPrefix) { + finalURL.Path = strings.TrimPrefix(originalURL.Path, proxyPrefix) + } else { + // Fallback for safety. + finalURL.Path = originalURL.Path + } + + // The API key will be added to RawQuery in ModifyRequest. + finalURL.RawQuery = originalURL.RawQuery + + return finalURL.String(), nil +} + +// ModifyRequest adds the API key as a query parameter for Gemini requests. +func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey, group *models.Group) { + q := req.URL.Query() + q.Set("key", apiKey.KeyValue) + req.URL.RawQuery = q.Encode() } // ValidateKey checks if the given API key is valid by making a generateContent request. @@ -95,12 +119,21 @@ func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, err return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError) } -// IsStreamingRequest checks if the request is for a streaming response. -func (ch *GeminiChannel) IsStreamingRequest(c *gin.Context) bool { - // For Gemini, streaming is indicated by the path containing streaming keywords +// IsStreamRequest checks if the request is for a streaming response. +// For Gemini, this is primarily determined by the URL path. +func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool { path := c.Request.URL.Path - return strings.Contains(path, ":streamGenerateContent") || - strings.Contains(path, "streamGenerateContent") || - strings.Contains(path, ":stream") || - strings.Contains(path, "/stream") + if strings.HasSuffix(path, ":streamGenerateContent") { + return true + } + + // Also check for standard streaming indicators as a fallback. + if strings.Contains(c.GetHeader("Accept"), "text/event-stream") { + return true + } + if c.Query("stream") == "true" { + return true + } + + return false } diff --git a/internal/channel/openai_channel.go b/internal/channel/openai_channel.go index 3ecd7e5..179ad2b 100644 --- a/internal/channel/openai_channel.go +++ b/internal/channel/openai_channel.go @@ -9,9 +9,10 @@ import ( "gpt-load/internal/models" "io" "net/http" + "net/url" + "strings" "github.com/gin-gonic/gin" - "github.com/gin-gonic/gin/binding" ) func init() { @@ -33,11 +34,38 @@ func newOpenAIChannel(f *Factory, group *models.Group) (ChannelProxy, error) { }, nil } -func (ch *OpenAIChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error { - modifier := func(req *http.Request, key *models.APIKey) { - req.Header.Set("Authorization", "Bearer "+key.KeyValue) +// BuildUpstreamURL constructs the target URL for the OpenAI service. +func (ch *OpenAIChannel) BuildUpstreamURL(originalURL *url.URL, group *models.Group) (string, error) { + // Use the weighted round-robin selection from the base channel. + // This method already handles parsing the group's Upstreams JSON. + base := ch.getUpstreamURL() + if base == nil { + // If no upstreams are configured in the group, fallback to a default. + // This can be considered an error or a feature depending on requirements. + // For now, we'll use the official OpenAI URL as a last resort. + base, _ = url.Parse("https://api.openai.com") } - return ch.ProcessRequest(c, apiKey, modifier, ch) + + // It's crucial to create a copy to avoid modifying the cached URL object in BaseChannel. + finalURL := *base + // The originalURL.Path contains the full path, e.g., "/proxy/openai/v1/chat/completions". + // We need to strip the proxy prefix to get the correct upstream path. + proxyPrefix := "/proxy/" + group.Name + if strings.HasPrefix(originalURL.Path, proxyPrefix) { + finalURL.Path = strings.TrimPrefix(originalURL.Path, proxyPrefix) + } else { + // Fallback for safety, though this case should ideally not be hit. + finalURL.Path = originalURL.Path + } + + finalURL.RawQuery = originalURL.RawQuery + + return finalURL.String(), nil +} + +// ModifyRequest sets the Authorization header for the OpenAI service. +func (ch *OpenAIChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey, group *models.Group) { + req.Header.Set("Authorization", "Bearer "+apiKey.KeyValue) } // ValidateKey checks if the given API key is valid by making a chat completion request. @@ -92,16 +120,23 @@ func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, err return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError) } -// IsStreamingRequest checks if the request is for a streaming response. -func (ch *OpenAIChannel) IsStreamingRequest(c *gin.Context) bool { - // For OpenAI, streaming is indicated by a "stream": true field in the JSON body. - // We use ShouldBindBodyWith to check the body without consuming it, so it can be read again by the proxy. +// IsStreamRequest checks if the request is for a streaming response using the pre-read body. +func (ch *OpenAIChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool { + if strings.Contains(c.GetHeader("Accept"), "text/event-stream") { + return true + } + + if c.Query("stream") == "true" { + return true + } + type streamPayload struct { Stream bool `json:"stream"` } var p streamPayload - if err := c.ShouldBindBodyWith(&p, binding.JSON); err == nil { + if err := json.Unmarshal(bodyBytes, &p); err == nil { return p.Stream } + return false } diff --git a/internal/container/container.go b/internal/container/container.go index dfe1a43..1fa256b 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -7,6 +7,7 @@ import ( "gpt-load/internal/config" "gpt-load/internal/db" "gpt-load/internal/handler" + "gpt-load/internal/httpclient" "gpt-load/internal/keypool" "gpt-load/internal/proxy" "gpt-load/internal/router" @@ -36,6 +37,9 @@ func BuildContainer() (*dig.Container, error) { if err := container.Provide(store.NewStore); err != nil { return nil, err } + if err := container.Provide(httpclient.NewHTTPClientManager); err != nil { + return nil, err + } if err := container.Provide(channel.NewFactory); err != nil { return nil, err } diff --git a/internal/errors/errors.go b/internal/errors/errors.go index 4431907..52096c6 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -34,6 +34,8 @@ var ( ErrTaskInProgress = &APIError{HTTPStatus: http.StatusConflict, Code: "TASK_IN_PROGRESS", Message: "A task is already in progress"} ErrBadGateway = &APIError{HTTPStatus: http.StatusBadGateway, Code: "BAD_GATEWAY", Message: "Upstream service error"} ErrNoActiveKeys = &APIError{HTTPStatus: http.StatusServiceUnavailable, Code: "NO_ACTIVE_KEYS", Message: "No active API keys available for this group"} + ErrMaxRetriesExceeded = &APIError{HTTPStatus: http.StatusBadGateway, Code: "MAX_RETRIES_EXCEEDED", Message: "Request failed after maximum retries"} + ErrNoKeysAvailable = &APIError{HTTPStatus: http.StatusServiceUnavailable, Code: "NO_KEYS_AVAILABLE", Message: "No API keys available to process the request"} ) // NewAPIError creates a new APIError with a custom message. @@ -45,6 +47,15 @@ func NewAPIError(base *APIError, message string) *APIError { } } +// NewAPIErrorWithUpstream creates a new APIError specifically for wrapping raw upstream errors. +func NewAPIErrorWithUpstream(statusCode int, code string, upstreamMessage string) *APIError { + return &APIError{ + HTTPStatus: statusCode, + Code: code, + Message: upstreamMessage, + } +} + // ParseDBError intelligently converts a GORM error into a standard APIError. func ParseDBError(err error) *APIError { if err == nil { diff --git a/internal/errors/ignorable_errors.go b/internal/errors/ignorable_errors.go new file mode 100644 index 0000000..8ea4b29 --- /dev/null +++ b/internal/errors/ignorable_errors.go @@ -0,0 +1,31 @@ +package errors + +import ( + "strings" +) + +// ignorableErrorSubstrings contains a list of substrings that indicate an error +// can be safely ignored. These typically occur when a client disconnects prematurely. +var ignorableErrorSubstrings = []string{ + "context canceled", + "connection reset by peer", + "broken pipe", + "use of closed network connection", + "request canceled", +} + +// IsIgnorableError checks if the given error is a common, non-critical error +// that can occur when a client disconnects. This is used to prevent logging +// unnecessary errors and to avoid marking keys as failed for client-side issues. +func IsIgnorableError(err error) bool { + if err == nil { + return false + } + errStr := err.Error() + for _, sub := range ignorableErrorSubstrings { + if strings.Contains(errStr, sub) { + return true + } + } + return false +} diff --git a/internal/httpclient/manager.go b/internal/httpclient/manager.go new file mode 100644 index 0000000..7d93f6b --- /dev/null +++ b/internal/httpclient/manager.go @@ -0,0 +1,105 @@ +package httpclient + +import ( + "fmt" + "net" + "net/http" + "sync" + "time" +) + +// Config defines the parameters for creating an HTTP client. +// This struct is used to generate a unique fingerprint for client reuse. +type Config struct { + ConnectTimeout time.Duration + RequestTimeout time.Duration + IdleConnTimeout time.Duration + MaxIdleConns int + MaxIdleConnsPerHost int + ResponseHeaderTimeout time.Duration + DisableCompression bool + WriteBufferSize int + ReadBufferSize int +} + +// HTTPClientManager manages the lifecycle of HTTP clients. +// It creates and caches clients based on their configuration fingerprint, +// ensuring that clients with the same configuration are reused. +type HTTPClientManager struct { + clients map[string]*http.Client + lock sync.RWMutex +} + +// NewHTTPClientManager creates a new client manager. +func NewHTTPClientManager() *HTTPClientManager { + return &HTTPClientManager{ + clients: make(map[string]*http.Client), + } +} + +// GetClient returns an HTTP client that matches the given configuration. +// If a matching client already exists in the cache, it is returned. +// Otherwise, a new client is created, cached, and returned. +func (m *HTTPClientManager) GetClient(config *Config) *http.Client { + fingerprint := config.getFingerprint() + + // Fast path with read lock + m.lock.RLock() + client, exists := m.clients[fingerprint] + m.lock.RUnlock() + if exists { + return client + } + + // Slow path with write lock + m.lock.Lock() + defer m.lock.Unlock() + + // Double-check in case another goroutine created the client while we were waiting for the lock. + if client, exists = m.clients[fingerprint]; exists { + return client + } + + // Create a new transport and client with the specified configuration. + transport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: config.ConnectTimeout, + KeepAlive: 30 * time.Second, // KeepAlive is a good default + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: config.MaxIdleConns, + MaxIdleConnsPerHost: config.MaxIdleConnsPerHost, + IdleConnTimeout: config.IdleConnTimeout, + TLSHandshakeTimeout: 10 * time.Second, // A reasonable default + ExpectContinueTimeout: 1 * time.Second, // A reasonable default + ResponseHeaderTimeout: config.ResponseHeaderTimeout, + DisableCompression: config.DisableCompression, + WriteBufferSize: config.WriteBufferSize, + ReadBufferSize: config.ReadBufferSize, + } + + newClient := &http.Client{ + Transport: transport, + Timeout: config.RequestTimeout, + } + + m.clients[fingerprint] = newClient + return newClient +} + +// getFingerprint generates a unique string representation of the client configuration. +func (c *Config) getFingerprint() string { + return fmt.Sprintf( + "ct:%.0fs|rt:%.0fs|it:%.0fs|mic:%d|mich:%d|rht:%.0fs|dc:%t|wbs:%d|rbs:%d", + c.ConnectTimeout.Seconds(), + c.RequestTimeout.Seconds(), + c.IdleConnTimeout.Seconds(), + c.MaxIdleConns, + c.MaxIdleConnsPerHost, + c.ResponseHeaderTimeout.Seconds(), + c.DisableCompression, + c.WriteBufferSize, + c.ReadBufferSize, + ) +} diff --git a/internal/models/types.go b/internal/models/types.go index 25eef6b..efbe8ea 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -35,6 +35,11 @@ type GroupConfig struct { ResponseTimeout *int `json:"response_timeout,omitempty"` IdleConnTimeout *int `json:"idle_conn_timeout,omitempty"` KeyValidationIntervalMinutes *int `json:"key_validation_interval_minutes,omitempty"` + ConnectTimeout *int `json:"connect_timeout,omitempty"` + MaxIdleConns *int `json:"max_idle_conns,omitempty"` + MaxIdleConnsPerHost *int `json:"max_idle_conns_per_host,omitempty"` + ResponseHeaderTimeout *int `json:"response_header_timeout,omitempty"` + DisableCompression *bool `json:"disable_compression,omitempty"` } // Group 对应 groups 表 diff --git a/internal/proxy/server.go b/internal/proxy/server.go index d7a8fd8..22c265b 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -2,171 +2,311 @@ package proxy import ( + "bufio" "bytes" + "compress/gzip" + "context" "encoding/json" "fmt" + "io" + "net/http" + "strings" + "time" + "gpt-load/internal/channel" + "gpt-load/internal/config" app_errors "gpt-load/internal/errors" "gpt-load/internal/keypool" "gpt-load/internal/models" "gpt-load/internal/response" - "io" - "time" + "gpt-load/internal/services" + "gpt-load/internal/types" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" - "gorm.io/gorm" ) +// A list of errors that are considered normal during streaming when a client disconnects. +var ignorableStreamErrors = []string{ + "context canceled", + "connection reset by peer", + "broken pipe", + "use of closed network connection", +} + +// isIgnorableStreamError checks if the error is a common, non-critical error that can occur +// when a client disconnects during a streaming response. +func isIgnorableStreamError(err error) bool { + if err == nil { + return false + } + errStr := err.Error() + for _, ignorableError := range ignorableStreamErrors { + if strings.Contains(errStr, ignorableError) { + return true + } + } + return false +} + // ProxyServer represents the proxy server type ProxyServer struct { - DB *gorm.DB - channelFactory *channel.Factory - keyProvider *keypool.KeyProvider - requestLogChan chan models.RequestLog + keyProvider *keypool.KeyProvider + groupManager *services.GroupManager + settingsManager *config.SystemSettingsManager + channelFactory *channel.Factory } // NewProxyServer creates a new proxy server func NewProxyServer( - db *gorm.DB, - channelFactory *channel.Factory, keyProvider *keypool.KeyProvider, - requestLogChan chan models.RequestLog, + groupManager *services.GroupManager, + settingsManager *config.SystemSettingsManager, + channelFactory *channel.Factory, ) (*ProxyServer, error) { return &ProxyServer{ - DB: db, - channelFactory: channelFactory, - keyProvider: keyProvider, - requestLogChan: requestLogChan, + keyProvider: keyProvider, + groupManager: groupManager, + settingsManager: settingsManager, + channelFactory: channelFactory, }, nil } -// HandleProxy handles the main proxy logic +// HandleProxy is the main entry point for proxy requests, refactored based on the stable .bak logic. func (ps *ProxyServer) HandleProxy(c *gin.Context) { startTime := time.Now() groupName := c.Param("group_name") - // 1. Find the group by name (without preloading keys) - var group models.Group - if err := ps.DB.Where("name = ?", groupName).First(&group).Error; err != nil { + group, err := ps.groupManager.GetGroupByName(groupName) + if err != nil { response.Error(c, app_errors.ParseDBError(err)) return } - // 2. Select an available API key from the KeyPool - apiKey, err := ps.keyProvider.SelectKey(group.ID) - if err != nil { - // Properly handle the case where no keys are available - if apiErr, ok := err.(*app_errors.APIError); ok { - response.Error(c, apiErr) - } else { - response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, err.Error())) - } - return - } - - // 3. Get the appropriate channel handler from the factory - channelHandler, err := ps.channelFactory.GetChannel(&group) + channelHandler, err := ps.channelFactory.GetChannel(group) if err != nil { response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to get channel for group '%s': %v", groupName, err))) return } - // 4. Apply parameter overrides if they exist - if len(group.ParamOverrides) > 0 { - err := ps.applyParamOverrides(c, &group) - if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to apply parameter overrides: %v", err))) - return - } - } - - // 5. Forward the request using the channel handler - err = channelHandler.Handle(c, apiKey, &group) - - // 6. Update key status and log the request asynchronously - isSuccess := err == nil - ps.keyProvider.UpdateStatus(apiKey.ID, group.ID, isSuccess) - - if !isSuccess { - logrus.WithFields(logrus.Fields{ - "group": group.Name, - "key_id": apiKey.ID, - "error": err.Error(), - }).Error("Channel handler failed") - } - go ps.logRequest(c, &group, apiKey, startTime) -} - -func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *models.APIKey, startTime time.Time) { - logEntry := models.RequestLog{ - ID: fmt.Sprintf("req_%d", time.Now().UnixNano()), - Timestamp: startTime, - GroupID: group.ID, - KeyID: key.ID, - SourceIP: c.ClientIP(), - StatusCode: c.Writer.Status(), - RequestPath: c.Request.URL.Path, - RequestBodySnippet: "", // Can be implemented later if needed - } - - // Send to the logging channel without blocking - select { - case ps.requestLogChan <- logEntry: - default: - logrus.Warn("Request log channel is full. Dropping log entry.") - } -} - -// Close cleans up resources -func (ps *ProxyServer) Close() { - close(ps.requestLogChan) -} - -func (ps *ProxyServer) applyParamOverrides(c *gin.Context, group *models.Group) error { - // Read the original request body bodyBytes, err := io.ReadAll(c.Request.Body) if err != nil { - return fmt.Errorf("failed to read request body: %w", err) + logrus.Errorf("Failed to read request body: %v", err) + response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Failed to read request body")) + return } c.Request.Body.Close() - // If body is empty, nothing to override, just restore the body - if len(bodyBytes) == 0 { - c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - return nil + // 4. Apply parameter overrides if any. + finalBodyBytes, err := ps.applyParamOverrides(bodyBytes, group) + if err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to apply parameter overrides: %v", err))) + return } - // Save the original Content-Type - originalContentType := c.GetHeader("Content-Type") + // 5. Determine if this is a streaming request. + isStream := channelHandler.IsStreamRequest(c, bodyBytes) + + // 6. Execute the request using the recursive retry logic. + ps.executeRequestWithRetry(c, channelHandler, group, finalBodyBytes, isStream, startTime, 0, nil) +} + +// executeRequestWithRetry is the core recursive function for handling requests and retries. +func (ps *ProxyServer) executeRequestWithRetry( + c *gin.Context, + channelHandler channel.ChannelProxy, + group *models.Group, + bodyBytes []byte, + isStream bool, + startTime time.Time, + retryCount int, + retryErrors []types.RetryError, +) { + cfg := group.EffectiveConfig + if retryCount > cfg.MaxRetries { + logrus.Errorf("Max retries exceeded for group %s after %d attempts.", group.Name, retryCount) + if len(retryErrors) > 0 { + lastError := retryErrors[len(retryErrors)-1] + var errorJSON map[string]any + if err := json.Unmarshal([]byte(lastError.ErrorMessage), &errorJSON); err == nil { + c.JSON(lastError.StatusCode, errorJSON) + } else { + response.Error(c, app_errors.NewAPIErrorWithUpstream(lastError.StatusCode, "UPSTREAM_ERROR", lastError.ErrorMessage)) + } + } else { + response.Error(c, app_errors.ErrMaxRetriesExceeded) + } + return + } + + apiKey, err := ps.keyProvider.SelectKey(group.ID) + if err != nil { + logrus.Errorf("Failed to select a key for group %s on attempt %d: %v", group.Name, retryCount+1, err) + response.Error(c, app_errors.NewAPIError(app_errors.ErrNoKeysAvailable, err.Error())) + return + } + + upstreamURL, err := channelHandler.BuildUpstreamURL(c.Request.URL, group) + if err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to build upstream URL: %v", err))) + return + } + + var ctx context.Context + var cancel context.CancelFunc + if isStream { + ctx, cancel = context.WithCancel(c.Request.Context()) + } else { + timeout := time.Duration(cfg.RequestTimeout) * time.Second + ctx, cancel = context.WithTimeout(c.Request.Context(), timeout) + } + defer cancel() + + req, err := http.NewRequestWithContext(ctx, c.Request.Method, upstreamURL, bytes.NewReader(bodyBytes)) + if err != nil { + logrus.Errorf("Failed to create upstream request: %v", err) + response.Error(c, app_errors.ErrInternalServer) + return + } + req.ContentLength = int64(len(bodyBytes)) + + req.Header = c.Request.Header.Clone() + channelHandler.ModifyRequest(req, apiKey, group) + + client := channelHandler.GetHTTPClient() + if isStream { + client = channelHandler.GetStreamClient() + req.Header.Set("X-Accel-Buffering", "no") + } + + resp, err := client.Do(req) + if err != nil { + ps.keyProvider.UpdateStatus(apiKey.ID, group.ID, false) + logrus.Warnf("Request failed (attempt %d/%d) for key %s: %v", retryCount+1, cfg.MaxRetries, apiKey.KeyValue[:8], err) + + newRetryErrors := append(retryErrors, types.RetryError{ + StatusCode: 0, + ErrorMessage: err.Error(), + KeyID: fmt.Sprintf("%d", apiKey.ID), + Attempt: retryCount + 1, + }) + ps.executeRequestWithRetry(c, channelHandler, group, bodyBytes, isStream, startTime, retryCount+1, newRetryErrors) + return + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + ps.keyProvider.UpdateStatus(apiKey.ID, group.ID, false) + errorBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + logrus.Errorf("Failed to read error body: %v", readErr) + // Even if reading fails, we should proceed with retry logic + errorBody = []byte("Failed to read error body") + } + + // Check for gzip encoding and decompress if necessary. + if resp.Header.Get("Content-Encoding") == "gzip" { + reader, err := gzip.NewReader(bytes.NewReader(errorBody)) + if err == nil { + decompressedBody, err := io.ReadAll(reader) + if err == nil { + errorBody = decompressedBody + } else { + logrus.Warnf("Failed to decompress gzip error body: %v", err) + } + reader.Close() + } else { + logrus.Warnf("Failed to create gzip reader for error body: %v", err) + } + } + + logrus.Warnf("Request failed with status %d (attempt %d/%d) for key %s. Body: %s", resp.StatusCode, retryCount+1, cfg.MaxRetries, apiKey.KeyValue[:8], string(errorBody)) + + newRetryErrors := append(retryErrors, types.RetryError{ + StatusCode: resp.StatusCode, + ErrorMessage: string(errorBody), + KeyID: fmt.Sprintf("%d", apiKey.ID), + Attempt: retryCount + 1, + }) + ps.executeRequestWithRetry(c, channelHandler, group, bodyBytes, isStream, startTime, retryCount+1, newRetryErrors) + return + } + + ps.keyProvider.UpdateStatus(apiKey.ID, group.ID, true) + logrus.Debugf("Request for group %s succeeded on attempt %d with key %s", group.Name, retryCount+1, apiKey.KeyValue[:8]) + + for key, values := range resp.Header { + for _, value := range values { + c.Header(key, value) + } + } + c.Status(resp.StatusCode) + + if isStream { + ps.handleStreamingResponse(c, resp) + } else { + ps.handleNormalResponse(c, resp) + } +} + +func (ps *ProxyServer) handleStreamingResponse(c *gin.Context, resp *http.Response) { + flusher, ok := c.Writer.(http.Flusher) + if !ok { + logrus.Error("Streaming unsupported by the writer") + ps.handleNormalResponse(c, resp) + return + } + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + if _, err := c.Writer.Write(scanner.Bytes()); err != nil { + if !isIgnorableStreamError(err) { + logrus.Errorf("Error writing to client: %v", err) + } + return + } + if _, err := c.Writer.Write([]byte("\n")); err != nil { + if !isIgnorableStreamError(err) { + logrus.Errorf("Error writing newline to client: %v", err) + } + return + } + flusher.Flush() + } + + if err := scanner.Err(); err != nil && !isIgnorableStreamError(err) { + logrus.Errorf("Error reading from upstream: %v", err) + } +} + +func (ps *ProxyServer) handleNormalResponse(c *gin.Context, resp *http.Response) { + if _, err := io.Copy(c.Writer, resp.Body); err != nil { + if !isIgnorableStreamError(err) { + logrus.Errorf("Failed to copy response body to client: %v", err) + } + } +} + +func (ps *ProxyServer) applyParamOverrides(bodyBytes []byte, group *models.Group) ([]byte, error) { + if len(group.ParamOverrides) == 0 || len(bodyBytes) == 0 { + return bodyBytes, nil + } - // Unmarshal the body into a map var requestData map[string]any if err := json.Unmarshal(bodyBytes, &requestData); err != nil { - // If not a valid JSON, just pass it through - c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - return nil + logrus.Warnf("failed to unmarshal request body for param override, passing through: %v", err) + return bodyBytes, nil } - // Merge the overrides into the request data for key, value := range group.ParamOverrides { requestData[key] = value } - // Marshal the new data back to JSON - newBodyBytes, err := json.Marshal(requestData) - if err != nil { - return fmt.Errorf("failed to marshal new request body: %w", err) - } - - // Replace the request body with the new one - c.Request.Body = io.NopCloser(bytes.NewBuffer(newBodyBytes)) - c.Request.ContentLength = int64(len(newBodyBytes)) - - // Restore the original Content-Type header - if originalContentType != "" { - c.Request.Header.Set("Content-Type", originalContentType) - } - - return nil + return json.Marshal(requestData) +} + +func (ps *ProxyServer) Close() { + // The HTTP clients are now managed by the channel factory and httpclient manager, + // so the proxy server itself doesn't need to close them. + // The httpclient manager will handle closing idle connections for all its clients. } diff --git a/internal/types/retry.go b/internal/types/retry.go new file mode 100644 index 0000000..2a46f42 --- /dev/null +++ b/internal/types/retry.go @@ -0,0 +1,9 @@ +package types + +// RetryError captures detailed information about a failed request attempt during retries. +type RetryError struct { + StatusCode int `json:"status_code"` + ErrorMessage string `json:"error_message"` + KeyID string `json:"key_id"` + Attempt int `json:"attempt"` +} diff --git a/internal/types/types.go b/internal/types/types.go index 91048c0..0ecce43 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -32,6 +32,8 @@ type SystemSettings struct { IdleConnTimeout int `json:"idle_conn_timeout" default:"120" name:"空闲连接超时" category:"请求超时" desc:"HTTP 客户端中空闲连接的超时时间(秒)。" validate:"min=1"` MaxIdleConns int `json:"max_idle_conns" default:"100" name:"最大空闲连接数" category:"请求超时" desc:"HTTP 客户端连接池中允许的最大空闲连接总数。" validate:"min=1"` MaxIdleConnsPerHost int `json:"max_idle_conns_per_host" default:"10" name:"每主机最大空闲连接数" category:"请求超时" desc:"HTTP 客户端连接池对每个上游主机允许的最大空闲连接数。" validate:"min=1"` + ResponseHeaderTimeout int `json:"response_header_timeout" default:"120" name:"响应头超时" category:"请求超时" desc:"等待上游服务响应头的最长时间(秒),用于流式请求。" validate:"min=1"` + DisableCompression bool `json:"disable_compression" default:"false" name:"禁用压缩" category:"请求超时" desc:"是否禁用对上游请求的传输压缩(Gzip)。对于流式请求建议开启以降低延迟。"` // 密钥配置 MaxRetries int `json:"max_retries" default:"3" name:"最大重试次数" category:"密钥配置" desc:"单个请求使用不同 Key 的最大重试次数" validate:"min=0"`