fix: 优化

This commit is contained in:
tbphp
2025-06-06 21:56:40 +08:00
parent 219c068dbf
commit 1faa0b1b73
5 changed files with 193 additions and 156 deletions

View File

@@ -32,6 +32,24 @@ OPENAI_BASE_URL=https://api.openai.com
# 请求超时时间(毫秒)
REQUEST_TIMEOUT=30000
# ===========================================
# 性能优化配置
# ===========================================
# 最大连接数
MAX_SOCKETS=100
# 最大空闲连接数
MAX_FREE_SOCKETS=20
# 启用 Keep-Alive
ENABLE_KEEP_ALIVE=true
# 禁用压缩减少CPU开销
DISABLE_COMPRESSION=true
# 缓冲区大小(字节)
BUFFER_SIZE=32768
# ===========================================
# 认证配置
# ===========================================

View File

@@ -44,13 +44,14 @@ func main() {
// 设置路由
router := proxyServer.SetupRoutes()
// 创建HTTP服务器
// 创建HTTP服务器,优化超时配置
server := &http.Server{
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
Handler: router,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 60 * time.Second,
ReadTimeout: 60 * time.Second, // 增加读超时,支持大文件上传
WriteTimeout: 300 * time.Second, // 增加写超时,支持流式响应
IdleTimeout: 120 * time.Second, // 增加空闲超时,复用连接
MaxHeaderBytes: 1 << 20, // 1MB header limit
}
// 启动服务器

View File

@@ -69,6 +69,9 @@ type CORSConfig struct {
type PerformanceConfig struct {
MaxSockets int `json:"maxSockets"`
MaxFreeSockets int `json:"maxFreeSockets"`
EnableKeepAlive bool `json:"enableKeepAlive"`
DisableCompression bool `json:"disableCompression"`
BufferSize int `json:"bufferSize"`
}
// Config 应用配置
@@ -116,6 +119,9 @@ func LoadConfig() (*Config, error) {
Performance: PerformanceConfig{
MaxSockets: parseInteger(os.Getenv("MAX_SOCKETS"), DefaultConstants.DefaultMaxSockets),
MaxFreeSockets: parseInteger(os.Getenv("MAX_FREE_SOCKETS"), DefaultConstants.DefaultMaxFreeSockets),
EnableKeepAlive: parseBoolean(os.Getenv("ENABLE_KEEP_ALIVE"), true),
DisableCompression: parseBoolean(os.Getenv("DISABLE_COMPRESSION"), true),
BufferSize: parseInteger(os.Getenv("BUFFER_SIZE"), 32*1024),
},
}

View File

@@ -155,36 +155,40 @@ func (km *KeyManager) LoadKeys() error {
func (km *KeyManager) GetNextKey() (*KeyInfo, error) {
km.keysMutex.RLock()
keysLen := len(km.keys)
km.keysMutex.RUnlock()
if keysLen == 0 {
km.keysMutex.RUnlock()
return nil, fmt.Errorf("没有可用的 API 密钥")
}
// 检查是否所有密钥都被拉黑
blacklistedCount := 0
km.blacklistedKeys.Range(func(key, value interface{}) bool {
blacklistedCount++
return true
})
if blacklistedCount >= keysLen {
logrus.Warn("⚠️ 所有密钥都被拉黑,重置黑名单")
km.blacklistedKeys = sync.Map{}
km.keyFailureCounts = sync.Map{}
}
// 使用原子操作避免竞态条件
attempts := 0
for attempts < keysLen {
// 快速路径:直接获取下一个密钥,避免黑名单检查的开销
currentIdx := atomic.AddInt64(&km.currentIndex, 1) - 1
keyIndex := int(currentIdx) % keysLen
km.keysMutex.RLock()
selectedKey := km.keys[keyIndex]
keyPreview := km.keyPreviews[keyIndex]
km.keysMutex.RUnlock()
// 检查是否被拉黑
if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted {
return &KeyInfo{
Key: selectedKey,
Index: keyIndex,
Preview: keyPreview,
}, nil
}
// 慢速路径:寻找可用密钥
attempts := 0
maxAttempts := keysLen * 2 // 最多尝试两轮
for attempts < maxAttempts {
currentIdx = atomic.AddInt64(&km.currentIndex, 1) - 1
keyIndex = int(currentIdx) % keysLen
km.keysMutex.RLock()
selectedKey = km.keys[keyIndex]
keyPreview = km.keyPreviews[keyIndex]
km.keysMutex.RUnlock()
if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted {
return &KeyInfo{
Key: selectedKey,
@@ -196,7 +200,19 @@ func (km *KeyManager) GetNextKey() (*KeyInfo, error) {
attempts++
}
// 兜底:返回第一个密钥
// 检查是否所有密钥都被拉黑,如果是则重置
blacklistedCount := 0
km.blacklistedKeys.Range(func(key, value interface{}) bool {
blacklistedCount++
return blacklistedCount < keysLen // 提前退出优化
})
if blacklistedCount >= keysLen {
logrus.Warn("⚠️ 所有密钥都被拉黑,重置黑名单")
km.blacklistedKeys = sync.Map{}
km.keyFailureCounts = sync.Map{}
// 重置后返回第一个密钥
km.keysMutex.RLock()
firstKey := km.keys[0]
firstPreview := km.keyPreviews[0]
@@ -209,6 +225,9 @@ func (km *KeyManager) GetNextKey() (*KeyInfo, error) {
}, nil
}
return nil, fmt.Errorf("暂时没有可用的 API 密钥")
}
// RecordSuccess 记录密钥使用成功
func (km *KeyManager) RecordSuccess(key string) {
atomic.AddInt64(&km.successCount, 1)

View File

@@ -4,7 +4,6 @@
package proxy
import (
"bytes"
"context"
"fmt"
"io"
@@ -48,14 +47,20 @@ func NewProxyServer() (*ProxyServer, error) {
transport := &http.Transport{
MaxIdleConns: config.AppConfig.Performance.MaxSockets,
MaxIdleConnsPerHost: config.AppConfig.Performance.MaxFreeSockets,
MaxConnsPerHost: 0, // 无限制,避免连接池瓶颈
IdleConnTimeout: 90 * time.Second,
DisableCompression: false,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DisableCompression: true, // 禁用压缩以减少CPU开销让上游处理
ForceAttemptHTTP2: true,
WriteBufferSize: 32 * 1024, // 32KB写缓冲
ReadBufferSize: 32 * 1024, // 32KB读缓冲
}
httpClient := &http.Client{
Transport: transport,
Timeout: time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond,
// 移除全局超时,使用更细粒度的超时控制
// Timeout: time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond,
}
return &ProxyServer{
@@ -180,9 +185,19 @@ func (ps *ProxyServer) authMiddleware() gin.HandlerFunc {
}
}
// loggerMiddleware 自定义日志中间件
// loggerMiddleware 高性能日志中间件
func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// 只在非生产环境或调试模式下记录详细日志
if gin.Mode() == gin.ReleaseMode {
// 生产模式下只记录错误和关键信息
c.Next()
if c.Writer.Status() >= 400 {
logrus.Errorf("Error %d: %s %s", c.Writer.Status(), c.Request.Method, c.Request.URL.Path)
}
return
}
start := time.Now()
path := c.Request.URL.Path
raw := c.Request.URL.RawQuery
@@ -193,16 +208,14 @@ func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc {
// 计算响应时间
latency := time.Since(start)
// 获取客户端IP
clientIP := c.ClientIP()
// 获取方法和状态码
// 获取基本信息
method := c.Request.Method
statusCode := c.Writer.Status()
// 构建完整路径
// 构建完整路径(避免字符串拼接)
fullPath := path
if raw != "" {
path = path + "?" + raw
fullPath = path + "?" + raw
}
// 获取密钥信息(如果存在)
@@ -213,22 +226,8 @@ func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc {
}
}
// 根据状态码选择颜色
var statusColor string
if statusCode >= 200 && statusCode < 300 {
statusColor = "\033[32m" // 绿色
} else {
statusColor = "\033[31m" // 红色
}
resetColor := "\033[0m"
keyColor := "\033[36m" // 青色
// 输出日志
logrus.Infof("%s[%s] %s %s%s%s%s - %s%d%s - %v - %s",
statusColor, time.Now().Format(time.RFC3339), method, path, resetColor,
keyColor, keyInfo, resetColor,
statusColor, statusCode, resetColor,
latency, clientIP)
// 简化日志输出,减少格式化开销
logrus.Infof("%s %s - %d - %v%s", method, fullPath, statusCode, latency, keyInfo)
}
}
@@ -311,36 +310,30 @@ func (ps *ProxyServer) handleProxy(c *gin.Context) {
c.Set("keyIndex", keyInfo.Index)
c.Set("keyPreview", keyInfo.Preview)
// 读取请求体
var bodyBytes []byte
if c.Request.Body != nil {
bodyBytes, err = io.ReadAll(c.Request.Body)
if err != nil {
logrus.Errorf("读取请求体失败: %v", err)
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "读取请求体失败",
"type": "request_error",
"code": "invalid_request_body",
"timestamp": time.Now().Format(time.RFC3339),
},
})
return
}
}
// 直接使用请求体,避免完整读取到内存
var requestBody io.Reader = c.Request.Body
var contentLength int64 = c.Request.ContentLength
// 构建上游请求URL
targetURL := *ps.upstreamURL
targetURL.Path = c.Request.URL.Path
targetURL.RawQuery = c.Request.URL.RawQuery
// 创建上游请求
// 创建带超时的上下文
timeout := time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
// 创建上游请求,直接使用原始请求体进行流式传输
req, err := http.NewRequestWithContext(
context.Background(),
ctx,
c.Request.Method,
targetURL.String(),
bytes.NewReader(bodyBytes),
requestBody,
)
if err == nil && contentLength > 0 {
req.ContentLength = contentLength
}
if err != nil {
logrus.Errorf("创建上游请求失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{