From c94ed14357b8b1102795cd42fad3f12e7135b506 Mon Sep 17 00:00:00 2001
From: tbphp
Date: Sat, 7 Jun 2025 11:34:05 +0800
Subject: [PATCH] feat: optimize streaming performance
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .env.example              |   9 ++
 internal/config/config.go |  41 ++++++--
 internal/proxy/proxy.go   | 207 ++++++++++++++++++++++++++------------
 3 files changed, 185 insertions(+), 72 deletions(-)

diff --git a/.env.example b/.env.example
index 992c627..c6de4b2 100644
--- a/.env.example
+++ b/.env.example
@@ -53,6 +53,15 @@ DISABLE_COMPRESSION=true
 # Buffer size (bytes; 64KB or larger is recommended for streaming responses)
 BUFFER_SIZE=65536

+# Streaming buffer size (bytes, default 64KB)
+STREAM_BUFFER_SIZE=65536
+
+# Streaming flush interval (milliseconds, default 100ms)
+STREAM_FLUSH_INTERVAL=100
+
+# Response header timeout for streaming requests (milliseconds, default 10 seconds)
+STREAM_HEADER_TIMEOUT=10000
+
 # ===========================================
 # Logging configuration
 # ===========================================
diff --git a/internal/config/config.go b/internal/config/config.go
index e7f9ed4..f776383 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -68,11 +68,14 @@ type CORSConfig struct {

 // PerformanceConfig performance configuration
 type PerformanceConfig struct {
-    MaxSockets         int  `json:"maxSockets"`
-    MaxFreeSockets     int  `json:"maxFreeSockets"`
-    EnableKeepAlive    bool `json:"enableKeepAlive"`
-    DisableCompression bool `json:"disableCompression"`
-    BufferSize         int  `json:"bufferSize"`
+    MaxSockets          int  `json:"maxSockets"`
+    MaxFreeSockets      int  `json:"maxFreeSockets"`
+    EnableKeepAlive     bool `json:"enableKeepAlive"`
+    DisableCompression  bool `json:"disableCompression"`
+    BufferSize          int  `json:"bufferSize"`
+    StreamBufferSize    int  `json:"streamBufferSize"`    // streaming buffer size
+    StreamFlushInterval int  `json:"streamFlushInterval"` // streaming flush interval (milliseconds)
+    StreamHeaderTimeout int  `json:"streamHeaderTimeout"` // streaming response header timeout (milliseconds)
 }

 // LogConfig logging configuration
@@ -129,11 +132,14 @@ func LoadConfig() (*Config, error) {
             AllowedOrigins: parseArray(os.Getenv("ALLOWED_ORIGINS"), []string{"*"}),
         },
         Performance: PerformanceConfig{
-            MaxSockets:         parseInteger(os.Getenv("MAX_SOCKETS"), DefaultConstants.DefaultMaxSockets),
-            MaxFreeSockets:     parseInteger(os.Getenv("MAX_FREE_SOCKETS"), DefaultConstants.DefaultMaxFreeSockets),
-            EnableKeepAlive:    parseBoolean(os.Getenv("ENABLE_KEEP_ALIVE"), true),
-            DisableCompression: parseBoolean(os.Getenv("DISABLE_COMPRESSION"), true),
-            BufferSize:         parseInteger(os.Getenv("BUFFER_SIZE"), 32*1024),
+            MaxSockets:          parseInteger(os.Getenv("MAX_SOCKETS"), DefaultConstants.DefaultMaxSockets),
+            MaxFreeSockets:      parseInteger(os.Getenv("MAX_FREE_SOCKETS"), DefaultConstants.DefaultMaxFreeSockets),
+            EnableKeepAlive:     parseBoolean(os.Getenv("ENABLE_KEEP_ALIVE"), true),
+            DisableCompression:  parseBoolean(os.Getenv("DISABLE_COMPRESSION"), true),
+            BufferSize:          parseInteger(os.Getenv("BUFFER_SIZE"), 32*1024),
+            StreamBufferSize:    parseInteger(os.Getenv("STREAM_BUFFER_SIZE"), 64*1024),  // default 64KB
+            StreamFlushInterval: parseInteger(os.Getenv("STREAM_FLUSH_INTERVAL"), 100),   // default 100ms
+            StreamHeaderTimeout: parseInteger(os.Getenv("STREAM_HEADER_TIMEOUT"), 10000), // default 10 seconds
         },
         Log: LogConfig{
             Level: getEnvOrDefault("LOG_LEVEL", "info"),
@@ -191,6 +197,18 @@ func validateConfig(config *Config) error {
         errors = append(errors, "Max free sockets must not be less than 0")
     }

+    if config.Performance.StreamBufferSize < 1024 {
+        errors = append(errors, "Stream buffer size must not be less than 1KB")
+    }
+
+    if config.Performance.StreamFlushInterval < 10 {
+        errors = append(errors, "Stream flush interval must not be less than 10ms")
+    }
+
+    if config.Performance.StreamHeaderTimeout < 1000 {
+        errors = append(errors, "Stream response header timeout must not be less than 1 second")
+    }
+
     if len(errors) > 0 {
         logrus.Error("❌ Configuration validation failed:")
         for _, err := range errors {
@@ -238,6 +256,9 @@ func DisplayConfig(config *Config) {
     }
     logrus.Infof("  Compression: %s", compressionStatus)
     logrus.Infof("  Buffer size: %d bytes", config.Performance.BufferSize)
+    logrus.Infof("  Stream buffer size: %d bytes", config.Performance.StreamBufferSize)
+    logrus.Infof("  Stream flush interval: %dms", config.Performance.StreamFlushInterval)
+    logrus.Infof("  Stream header timeout: %dms", config.Performance.StreamHeaderTimeout)

     // Display logging configuration
     requestLogStatus := "enabled"
diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go
index 60cbeec..8beb6fc 100644
--- a/internal/proxy/proxy.go
+++ b/internal/proxy/proxy.go
@@ -4,6 +4,7 @@
 package proxy

 import (
+    "bytes"
     "context"
     "fmt"
     "io"
@@ -24,9 +25,11 @@ import (
 type ProxyServer struct {
     keyManager   *keymanager.KeyManager
     httpClient   *http.Client
+    streamClient *http.Client // dedicated client for streaming transfers
     upstreamURL  *url.URL
     requestCount int64
     startTime    time.Time
+    flushTicker  *time.Ticker // periodic flush timer for streaming responses
 }

 // NewProxyServer creates a new proxy server
@@ -57,9 +60,25 @@ func NewProxyServer() (*ProxyServer, error) {
         ReadBufferSize:        config.AppConfig.Performance.BufferSize,
     }

+    // Create a dedicated transport for streaming with tuned TCP parameters
+    streamTransport := &http.Transport{
+        MaxIdleConns:          config.AppConfig.Performance.MaxSockets * 2,
+        MaxIdleConnsPerHost:   config.AppConfig.Performance.MaxFreeSockets * 2,
+        MaxConnsPerHost:       0,
+        IdleConnTimeout:       300 * time.Second, // keep streaming connections alive longer
+        TLSHandshakeTimeout:   10 * time.Second,
+        ExpectContinueTimeout: 1 * time.Second,
+        DisableCompression:    true, // always disable compression for streaming
+        ForceAttemptHTTP2:     true,
+        WriteBufferSize:       config.AppConfig.Performance.StreamBufferSize,
+        ReadBufferSize:        config.AppConfig.Performance.StreamBufferSize,
+        ResponseHeaderTimeout: time.Duration(config.AppConfig.Performance.StreamHeaderTimeout) * time.Millisecond,
+    }
+
     // Configure keep-alive
     if !config.AppConfig.Performance.EnableKeepAlive {
         transport.DisableKeepAlives = true
+        streamTransport.DisableKeepAlives = true
     }

     httpClient := &http.Client{
@@ -68,11 +87,17 @@ func NewProxyServer() (*ProxyServer, error) {
         // Timeout: time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond,
     }

+    // The streaming client deliberately has no overall timeout
+    streamClient := &http.Client{
+        Transport: streamTransport,
+    }
+
     return &ProxyServer{
-        keyManager:  keyManager,
-        httpClient:  httpClient,
-        upstreamURL: upstreamURL,
-        startTime:   time.Now(),
+        keyManager:   keyManager,
+        httpClient:   httpClient,
+        streamClient: streamClient,
+        upstreamURL:  upstreamURL,
+        startTime:    time.Now(),
     }, nil
 }

@@ -309,15 +334,8 @@ func (ps *ProxyServer) handleProxy(c *gin.Context) {
     // Increment the request counter
     atomic.AddInt64(&ps.requestCount, 1)

-    // Check whether this is a streaming request
-    isStreamRequest := strings.Contains(c.GetHeader("Accept"), "text/event-stream") ||
-        strings.Contains(c.Request.URL.RawQuery, "stream=true") ||
-        strings.Contains(c.Request.Header.Get("Content-Type"), "text/event-stream")
-
+    // Single entry point: cache the entire request body up front
     var bodyBytes []byte
-    var requestBody io.Reader = c.Request.Body
-
-    // Cache the request body to support retries (including streaming requests)
     if c.Request.Body != nil {
         var err error
         bodyBytes, err = io.ReadAll(c.Request.Body)
@@ -333,15 +351,40 @@
             })
             return
         }
-        requestBody = strings.NewReader(string(bodyBytes))
     }

+    // Use the cached body to determine the request type
+    isStreamRequest := ps.isStreamRequest(bodyBytes, c)
+
     // Execute the request with retries
-    ps.executeRequestWithRetry(c, startTime, bodyBytes, requestBody, isStreamRequest, 0)
+    ps.executeRequestWithRetry(c, startTime, bodyBytes, isStreamRequest, 0)
+}
+
+// isStreamRequest determines whether a request is a streaming request
+func (ps *ProxyServer) isStreamRequest(bodyBytes []byte, c *gin.Context) bool {
+    // Check the Accept header
+    if strings.Contains(c.GetHeader("Accept"), "text/event-stream") {
+        return true
+    }
+
+    // Check the URL query parameter
+    if c.Query("stream") == "true" {
+        return true
+    }
+
+    // Check the stream flag in the request body
+    if len(bodyBytes) > 0 {
+        if strings.Contains(string(bodyBytes), `"stream":true`) ||
+            strings.Contains(string(bodyBytes), `"stream": true`) {
+            return true
+        }
+    }
+
+    return false
 }

 // executeRequestWithRetry executes the request with retries
-func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Time, bodyBytes []byte, requestBody io.Reader, isStreamRequest bool, retryCount int) {
+func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Time, bodyBytes []byte, isStreamRequest bool, retryCount int) {
     // Get key info
     keyInfo, err := ps.keyManager.GetNextKey()
     if err != nil {
@@ -366,34 +409,32 @@ func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Ti
         c.Set("retryCount", retryCount)
     }

-    // Prepare the request body and content length
-    var contentLength int64
-    if len(bodyBytes) > 0 {
-        // Use the cached request body (supports retries)
-        requestBody = strings.NewReader(string(bodyBytes))
-        contentLength = int64(len(bodyBytes))
-    }
-
     // Build the upstream request URL
     targetURL := *ps.upstreamURL
     targetURL.Path = c.Request.URL.Path
     targetURL.RawQuery = c.Request.URL.RawQuery

-    // Create a context with timeout
-    timeout := time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond
-    ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
+    // Use different timeout strategies for streaming and non-streaming requests
+    var ctx context.Context
+    var cancel context.CancelFunc
+
+    if isStreamRequest {
+        // Streaming requests only get a response header timeout, no overall timeout
+        ctx, cancel = context.WithCancel(c.Request.Context())
+    } else {
+        // Non-streaming requests use the configured timeout
+        timeout := time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond
+        ctx, cancel = context.WithTimeout(c.Request.Context(), timeout)
+    }
     defer cancel()

-    // Create the upstream request, streaming the original request body directly
+    // Always create the request from the cached bodyBytes
     req, err := http.NewRequestWithContext(
         ctx,
         c.Request.Method,
         targetURL.String(),
-        requestBody,
+        bytes.NewReader(bodyBytes),
     )
-    if err == nil && contentLength > 0 {
-        req.ContentLength = contentLength
-    }
     if err != nil {
         logrus.Errorf("Failed to create upstream request: %v", err)
         c.JSON(http.StatusInternalServerError, gin.H{
@@ -406,6 +447,7 @@ func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Ti
         })
         return
     }
+    req.ContentLength = int64(len(bodyBytes))

     // Copy request headers
     for key, values := range c.Request.Header {
@@ -419,8 +461,18 @@ func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Ti
     // Set the authorization header
     req.Header.Set("Authorization", "Bearer "+keyInfo.Key)

+    // Choose the appropriate client based on the request type
+    var client *http.Client
+    if isStreamRequest {
+        client = ps.streamClient
+        // Add a header that disables nginx buffering
+        req.Header.Set("X-Accel-Buffering", "no")
+    } else {
+        client = ps.httpClient
+    }
+
     // Send the request
-    resp, err := ps.httpClient.Do(req)
+    resp, err := client.Do(req)
     if err != nil {
         responseTime := time.Since(startTime)

@@ -437,7 +489,7 @@ func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Ti
         // Check whether a retry is possible
         if retryCount < config.AppConfig.Keys.MaxRetries {
             logrus.Infof("Preparing to retry request (attempt %d/%d)", retryCount+1, config.AppConfig.Keys.MaxRetries)
-            ps.executeRequestWithRetry(c, startTime, bodyBytes, nil, isStreamRequest, retryCount+1)
+            ps.executeRequestWithRetry(c, startTime, bodyBytes, isStreamRequest, retryCount+1)
             return
         }

@@ -473,7 +525,7 @@ func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Ti
         resp.Body.Close()

        logrus.Infof("Preparing to retry request (attempt %d/%d)", retryCount+1, config.AppConfig.Keys.MaxRetries)
-        ps.executeRequestWithRetry(c, startTime, bodyBytes, nil, isStreamRequest, retryCount+1)
+        ps.executeRequestWithRetry(c, startTime, bodyBytes, isStreamRequest, retryCount+1)
         return
     }

@@ -493,39 +545,19 @@ func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Ti
         }
     }

+    // Add headers that disable buffering for streaming responses
+    if isStreamRequest {
+        c.Header("Cache-Control", "no-cache")
+        c.Header("Connection", "keep-alive")
+        c.Header("X-Accel-Buffering", "no")
+    }
+
     // Set the status code
     c.Status(resp.StatusCode)

     // Optimized streaming response transfer
     if isStreamRequest {
-        // Streaming response: enable real-time flushing
-        if flusher, ok := c.Writer.(http.Flusher); ok {
-            // Stream-copy using an optimized buffer
-            buf := make([]byte, config.AppConfig.Performance.BufferSize)
-            for {
-                n, err := resp.Body.Read(buf)
-                if n > 0 {
-                    _, writeErr := c.Writer.Write(buf[:n])
-                    if writeErr != nil {
-                        logrus.Errorf("Failed to write streaming response: %v", writeErr)
-                        break
-                    }
-                    flusher.Flush() // flush to the client immediately
-                }
-                if err != nil {
-                    if err != io.EOF {
-                        logrus.Errorf("Failed to read streaming response: %v", err)
-                    }
-                    break
-                }
-            }
-        } else {
-            // Fall back to a standard copy
-            _, err = io.Copy(c.Writer, resp.Body)
-            if err != nil {
-                logrus.Errorf("Failed to copy streaming response: %v", err)
-            }
-        }
+        ps.handleStreamResponse(c, resp.Body)
     } else {
         // Non-streaming response: standard zero-copy transfer
         _, err = io.Copy(c.Writer, resp.Body)
@@ -535,9 +567,60 @@ func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Ti
     }
 }

+// handleStreamResponse handles a streaming response
+func (ps *ProxyServer) handleStreamResponse(c *gin.Context, body io.ReadCloser) {
+    defer body.Close()
+
+    flusher, ok := c.Writer.(http.Flusher)
+    if !ok {
+        // Fall back to a standard copy
+        _, err := io.Copy(c.Writer, body)
+        if err != nil {
+            logrus.Errorf("Failed to copy streaming response: %v", err)
+        }
+        return
+    }
+
+    // Use a smart flush strategy
+    flushInterval := time.Duration(config.AppConfig.Performance.StreamFlushInterval) * time.Millisecond
+    lastFlush := time.Now()
+
+    // Use a larger buffer to reduce system calls
+    buf := make([]byte, config.AppConfig.Performance.StreamBufferSize)
+
+    for {
+        n, err := body.Read(buf)
+        if n > 0 {
+            _, writeErr := c.Writer.Write(buf[:n])
+            if writeErr != nil {
+                logrus.Errorf("Failed to write streaming response: %v", writeErr)
+                break
+            }
+
+            // Smart flush: based on elapsed time or amount of data
+            if time.Since(lastFlush) >= flushInterval || n >= config.AppConfig.Performance.StreamBufferSize/2 {
+                flusher.Flush()
+                lastFlush = time.Now()
+            }
+        }
+
+        if err != nil {
+            // Final flush to make sure all buffered data is sent
+            flusher.Flush()
+            if err != io.EOF {
+                logrus.Errorf("Failed to read streaming response: %v", err)
+            }
+            break
+        }
+    }
+}
+
 // Close shuts down the proxy server
 func (ps *ProxyServer) Close() {
     if ps.keyManager != nil {
         ps.keyManager.Close()
     }
+    if ps.flushTicker != nil {
+        ps.flushTicker.Stop()
+    }
 }
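
Note on the new detection helper: isStreamRequest only matches the two literal spellings "stream":true and "stream": true in the request body, plus the Accept header and the stream=true query parameter. A table-driven test along the following lines could pin that behavior down. This is only a sketch under stated assumptions: it presumes it sits in the internal/proxy package next to proxy.go, and the file name proxy_stream_test.go, the gin test-context setup, and the zero-value ProxyServer receiver (the method reads no other fields) are illustrative, not part of this patch.

// proxy_stream_test.go (hypothetical): exercises the detection paths of
// isStreamRequest introduced by this patch.
package proxy

import (
    "net/http/httptest"
    "strings"
    "testing"

    "github.com/gin-gonic/gin"
)

func TestIsStreamRequest(t *testing.T) {
    gin.SetMode(gin.TestMode)
    ps := &ProxyServer{} // isStreamRequest does not touch other fields

    cases := []struct {
        name   string
        target string
        accept string
        body   string
        want   bool
    }{
        {"accept header", "/v1/chat/completions", "text/event-stream", "", true},
        {"query param", "/v1/chat/completions?stream=true", "", "", true},
        {"body flag", "/v1/chat/completions", "", `{"stream":true}`, true},
        {"body flag with space", "/v1/chat/completions", "", `{"stream": true}`, true},
        {"non-streaming", "/v1/chat/completions", "", `{"stream":false}`, false},
    }

    for _, tc := range cases {
        t.Run(tc.name, func(t *testing.T) {
            // Build a minimal gin context around an httptest request.
            c, _ := gin.CreateTestContext(httptest.NewRecorder())
            c.Request = httptest.NewRequest("POST", tc.target, strings.NewReader(tc.body))
            if tc.accept != "" {
                c.Request.Header.Set("Accept", tc.accept)
            }
            if got := ps.isStreamRequest([]byte(tc.body), c); got != tc.want {
                t.Errorf("%s: got %v, want %v", tc.name, got, tc.want)
            }
        })
    }
}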
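To observe the effect of STREAM_FLUSH_INTERVAL and STREAM_BUFFER_SIZE end to end, a throwaway client that timestamps each received SSE line can be pointed at the proxy; chunks should arrive no later than roughly one flush interval after the upstream emits them. The sketch below is illustrative only: the listen address localhost:3000, the /v1/chat/completions path, and the request payload are placeholder assumptions, not values defined by this patch.

// streamcheck.go (hypothetical): reads a proxied SSE response and logs when
// each line arrives, which roughly reflects the proxy's flush cadence.
package main

import (
    "bufio"
    "fmt"
    "net/http"
    "strings"
    "time"
)

func main() {
    body := `{"model":"example-model","stream":true,"messages":[{"role":"user","content":"hi"}]}`
    req, err := http.NewRequest("POST", "http://localhost:3000/v1/chat/completions", strings.NewReader(body))
    if err != nil {
        panic(err)
    }
    req.Header.Set("Content-Type", "application/json")
    req.Header.Set("Accept", "text/event-stream")

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    start := time.Now()
    scanner := bufio.NewScanner(resp.Body)
    scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) // allow large SSE lines
    for scanner.Scan() {
        line := scanner.Text()
        if line == "" {
            continue // skip SSE event separators
        }
        // Each printed timestamp corresponds to data the proxy has flushed.
        fmt.Printf("+%6.0fms  %.60s\n", time.Since(start).Seconds()*1000, line)
    }
    if err := scanner.Err(); err != nil {
        fmt.Println("stream read error:", err)
    }
}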