fix: 优化

This commit is contained in:
tbphp
2025-06-06 21:56:40 +08:00
parent 219c068dbf
commit 1faa0b1b73
5 changed files with 193 additions and 156 deletions

View File

@@ -32,6 +32,24 @@ OPENAI_BASE_URL=https://api.openai.com
# 请求超时时间(毫秒) # 请求超时时间(毫秒)
REQUEST_TIMEOUT=30000 REQUEST_TIMEOUT=30000
# ===========================================
# 性能优化配置
# ===========================================
# 最大连接数
MAX_SOCKETS=100
# 最大空闲连接数
MAX_FREE_SOCKETS=20
# 启用 Keep-Alive
ENABLE_KEEP_ALIVE=true
# 禁用压缩减少CPU开销
DISABLE_COMPRESSION=true
# 缓冲区大小(字节)
BUFFER_SIZE=32768
# =========================================== # ===========================================
# 认证配置 # 认证配置
# =========================================== # ===========================================

View File

@@ -44,13 +44,14 @@ func main() {
// 设置路由 // 设置路由
router := proxyServer.SetupRoutes() router := proxyServer.SetupRoutes()
// 创建HTTP服务器 // 创建HTTP服务器,优化超时配置
server := &http.Server{ server := &http.Server{
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port), Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
Handler: router, Handler: router,
ReadTimeout: 30 * time.Second, ReadTimeout: 60 * time.Second, // 增加读超时,支持大文件上传
WriteTimeout: 30 * time.Second, WriteTimeout: 300 * time.Second, // 增加写超时,支持流式响应
IdleTimeout: 60 * time.Second, IdleTimeout: 120 * time.Second, // 增加空闲超时,复用连接
MaxHeaderBytes: 1 << 20, // 1MB header limit
} }
// 启动服务器 // 启动服务器
@@ -90,7 +91,7 @@ func main() {
func displayStartupInfo(cfg *config.Config) { func displayStartupInfo(cfg *config.Config) {
logrus.Info("🚀 OpenAI 多密钥代理服务器 v2.0.0 (Go版本)") logrus.Info("🚀 OpenAI 多密钥代理服务器 v2.0.0 (Go版本)")
logrus.Info("") logrus.Info("")
// 显示配置 // 显示配置
config.DisplayConfig(cfg) config.DisplayConfig(cfg)
logrus.Info("") logrus.Info("")

View File

@@ -16,21 +16,21 @@ import (
// Constants 配置常量 // Constants 配置常量
type Constants struct { type Constants struct {
MinPort int MinPort int
MaxPort int MaxPort int
MinTimeout int MinTimeout int
DefaultTimeout int DefaultTimeout int
DefaultMaxSockets int DefaultMaxSockets int
DefaultMaxFreeSockets int DefaultMaxFreeSockets int
} }
// DefaultConstants 默认常量 // DefaultConstants 默认常量
var DefaultConstants = Constants{ var DefaultConstants = Constants{
MinPort: 1, MinPort: 1,
MaxPort: 65535, MaxPort: 65535,
MinTimeout: 1000, MinTimeout: 1000,
DefaultTimeout: 30000, DefaultTimeout: 30000,
DefaultMaxSockets: 50, DefaultMaxSockets: 50,
DefaultMaxFreeSockets: 10, DefaultMaxFreeSockets: 10,
} }
@@ -67,8 +67,11 @@ type CORSConfig struct {
// PerformanceConfig 性能配置 // PerformanceConfig 性能配置
type PerformanceConfig struct { type PerformanceConfig struct {
MaxSockets int `json:"maxSockets"` MaxSockets int `json:"maxSockets"`
MaxFreeSockets int `json:"maxFreeSockets"` MaxFreeSockets int `json:"maxFreeSockets"`
EnableKeepAlive bool `json:"enableKeepAlive"`
DisableCompression bool `json:"disableCompression"`
BufferSize int `json:"bufferSize"`
} }
// Config 应用配置 // Config 应用配置
@@ -114,8 +117,11 @@ func LoadConfig() (*Config, error) {
AllowedOrigins: parseArray(os.Getenv("ALLOWED_ORIGINS"), []string{"*"}), AllowedOrigins: parseArray(os.Getenv("ALLOWED_ORIGINS"), []string{"*"}),
}, },
Performance: PerformanceConfig{ Performance: PerformanceConfig{
MaxSockets: parseInteger(os.Getenv("MAX_SOCKETS"), DefaultConstants.DefaultMaxSockets), MaxSockets: parseInteger(os.Getenv("MAX_SOCKETS"), DefaultConstants.DefaultMaxSockets),
MaxFreeSockets: parseInteger(os.Getenv("MAX_FREE_SOCKETS"), DefaultConstants.DefaultMaxFreeSockets), MaxFreeSockets: parseInteger(os.Getenv("MAX_FREE_SOCKETS"), DefaultConstants.DefaultMaxFreeSockets),
EnableKeepAlive: parseBoolean(os.Getenv("ENABLE_KEEP_ALIVE"), true),
DisableCompression: parseBoolean(os.Getenv("DISABLE_COMPRESSION"), true),
BufferSize: parseInteger(os.Getenv("BUFFER_SIZE"), 32*1024),
}, },
} }
@@ -186,13 +192,13 @@ func DisplayConfig(config *Config) {
logrus.Infof(" 黑名单阈值: %d 次错误", config.Keys.BlacklistThreshold) logrus.Infof(" 黑名单阈值: %d 次错误", config.Keys.BlacklistThreshold)
logrus.Infof(" 上游地址: %s", config.OpenAI.BaseURL) logrus.Infof(" 上游地址: %s", config.OpenAI.BaseURL)
logrus.Infof(" 请求超时: %dms", config.OpenAI.Timeout) logrus.Infof(" 请求超时: %dms", config.OpenAI.Timeout)
authStatus := "未启用" authStatus := "未启用"
if config.Auth.Enabled { if config.Auth.Enabled {
authStatus = "已启用" authStatus = "已启用"
} }
logrus.Infof(" 认证: %s", authStatus) logrus.Infof(" 认证: %s", authStatus)
corsStatus := "已禁用" corsStatus := "已禁用"
if config.CORS.Enabled { if config.CORS.Enabled {
corsStatus = "已启用" corsStatus = "已启用"
@@ -227,7 +233,7 @@ func parseArray(value string, defaultValue []string) []string {
if value == "" { if value == "" {
return defaultValue return defaultValue
} }
parts := strings.Split(value, ",") parts := strings.Split(value, ",")
result := make([]string, 0, len(parts)) result := make([]string, 0, len(parts))
for _, part := range parts { for _, part := range parts {
@@ -235,7 +241,7 @@ func parseArray(value string, defaultValue []string) []string {
result = append(result, trimmed) result = append(result, trimmed)
} }
} }
if len(result) == 0 { if len(result) == 0 {
return defaultValue return defaultValue
} }

View File

@@ -27,13 +27,13 @@ type KeyInfo struct {
// Stats 统计信息 // Stats 统计信息
type Stats struct { type Stats struct {
CurrentIndex int64 `json:"currentIndex"` CurrentIndex int64 `json:"currentIndex"`
TotalKeys int `json:"totalKeys"` TotalKeys int `json:"totalKeys"`
HealthyKeys int `json:"healthyKeys"` HealthyKeys int `json:"healthyKeys"`
BlacklistedKeys int `json:"blacklistedKeys"` BlacklistedKeys int `json:"blacklistedKeys"`
SuccessCount int64 `json:"successCount"` SuccessCount int64 `json:"successCount"`
FailureCount int64 `json:"failureCount"` FailureCount int64 `json:"failureCount"`
MemoryUsage MemoryUsage `json:"memoryUsage"` MemoryUsage MemoryUsage `json:"memoryUsage"`
} }
// MemoryUsage 内存使用情况 // MemoryUsage 内存使用情况
@@ -60,22 +60,22 @@ type BlacklistInfo struct {
// KeyManager 密钥管理器 // KeyManager 密钥管理器
type KeyManager struct { type KeyManager struct {
keysFilePath string keysFilePath string
keys []string keys []string
keyPreviews []string keyPreviews []string
currentIndex int64 currentIndex int64
blacklistedKeys sync.Map blacklistedKeys sync.Map
successCount int64 successCount int64
failureCount int64 failureCount int64
keyFailureCounts sync.Map keyFailureCounts sync.Map
// 性能优化:预编译正则表达式 // 性能优化:预编译正则表达式
permanentErrorPatterns []*regexp.Regexp permanentErrorPatterns []*regexp.Regexp
// 内存管理 // 内存管理
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
stopCleanup chan bool stopCleanup chan bool
// 读写锁保护密钥列表 // 读写锁保护密钥列表
keysMutex sync.RWMutex keysMutex sync.RWMutex
} }
@@ -85,12 +85,12 @@ func NewKeyManager(keysFilePath string) *KeyManager {
if keysFilePath == "" { if keysFilePath == "" {
keysFilePath = config.AppConfig.Keys.FilePath keysFilePath = config.AppConfig.Keys.FilePath
} }
km := &KeyManager{ km := &KeyManager{
keysFilePath: keysFilePath, keysFilePath: keysFilePath,
currentIndex: int64(config.AppConfig.Keys.StartIndex), currentIndex: int64(config.AppConfig.Keys.StartIndex),
stopCleanup: make(chan bool), stopCleanup: make(chan bool),
// 预编译正则表达式 // 预编译正则表达式
permanentErrorPatterns: []*regexp.Regexp{ permanentErrorPatterns: []*regexp.Regexp{
regexp.MustCompile(`(?i)invalid api key`), regexp.MustCompile(`(?i)invalid api key`),
@@ -101,10 +101,10 @@ func NewKeyManager(keysFilePath string) *KeyManager {
regexp.MustCompile(`(?i)billing`), regexp.MustCompile(`(?i)billing`),
}, },
} }
// 启动内存清理 // 启动内存清理
km.setupMemoryCleanup() km.setupMemoryCleanup()
return km return km
} }
@@ -115,25 +115,25 @@ func (km *KeyManager) LoadKeys() error {
return fmt.Errorf("无法打开密钥文件: %w", err) return fmt.Errorf("无法打开密钥文件: %w", err)
} }
defer file.Close() defer file.Close()
var keys []string var keys []string
scanner := bufio.NewScanner(file) scanner := bufio.NewScanner(file)
for scanner.Scan() { for scanner.Scan() {
line := strings.TrimSpace(scanner.Text()) line := strings.TrimSpace(scanner.Text())
if line != "" && strings.HasPrefix(line, "sk-") { if line != "" && strings.HasPrefix(line, "sk-") {
keys = append(keys, line) keys = append(keys, line)
} }
} }
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
return fmt.Errorf("读取密钥文件失败: %w", err) return fmt.Errorf("读取密钥文件失败: %w", err)
} }
if len(keys) == 0 { if len(keys) == 0 {
return fmt.Errorf("密钥文件中没有有效的API密钥") return fmt.Errorf("密钥文件中没有有效的API密钥")
} }
km.keysMutex.Lock() km.keysMutex.Lock()
km.keys = keys km.keys = keys
// 预生成密钥预览,避免运行时重复计算 // 预生成密钥预览,避免运行时重复计算
@@ -146,7 +146,7 @@ func (km *KeyManager) LoadKeys() error {
} }
} }
km.keysMutex.Unlock() km.keysMutex.Unlock()
logrus.Infof("✅ 成功加载 %d 个 API 密钥", len(keys)) logrus.Infof("✅ 成功加载 %d 个 API 密钥", len(keys))
return nil return nil
} }
@@ -155,36 +155,40 @@ func (km *KeyManager) LoadKeys() error {
func (km *KeyManager) GetNextKey() (*KeyInfo, error) { func (km *KeyManager) GetNextKey() (*KeyInfo, error) {
km.keysMutex.RLock() km.keysMutex.RLock()
keysLen := len(km.keys) keysLen := len(km.keys)
km.keysMutex.RUnlock()
if keysLen == 0 { if keysLen == 0 {
km.keysMutex.RUnlock()
return nil, fmt.Errorf("没有可用的 API 密钥") return nil, fmt.Errorf("没有可用的 API 密钥")
} }
// 检查是否所有密钥都被拉黑 // 快速路径:直接获取下一个密钥,避免黑名单检查的开销
blacklistedCount := 0 currentIdx := atomic.AddInt64(&km.currentIndex, 1) - 1
km.blacklistedKeys.Range(func(key, value interface{}) bool { keyIndex := int(currentIdx) % keysLen
blacklistedCount++ selectedKey := km.keys[keyIndex]
return true keyPreview := km.keyPreviews[keyIndex]
}) km.keysMutex.RUnlock()
if blacklistedCount >= keysLen { // 检查是否被拉黑
logrus.Warn("⚠️ 所有密钥都被拉黑,重置黑名单") if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted {
km.blacklistedKeys = sync.Map{} return &KeyInfo{
km.keyFailureCounts = sync.Map{} Key: selectedKey,
Index: keyIndex,
Preview: keyPreview,
}, nil
} }
// 使用原子操作避免竞态条件 // 慢速路径:寻找可用密钥
attempts := 0 attempts := 0
for attempts < keysLen { maxAttempts := keysLen * 2 // 最多尝试两轮
currentIdx := atomic.AddInt64(&km.currentIndex, 1) - 1
keyIndex := int(currentIdx) % keysLen for attempts < maxAttempts {
currentIdx = atomic.AddInt64(&km.currentIndex, 1) - 1
keyIndex = int(currentIdx) % keysLen
km.keysMutex.RLock() km.keysMutex.RLock()
selectedKey := km.keys[keyIndex] selectedKey = km.keys[keyIndex]
keyPreview := km.keyPreviews[keyIndex] keyPreview = km.keyPreviews[keyIndex]
km.keysMutex.RUnlock() km.keysMutex.RUnlock()
if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted { if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted {
return &KeyInfo{ return &KeyInfo{
Key: selectedKey, Key: selectedKey,
@@ -192,21 +196,36 @@ func (km *KeyManager) GetNextKey() (*KeyInfo, error) {
Preview: keyPreview, Preview: keyPreview,
}, nil }, nil
} }
attempts++ attempts++
} }
// 兜底:返回第一个密钥 // 检查是否所有密钥都被拉黑,如果是则重置
km.keysMutex.RLock() blacklistedCount := 0
firstKey := km.keys[0] km.blacklistedKeys.Range(func(key, value interface{}) bool {
firstPreview := km.keyPreviews[0] blacklistedCount++
km.keysMutex.RUnlock() return blacklistedCount < keysLen // 提前退出优化
})
return &KeyInfo{
Key: firstKey, if blacklistedCount >= keysLen {
Index: 0, logrus.Warn("⚠️ 所有密钥都被拉黑,重置黑名单")
Preview: firstPreview, km.blacklistedKeys = sync.Map{}
}, nil km.keyFailureCounts = sync.Map{}
// 重置后返回第一个密钥
km.keysMutex.RLock()
firstKey := km.keys[0]
firstPreview := km.keyPreviews[0]
km.keysMutex.RUnlock()
return &KeyInfo{
Key: firstKey,
Index: 0,
Preview: firstPreview,
}, nil
}
return nil, fmt.Errorf("暂时没有可用的 API 密钥")
} }
// RecordSuccess 记录密钥使用成功 // RecordSuccess 记录密钥使用成功
@@ -219,7 +238,7 @@ func (km *KeyManager) RecordSuccess(key string) {
// RecordFailure 记录密钥使用失败 // RecordFailure 记录密钥使用失败
func (km *KeyManager) RecordFailure(key string, err error) { func (km *KeyManager) RecordFailure(key string, err error) {
atomic.AddInt64(&km.failureCount, 1) atomic.AddInt64(&km.failureCount, 1)
// 检查是否是永久性错误 // 检查是否是永久性错误
if km.isPermanentError(err) { if km.isPermanentError(err) {
km.blacklistedKeys.Store(key, true) km.blacklistedKeys.Store(key, true)
@@ -227,7 +246,7 @@ func (km *KeyManager) RecordFailure(key string, err error) {
logrus.Warnf("🚫 密钥已被拉黑(永久性错误): %s (%s)", key[:20]+"...", err.Error()) logrus.Warnf("🚫 密钥已被拉黑(永久性错误): %s (%s)", key[:20]+"...", err.Error())
return return
} }
// 临时性错误:增加失败计数 // 临时性错误:增加失败计数
currentFailures := 0 currentFailures := 0
if val, exists := km.keyFailureCounts.Load(key); exists { if val, exists := km.keyFailureCounts.Load(key); exists {
@@ -235,7 +254,7 @@ func (km *KeyManager) RecordFailure(key string, err error) {
} }
newFailures := currentFailures + 1 newFailures := currentFailures + 1
km.keyFailureCounts.Store(key, newFailures) km.keyFailureCounts.Store(key, newFailures)
threshold := config.AppConfig.Keys.BlacklistThreshold threshold := config.AppConfig.Keys.BlacklistThreshold
if newFailures >= threshold { if newFailures >= threshold {
km.blacklistedKeys.Store(key, true) km.blacklistedKeys.Store(key, true)
@@ -262,19 +281,19 @@ func (km *KeyManager) GetStats() *Stats {
km.keysMutex.RLock() km.keysMutex.RLock()
totalKeys := len(km.keys) totalKeys := len(km.keys)
km.keysMutex.RUnlock() km.keysMutex.RUnlock()
blacklistedCount := 0 blacklistedCount := 0
km.blacklistedKeys.Range(func(key, value interface{}) bool { km.blacklistedKeys.Range(func(key, value interface{}) bool {
blacklistedCount++ blacklistedCount++
return true return true
}) })
failureCountsSize := 0 failureCountsSize := 0
km.keyFailureCounts.Range(func(key, value interface{}) bool { km.keyFailureCounts.Range(func(key, value interface{}) bool {
failureCountsSize++ failureCountsSize++
return true return true
}) })
return &Stats{ return &Stats{
CurrentIndex: atomic.LoadInt64(&km.currentIndex), CurrentIndex: atomic.LoadInt64(&km.currentIndex),
TotalKeys: totalKeys, TotalKeys: totalKeys,
@@ -296,16 +315,16 @@ func (km *KeyManager) ResetKeys() map[string]interface{} {
beforeCount++ beforeCount++
return true return true
}) })
km.blacklistedKeys = sync.Map{} km.blacklistedKeys = sync.Map{}
km.keyFailureCounts = sync.Map{} km.keyFailureCounts = sync.Map{}
logrus.Infof("🔄 密钥状态已重置,清除了 %d 个黑名单密钥", beforeCount) logrus.Infof("🔄 密钥状态已重置,清除了 %d 个黑名单密钥", beforeCount)
km.keysMutex.RLock() km.keysMutex.RLock()
totalKeys := len(km.keys) totalKeys := len(km.keys)
km.keysMutex.RUnlock() km.keysMutex.RUnlock()
return map[string]interface{}{ return map[string]interface{}{
"success": true, "success": true,
"message": fmt.Sprintf("已清除 %d 个黑名单密钥", beforeCount), "message": fmt.Sprintf("已清除 %d 个黑名单密钥", beforeCount),
@@ -317,12 +336,12 @@ func (km *KeyManager) ResetKeys() map[string]interface{} {
// GetBlacklistDetails 获取黑名单详情 // GetBlacklistDetails 获取黑名单详情
func (km *KeyManager) GetBlacklistDetails() *BlacklistInfo { func (km *KeyManager) GetBlacklistDetails() *BlacklistInfo {
var blacklistDetails []BlacklistDetail var blacklistDetails []BlacklistDetail
km.keysMutex.RLock() km.keysMutex.RLock()
keys := km.keys keys := km.keys
keyPreviews := km.keyPreviews keyPreviews := km.keyPreviews
km.keysMutex.RUnlock() km.keysMutex.RUnlock()
for i, key := range keys { for i, key := range keys {
if _, blacklisted := km.blacklistedKeys.Load(key); blacklisted { if _, blacklisted := km.blacklistedKeys.Load(key); blacklisted {
blacklistDetails = append(blacklistDetails, BlacklistDetail{ blacklistDetails = append(blacklistDetails, BlacklistDetail{
@@ -333,7 +352,7 @@ func (km *KeyManager) GetBlacklistDetails() *BlacklistInfo {
}) })
} }
} }
return &BlacklistInfo{ return &BlacklistInfo{
TotalBlacklisted: len(blacklistDetails), TotalBlacklisted: len(blacklistDetails),
TotalKeys: len(keys), TotalKeys: len(keys),
@@ -345,7 +364,7 @@ func (km *KeyManager) GetBlacklistDetails() *BlacklistInfo {
// setupMemoryCleanup 设置内存清理机制 // setupMemoryCleanup 设置内存清理机制
func (km *KeyManager) setupMemoryCleanup() { func (km *KeyManager) setupMemoryCleanup() {
km.cleanupTicker = time.NewTicker(10 * time.Minute) km.cleanupTicker = time.NewTicker(10 * time.Minute)
go func() { go func() {
for { for {
select { select {
@@ -367,20 +386,20 @@ func (km *KeyManager) performMemoryCleanup() {
maxSize = 1000 maxSize = 1000
} }
km.keysMutex.RUnlock() km.keysMutex.RUnlock()
currentSize := 0 currentSize := 0
km.keyFailureCounts.Range(func(key, value interface{}) bool { km.keyFailureCounts.Range(func(key, value interface{}) bool {
currentSize++ currentSize++
return true return true
}) })
if currentSize > maxSize { if currentSize > maxSize {
logrus.Infof("🧹 清理失败计数缓存 (%d -> %d)", currentSize, maxSize) logrus.Infof("🧹 清理失败计数缓存 (%d -> %d)", currentSize, maxSize)
// 简单策略:清理一半的失败计数 // 简单策略:清理一半的失败计数
cleared := 0 cleared := 0
target := currentSize - maxSize target := currentSize - maxSize
km.keyFailureCounts.Range(func(key, value interface{}) bool { km.keyFailureCounts.Range(func(key, value interface{}) bool {
if cleared < target { if cleared < target {
km.keyFailureCounts.Delete(key) km.keyFailureCounts.Delete(key)

View File

@@ -4,7 +4,6 @@
package proxy package proxy
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"io" "io"
@@ -46,16 +45,22 @@ func NewProxyServer() (*ProxyServer, error) {
// 创建高性能HTTP客户端 // 创建高性能HTTP客户端
transport := &http.Transport{ transport := &http.Transport{
MaxIdleConns: config.AppConfig.Performance.MaxSockets, MaxIdleConns: config.AppConfig.Performance.MaxSockets,
MaxIdleConnsPerHost: config.AppConfig.Performance.MaxFreeSockets, MaxIdleConnsPerHost: config.AppConfig.Performance.MaxFreeSockets,
IdleConnTimeout: 90 * time.Second, MaxConnsPerHost: 0, // 无限制,避免连接池瓶颈
DisableCompression: false, IdleConnTimeout: 90 * time.Second,
ForceAttemptHTTP2: true, TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DisableCompression: true, // 禁用压缩以减少CPU开销让上游处理
ForceAttemptHTTP2: true,
WriteBufferSize: 32 * 1024, // 32KB写缓冲
ReadBufferSize: 32 * 1024, // 32KB读缓冲
} }
httpClient := &http.Client{ httpClient := &http.Client{
Transport: transport, Transport: transport,
Timeout: time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond, // 移除全局超时,使用更细粒度的超时控制
// Timeout: time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond,
} }
return &ProxyServer{ return &ProxyServer{
@@ -180,9 +185,19 @@ func (ps *ProxyServer) authMiddleware() gin.HandlerFunc {
} }
} }
// loggerMiddleware 自定义日志中间件 // loggerMiddleware 高性能日志中间件
func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc { func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 只在非生产环境或调试模式下记录详细日志
if gin.Mode() == gin.ReleaseMode {
// 生产模式下只记录错误和关键信息
c.Next()
if c.Writer.Status() >= 400 {
logrus.Errorf("Error %d: %s %s", c.Writer.Status(), c.Request.Method, c.Request.URL.Path)
}
return
}
start := time.Now() start := time.Now()
path := c.Request.URL.Path path := c.Request.URL.Path
raw := c.Request.URL.RawQuery raw := c.Request.URL.RawQuery
@@ -193,16 +208,14 @@ func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc {
// 计算响应时间 // 计算响应时间
latency := time.Since(start) latency := time.Since(start)
// 获取客户端IP // 获取基本信息
clientIP := c.ClientIP()
// 获取方法和状态码
method := c.Request.Method method := c.Request.Method
statusCode := c.Writer.Status() statusCode := c.Writer.Status()
// 构建完整路径 // 构建完整路径(避免字符串拼接)
fullPath := path
if raw != "" { if raw != "" {
path = path + "?" + raw fullPath = path + "?" + raw
} }
// 获取密钥信息(如果存在) // 获取密钥信息(如果存在)
@@ -213,22 +226,8 @@ func (ps *ProxyServer) loggerMiddleware() gin.HandlerFunc {
} }
} }
// 根据状态码选择颜色 // 简化日志输出,减少格式化开销
var statusColor string logrus.Infof("%s %s - %d - %v%s", method, fullPath, statusCode, latency, keyInfo)
if statusCode >= 200 && statusCode < 300 {
statusColor = "\033[32m" // 绿色
} else {
statusColor = "\033[31m" // 红色
}
resetColor := "\033[0m"
keyColor := "\033[36m" // 青色
// 输出日志
logrus.Infof("%s[%s] %s %s%s%s%s - %s%d%s - %v - %s",
statusColor, time.Now().Format(time.RFC3339), method, path, resetColor,
keyColor, keyInfo, resetColor,
statusColor, statusCode, resetColor,
latency, clientIP)
} }
} }
@@ -311,36 +310,30 @@ func (ps *ProxyServer) handleProxy(c *gin.Context) {
c.Set("keyIndex", keyInfo.Index) c.Set("keyIndex", keyInfo.Index)
c.Set("keyPreview", keyInfo.Preview) c.Set("keyPreview", keyInfo.Preview)
// 读取请求体 // 直接使用请求体,避免完整读取到内存
var bodyBytes []byte var requestBody io.Reader = c.Request.Body
if c.Request.Body != nil { var contentLength int64 = c.Request.ContentLength
bodyBytes, err = io.ReadAll(c.Request.Body)
if err != nil {
logrus.Errorf("读取请求体失败: %v", err)
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "读取请求体失败",
"type": "request_error",
"code": "invalid_request_body",
"timestamp": time.Now().Format(time.RFC3339),
},
})
return
}
}
// 构建上游请求URL // 构建上游请求URL
targetURL := *ps.upstreamURL targetURL := *ps.upstreamURL
targetURL.Path = c.Request.URL.Path targetURL.Path = c.Request.URL.Path
targetURL.RawQuery = c.Request.URL.RawQuery targetURL.RawQuery = c.Request.URL.RawQuery
// 创建上游请求 // 创建带超时的上下文
timeout := time.Duration(config.AppConfig.OpenAI.Timeout) * time.Millisecond
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
// 创建上游请求,直接使用原始请求体进行流式传输
req, err := http.NewRequestWithContext( req, err := http.NewRequestWithContext(
context.Background(), ctx,
c.Request.Method, c.Request.Method,
targetURL.String(), targetURL.String(),
bytes.NewReader(bodyBytes), requestBody,
) )
if err == nil && contentLength > 0 {
req.ContentLength = contentLength
}
if err != nil { if err != nil {
logrus.Errorf("创建上游请求失败: %v", err) logrus.Errorf("创建上游请求失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{ c.JSON(http.StatusInternalServerError, gin.H{