fix: 优化
This commit is contained in:
@@ -16,21 +16,21 @@ import (
|
||||
|
||||
// Constants 配置常量
|
||||
type Constants struct {
|
||||
MinPort int
|
||||
MaxPort int
|
||||
MinTimeout int
|
||||
DefaultTimeout int
|
||||
DefaultMaxSockets int
|
||||
MinPort int
|
||||
MaxPort int
|
||||
MinTimeout int
|
||||
DefaultTimeout int
|
||||
DefaultMaxSockets int
|
||||
DefaultMaxFreeSockets int
|
||||
}
|
||||
|
||||
// DefaultConstants 默认常量
|
||||
var DefaultConstants = Constants{
|
||||
MinPort: 1,
|
||||
MaxPort: 65535,
|
||||
MinTimeout: 1000,
|
||||
DefaultTimeout: 30000,
|
||||
DefaultMaxSockets: 50,
|
||||
MinPort: 1,
|
||||
MaxPort: 65535,
|
||||
MinTimeout: 1000,
|
||||
DefaultTimeout: 30000,
|
||||
DefaultMaxSockets: 50,
|
||||
DefaultMaxFreeSockets: 10,
|
||||
}
|
||||
|
||||
@@ -67,8 +67,11 @@ type CORSConfig struct {
|
||||
|
||||
// PerformanceConfig 性能配置
|
||||
type PerformanceConfig struct {
|
||||
MaxSockets int `json:"maxSockets"`
|
||||
MaxFreeSockets int `json:"maxFreeSockets"`
|
||||
MaxSockets int `json:"maxSockets"`
|
||||
MaxFreeSockets int `json:"maxFreeSockets"`
|
||||
EnableKeepAlive bool `json:"enableKeepAlive"`
|
||||
DisableCompression bool `json:"disableCompression"`
|
||||
BufferSize int `json:"bufferSize"`
|
||||
}
|
||||
|
||||
// Config 应用配置
|
||||
@@ -114,8 +117,11 @@ func LoadConfig() (*Config, error) {
|
||||
AllowedOrigins: parseArray(os.Getenv("ALLOWED_ORIGINS"), []string{"*"}),
|
||||
},
|
||||
Performance: PerformanceConfig{
|
||||
MaxSockets: parseInteger(os.Getenv("MAX_SOCKETS"), DefaultConstants.DefaultMaxSockets),
|
||||
MaxFreeSockets: parseInteger(os.Getenv("MAX_FREE_SOCKETS"), DefaultConstants.DefaultMaxFreeSockets),
|
||||
MaxSockets: parseInteger(os.Getenv("MAX_SOCKETS"), DefaultConstants.DefaultMaxSockets),
|
||||
MaxFreeSockets: parseInteger(os.Getenv("MAX_FREE_SOCKETS"), DefaultConstants.DefaultMaxFreeSockets),
|
||||
EnableKeepAlive: parseBoolean(os.Getenv("ENABLE_KEEP_ALIVE"), true),
|
||||
DisableCompression: parseBoolean(os.Getenv("DISABLE_COMPRESSION"), true),
|
||||
BufferSize: parseInteger(os.Getenv("BUFFER_SIZE"), 32*1024),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -186,13 +192,13 @@ func DisplayConfig(config *Config) {
|
||||
logrus.Infof(" 黑名单阈值: %d 次错误", config.Keys.BlacklistThreshold)
|
||||
logrus.Infof(" 上游地址: %s", config.OpenAI.BaseURL)
|
||||
logrus.Infof(" 请求超时: %dms", config.OpenAI.Timeout)
|
||||
|
||||
|
||||
authStatus := "未启用"
|
||||
if config.Auth.Enabled {
|
||||
authStatus = "已启用"
|
||||
}
|
||||
logrus.Infof(" 认证: %s", authStatus)
|
||||
|
||||
|
||||
corsStatus := "已禁用"
|
||||
if config.CORS.Enabled {
|
||||
corsStatus = "已启用"
|
||||
@@ -227,7 +233,7 @@ func parseArray(value string, defaultValue []string) []string {
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
|
||||
parts := strings.Split(value, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
@@ -235,7 +241,7 @@ func parseArray(value string, defaultValue []string) []string {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if len(result) == 0 {
|
||||
return defaultValue
|
||||
}
|
||||
|
@@ -27,13 +27,13 @@ type KeyInfo struct {
|
||||
|
||||
// Stats 统计信息
|
||||
type Stats struct {
|
||||
CurrentIndex int64 `json:"currentIndex"`
|
||||
TotalKeys int `json:"totalKeys"`
|
||||
HealthyKeys int `json:"healthyKeys"`
|
||||
BlacklistedKeys int `json:"blacklistedKeys"`
|
||||
SuccessCount int64 `json:"successCount"`
|
||||
FailureCount int64 `json:"failureCount"`
|
||||
MemoryUsage MemoryUsage `json:"memoryUsage"`
|
||||
CurrentIndex int64 `json:"currentIndex"`
|
||||
TotalKeys int `json:"totalKeys"`
|
||||
HealthyKeys int `json:"healthyKeys"`
|
||||
BlacklistedKeys int `json:"blacklistedKeys"`
|
||||
SuccessCount int64 `json:"successCount"`
|
||||
FailureCount int64 `json:"failureCount"`
|
||||
MemoryUsage MemoryUsage `json:"memoryUsage"`
|
||||
}
|
||||
|
||||
// MemoryUsage 内存使用情况
|
||||
@@ -60,22 +60,22 @@ type BlacklistInfo struct {
|
||||
|
||||
// KeyManager 密钥管理器
|
||||
type KeyManager struct {
|
||||
keysFilePath string
|
||||
keys []string
|
||||
keyPreviews []string
|
||||
currentIndex int64
|
||||
blacklistedKeys sync.Map
|
||||
successCount int64
|
||||
failureCount int64
|
||||
keysFilePath string
|
||||
keys []string
|
||||
keyPreviews []string
|
||||
currentIndex int64
|
||||
blacklistedKeys sync.Map
|
||||
successCount int64
|
||||
failureCount int64
|
||||
keyFailureCounts sync.Map
|
||||
|
||||
|
||||
// 性能优化:预编译正则表达式
|
||||
permanentErrorPatterns []*regexp.Regexp
|
||||
|
||||
|
||||
// 内存管理
|
||||
cleanupTicker *time.Ticker
|
||||
stopCleanup chan bool
|
||||
|
||||
|
||||
// 读写锁保护密钥列表
|
||||
keysMutex sync.RWMutex
|
||||
}
|
||||
@@ -85,12 +85,12 @@ func NewKeyManager(keysFilePath string) *KeyManager {
|
||||
if keysFilePath == "" {
|
||||
keysFilePath = config.AppConfig.Keys.FilePath
|
||||
}
|
||||
|
||||
|
||||
km := &KeyManager{
|
||||
keysFilePath: keysFilePath,
|
||||
currentIndex: int64(config.AppConfig.Keys.StartIndex),
|
||||
stopCleanup: make(chan bool),
|
||||
|
||||
|
||||
// 预编译正则表达式
|
||||
permanentErrorPatterns: []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)invalid api key`),
|
||||
@@ -101,10 +101,10 @@ func NewKeyManager(keysFilePath string) *KeyManager {
|
||||
regexp.MustCompile(`(?i)billing`),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
// 启动内存清理
|
||||
km.setupMemoryCleanup()
|
||||
|
||||
|
||||
return km
|
||||
}
|
||||
|
||||
@@ -115,25 +115,25 @@ func (km *KeyManager) LoadKeys() error {
|
||||
return fmt.Errorf("无法打开密钥文件: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
|
||||
var keys []string
|
||||
scanner := bufio.NewScanner(file)
|
||||
|
||||
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line != "" && strings.HasPrefix(line, "sk-") {
|
||||
keys = append(keys, line)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("读取密钥文件失败: %w", err)
|
||||
}
|
||||
|
||||
|
||||
if len(keys) == 0 {
|
||||
return fmt.Errorf("密钥文件中没有有效的API密钥")
|
||||
}
|
||||
|
||||
|
||||
km.keysMutex.Lock()
|
||||
km.keys = keys
|
||||
// 预生成密钥预览,避免运行时重复计算
|
||||
@@ -146,7 +146,7 @@ func (km *KeyManager) LoadKeys() error {
|
||||
}
|
||||
}
|
||||
km.keysMutex.Unlock()
|
||||
|
||||
|
||||
logrus.Infof("✅ 成功加载 %d 个 API 密钥", len(keys))
|
||||
return nil
|
||||
}
|
||||
@@ -155,36 +155,40 @@ func (km *KeyManager) LoadKeys() error {
|
||||
func (km *KeyManager) GetNextKey() (*KeyInfo, error) {
|
||||
km.keysMutex.RLock()
|
||||
keysLen := len(km.keys)
|
||||
km.keysMutex.RUnlock()
|
||||
|
||||
if keysLen == 0 {
|
||||
km.keysMutex.RUnlock()
|
||||
return nil, fmt.Errorf("没有可用的 API 密钥")
|
||||
}
|
||||
|
||||
// 检查是否所有密钥都被拉黑
|
||||
blacklistedCount := 0
|
||||
km.blacklistedKeys.Range(func(key, value interface{}) bool {
|
||||
blacklistedCount++
|
||||
return true
|
||||
})
|
||||
|
||||
if blacklistedCount >= keysLen {
|
||||
logrus.Warn("⚠️ 所有密钥都被拉黑,重置黑名单")
|
||||
km.blacklistedKeys = sync.Map{}
|
||||
km.keyFailureCounts = sync.Map{}
|
||||
|
||||
// 快速路径:直接获取下一个密钥,避免黑名单检查的开销
|
||||
currentIdx := atomic.AddInt64(&km.currentIndex, 1) - 1
|
||||
keyIndex := int(currentIdx) % keysLen
|
||||
selectedKey := km.keys[keyIndex]
|
||||
keyPreview := km.keyPreviews[keyIndex]
|
||||
km.keysMutex.RUnlock()
|
||||
|
||||
// 检查是否被拉黑
|
||||
if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted {
|
||||
return &KeyInfo{
|
||||
Key: selectedKey,
|
||||
Index: keyIndex,
|
||||
Preview: keyPreview,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 使用原子操作避免竞态条件
|
||||
|
||||
// 慢速路径:寻找可用密钥
|
||||
attempts := 0
|
||||
for attempts < keysLen {
|
||||
currentIdx := atomic.AddInt64(&km.currentIndex, 1) - 1
|
||||
keyIndex := int(currentIdx) % keysLen
|
||||
|
||||
maxAttempts := keysLen * 2 // 最多尝试两轮
|
||||
|
||||
for attempts < maxAttempts {
|
||||
currentIdx = atomic.AddInt64(&km.currentIndex, 1) - 1
|
||||
keyIndex = int(currentIdx) % keysLen
|
||||
|
||||
km.keysMutex.RLock()
|
||||
selectedKey := km.keys[keyIndex]
|
||||
keyPreview := km.keyPreviews[keyIndex]
|
||||
selectedKey = km.keys[keyIndex]
|
||||
keyPreview = km.keyPreviews[keyIndex]
|
||||
km.keysMutex.RUnlock()
|
||||
|
||||
|
||||
if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted {
|
||||
return &KeyInfo{
|
||||
Key: selectedKey,
|
||||
@@ -192,21 +196,36 @@ func (km *KeyManager) GetNextKey() (*KeyInfo, error) {
|
||||
Preview: keyPreview,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
attempts++
|
||||
}
|
||||
|
||||
// 兜底:返回第一个密钥
|
||||
km.keysMutex.RLock()
|
||||
firstKey := km.keys[0]
|
||||
firstPreview := km.keyPreviews[0]
|
||||
km.keysMutex.RUnlock()
|
||||
|
||||
return &KeyInfo{
|
||||
Key: firstKey,
|
||||
Index: 0,
|
||||
Preview: firstPreview,
|
||||
}, nil
|
||||
|
||||
// 检查是否所有密钥都被拉黑,如果是则重置
|
||||
blacklistedCount := 0
|
||||
km.blacklistedKeys.Range(func(key, value interface{}) bool {
|
||||
blacklistedCount++
|
||||
return blacklistedCount < keysLen // 提前退出优化
|
||||
})
|
||||
|
||||
if blacklistedCount >= keysLen {
|
||||
logrus.Warn("⚠️ 所有密钥都被拉黑,重置黑名单")
|
||||
km.blacklistedKeys = sync.Map{}
|
||||
km.keyFailureCounts = sync.Map{}
|
||||
|
||||
// 重置后返回第一个密钥
|
||||
km.keysMutex.RLock()
|
||||
firstKey := km.keys[0]
|
||||
firstPreview := km.keyPreviews[0]
|
||||
km.keysMutex.RUnlock()
|
||||
|
||||
return &KeyInfo{
|
||||
Key: firstKey,
|
||||
Index: 0,
|
||||
Preview: firstPreview,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("暂时没有可用的 API 密钥")
|
||||
}
|
||||
|
||||
// RecordSuccess 记录密钥使用成功
|
||||
@@ -219,7 +238,7 @@ func (km *KeyManager) RecordSuccess(key string) {
|
||||
// RecordFailure 记录密钥使用失败
|
||||
func (km *KeyManager) RecordFailure(key string, err error) {
|
||||
atomic.AddInt64(&km.failureCount, 1)
|
||||
|
||||
|
||||
// 检查是否是永久性错误
|
||||
if km.isPermanentError(err) {
|
||||
km.blacklistedKeys.Store(key, true)
|
||||
@@ -227,7 +246,7 @@ func (km *KeyManager) RecordFailure(key string, err error) {
|
||||
logrus.Warnf("🚫 密钥已被拉黑(永久性错误): %s (%s)", key[:20]+"...", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 临时性错误:增加失败计数
|
||||
currentFailures := 0
|
||||
if val, exists := km.keyFailureCounts.Load(key); exists {
|
||||
@@ -235,7 +254,7 @@ func (km *KeyManager) RecordFailure(key string, err error) {
|
||||
}
|
||||
newFailures := currentFailures + 1
|
||||
km.keyFailureCounts.Store(key, newFailures)
|
||||
|
||||
|
||||
threshold := config.AppConfig.Keys.BlacklistThreshold
|
||||
if newFailures >= threshold {
|
||||
km.blacklistedKeys.Store(key, true)
|
||||
@@ -262,19 +281,19 @@ func (km *KeyManager) GetStats() *Stats {
|
||||
km.keysMutex.RLock()
|
||||
totalKeys := len(km.keys)
|
||||
km.keysMutex.RUnlock()
|
||||
|
||||
|
||||
blacklistedCount := 0
|
||||
km.blacklistedKeys.Range(func(key, value interface{}) bool {
|
||||
blacklistedCount++
|
||||
return true
|
||||
})
|
||||
|
||||
|
||||
failureCountsSize := 0
|
||||
km.keyFailureCounts.Range(func(key, value interface{}) bool {
|
||||
failureCountsSize++
|
||||
return true
|
||||
})
|
||||
|
||||
|
||||
return &Stats{
|
||||
CurrentIndex: atomic.LoadInt64(&km.currentIndex),
|
||||
TotalKeys: totalKeys,
|
||||
@@ -296,16 +315,16 @@ func (km *KeyManager) ResetKeys() map[string]interface{} {
|
||||
beforeCount++
|
||||
return true
|
||||
})
|
||||
|
||||
|
||||
km.blacklistedKeys = sync.Map{}
|
||||
km.keyFailureCounts = sync.Map{}
|
||||
|
||||
|
||||
logrus.Infof("🔄 密钥状态已重置,清除了 %d 个黑名单密钥", beforeCount)
|
||||
|
||||
|
||||
km.keysMutex.RLock()
|
||||
totalKeys := len(km.keys)
|
||||
km.keysMutex.RUnlock()
|
||||
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("已清除 %d 个黑名单密钥", beforeCount),
|
||||
@@ -317,12 +336,12 @@ func (km *KeyManager) ResetKeys() map[string]interface{} {
|
||||
// GetBlacklistDetails 获取黑名单详情
|
||||
func (km *KeyManager) GetBlacklistDetails() *BlacklistInfo {
|
||||
var blacklistDetails []BlacklistDetail
|
||||
|
||||
|
||||
km.keysMutex.RLock()
|
||||
keys := km.keys
|
||||
keyPreviews := km.keyPreviews
|
||||
km.keysMutex.RUnlock()
|
||||
|
||||
|
||||
for i, key := range keys {
|
||||
if _, blacklisted := km.blacklistedKeys.Load(key); blacklisted {
|
||||
blacklistDetails = append(blacklistDetails, BlacklistDetail{
|
||||
@@ -333,7 +352,7 @@ func (km *KeyManager) GetBlacklistDetails() *BlacklistInfo {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return &BlacklistInfo{
|
||||
TotalBlacklisted: len(blacklistDetails),
|
||||
TotalKeys: len(keys),
|
||||
@@ -345,7 +364,7 @@ func (km *KeyManager) GetBlacklistDetails() *BlacklistInfo {
|
||||
// setupMemoryCleanup 设置内存清理机制
|
||||
func (km *KeyManager) setupMemoryCleanup() {
|
||||
km.cleanupTicker = time.NewTicker(10 * time.Minute)
|
||||
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
@@ -367,20 +386,20 @@ func (km *KeyManager) performMemoryCleanup() {
|
||||
maxSize = 1000
|
||||
}
|
||||
km.keysMutex.RUnlock()
|
||||
|
||||
|
||||
currentSize := 0
|
||||
km.keyFailureCounts.Range(func(key, value interface{}) bool {
|
||||
currentSize++
|
||||
return true
|
||||
})
|
||||
|
||||
|
||||
if currentSize > maxSize {
|
||||
logrus.Infof("🧹 清理失败计数缓存 (%d -> %d)", currentSize, maxSize)
|
||||
|
||||
|
||||
// 简单策略:清理一半的失败计数
|
||||
cleared := 0
|
||||
target := currentSize - maxSize
|
||||
|
||||
|
||||
km.keyFailureCounts.Range(func(key, value interface{}) bool {
|
||||
if cleared < target {
|
||||
km.keyFailureCounts.Delete(key)
|
||||
|
@@ -4,7 +4,6 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -46,16 +45,22 @@ func NewProxyServer() (*ProxyServer, error) {
|
||||
|
||||
// 创建高性能HTTP客户端
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: config.AppConfig.Performance.MaxSockets,
|
||||
MaxIdleConnsPerHost: config.AppConfig.Performance.MaxFreeSockets,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
DisableCompression: false,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: config.AppConfig.Performance.MaxSockets,
|
||||
MaxIdleConnsPerHost: config.AppConfig.Performance.MaxFreeSockets,
|
||||
MaxConnsPerHost: 0, // 无限制,避免连接池瓶颈
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
DisableCompression: true, // 禁用压缩以减少CPU开销,让上游处理
|
||||
ForceAttemptHTTP2: true,
|
||||
WriteBufferSize: 32 * 1024, // 32KB写缓冲
|
||||
ReadBufferSize: 32 * 1024, // 32KB读缓冲
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond,
|
||||
// 移除全局超时,使用更细粒度的超时控制
|
||||
// Timeout: time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond,
|
||||
}
|
||||
|
||||
return &ProxyServer{
|
||||
@@ -180,9 +185,19 @@ func (ps *ProxyServer) authMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// loggerMiddleware 自定义日志中间件
|
||||
// loggerMiddleware 高性能日志中间件
|
||||
func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 只在非生产环境或调试模式下记录详细日志
|
||||
if gin.Mode() == gin.ReleaseMode {
|
||||
// 生产模式下只记录错误和关键信息
|
||||
c.Next()
|
||||
if c.Writer.Status() >= 400 {
|
||||
logrus.Errorf("Error %d: %s %s", c.Writer.Status(), c.Request.Method, c.Request.URL.Path)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
raw := c.Request.URL.RawQuery
|
||||
@@ -193,16 +208,14 @@ func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc {
|
||||
// 计算响应时间
|
||||
latency := time.Since(start)
|
||||
|
||||
// 获取客户端IP
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
// 获取方法和状态码
|
||||
// 获取基本信息
|
||||
method := c.Request.Method
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
// 构建完整路径
|
||||
// 构建完整路径(避免字符串拼接)
|
||||
fullPath := path
|
||||
if raw != "" {
|
||||
path = path + "?" + raw
|
||||
fullPath = path + "?" + raw
|
||||
}
|
||||
|
||||
// 获取密钥信息(如果存在)
|
||||
@@ -213,22 +226,8 @@ func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// 根据状态码选择颜色
|
||||
var statusColor string
|
||||
if statusCode >= 200 && statusCode < 300 {
|
||||
statusColor = "\033[32m" // 绿色
|
||||
} else {
|
||||
statusColor = "\033[31m" // 红色
|
||||
}
|
||||
resetColor := "\033[0m"
|
||||
keyColor := "\033[36m" // 青色
|
||||
|
||||
// 输出日志
|
||||
logrus.Infof("%s[%s] %s %s%s%s%s - %s%d%s - %v - %s",
|
||||
statusColor, time.Now().Format(time.RFC3339), method, path, resetColor,
|
||||
keyColor, keyInfo, resetColor,
|
||||
statusColor, statusCode, resetColor,
|
||||
latency, clientIP)
|
||||
// 简化日志输出,减少格式化开销
|
||||
logrus.Infof("%s %s - %d - %v%s", method, fullPath, statusCode, latency, keyInfo)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -311,36 +310,30 @@ func (ps *ProxyServer) handleProxy(c *gin.Context) {
|
||||
c.Set("keyIndex", keyInfo.Index)
|
||||
c.Set("keyPreview", keyInfo.Preview)
|
||||
|
||||
// 读取请求体
|
||||
var bodyBytes []byte
|
||||
if c.Request.Body != nil {
|
||||
bodyBytes, err = io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
logrus.Errorf("读取请求体失败: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "读取请求体失败",
|
||||
"type": "request_error",
|
||||
"code": "invalid_request_body",
|
||||
"timestamp": time.Now().Format(time.RFC3339),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
// 直接使用请求体,避免完整读取到内存
|
||||
var requestBody io.Reader = c.Request.Body
|
||||
var contentLength int64 = c.Request.ContentLength
|
||||
|
||||
// 构建上游请求URL
|
||||
targetURL := *ps.upstreamURL
|
||||
targetURL.Path = c.Request.URL.Path
|
||||
targetURL.RawQuery = c.Request.URL.RawQuery
|
||||
|
||||
// 创建上游请求
|
||||
// 创建带超时的上下文
|
||||
timeout := time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// 创建上游请求,直接使用原始请求体进行流式传输
|
||||
req, err := http.NewRequestWithContext(
|
||||
context.Background(),
|
||||
ctx,
|
||||
c.Request.Method,
|
||||
targetURL.String(),
|
||||
bytes.NewReader(bodyBytes),
|
||||
requestBody,
|
||||
)
|
||||
if err == nil && contentLength > 0 {
|
||||
req.ContentLength = contentLength
|
||||
}
|
||||
if err != nil {
|
||||
logrus.Errorf("创建上游请求失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
|
Reference in New Issue
Block a user