diff --git a/internal/app/app.go b/internal/app/app.go
index 3b9fca0..a3f2902 100644
--- a/internal/app/app.go
+++ b/internal/app/app.go
@@ -184,8 +184,6 @@ func (a *App) Stop(ctx context.Context) {
     a.logCleanupService.Stop()
     a.groupManager.Stop()
     a.settingsManager.Stop()
-
-    // Close resources
     a.storage.Close()
 
     // Wait for the logger to finish writing all logs
diff --git a/internal/channel/base_channel.go b/internal/channel/base_channel.go
index 77f5647..b51eecb 100644
--- a/internal/channel/base_channel.go
+++ b/internal/channel/base_channel.go
@@ -2,13 +2,15 @@ package channel
 
 import (
     "bytes"
-    "encoding/json"
+    "fmt"
     "gpt-load/internal/models"
+    "gpt-load/internal/types"
     "net/http"
     "net/url"
+    "reflect"
+    "strings"
     "sync"
 
-    "github.com/sirupsen/logrus"
     "gorm.io/datatypes"
 )
@@ -28,7 +30,7 @@ type BaseChannel struct {
     TestModel      string
     upstreamLock   sync.Mutex
     groupUpstreams datatypes.JSON
-    groupConfig    datatypes.JSONMap
+    effectiveConfig *types.SystemSettings
 }
 
 // getUpstreamURL selects an upstream URL using a smooth weighted round-robin algorithm.
@@ -64,28 +66,33 @@ func (b *BaseChannel) getUpstreamURL() *url.URL {
     return best.URL
 }
 
+// BuildUpstreamURL constructs the target URL for the upstream service.
+func (b *BaseChannel) BuildUpstreamURL(originalURL *url.URL, group *models.Group) (string, error) {
+    base := b.getUpstreamURL()
+    if base == nil {
+        return "", fmt.Errorf("no upstream URL configured for channel %s", b.Name)
+    }
+
+    finalURL := *base
+    proxyPrefix := "/proxy/" + group.Name
+    if strings.HasPrefix(originalURL.Path, proxyPrefix) {
+        finalURL.Path = strings.TrimPrefix(originalURL.Path, proxyPrefix)
+    } else {
+        finalURL.Path = originalURL.Path
+    }
+
+    finalURL.RawQuery = originalURL.RawQuery
+
+    return finalURL.String(), nil
+}
+
 // IsConfigStale checks if the channel's configuration is stale compared to the provided group.
 func (b *BaseChannel) IsConfigStale(group *models.Group) bool {
-    // It's important to compare the raw JSON here to detect any changes.
     if !bytes.Equal(b.groupUpstreams, group.Upstreams) {
         return true
     }
 
-    // For JSONMap, we need to marshal it to compare.
-    currentConfigBytes, err := json.Marshal(b.groupConfig)
-    if err != nil {
-        // Log the error and assume it's stale to be safe
-        logrus.Errorf("failed to marshal current group config: %v", err)
-        return true
-    }
-    newConfigBytes, err := json.Marshal(group.Config)
-    if err != nil {
-        // Log the error and assume it's stale
-        logrus.Errorf("failed to marshal new group config: %v", err)
-        return true
-    }
-
-    if !bytes.Equal(currentConfigBytes, newConfigBytes) {
+    if !reflect.DeepEqual(b.effectiveConfig, &group.EffectiveConfig) {
         return true
     }
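The hunk above shows only the tail of getUpstreamURL, whose doc comment names a smooth weighted round-robin selection. A minimal, self-contained sketch of that algorithm, under assumed field names (weight, currentWeight) rather than the project's actual upstream type:

package main

import "fmt"

// upstream is a hypothetical stand-in for the channel's internal upstream entry.
type upstream struct {
	name          string
	weight        int
	currentWeight int
}

// pick implements smooth weighted round-robin: every upstream gains its weight,
// the largest current weight wins, and the winner is penalized by the total weight.
func pick(ups []upstream) *upstream {
	total := 0
	var best *upstream
	for i := range ups {
		u := &ups[i]
		u.currentWeight += u.weight
		total += u.weight
		if best == nil || u.currentWeight > best.currentWeight {
			best = u
		}
	}
	if best == nil {
		return nil
	}
	best.currentWeight -= total
	return best
}

func main() {
	ups := []upstream{{name: "a", weight: 5}, {name: "b", weight: 1}}
	for i := 0; i < 6; i++ {
		fmt.Print(pick(ups).name, " ") // a a a b a a — "b" gets one pick per six, spread smoothly
	}
}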
diff --git a/internal/channel/channel.go b/internal/channel/channel.go
index 53ae3bf..de074d6 100644
--- a/internal/channel/channel.go
+++ b/internal/channel/channel.go
@@ -10,17 +10,14 @@ import (
 )
 
 // ChannelProxy defines the interface for different API channel proxies.
-// It's responsible for channel-specific logic like URL building and request modification.
 type ChannelProxy interface {
     // BuildUpstreamURL constructs the target URL for the upstream service.
     BuildUpstreamURL(originalURL *url.URL, group *models.Group) (string, error)
 
     // ModifyRequest allows the channel to add specific headers or modify the request
-    // before it's sent to the upstream service.
     ModifyRequest(req *http.Request, apiKey *models.APIKey, group *models.Group)
 
     // IsStreamRequest checks if the request is for a streaming response,
-    // now using the cached request body to avoid re-reading the stream.
     IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
 
     // ValidateKey checks if the given API key is valid.
diff --git a/internal/channel/factory.go b/internal/channel/factory.go
index c647765..133216d 100644
--- a/internal/channel/factory.go
+++ b/internal/channel/factory.go
@@ -22,7 +22,6 @@ var (
 )
 
 // Register adds a new channel constructor to the registry.
-// This function is intended to be called from the init() function of each channel implementation.
 func Register(channelType string, constructor channelConstructor) {
     if _, exists := channelRegistry[channelType]; exists {
         panic(fmt.Sprintf("channel type '%s' is already registered", channelType))
@@ -57,7 +56,6 @@ func NewFactory(settingsManager *config.SystemSettingsManager, clientManager *ht
 }
 
 // GetChannel returns a channel proxy based on the group's channel type.
-// It uses a cache to ensure that only one instance of a channel is created for each group.
 func (f *Factory) GetChannel(group *models.Group) (ChannelProxy, error) {
     f.cacheLock.Lock()
     defer f.cacheLock.Unlock()
@@ -120,20 +118,28 @@ func (f *Factory) newBaseChannel(name string, group *models.Group) (*BaseChannel
         MaxIdleConnsPerHost:   group.EffectiveConfig.MaxIdleConnsPerHost,
         ResponseHeaderTimeout: time.Duration(group.EffectiveConfig.ResponseHeaderTimeout) * time.Second,
         DisableCompression:    group.EffectiveConfig.DisableCompression,
-        WriteBufferSize:       32 * 1024, // Use a reasonable default buffer size for regular requests
+        WriteBufferSize:       32 * 1024,
         ReadBufferSize:        32 * 1024,
+        ForceAttemptHTTP2:     true,
+        TLSHandshakeTimeout:   15 * time.Second,
+        ExpectContinueTimeout: 1 * time.Second,
     }
 
     // Create a dedicated configuration for streaming requests.
-    // This configuration is optimized for low-latency, long-running connections.
     streamConfig := *clientConfig
-    streamConfig.RequestTimeout = 0        // No overall timeout for the entire request.
-    streamConfig.DisableCompression = true // Always disable compression for streaming to reduce latency.
-    streamConfig.WriteBufferSize = 0       // Disable buffering for real-time data transfer.
+    streamConfig.RequestTimeout = 0
+    streamConfig.DisableCompression = true
+    streamConfig.WriteBufferSize = 0
     streamConfig.ReadBufferSize = 0
-    // For stream-specific connection pool, we can use a simple heuristic like doubling the regular one.
+    // Use a larger, independent connection pool for streaming clients to avoid exhaustion.
     streamConfig.MaxIdleConns = group.EffectiveConfig.MaxIdleConns * 2
+    if streamConfig.MaxIdleConns < 200 {
+        streamConfig.MaxIdleConns = 200
+    }
     streamConfig.MaxIdleConnsPerHost = group.EffectiveConfig.MaxIdleConnsPerHost * 2
+    if streamConfig.MaxIdleConnsPerHost < 40 {
+        streamConfig.MaxIdleConnsPerHost = 40
+    }
 
     // Get both clients from the manager using their respective configurations.
     httpClient := f.clientManager.GetClient(clientConfig)
@@ -145,7 +151,7 @@ func (f *Factory) newBaseChannel(name string, group *models.Group) (*BaseChannel
         HTTPClient:     httpClient,
         StreamClient:   streamClient,
         TestModel:      group.TestModel,
-        groupUpstreams: group.Upstreams,
-        groupConfig:    group.Config,
+        groupUpstreams:  group.Upstreams,
+        effectiveConfig: &group.EffectiveConfig,
     }, nil
 }
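GetChannel's removed comment noted that the factory keeps one channel instance per group; combined with the new IsConfigStale, a cached channel is rebuilt only when the group's upstreams or effective config actually change. A rough, self-contained sketch of that interaction, using trimmed-down hypothetical types rather than the real Factory:

package main

import (
	"bytes"
	"fmt"
	"reflect"
)

type settings struct{ MaxRetries int }

type group struct {
	ID              uint
	Upstreams       []byte
	EffectiveConfig settings
}

type channel struct {
	upstreams       []byte
	effectiveConfig *settings
}

// isConfigStale mirrors the new comparison: raw upstream bytes plus a deep
// comparison of the resolved settings, instead of re-marshalling a JSON map.
func (c *channel) isConfigStale(g *group) bool {
	return !bytes.Equal(c.upstreams, g.Upstreams) ||
		!reflect.DeepEqual(c.effectiveConfig, &g.EffectiveConfig)
}

// cache keyed by group ID; rebuild only when the cached channel reports staleness.
var cache = map[uint]*channel{}

func getChannel(g *group) *channel {
	if ch, ok := cache[g.ID]; ok && !ch.isConfigStale(g) {
		return ch
	}
	ch := &channel{upstreams: g.Upstreams, effectiveConfig: &g.EffectiveConfig}
	cache[g.ID] = ch
	return ch
}

func main() {
	g1 := &group{ID: 1, Upstreams: []byte(`[{"url":"https://a"}]`), EffectiveConfig: settings{MaxRetries: 3}}
	first := getChannel(g1)

	// A later request loads the group afresh with a changed effective config.
	g2 := &group{ID: 1, Upstreams: g1.Upstreams, EffectiveConfig: settings{MaxRetries: 5}}
	second := getChannel(g2)
	fmt.Println(first == second) // false: the stale channel was rebuilt
}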
diff --git a/internal/channel/gemini_channel.go b/internal/channel/gemini_channel.go
index 4d8b957..9b463fd 100644
--- a/internal/channel/gemini_channel.go
+++ b/internal/channel/gemini_channel.go
@@ -9,7 +9,6 @@ import (
     "gpt-load/internal/models"
     "io"
     "net/http"
-    "net/url"
     "strings"
 
     "github.com/gin-gonic/gin"
@@ -34,31 +33,6 @@ func newGeminiChannel(f *Factory, group *models.Group) (ChannelProxy, error) {
     }, nil
 }
 
-// BuildUpstreamURL constructs the target URL for the Gemini service.
-func (ch *GeminiChannel) BuildUpstreamURL(originalURL *url.URL, group *models.Group) (string, error) {
-    base := ch.getUpstreamURL()
-    if base == nil {
-        // Fallback to default Gemini URL
-        base, _ = url.Parse("https://generativelanguage.googleapis.com")
-    }
-
-    finalURL := *base
-    // The originalURL.Path contains the full path, e.g., "/proxy/gemini/v1beta/models/gemini-pro:generateContent".
-    // We need to strip the proxy prefix to get the correct upstream path.
-    proxyPrefix := "/proxy/" + group.Name
-    if strings.HasPrefix(originalURL.Path, proxyPrefix) {
-        finalURL.Path = strings.TrimPrefix(originalURL.Path, proxyPrefix)
-    } else {
-        // Fallback for safety.
-        finalURL.Path = originalURL.Path
-    }
-
-    // The API key will be added to RawQuery in ModifyRequest.
-    finalURL.RawQuery = originalURL.RawQuery
-
-    return finalURL.String(), nil
-}
-
 // ModifyRequest adds the API key as a query parameter for Gemini requests.
 func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey, group *models.Group) {
     q := req.URL.Query()
@@ -73,11 +47,8 @@ func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, err
         return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
     }
 
-    // Use the test model specified in the group settings.
-    // The path format for Gemini is /v1beta/models/{model}:generateContent
     reqURL := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", upstreamURL.String(), ch.TestModel, key)
 
-    // Use a minimal, low-cost payload for validation
     payload := gin.H{
         "contents": []gin.H{
             {"parts": []gin.H{
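With BuildUpstreamURL consolidated in the base channel, what remains channel-specific is authentication: Gemini's ModifyRequest appends the key as a query parameter, while OpenAI's (per its doc comment) sets the Authorization header. A standalone sketch of the two styles; the Bearer scheme for OpenAI is an assumption consistent with its public API, not copied from this diff:

package main

import (
	"fmt"
	"net/http"
)

// modifyGemini mirrors the query-parameter style shown in GeminiChannel.ModifyRequest.
func modifyGemini(req *http.Request, key string) {
	q := req.URL.Query()
	q.Set("key", key)
	req.URL.RawQuery = q.Encode()
}

// modifyOpenAI illustrates the header style described for OpenAIChannel.ModifyRequest.
func modifyOpenAI(req *http.Request, key string) {
	req.Header.Set("Authorization", "Bearer "+key)
}

func main() {
	gReq, _ := http.NewRequest(http.MethodPost, "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent", nil)
	modifyGemini(gReq, "demo-key")
	fmt.Println(gReq.URL.String()) // ...:generateContent?key=demo-key

	oReq, _ := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", nil)
	modifyOpenAI(oReq, "demo-key")
	fmt.Println(oReq.Header.Get("Authorization")) // Bearer demo-key
}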
diff --git a/internal/channel/openai_channel.go b/internal/channel/openai_channel.go
index 179ad2b..2bc5963 100644
--- a/internal/channel/openai_channel.go
+++ b/internal/channel/openai_channel.go
@@ -9,7 +9,6 @@ import (
     "gpt-load/internal/models"
     "io"
     "net/http"
-    "net/url"
     "strings"
 
     "github.com/gin-gonic/gin"
@@ -34,34 +33,6 @@ func newOpenAIChannel(f *Factory, group *models.Group) (ChannelProxy, error) {
     }, nil
 }
 
-// BuildUpstreamURL constructs the target URL for the OpenAI service.
-func (ch *OpenAIChannel) BuildUpstreamURL(originalURL *url.URL, group *models.Group) (string, error) {
-    // Use the weighted round-robin selection from the base channel.
-    // This method already handles parsing the group's Upstreams JSON.
-    base := ch.getUpstreamURL()
-    if base == nil {
-        // If no upstreams are configured in the group, fallback to a default.
-        // This can be considered an error or a feature depending on requirements.
-        // For now, we'll use the official OpenAI URL as a last resort.
-        base, _ = url.Parse("https://api.openai.com")
-    }
-
-    // It's crucial to create a copy to avoid modifying the cached URL object in BaseChannel.
-    finalURL := *base
-    // The originalURL.Path contains the full path, e.g., "/proxy/openai/v1/chat/completions".
-    // We need to strip the proxy prefix to get the correct upstream path.
-    proxyPrefix := "/proxy/" + group.Name
-    if strings.HasPrefix(originalURL.Path, proxyPrefix) {
-        finalURL.Path = strings.TrimPrefix(originalURL.Path, proxyPrefix)
-    } else {
-        // Fallback for safety, though this case should ideally not be hit.
-        finalURL.Path = originalURL.Path
-    }
-
-    finalURL.RawQuery = originalURL.RawQuery
-
-    return finalURL.String(), nil
-}
 
 // ModifyRequest sets the Authorization header for the OpenAI service.
 func (ch *OpenAIChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey, group *models.Group) {
diff --git a/internal/config/manager.go b/internal/config/manager.go
index d4712a5..b52bd88 100644
--- a/internal/config/manager.go
+++ b/internal/config/manager.go
@@ -79,8 +79,6 @@ func (m *Manager) ReloadConfig() error {
         Server: types.ServerConfig{
             Port: parseInteger(os.Getenv("PORT"), 3000),
             Host: getEnvOrDefault("HOST", "0.0.0.0"),
-            // Server timeout configs now come from system settings, not environment
-            // Using defaults from SystemSettings struct as the initial value
             ReadTimeout:  defaultSettings.ServerReadTimeout,
             WriteTimeout: defaultSettings.ServerWriteTimeout,
             IdleTimeout:  defaultSettings.ServerIdleTimeout,
format"} + ErrValidation = &APIError{HTTPStatus: http.StatusBadRequest, Code: "VALIDATION_FAILED", Message: "Input validation failed"} + ErrDuplicateResource = &APIError{HTTPStatus: http.StatusConflict, Code: "DUPLICATE_RESOURCE", Message: "Resource already exists"} + ErrResourceNotFound = &APIError{HTTPStatus: http.StatusNotFound, Code: "NOT_FOUND", Message: "Resource not found"} + ErrInternalServer = &APIError{HTTPStatus: http.StatusInternalServerError, Code: "INTERNAL_SERVER_ERROR", Message: "An unexpected error occurred"} + ErrDatabase = &APIError{HTTPStatus: http.StatusInternalServerError, Code: "DATABASE_ERROR", Message: "Database operation failed"} + ErrUnauthorized = &APIError{HTTPStatus: http.StatusUnauthorized, Code: "UNAUTHORIZED", Message: "Authentication failed"} + ErrForbidden = &APIError{HTTPStatus: http.StatusForbidden, Code: "FORBIDDEN", Message: "You do not have permission to access this resource"} + ErrTaskInProgress = &APIError{HTTPStatus: http.StatusConflict, Code: "TASK_IN_PROGRESS", Message: "A task is already in progress"} + ErrBadGateway = &APIError{HTTPStatus: http.StatusBadGateway, Code: "BAD_GATEWAY", Message: "Upstream service error"} + ErrNoActiveKeys = &APIError{HTTPStatus: http.StatusServiceUnavailable, Code: "NO_ACTIVE_KEYS", Message: "No active API keys available for this group"} ErrMaxRetriesExceeded = &APIError{HTTPStatus: http.StatusBadGateway, Code: "MAX_RETRIES_EXCEEDED", Message: "Request failed after maximum retries"} ErrNoKeysAvailable = &APIError{HTTPStatus: http.StatusServiceUnavailable, Code: "NO_KEYS_AVAILABLE", Message: "No API keys available to process the request"} ) diff --git a/internal/errors/parser.go b/internal/errors/parser.go index cda5eff..b91bd6f 100644 --- a/internal/errors/parser.go +++ b/internal/errors/parser.go @@ -33,8 +33,6 @@ type rootMessageErrorResponse struct { } // ParseUpstreamError attempts to parse a structured error message from an upstream response body -// using a chain of responsibility pattern. It tries various common formats and gracefully -// degrades to a raw string if all parsing attempts fail. func ParseUpstreamError(body []byte) string { // 1. Attempt to parse the standard OpenAI/Gemini format. var stdErr standardErrorResponse diff --git a/internal/httpclient/manager.go b/internal/httpclient/manager.go index 7d93f6b..7dfaf56 100644 --- a/internal/httpclient/manager.go +++ b/internal/httpclient/manager.go @@ -20,6 +20,9 @@ type Config struct { DisableCompression bool WriteBufferSize int ReadBufferSize int + ForceAttemptHTTP2 bool + TLSHandshakeTimeout time.Duration + ExpectContinueTimeout time.Duration } // HTTPClientManager manages the lifecycle of HTTP clients. 
diff --git a/internal/httpclient/manager.go b/internal/httpclient/manager.go
index 7d93f6b..7dfaf56 100644
--- a/internal/httpclient/manager.go
+++ b/internal/httpclient/manager.go
@@ -20,6 +20,9 @@ type Config struct {
     DisableCompression    bool
     WriteBufferSize       int
     ReadBufferSize        int
+    ForceAttemptHTTP2     bool
+    TLSHandshakeTimeout   time.Duration
+    ExpectContinueTimeout time.Duration
 }
 
 // HTTPClientManager manages the lifecycle of HTTP clients.
@@ -65,14 +68,14 @@ func (m *HTTPClientManager) GetClient(config *Config) *http.Client {
         Proxy: http.ProxyFromEnvironment,
         DialContext: (&net.Dialer{
             Timeout:   config.ConnectTimeout,
-            KeepAlive: 30 * time.Second, // KeepAlive is a good default
+            KeepAlive: 30 * time.Second,
         }).DialContext,
-        ForceAttemptHTTP2:     true,
+        ForceAttemptHTTP2:     config.ForceAttemptHTTP2,
         MaxIdleConns:          config.MaxIdleConns,
         MaxIdleConnsPerHost:   config.MaxIdleConnsPerHost,
         IdleConnTimeout:       config.IdleConnTimeout,
-        TLSHandshakeTimeout:   10 * time.Second, // A reasonable default
-        ExpectContinueTimeout: 1 * time.Second,  // A reasonable default
+        TLSHandshakeTimeout:   config.TLSHandshakeTimeout,
+        ExpectContinueTimeout: config.ExpectContinueTimeout,
         ResponseHeaderTimeout: config.ResponseHeaderTimeout,
         DisableCompression:    config.DisableCompression,
         WriteBufferSize:       config.WriteBufferSize,
@@ -91,7 +94,7 @@ func (m *HTTPClientManager) GetClient(config *Config) *http.Client {
 // getFingerprint generates a unique string representation of the client configuration.
 func (c *Config) getFingerprint() string {
     return fmt.Sprintf(
-        "ct:%.0fs|rt:%.0fs|it:%.0fs|mic:%d|mich:%d|rht:%.0fs|dc:%t|wbs:%d|rbs:%d",
+        "ct:%.0fs|rt:%.0fs|it:%.0fs|mic:%d|mich:%d|rht:%.0fs|dc:%t|wbs:%d|rbs:%d|fh2:%t|tlst:%.0fs|ect:%.0fs",
         c.ConnectTimeout.Seconds(),
         c.RequestTimeout.Seconds(),
         c.IdleConnTimeout.Seconds(),
@@ -101,5 +104,8 @@ func (c *Config) getFingerprint() string {
         c.DisableCompression,
         c.WriteBufferSize,
         c.ReadBufferSize,
+        c.ForceAttemptHTTP2,
+        c.TLSHandshakeTimeout.Seconds(),
+        c.ExpectContinueTimeout.Seconds(),
     )
 }
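Every transport-relevant field has to appear in getFingerprint; otherwise two configs differing only in the new fields would silently share one cached client. A runnable demonstration with sample values, using a trimmed local copy of the Config fields rather than the manager itself:

package main

import (
	"fmt"
	"time"
)

// clientConfig mirrors the fingerprinted fields from the diff; values below are sample data.
type clientConfig struct {
	ConnectTimeout, RequestTimeout, IdleConnTimeout   time.Duration
	MaxIdleConns, MaxIdleConnsPerHost                 int
	ResponseHeaderTimeout                             time.Duration
	DisableCompression                                bool
	WriteBufferSize, ReadBufferSize                   int
	ForceAttemptHTTP2                                 bool
	TLSHandshakeTimeout, ExpectContinueTimeout        time.Duration
}

func (c clientConfig) fingerprint() string {
	return fmt.Sprintf(
		"ct:%.0fs|rt:%.0fs|it:%.0fs|mic:%d|mich:%d|rht:%.0fs|dc:%t|wbs:%d|rbs:%d|fh2:%t|tlst:%.0fs|ect:%.0fs",
		c.ConnectTimeout.Seconds(), c.RequestTimeout.Seconds(), c.IdleConnTimeout.Seconds(),
		c.MaxIdleConns, c.MaxIdleConnsPerHost, c.ResponseHeaderTimeout.Seconds(),
		c.DisableCompression, c.WriteBufferSize, c.ReadBufferSize,
		c.ForceAttemptHTTP2, c.TLSHandshakeTimeout.Seconds(), c.ExpectContinueTimeout.Seconds(),
	)
}

func main() {
	regular := clientConfig{
		ConnectTimeout: 5 * time.Second, RequestTimeout: 600 * time.Second, IdleConnTimeout: 120 * time.Second,
		MaxIdleConns: 100, MaxIdleConnsPerHost: 10, ResponseHeaderTimeout: 120 * time.Second,
		WriteBufferSize: 32 * 1024, ReadBufferSize: 32 * 1024,
		ForceAttemptHTTP2: true, TLSHandshakeTimeout: 15 * time.Second, ExpectContinueTimeout: time.Second,
	}
	stream := regular
	stream.RequestTimeout, stream.DisableCompression = 0, true
	stream.WriteBufferSize, stream.ReadBufferSize = 0, 0

	fmt.Println(regular.fingerprint())
	fmt.Println(stream.fingerprint()) // differs, so the manager would build and cache a second client
}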
diff --git a/internal/proxy/request_helpers.go b/internal/proxy/request_helpers.go
new file mode 100644
index 0000000..a3ebd23
--- /dev/null
+++ b/internal/proxy/request_helpers.go
@@ -0,0 +1,63 @@
+package proxy
+
+import (
+    "bytes"
+    "compress/gzip"
+    "encoding/json"
+    app_errors "gpt-load/internal/errors"
+    "gpt-load/internal/models"
+    "io"
+    "net/http"
+
+    "github.com/sirupsen/logrus"
+)
+
+func (ps *ProxyServer) applyParamOverrides(bodyBytes []byte, group *models.Group) ([]byte, error) {
+    if len(group.ParamOverrides) == 0 || len(bodyBytes) == 0 {
+        return bodyBytes, nil
+    }
+
+    var requestData map[string]any
+    if err := json.Unmarshal(bodyBytes, &requestData); err != nil {
+        logrus.Warnf("failed to unmarshal request body for param override, passing through: %v", err)
+        return bodyBytes, nil
+    }
+
+    for key, value := range group.ParamOverrides {
+        requestData[key] = value
+    }
+
+    return json.Marshal(requestData)
+}
+
+// logUpstreamError provides a centralized way to log errors from upstream interactions.
+func logUpstreamError(context string, err error) {
+    if err == nil {
+        return
+    }
+    if app_errors.IsIgnorableError(err) {
+        logrus.Debugf("Ignorable upstream error in %s: %v", context, err)
+    } else {
+        logrus.Errorf("Upstream error in %s: %v", context, err)
+    }
+}
+
+// handleGzipCompression checks for gzip encoding and decompresses the body if necessary.
+func handleGzipCompression(resp *http.Response, bodyBytes []byte) []byte {
+    if resp.Header.Get("Content-Encoding") == "gzip" {
+        reader, gzipErr := gzip.NewReader(bytes.NewReader(bodyBytes))
+        if gzipErr != nil {
+            logrus.Warnf("Failed to create gzip reader for error body: %v", gzipErr)
+            return bodyBytes
+        }
+        defer reader.Close()
+
+        decompressedBody, readAllErr := io.ReadAll(reader)
+        if readAllErr != nil {
+            logrus.Warnf("Failed to decompress gzip error body: %v", readAllErr)
+            return bodyBytes
+        }
+        return decompressedBody
+    }
+    return bodyBytes
+}
diff --git a/internal/proxy/response_handlers.go b/internal/proxy/response_handlers.go
new file mode 100644
index 0000000..979aee4
--- /dev/null
+++ b/internal/proxy/response_handlers.go
@@ -0,0 +1,55 @@
+package proxy
+
+import (
+    "bufio"
+    "net/http"
+
+    "io"
+
+    "github.com/gin-gonic/gin"
+    "github.com/sirupsen/logrus"
+)
+
+func (ps *ProxyServer) handleStreamingResponse(c *gin.Context, resp *http.Response) {
+    c.Header("Content-Type", "text/event-stream")
+    c.Header("Cache-Control", "no-cache")
+    c.Header("Connection", "keep-alive")
+    c.Header("X-Accel-Buffering", "no")
+
+    flusher, ok := c.Writer.(http.Flusher)
+    if !ok {
+        logrus.Error("Streaming unsupported by the writer, falling back to normal response")
+        ps.handleNormalResponse(c, resp)
+        return
+    }
+
+    scanner := bufio.NewScanner(resp.Body)
+    for scanner.Scan() {
+        select {
+        case <-c.Request.Context().Done():
+            logrus.Debugf("Client disconnected, closing stream.")
+            return
+        default:
+        }
+
+        if _, err := c.Writer.Write(scanner.Bytes()); err != nil {
+            logUpstreamError("writing stream to client", err)
+            return
+        }
+        if _, err := c.Writer.Write([]byte("\n\n")); err != nil {
+            logUpstreamError("writing stream newline to client", err)
+            return
+        }
+        flusher.Flush()
+    }
+
+    if err := scanner.Err(); err != nil {
+        logUpstreamError("reading from upstream scanner", err)
+    }
+}
+
+func (ps *ProxyServer) handleNormalResponse(c *gin.Context, resp *http.Response) {
+    if _, err := io.Copy(c.Writer, resp.Body); err != nil {
+        logUpstreamError("copying response body", err)
+    }
+}
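A worked example of the merge performed by applyParamOverrides: top-level keys from the group's ParamOverrides replace or extend the client's JSON body. The helper below mirrors the logic shown above; the body and override values in main are sample data:

package main

import (
	"encoding/json"
	"fmt"
	"log"
)

// applyOverrides mirrors applyParamOverrides: non-JSON or empty bodies pass through untouched.
func applyOverrides(body []byte, overrides map[string]any) ([]byte, error) {
	if len(overrides) == 0 || len(body) == 0 {
		return body, nil
	}
	var requestData map[string]any
	if err := json.Unmarshal(body, &requestData); err != nil {
		return body, nil
	}
	for k, v := range overrides {
		requestData[k] = v
	}
	return json.Marshal(requestData)
}

func main() {
	body := []byte(`{"model":"gpt-4o","temperature":1.0}`)
	out, err := applyOverrides(body, map[string]any{"temperature": 0.2, "user": "gpt-load"})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(string(out)) // {"model":"gpt-4o","temperature":0.2,"user":"gpt-load"}
}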
diff --git a/internal/proxy/server.go b/internal/proxy/server.go
index 22c265b..cdb4517 100644
--- a/internal/proxy/server.go
+++ b/internal/proxy/server.go
@@ -2,15 +2,12 @@ package proxy
 
 import (
-    "bufio"
     "bytes"
-    "compress/gzip"
     "context"
     "encoding/json"
     "fmt"
     "io"
     "net/http"
-    "strings"
     "time"
 
     "gpt-load/internal/channel"
@@ -21,34 +18,12 @@ import (
     "gpt-load/internal/response"
     "gpt-load/internal/services"
     "gpt-load/internal/types"
+    "gpt-load/internal/utils"
 
     "github.com/gin-gonic/gin"
     "github.com/sirupsen/logrus"
 )
 
-// A list of errors that are considered normal during streaming when a client disconnects.
-var ignorableStreamErrors = []string{
-    "context canceled",
-    "connection reset by peer",
-    "broken pipe",
-    "use of closed network connection",
-}
-
-// isIgnorableStreamError checks if the error is a common, non-critical error that can occur
-// when a client disconnects during a streaming response.
-func isIgnorableStreamError(err error) bool {
-    if err == nil {
-        return false
-    }
-    errStr := err.Error()
-    for _, ignorableError := range ignorableStreamErrors {
-        if strings.Contains(errStr, ignorableError) {
-            return true
-        }
-    }
-    return false
-}
-
 // ProxyServer represents the proxy server
 type ProxyServer struct {
     keyProvider *keypool.KeyProvider
@@ -97,17 +72,14 @@ func (ps *ProxyServer) HandleProxy(c *gin.Context) {
     }
     c.Request.Body.Close()
 
-    // 4. Apply parameter overrides if any.
     finalBodyBytes, err := ps.applyParamOverrides(bodyBytes, group)
     if err != nil {
         response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to apply parameter overrides: %v", err)))
         return
     }
 
-    // 5. Determine if this is a streaming request.
     isStream := channelHandler.IsStreamRequest(c, bodyBytes)
 
-    // 6. Execute the request using the recursive retry logic.
     ps.executeRequestWithRetry(c, channelHandler, group, finalBodyBytes, isStream, startTime, 0, nil)
 }
@@ -124,7 +96,6 @@ func (ps *ProxyServer) executeRequestWithRetry(
 ) {
     cfg := group.EffectiveConfig
     if retryCount > cfg.MaxRetries {
-        logrus.Errorf("Max retries exceeded for group %s after %d attempts.", group.Name, retryCount)
         if len(retryErrors) > 0 {
             lastError := retryErrors[len(retryErrors)-1]
             var errorJSON map[string]any
@@ -133,8 +104,10 @@ func (ps *ProxyServer) executeRequestWithRetry(
             } else {
                 response.Error(c, app_errors.NewAPIErrorWithUpstream(lastError.StatusCode, "UPSTREAM_ERROR", lastError.ErrorMessage))
             }
+            logrus.Debugf("Max retries exceeded for group %s after %d attempts. Parsed Error: %s", group.Name, retryCount, lastError.ErrorMessage)
         } else {
             response.Error(c, app_errors.ErrMaxRetriesExceeded)
+            logrus.Debugf("Max retries exceeded for group %s after %d attempts.", group.Name, retryCount)
         }
         return
     }
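The new code leans on app_errors.IsIgnorableError in place of the isIgnorableStreamError helper removed above, but its body is not part of this diff. A plausible, self-contained sketch based on the deleted substring list; the real helper in gpt-load/internal/errors may differ in detail:

package main

import (
	"context"
	"errors"
	"fmt"
	"strings"
)

// Substrings taken from the list deleted in server.go; assumed to have moved into the errors package.
var ignorable = []string{
	"context canceled",
	"connection reset by peer",
	"broken pipe",
	"use of closed network connection",
}

func isIgnorableError(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, context.Canceled) {
		return true
	}
	msg := err.Error()
	for _, s := range ignorable {
		if strings.Contains(msg, s) {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(isIgnorableError(context.Canceled))               // true
	fmt.Println(isIgnorableError(errors.New("broken pipe")))      // true
	fmt.Println(isIgnorableError(errors.New("upstream timeout"))) // false
}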
@@ -173,58 +146,54 @@ func (ps *ProxyServer) executeRequestWithRetry(
     req.Header = c.Request.Header.Clone()
     channelHandler.ModifyRequest(req, apiKey, group)
 
-    client := channelHandler.GetHTTPClient()
+    var client *http.Client
     if isStream {
         client = channelHandler.GetStreamClient()
         req.Header.Set("X-Accel-Buffering", "no")
+    } else {
+        client = channelHandler.GetHTTPClient()
     }
 
     resp, err := client.Do(req)
-    if err != nil {
-        ps.keyProvider.UpdateStatus(apiKey.ID, group.ID, false)
-        logrus.Warnf("Request failed (attempt %d/%d) for key %s: %v", retryCount+1, cfg.MaxRetries, apiKey.KeyValue[:8], err)
-
-        newRetryErrors := append(retryErrors, types.RetryError{
-            StatusCode:   0,
-            ErrorMessage: err.Error(),
-            KeyID:        fmt.Sprintf("%d", apiKey.ID),
-            Attempt:      retryCount + 1,
-        })
-        ps.executeRequestWithRetry(c, channelHandler, group, bodyBytes, isStream, startTime, retryCount+1, newRetryErrors)
-        return
+    if resp != nil {
+        defer resp.Body.Close()
     }
-    defer resp.Body.Close()
 
-    if resp.StatusCode >= 400 {
+    // Unified error handling for retries.
+    if err != nil || (resp != nil && resp.StatusCode >= 400) {
+        if err != nil && app_errors.IsIgnorableError(err) {
+            logrus.Debugf("Client-side ignorable error for key %s, aborting retries: %v", utils.MaskAPIKey(apiKey.KeyValue), err)
+            return
+        }
+
         ps.keyProvider.UpdateStatus(apiKey.ID, group.ID, false)
-        errorBody, readErr := io.ReadAll(resp.Body)
-        if readErr != nil {
-            logrus.Errorf("Failed to read error body: %v", readErr)
-            // Even if reading fails, we should proceed with retry logic
-            errorBody = []byte("Failed to read error body")
-        }
 
-        // Check for gzip encoding and decompress if necessary.
-        if resp.Header.Get("Content-Encoding") == "gzip" {
-            reader, err := gzip.NewReader(bytes.NewReader(errorBody))
-            if err == nil {
-                decompressedBody, err := io.ReadAll(reader)
-                if err == nil {
-                    errorBody = decompressedBody
-                } else {
-                    logrus.Warnf("Failed to decompress gzip error body: %v", err)
-                }
-                reader.Close()
-            } else {
-                logrus.Warnf("Failed to create gzip reader for error body: %v", err)
+        var statusCode int
+        var errorMessage string
+        var parsedError string
+
+        if err != nil {
+            statusCode = 0
+            errorMessage = err.Error()
+            logrus.Debugf("Request failed (attempt %d/%d) for key %s: %v", retryCount+1, cfg.MaxRetries, utils.MaskAPIKey(apiKey.KeyValue), err)
+        } else {
+            // HTTP-level error (status >= 400)
+            statusCode = resp.StatusCode
+            errorBody, readErr := io.ReadAll(resp.Body)
+            if readErr != nil {
+                logrus.Errorf("Failed to read error body: %v", readErr)
+                errorBody = []byte("Failed to read error body")
             }
-        }
 
-        logrus.Warnf("Request failed with status %d (attempt %d/%d) for key %s. Body: %s", resp.StatusCode, retryCount+1, cfg.MaxRetries, apiKey.KeyValue[:8], string(errorBody))
+            errorBody = handleGzipCompression(resp, errorBody)
+            errorMessage = string(errorBody)
+            parsedError = app_errors.ParseUpstreamError(errorBody)
+            logrus.Debugf("Request failed with status %d (attempt %d/%d) for key %s. Parsed Error: %s", statusCode, retryCount+1, cfg.MaxRetries, utils.MaskAPIKey(apiKey.KeyValue), parsedError)
+        }
 
         newRetryErrors := append(retryErrors, types.RetryError{
-            StatusCode:   resp.StatusCode,
-            ErrorMessage: string(errorBody),
+            StatusCode:   statusCode,
+            ErrorMessage: errorMessage,
             KeyID:        fmt.Sprintf("%d", apiKey.ID),
             Attempt:      retryCount + 1,
         })
@@ -233,7 +202,7 @@ func (ps *ProxyServer) executeRequestWithRetry(
     }
 
     ps.keyProvider.UpdateStatus(apiKey.ID, group.ID, true)
-    logrus.Debugf("Request for group %s succeeded on attempt %d with key %s", group.Name, retryCount+1, apiKey.KeyValue[:8])
+    logrus.Debugf("Request for group %s succeeded on attempt %d with key %s", group.Name, retryCount+1, utils.MaskAPIKey(apiKey.KeyValue))
 
     for key, values := range resp.Header {
         for _, value := range values {
@@ -248,65 +217,3 @@ func (ps *ProxyServer) executeRequestWithRetry(
         ps.handleNormalResponse(c, resp)
     }
 }
-
-func (ps *ProxyServer) handleStreamingResponse(c *gin.Context, resp *http.Response) {
-    flusher, ok := c.Writer.(http.Flusher)
-    if !ok {
-        logrus.Error("Streaming unsupported by the writer")
-        ps.handleNormalResponse(c, resp)
-        return
-    }
-
-    scanner := bufio.NewScanner(resp.Body)
-    for scanner.Scan() {
-        if _, err := c.Writer.Write(scanner.Bytes()); err != nil {
-            if !isIgnorableStreamError(err) {
-                logrus.Errorf("Error writing to client: %v", err)
-            }
-            return
-        }
-        if _, err := c.Writer.Write([]byte("\n")); err != nil {
-            if !isIgnorableStreamError(err) {
-                logrus.Errorf("Error writing newline to client: %v", err)
-            }
-            return
-        }
-        flusher.Flush()
-    }
-
-    if err := scanner.Err(); err != nil && !isIgnorableStreamError(err) {
-        logrus.Errorf("Error reading from upstream: %v", err)
-    }
-}
-
-func (ps *ProxyServer) handleNormalResponse(c *gin.Context, resp *http.Response) {
-    if _, err := io.Copy(c.Writer, resp.Body); err != nil {
-        if !isIgnorableStreamError(err) {
-            logrus.Errorf("Failed to copy response body to client: %v", err)
-        }
-    }
-}
-
-func (ps *ProxyServer) applyParamOverrides(bodyBytes []byte, group *models.Group) ([]byte, error) {
-    if len(group.ParamOverrides) == 0 || len(bodyBytes) == 0 {
-        return bodyBytes, nil
-    }
-
-    var requestData map[string]any
-    if err := json.Unmarshal(bodyBytes, &requestData); err != nil {
-        logrus.Warnf("failed to unmarshal request body for param override, passing through: %v", err)
-        return bodyBytes, nil
-    }
-
-    for key, value := range group.ParamOverrides {
-        requestData[key] = value
-    }
-
-    return json.Marshal(requestData)
-}
-
-func (ps *ProxyServer) Close() {
-    // The HTTP clients are now managed by the channel factory and httpclient manager,
-    // so the proxy server itself doesn't need to close them.
-    // The httpclient manager will handle closing idle connections for all its clients.
-}
diff --git a/internal/types/retry.go b/internal/types/retry.go
deleted file mode 100644
index 2a46f42..0000000
--- a/internal/types/retry.go
+++ /dev/null
@@ -1,9 +0,0 @@
-package types
-
-// RetryError captures detailed information about a failed request attempt during retries.
-type RetryError struct {
-    StatusCode   int    `json:"status_code"`
-    ErrorMessage string `json:"error_message"`
-    KeyID        string `json:"key_id"`
-    Attempt      int    `json:"attempt"`
-}
name:"空闲连接超时" category:"请求超时" desc:"HTTP 客户端中空闲连接的超时时间(秒)。" validate:"min=1"` + MaxIdleConns int `json:"max_idle_conns" default:"100" name:"最大空闲连接数" category:"请求超时" desc:"HTTP 客户端连接池中允许的最大空闲连接总数。" validate:"min=1"` + MaxIdleConnsPerHost int `json:"max_idle_conns_per_host" default:"10" name:"每主机最大空闲连接数" category:"请求超时" desc:"HTTP 客户端连接池对每个上游主机允许的最大空闲连接数。" validate:"min=1"` + ResponseHeaderTimeout int `json:"response_header_timeout" default:"120" name:"响应头超时" category:"请求超时" desc:"等待上游服务响应头的最长时间(秒),用于流式请求。" validate:"min=1"` + DisableCompression bool `json:"disable_compression" default:"false" name:"禁用压缩" category:"请求超时" desc:"是否禁用对上游请求的传输压缩(Gzip)。对于流式请求建议开启以降低延迟。"` // 密钥配置 MaxRetries int `json:"max_retries" default:"3" name:"最大重试次数" category:"密钥配置" desc:"单个请求使用不同 Key 的最大重试次数" validate:"min=0"` @@ -87,3 +87,10 @@ type LogConfig struct { type DatabaseConfig struct { DSN string `json:"dsn"` } + +type RetryError struct { + StatusCode int `json:"status_code"` + ErrorMessage string `json:"error_message"` + KeyID string `json:"key_id"` + Attempt int `json:"attempt"` +} diff --git a/internal/utils/string_utils.go b/internal/utils/string_utils.go new file mode 100644 index 0000000..00b0c9b --- /dev/null +++ b/internal/utils/string_utils.go @@ -0,0 +1,12 @@ +package utils + +import "fmt" + +// MaskAPIKey masks an API key for safe logging. +func MaskAPIKey(key string) string { + length := len(key) + if length <= 8 { + return key + } + return fmt.Sprintf("%s****%s", key[:4], key[length-4:]) +}