feat: 优化流式性能

This commit is contained in:
tbphp
2025-06-07 11:34:05 +08:00
parent 8a4cb65bad
commit c94ed14357
3 changed files with 185 additions and 72 deletions

View File

@@ -4,6 +4,7 @@
package proxy
import (
"bytes"
"context"
"fmt"
"io"
@@ -24,9 +25,11 @@ import (
// ProxyServer holds the state shared by all proxied requests: the key
// manager, the upstream target, the HTTP clients used to reach the upstream,
// and basic runtime counters.
type ProxyServer struct {
	// requestCount is incremented with atomic.AddInt64. It is deliberately the
	// first field: sync/atomic only guarantees 64-bit alignment for the first
	// word of an allocated struct on 32-bit platforms, so keeping it here
	// stays correct no matter how many fields are added below.
	requestCount int64

	keyManager   *keymanager.KeyManager
	httpClient   *http.Client // client used for non-streaming requests
	streamClient *http.Client // dedicated client used for streaming requests
	upstreamURL  *url.URL
	startTime    time.Time
	flushTicker  *time.Ticker // ticker for periodic flushing of streaming responses; stopped in Close when non-nil
}
// NewProxyServer 创建新的代理服务器
@@ -57,9 +60,25 @@ func NewProxyServer() (*ProxyServer, error) {
ReadBufferSize: config.AppConfig.Performance.BufferSize,
}
// 创建专门用于流式传输的transport优化TCP参数
streamTransport := &http.Transport{
MaxIdleConns: config.AppConfig.Performance.MaxSockets * 2,
MaxIdleConnsPerHost: config.AppConfig.Performance.MaxFreeSockets * 2,
MaxConnsPerHost: 0,
IdleConnTimeout: 300 * time.Second, // 流式连接保持更长时间
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DisableCompression: true, // 流式传输始终禁用压缩
ForceAttemptHTTP2: true,
WriteBufferSize: config.AppConfig.Performance.StreamBufferSize,
ReadBufferSize: config.AppConfig.Performance.StreamBufferSize,
ResponseHeaderTimeout: time.Duration(config.AppConfig.Performance.StreamHeaderTimeout) * time.Millisecond,
}
// 配置 Keep-Alive
if !config.AppConfig.Performance.EnableKeepAlive {
transport.DisableKeepAlives = true
streamTransport.DisableKeepAlives = true
}
httpClient := &http.Client{
@@ -68,11 +87,17 @@ func NewProxyServer() (*ProxyServer, error) {
// Timeout: time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond,
}
// 流式客户端不设置整体超时
streamClient := &http.Client{
Transport: streamTransport,
}
return &ProxyServer{
keyManager: keyManager,
httpClient: httpClient,
upstreamURL: upstreamURL,
startTime: time.Now(),
keyManager: keyManager,
httpClient: httpClient,
streamClient: streamClient,
upstreamURL: upstreamURL,
startTime: time.Now(),
}, nil
}
@@ -309,15 +334,8 @@ func (ps *ProxyServer) handleProxy(c *gin.Context) {
// 增加请求计数
atomic.AddInt64(&ps.requestCount, 1)
// 检查是否为流式请求
isStreamRequest := strings.Contains(c.GetHeader("Accept"), "text/event-stream") ||
strings.Contains(c.Request.URL.RawQuery, "stream=true") ||
strings.Contains(c.Request.Header.Get("Content-Type"), "text/event-stream")
// 统一入口,提前缓存所有请求
var bodyBytes []byte
var requestBody io.Reader = c.Request.Body
// 为了支持重试,需要缓存请求体(包括流式请求)
if c.Request.Body != nil {
var err error
bodyBytes, err = io.ReadAll(c.Request.Body)
@@ -333,15 +351,40 @@ func (ps *ProxyServer) handleProxy(c *gin.Context) {
})
return
}
requestBody = strings.NewReader(string(bodyBytes))
}
// 使用缓存后的数据判断请求类型
isStreamRequest := ps.isStreamRequest(bodyBytes, c)
// 执行带重试的请求
ps.executeRequestWithRetry(c, startTime, bodyBytes, requestBody, isStreamRequest, 0)
ps.executeRequestWithRetry(c, startTime, bodyBytes, isStreamRequest, 0)
}
// isStreamRequest reports whether the incoming request asks for a streamed
// response.
//
// A request is treated as streaming when any of the following holds:
//   - the Accept header contains "text/event-stream";
//   - the URL query parameter "stream" equals "true";
//   - the cached request body contains a "stream" key whose value starts
//     with the literal true.
//
// bodyBytes is the already-buffered request body (may be nil or empty);
// c supplies the request headers and query string.
func (ps *ProxyServer) isStreamRequest(bodyBytes []byte, c *gin.Context) bool {
	if strings.Contains(c.GetHeader("Accept"), "text/event-stream") {
		return true
	}
	if c.Query("stream") == "true" {
		return true
	}
	return bodyHasStreamTrue(bodyBytes)
}

// bodyHasStreamTrue scans body for the text `"stream"` followed by optional
// JSON whitespace, a colon, optional JSON whitespace, and the literal true.
//
// This is a textual scan rather than a JSON parse, so it matches the key at
// any nesting depth. It works directly on the byte slice to avoid copying the
// whole body into a string, and it tolerates any JSON whitespace around the
// colon (for example `"stream" : true` or pretty-printed bodies).
func bodyHasStreamTrue(body []byte) bool {
	const jsonSpace = " \t\r\n"
	key := []byte(`"stream"`)

	rest := body
	for {
		idx := bytes.Index(rest, key)
		if idx < 0 {
			return false
		}
		rest = rest[idx+len(key):]

		val := bytes.TrimLeft(rest, jsonSpace)
		if len(val) == 0 || val[0] != ':' {
			// Not a key position (e.g. the text appears as a value); keep looking.
			continue
		}
		val = bytes.TrimLeft(val[1:], jsonSpace)
		if bytes.HasPrefix(val, []byte("true")) {
			return true
		}
	}
}
// executeRequestWithRetry 执行带重试的请求
func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Time, bodyBytes []byte, requestBody io.Reader, isStreamRequest bool, retryCount int) {
func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Time, bodyBytes []byte, isStreamRequest bool, retryCount int) {
// 获取密钥信息
keyInfo, err := ps.keyManager.GetNextKey()
if err != nil {
@@ -366,34 +409,32 @@ func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Ti
c.Set("retryCount", retryCount)
}
// 准备请求体和内容长度
var contentLength int64
if len(bodyBytes) > 0 {
// 使用缓存的请求体(支持重试)
requestBody = strings.NewReader(string(bodyBytes))
contentLength = int64(len(bodyBytes))
}
// 构建上游请求URL
targetURL := *ps.upstreamURL
targetURL.Path = c.Request.URL.Path
targetURL.RawQuery = c.Request.URL.RawQuery
// 创建带超时的上下文
timeout := time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
// 为流式和非流式请求使用不同的超时策略
var ctx context.Context
var cancel context.CancelFunc
if isStreamRequest {
// 流式请求只设置响应头超时,不设置整体超时
ctx, cancel = context.WithCancel(c.Request.Context())
} else {
// 非流式请求使用配置的超时
timeout := time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond
ctx, cancel = context.WithTimeout(c.Request.Context(), timeout)
}
defer cancel()
// 创建上游请求,直接使用原始请求体进行流式传输
// 统一使用缓存的 bodyBytes 创建请求
req, err := http.NewRequestWithContext(
ctx,
c.Request.Method,
targetURL.String(),
requestBody,
bytes.NewReader(bodyBytes),
)
if err == nil && contentLength > 0 {
req.ContentLength = contentLength
}
if err != nil {
logrus.Errorf("创建上游请求失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{
@@ -406,6 +447,7 @@ func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Ti
})
return
}
req.ContentLength = int64(len(bodyBytes))
// 复制请求头
for key, values := range c.Request.Header {
@@ -419,8 +461,18 @@ func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Ti
// 设置认证头
req.Header.Set("Authorization", "Bearer "+keyInfo.Key)
// 根据请求类型选择合适的客户端
var client *http.Client
if isStreamRequest {
client = ps.streamClient
// 添加禁用nginx缓冲的头
req.Header.Set("X-Accel-Buffering", "no")
} else {
client = ps.httpClient
}
// 发送请求
resp, err := ps.httpClient.Do(req)
resp, err := client.Do(req)
if err != nil {
responseTime := time.Since(startTime)
@@ -437,7 +489,7 @@ func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Ti
// 检查是否可以重试
if retryCount < config.AppConfig.Keys.MaxRetries {
logrus.Infof("准备重试请求 (第 %d/%d 次)", retryCount+1, config.AppConfig.Keys.MaxRetries)
ps.executeRequestWithRetry(c, startTime, bodyBytes, nil, isStreamRequest, retryCount+1)
ps.executeRequestWithRetry(c, startTime, bodyBytes, isStreamRequest, retryCount+1)
return
}
@@ -473,7 +525,7 @@ func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Ti
resp.Body.Close()
logrus.Infof("准备重试请求 (第 %d/%d 次)", retryCount+1, config.AppConfig.Keys.MaxRetries)
ps.executeRequestWithRetry(c, startTime, bodyBytes, nil, isStreamRequest, retryCount+1)
ps.executeRequestWithRetry(c, startTime, bodyBytes, isStreamRequest, retryCount+1)
return
}
@@ -493,39 +545,19 @@ func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Ti
}
}
// 流式响应添加禁用缓冲的头
if isStreamRequest {
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
}
// 设置状态码
c.Status(resp.StatusCode)
// 优化流式响应传输
if isStreamRequest {
// 流式响应:启用实时刷新
if flusher, ok := c.Writer.(http.Flusher); ok {
// 使用优化的缓冲区进行流式复制
buf := make([]byte, config.AppConfig.Performance.BufferSize)
for {
n, err := resp.Body.Read(buf)
if n > 0 {
_, writeErr := c.Writer.Write(buf[:n])
if writeErr != nil {
logrus.Errorf("写入流式响应失败: %v", writeErr)
break
}
flusher.Flush() // 立即刷新到客户端
}
if err != nil {
if err != io.EOF {
logrus.Errorf("读取流式响应失败: %v", err)
}
break
}
}
} else {
// 降级到标准复制
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
logrus.Errorf("复制流式响应失败: %v", err)
}
}
ps.handleStreamResponse(c, resp.Body)
} else {
// 非流式响应:使用标准零拷贝
_, err = io.Copy(c.Writer, resp.Body)
@@ -535,9 +567,60 @@ func (ps *ProxyServer) executeRequestWithRetry(c *gin.Context, startTime time.Ti
}
}
// handleStreamResponse copies a streaming upstream response body to the
// client and closes body before returning.
//
// When the response writer implements http.Flusher, every chunk read from
// the upstream is written and flushed immediately. When it does not, the body
// is copied with io.Copy. Read and write failures are logged; the function
// does not return an error.
func (ps *ProxyServer) handleStreamResponse(c *gin.Context, body io.ReadCloser) {
	defer body.Close()

	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		// Fall back to a plain copy when flushing is not supported.
		if _, err := io.Copy(c.Writer, body); err != nil {
			logrus.Errorf("复制流式响应失败: %v", err)
		}
		return
	}

	// A zero-length buffer would make Read return (0, nil) repeatedly and spin
	// this loop forever, so fall back to a sane size when the configured value
	// is missing or invalid.
	const fallbackBufferSize = 32 * 1024
	bufSize := config.AppConfig.Performance.StreamBufferSize
	if bufSize <= 0 {
		bufSize = fallbackBufferSize
	}
	buf := make([]byte, bufSize)

	for {
		n, err := body.Read(buf)
		if n > 0 {
			if _, writeErr := c.Writer.Write(buf[:n]); writeErr != nil {
				logrus.Errorf("写入流式响应失败: %v", writeErr)
				return
			}
			// Flush per chunk: the next Read blocks until the upstream sends
			// more data, so postponing the flush could hold an already-written
			// event back for as long as the upstream stays silent.
			flusher.Flush()
		}
		if err != nil {
			if err != io.EOF {
				logrus.Errorf("读取流式响应失败: %v", err)
			}
			return
		}
	}
}
// Close shuts down the proxy server: it closes the key manager and stops the
// flush ticker, skipping whichever of the two is nil.
func (ps *ProxyServer) Close() {
	if km := ps.keyManager; km != nil {
		km.Close()
	}
	if ticker := ps.flushTicker; ticker != nil {
		ticker.Stop()
	}
}