fix: 优化代理性能(服务器超时、上游连接池、密钥轮询快速路径、日志中间件、请求体流式转发)
This commit is contained in:
18
.env.example
18
.env.example
@@ -32,6 +32,24 @@ OPENAI_BASE_URL=https://api.openai.com
|
|||||||
# 请求超时时间(毫秒)
|
# 请求超时时间(毫秒)
|
||||||
REQUEST_TIMEOUT=30000
|
REQUEST_TIMEOUT=30000
|
||||||
|
|
||||||
|
# ===========================================
|
||||||
|
# 性能优化配置
|
||||||
|
# ===========================================
|
||||||
|
# 最大空闲连接数(所有主机合计,对应 http.Transport.MaxIdleConns)
|
||||||
|
MAX_SOCKETS=100
|
||||||
|
|
||||||
|
# 每个主机的最大空闲连接数(对应 http.Transport.MaxIdleConnsPerHost)
|
||||||
|
MAX_FREE_SOCKETS=20
|
||||||
|
|
||||||
|
# 启用 Keep-Alive
|
||||||
|
ENABLE_KEEP_ALIVE=true
|
||||||
|
|
||||||
|
# 禁用压缩(减少CPU开销)
|
||||||
|
DISABLE_COMPRESSION=true
|
||||||
|
|
||||||
|
# 缓冲区大小(字节)
|
||||||
|
BUFFER_SIZE=32768
|
||||||
|
|
||||||
# ===========================================
|
# ===========================================
|
||||||
# 认证配置
|
# 认证配置
|
||||||
# ===========================================
|
# ===========================================
|
||||||
|
13
cmd/main.go
13
cmd/main.go
@@ -44,13 +44,14 @@ func main() {
|
|||||||
// 设置路由
|
// 设置路由
|
||||||
router := proxyServer.SetupRoutes()
|
router := proxyServer.SetupRoutes()
|
||||||
|
|
||||||
// 创建HTTP服务器
|
// 创建HTTP服务器,优化超时配置
|
||||||
server := &http.Server{
|
server := &http.Server{
|
||||||
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
|
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
|
||||||
Handler: router,
|
Handler: router,
|
||||||
ReadTimeout: 30 * time.Second,
|
ReadTimeout: 60 * time.Second, // 增加读超时,支持大文件上传
|
||||||
WriteTimeout: 30 * time.Second,
|
WriteTimeout: 300 * time.Second, // 增加写超时,支持流式响应
|
||||||
IdleTimeout: 60 * time.Second,
|
IdleTimeout: 120 * time.Second, // 增加空闲超时,复用连接
|
||||||
|
MaxHeaderBytes: 1 << 20, // 1MB header limit
|
||||||
}
|
}
|
||||||
|
|
||||||
// 启动服务器
|
// 启动服务器
|
||||||
|
@@ -16,21 +16,21 @@ import (
|
|||||||
|
|
||||||
// Constants 配置常量
|
// Constants 配置常量
|
||||||
type Constants struct {
|
type Constants struct {
|
||||||
MinPort int
|
MinPort int
|
||||||
MaxPort int
|
MaxPort int
|
||||||
MinTimeout int
|
MinTimeout int
|
||||||
DefaultTimeout int
|
DefaultTimeout int
|
||||||
DefaultMaxSockets int
|
DefaultMaxSockets int
|
||||||
DefaultMaxFreeSockets int
|
DefaultMaxFreeSockets int
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultConstants 默认常量
|
// DefaultConstants 默认常量
|
||||||
var DefaultConstants = Constants{
|
var DefaultConstants = Constants{
|
||||||
MinPort: 1,
|
MinPort: 1,
|
||||||
MaxPort: 65535,
|
MaxPort: 65535,
|
||||||
MinTimeout: 1000,
|
MinTimeout: 1000,
|
||||||
DefaultTimeout: 30000,
|
DefaultTimeout: 30000,
|
||||||
DefaultMaxSockets: 50,
|
DefaultMaxSockets: 50,
|
||||||
DefaultMaxFreeSockets: 10,
|
DefaultMaxFreeSockets: 10,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -67,8 +67,11 @@ type CORSConfig struct {
|
|||||||
|
|
||||||
// PerformanceConfig 性能配置
|
// PerformanceConfig 性能配置
|
||||||
type PerformanceConfig struct {
|
type PerformanceConfig struct {
|
||||||
MaxSockets int `json:"maxSockets"`
|
MaxSockets int `json:"maxSockets"`
|
||||||
MaxFreeSockets int `json:"maxFreeSockets"`
|
MaxFreeSockets int `json:"maxFreeSockets"`
|
||||||
|
EnableKeepAlive bool `json:"enableKeepAlive"`
|
||||||
|
DisableCompression bool `json:"disableCompression"`
|
||||||
|
BufferSize int `json:"bufferSize"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config 应用配置
|
// Config 应用配置
|
||||||
@@ -114,8 +117,11 @@ func LoadConfig() (*Config, error) {
|
|||||||
AllowedOrigins: parseArray(os.Getenv("ALLOWED_ORIGINS"), []string{"*"}),
|
AllowedOrigins: parseArray(os.Getenv("ALLOWED_ORIGINS"), []string{"*"}),
|
||||||
},
|
},
|
||||||
Performance: PerformanceConfig{
|
Performance: PerformanceConfig{
|
||||||
MaxSockets: parseInteger(os.Getenv("MAX_SOCKETS"), DefaultConstants.DefaultMaxSockets),
|
MaxSockets: parseInteger(os.Getenv("MAX_SOCKETS"), DefaultConstants.DefaultMaxSockets),
|
||||||
MaxFreeSockets: parseInteger(os.Getenv("MAX_FREE_SOCKETS"), DefaultConstants.DefaultMaxFreeSockets),
|
MaxFreeSockets: parseInteger(os.Getenv("MAX_FREE_SOCKETS"), DefaultConstants.DefaultMaxFreeSockets),
|
||||||
|
EnableKeepAlive: parseBoolean(os.Getenv("ENABLE_KEEP_ALIVE"), true),
|
||||||
|
DisableCompression: parseBoolean(os.Getenv("DISABLE_COMPRESSION"), true),
|
||||||
|
BufferSize: parseInteger(os.Getenv("BUFFER_SIZE"), 32*1024),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -27,13 +27,13 @@ type KeyInfo struct {
|
|||||||
|
|
||||||
// Stats 统计信息
|
// Stats 统计信息
|
||||||
type Stats struct {
|
type Stats struct {
|
||||||
CurrentIndex int64 `json:"currentIndex"`
|
CurrentIndex int64 `json:"currentIndex"`
|
||||||
TotalKeys int `json:"totalKeys"`
|
TotalKeys int `json:"totalKeys"`
|
||||||
HealthyKeys int `json:"healthyKeys"`
|
HealthyKeys int `json:"healthyKeys"`
|
||||||
BlacklistedKeys int `json:"blacklistedKeys"`
|
BlacklistedKeys int `json:"blacklistedKeys"`
|
||||||
SuccessCount int64 `json:"successCount"`
|
SuccessCount int64 `json:"successCount"`
|
||||||
FailureCount int64 `json:"failureCount"`
|
FailureCount int64 `json:"failureCount"`
|
||||||
MemoryUsage MemoryUsage `json:"memoryUsage"`
|
MemoryUsage MemoryUsage `json:"memoryUsage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// MemoryUsage 内存使用情况
|
// MemoryUsage 内存使用情况
|
||||||
@@ -60,13 +60,13 @@ type BlacklistInfo struct {
|
|||||||
|
|
||||||
// KeyManager 密钥管理器
|
// KeyManager 密钥管理器
|
||||||
type KeyManager struct {
|
type KeyManager struct {
|
||||||
keysFilePath string
|
keysFilePath string
|
||||||
keys []string
|
keys []string
|
||||||
keyPreviews []string
|
keyPreviews []string
|
||||||
currentIndex int64
|
currentIndex int64
|
||||||
blacklistedKeys sync.Map
|
blacklistedKeys sync.Map
|
||||||
successCount int64
|
successCount int64
|
||||||
failureCount int64
|
failureCount int64
|
||||||
keyFailureCounts sync.Map
|
keyFailureCounts sync.Map
|
||||||
|
|
||||||
// 性能优化:预编译正则表达式
|
// 性能优化:预编译正则表达式
|
||||||
@@ -155,34 +155,38 @@ func (km *KeyManager) LoadKeys() error {
|
|||||||
func (km *KeyManager) GetNextKey() (*KeyInfo, error) {
|
func (km *KeyManager) GetNextKey() (*KeyInfo, error) {
|
||||||
km.keysMutex.RLock()
|
km.keysMutex.RLock()
|
||||||
keysLen := len(km.keys)
|
keysLen := len(km.keys)
|
||||||
km.keysMutex.RUnlock()
|
|
||||||
|
|
||||||
if keysLen == 0 {
|
if keysLen == 0 {
|
||||||
|
km.keysMutex.RUnlock()
|
||||||
return nil, fmt.Errorf("没有可用的 API 密钥")
|
return nil, fmt.Errorf("没有可用的 API 密钥")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否所有密钥都被拉黑
|
// 快速路径:直接取下一个密钥,只检查该密钥是否被拉黑,避免遍历整个黑名单
|
||||||
blacklistedCount := 0
|
currentIdx := atomic.AddInt64(&km.currentIndex, 1) - 1
|
||||||
km.blacklistedKeys.Range(func(key, value interface{}) bool {
|
keyIndex := int(currentIdx) % keysLen
|
||||||
blacklistedCount++
|
selectedKey := km.keys[keyIndex]
|
||||||
return true
|
keyPreview := km.keyPreviews[keyIndex]
|
||||||
})
|
km.keysMutex.RUnlock()
|
||||||
|
|
||||||
if blacklistedCount >= keysLen {
|
// 检查是否被拉黑
|
||||||
logrus.Warn("⚠️ 所有密钥都被拉黑,重置黑名单")
|
if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted {
|
||||||
km.blacklistedKeys = sync.Map{}
|
return &KeyInfo{
|
||||||
km.keyFailureCounts = sync.Map{}
|
Key: selectedKey,
|
||||||
|
Index: keyIndex,
|
||||||
|
Preview: keyPreview,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用原子操作避免竞态条件
|
// 慢速路径:寻找可用密钥
|
||||||
attempts := 0
|
attempts := 0
|
||||||
for attempts < keysLen {
|
maxAttempts := keysLen * 2 // 最多尝试两轮
|
||||||
currentIdx := atomic.AddInt64(&km.currentIndex, 1) - 1
|
|
||||||
keyIndex := int(currentIdx) % keysLen
|
for attempts < maxAttempts {
|
||||||
|
currentIdx = atomic.AddInt64(&km.currentIndex, 1) - 1
|
||||||
|
keyIndex = int(currentIdx) % keysLen
|
||||||
|
|
||||||
km.keysMutex.RLock()
|
km.keysMutex.RLock()
|
||||||
selectedKey := km.keys[keyIndex]
|
selectedKey = km.keys[keyIndex]
|
||||||
keyPreview := km.keyPreviews[keyIndex]
|
keyPreview = km.keyPreviews[keyIndex]
|
||||||
km.keysMutex.RUnlock()
|
km.keysMutex.RUnlock()
|
||||||
|
|
||||||
if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted {
|
if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted {
|
||||||
@@ -196,17 +200,32 @@ func (km *KeyManager) GetNextKey() (*KeyInfo, error) {
|
|||||||
attempts++
|
attempts++
|
||||||
}
|
}
|
||||||
|
|
||||||
// 兜底:返回第一个密钥
|
// 检查是否所有密钥都被拉黑,如果是则重置
|
||||||
km.keysMutex.RLock()
|
blacklistedCount := 0
|
||||||
firstKey := km.keys[0]
|
km.blacklistedKeys.Range(func(key, value interface{}) bool {
|
||||||
firstPreview := km.keyPreviews[0]
|
blacklistedCount++
|
||||||
km.keysMutex.RUnlock()
|
return blacklistedCount < keysLen // 提前退出优化
|
||||||
|
})
|
||||||
|
|
||||||
return &KeyInfo{
|
if blacklistedCount >= keysLen {
|
||||||
Key: firstKey,
|
logrus.Warn("⚠️ 所有密钥都被拉黑,重置黑名单")
|
||||||
Index: 0,
|
km.blacklistedKeys = sync.Map{}
|
||||||
Preview: firstPreview,
|
km.keyFailureCounts = sync.Map{}
|
||||||
}, nil
|
|
||||||
|
// 重置后返回第一个密钥
|
||||||
|
km.keysMutex.RLock()
|
||||||
|
firstKey := km.keys[0]
|
||||||
|
firstPreview := km.keyPreviews[0]
|
||||||
|
km.keysMutex.RUnlock()
|
||||||
|
|
||||||
|
return &KeyInfo{
|
||||||
|
Key: firstKey,
|
||||||
|
Index: 0,
|
||||||
|
Preview: firstPreview,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("暂时没有可用的 API 密钥")
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordSuccess 记录密钥使用成功
|
// RecordSuccess 记录密钥使用成功
|
||||||
|
@@ -4,7 +4,6 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -46,16 +45,22 @@ func NewProxyServer() (*ProxyServer, error) {
|
|||||||
|
|
||||||
// 创建高性能HTTP客户端
|
// 创建高性能HTTP客户端
|
||||||
transport := &http.Transport{
|
transport := &http.Transport{
|
||||||
MaxIdleConns: config.AppConfig.Performance.MaxSockets,
|
MaxIdleConns: config.AppConfig.Performance.MaxSockets,
|
||||||
MaxIdleConnsPerHost: config.AppConfig.Performance.MaxFreeSockets,
|
MaxIdleConnsPerHost: config.AppConfig.Performance.MaxFreeSockets,
|
||||||
IdleConnTimeout: 90 * time.Second,
|
MaxConnsPerHost: 0, // 无限制,避免连接池瓶颈
|
||||||
DisableCompression: false,
|
IdleConnTimeout: 90 * time.Second,
|
||||||
ForceAttemptHTTP2: true,
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
DisableCompression: true, // 禁用压缩以减少CPU开销,让上游处理
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
WriteBufferSize: 32 * 1024, // 32KB写缓冲
|
||||||
|
ReadBufferSize: 32 * 1024, // 32KB读缓冲
|
||||||
}
|
}
|
||||||
|
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Transport: transport,
|
Transport: transport,
|
||||||
Timeout: time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond,
|
// 移除全局超时,使用更细粒度的超时控制
|
||||||
|
// Timeout: time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond,
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ProxyServer{
|
return &ProxyServer{
|
||||||
@@ -180,9 +185,19 @@ func (ps *ProxyServer) authMiddleware() gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// loggerMiddleware 自定义日志中间件
|
// loggerMiddleware 高性能日志中间件
|
||||||
func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc {
|
func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
|
// 只在非生产环境或调试模式下记录详细日志
|
||||||
|
if gin.Mode() == gin.ReleaseMode {
|
||||||
|
// 生产模式下只记录错误和关键信息
|
||||||
|
c.Next()
|
||||||
|
if c.Writer.Status() >= 400 {
|
||||||
|
logrus.Errorf("Error %d: %s %s", c.Writer.Status(), c.Request.Method, c.Request.URL.Path)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
path := c.Request.URL.Path
|
path := c.Request.URL.Path
|
||||||
raw := c.Request.URL.RawQuery
|
raw := c.Request.URL.RawQuery
|
||||||
@@ -193,16 +208,14 @@ func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc {
|
|||||||
// 计算响应时间
|
// 计算响应时间
|
||||||
latency := time.Since(start)
|
latency := time.Since(start)
|
||||||
|
|
||||||
// 获取客户端IP
|
// 获取基本信息
|
||||||
clientIP := c.ClientIP()
|
|
||||||
|
|
||||||
// 获取方法和状态码
|
|
||||||
method := c.Request.Method
|
method := c.Request.Method
|
||||||
statusCode := c.Writer.Status()
|
statusCode := c.Writer.Status()
|
||||||
|
|
||||||
// 构建完整路径
|
// 构建完整路径(保留原始 path 不被覆盖,仅在存在查询串时拼接)
|
||||||
|
fullPath := path
|
||||||
if raw != "" {
|
if raw != "" {
|
||||||
path = path + "?" + raw
|
fullPath = path + "?" + raw
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取密钥信息(如果存在)
|
// 获取密钥信息(如果存在)
|
||||||
@@ -213,22 +226,8 @@ func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 根据状态码选择颜色
|
// 简化日志输出,减少格式化开销
|
||||||
var statusColor string
|
logrus.Infof("%s %s - %d - %v%s", method, fullPath, statusCode, latency, keyInfo)
|
||||||
if statusCode >= 200 && statusCode < 300 {
|
|
||||||
statusColor = "\033[32m" // 绿色
|
|
||||||
} else {
|
|
||||||
statusColor = "\033[31m" // 红色
|
|
||||||
}
|
|
||||||
resetColor := "\033[0m"
|
|
||||||
keyColor := "\033[36m" // 青色
|
|
||||||
|
|
||||||
// 输出日志
|
|
||||||
logrus.Infof("%s[%s] %s %s%s%s%s - %s%d%s - %v - %s",
|
|
||||||
statusColor, time.Now().Format(time.RFC3339), method, path, resetColor,
|
|
||||||
keyColor, keyInfo, resetColor,
|
|
||||||
statusColor, statusCode, resetColor,
|
|
||||||
latency, clientIP)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -311,36 +310,30 @@ func (ps *ProxyServer) handleProxy(c *gin.Context) {
|
|||||||
c.Set("keyIndex", keyInfo.Index)
|
c.Set("keyIndex", keyInfo.Index)
|
||||||
c.Set("keyPreview", keyInfo.Preview)
|
c.Set("keyPreview", keyInfo.Preview)
|
||||||
|
|
||||||
// 读取请求体
|
// 直接使用请求体,避免完整读取到内存
|
||||||
var bodyBytes []byte
|
var requestBody io.Reader = c.Request.Body
|
||||||
if c.Request.Body != nil {
|
var contentLength int64 = c.Request.ContentLength
|
||||||
bodyBytes, err = io.ReadAll(c.Request.Body)
|
|
||||||
if err != nil {
|
|
||||||
logrus.Errorf("读取请求体失败: %v", err)
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
|
||||||
"error": gin.H{
|
|
||||||
"message": "读取请求体失败",
|
|
||||||
"type": "request_error",
|
|
||||||
"code": "invalid_request_body",
|
|
||||||
"timestamp": time.Now().Format(time.RFC3339),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构建上游请求URL
|
// 构建上游请求URL
|
||||||
targetURL := *ps.upstreamURL
|
targetURL := *ps.upstreamURL
|
||||||
targetURL.Path = c.Request.URL.Path
|
targetURL.Path = c.Request.URL.Path
|
||||||
targetURL.RawQuery = c.Request.URL.RawQuery
|
targetURL.RawQuery = c.Request.URL.RawQuery
|
||||||
|
|
||||||
// 创建上游请求
|
// 创建带超时的上下文
|
||||||
|
timeout := time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// 创建上游请求,直接使用原始请求体进行流式传输
|
||||||
req, err := http.NewRequestWithContext(
|
req, err := http.NewRequestWithContext(
|
||||||
context.Background(),
|
ctx,
|
||||||
c.Request.Method,
|
c.Request.Method,
|
||||||
targetURL.String(),
|
targetURL.String(),
|
||||||
bytes.NewReader(bodyBytes),
|
requestBody,
|
||||||
)
|
)
|
||||||
|
if err == nil && contentLength > 0 {
|
||||||
|
req.ContentLength = contentLength
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorf("创建上游请求失败: %v", err)
|
logrus.Errorf("创建上游请求失败: %v", err)
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
|
Reference in New Issue
Block a user