diff --git a/.env.example b/.env.example index ed19778..8bf727a 100644 --- a/.env.example +++ b/.env.example @@ -11,21 +11,6 @@ PORT=3000 # 服务器主机地址 HOST=0.0.0.0 -# =========================================== -# 密钥管理配置 -# =========================================== -# 密钥文件路径 -KEYS_FILE=keys.txt - -# 起始密钥索引 -START_INDEX=0 - -# 黑名单阈值(错误多少次后拉黑密钥) -BLACKLIST_THRESHOLD=1 - -# 最大重试次数(换key重试) -MAX_RETRIES=3 - # =========================================== # OpenAI 兼容 API 配置 # =========================================== @@ -63,7 +48,7 @@ LOG_ENABLE_REQUEST=true # 认证配置 # =========================================== # 项目认证密钥(可选,如果设置则启用认证) -# AUTH_KEY=your-secret-key +AUTH_KEY=sk-123456 # =========================================== # CORS 配置 diff --git a/Makefile b/Makefile index 390301d..5d3291f 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,3 @@ -# OpenAI 多密钥代理服务器 Makefile (Go版本) - # 变量定义 BINARY_NAME=gpt-load MAIN_PATH=./cmd/gpt-load @@ -7,82 +5,77 @@ BUILD_DIR=./build VERSION=2.0.0 LDFLAGS=-ldflags "-X main.Version=$(VERSION) -s -w" -# 默认目标 -.PHONY: all -all: clean build +# 从 .env 文件加载环境变量,如果不存在则使用默认值 +HOST ?= $(shell sed -n 's/^HOST=//p' .env 2>/dev/null || echo "localhost") +PORT ?= $(shell sed -n 's/^PORT=//p' .env 2>/dev/null || echo "3000") +API_BASE_URL=http://$(HOST):$(PORT) -# 构建 +# 默认目标 +.DEFAULT_GOAL := help + +.PHONY: all +all: clean build ## 清理并构建项目 + +# ============================================================================== +# 构建相关命令 +# ============================================================================== .PHONY: build -build: +build: ## 构建二进制文件 @echo "🔨 构建 $(BINARY_NAME)..." @mkdir -p $(BUILD_DIR) go build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME) $(MAIN_PATH) @echo "✅ 构建完成: $(BUILD_DIR)/$(BINARY_NAME)" -# 构建所有平台 .PHONY: build-all -build-all: clean +build-all: clean ## 为所有支持的平台构建二进制文件 @echo "🔨 构建所有平台版本..." @mkdir -p $(BUILD_DIR) - - # Linux AMD64 GOOS=linux GOARCH=amd64 go build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 $(MAIN_PATH) - - # Linux ARM64 GOOS=linux GOARCH=arm64 go build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 $(MAIN_PATH) - - # macOS AMD64 GOOS=darwin GOARCH=amd64 go build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-amd64 $(MAIN_PATH) - - # macOS ARM64 (Apple Silicon) GOOS=darwin GOARCH=arm64 go build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 $(MAIN_PATH) - - # Windows AMD64 GOOS=windows GOARCH=amd64 go build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe $(MAIN_PATH) - @echo "✅ 所有平台构建完成" -# 运行 +# ============================================================================== +# 运行与开发 +# ============================================================================== .PHONY: run -.PHONY: run - -run: +run: ## 构建前端并运行服务器 @echo "--- Building frontend... ---" cd web && npm install && npm run build @echo "--- Preparing backend... ---" @rm -rf cmd/gpt-load/dist @cp -r web/dist cmd/gpt-load/dist @echo "--- Starting backend... ---" - cd $(MAIN_PATH) && go run . + go run $(MAIN_PATH) -# 开发模式运行 .PHONY: dev -dev: +dev: ## 以开发模式运行(带竞态检测) @echo "🔧 开发模式启动..." go run -race $(MAIN_PATH) -# 测试 +# ============================================================================== +# 测试与代码质量 +# ============================================================================== .PHONY: test -test: +test: ## 运行所有测试 @echo "🧪 运行测试..." go test -v -race -coverprofile=coverage.out ./... -# 测试覆盖率 .PHONY: coverage -coverage: test +coverage: test ## 生成并查看测试覆盖率报告 @echo "📊 生成测试覆盖率报告..." go tool cover -html=coverage.out -o coverage.html @echo "✅ 覆盖率报告生成: coverage.html" -# 基准测试 .PHONY: bench -bench: +bench: ## 运行基准测试 @echo "⚡ 运行基准测试..." go test -bench=. -benchmem ./... -# 代码检查 .PHONY: lint -lint: +lint: ## 使用 golangci-lint 检查代码 @echo "🔍 代码检查..." @if command -v golangci-lint >/dev/null 2>&1; then \ golangci-lint run; \ @@ -91,9 +84,8 @@ lint: echo "安装命令: go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest"; \ fi -# 格式化代码 .PHONY: fmt -fmt: +fmt: ## 格式化 Go 代码 @echo "🎨 格式化代码..." go fmt ./... @if command -v goimports >/dev/null 2>&1; then \ @@ -102,51 +94,50 @@ fmt: echo "💡 建议安装 goimports: go install golang.org/x/tools/cmd/goimports@latest"; \ fi -# 整理依赖 .PHONY: tidy -tidy: +tidy: ## 整理和验证模块依赖 @echo "📦 整理依赖..." go mod tidy go mod verify -# 安装依赖 .PHONY: deps -deps: +deps: ## 下载模块依赖 @echo "📥 安装依赖..." go mod download -# 清理 +# ============================================================================== +# 清理与安装 +# ============================================================================== .PHONY: clean -clean: +clean: ## 清理所有构建产物 @echo "🧹 清理构建文件..." rm -rf $(BUILD_DIR) rm -f coverage.out coverage.html -# 安装到系统 .PHONY: install -install: build +install: build ## 构建并安装二进制文件到 /usr/local/bin @echo "📦 安装到系统..." sudo cp $(BUILD_DIR)/$(BINARY_NAME) /usr/local/bin/ @echo "✅ 安装完成: /usr/local/bin/$(BINARY_NAME)" -# 卸载 .PHONY: uninstall -uninstall: +uninstall: ## 从 /usr/local/bin 卸载二进制文件 @echo "🗑️ 从系统卸载..." sudo rm -f /usr/local/bin/$(BINARY_NAME) @echo "✅ 卸载完成" -# Docker 构建 +# ============================================================================== +# Docker 相关命令 +# ============================================================================== .PHONY: docker-build -docker-build: +docker-build: ## 构建 Docker 镜像 @echo "🐳 构建 Docker 镜像..." docker build -t gpt-load:$(VERSION) . docker tag gpt-load:$(VERSION) gpt-load:latest @echo "✅ Docker 镜像构建完成" -# Docker 运行(使用预构建镜像) .PHONY: docker-run -docker-run: +docker-run: ## 使用预构建镜像运行 Docker 容器 @echo "🐳 运行 Docker 容器(预构建镜像)..." docker run -d \ --name gpt-load \ @@ -156,9 +147,8 @@ docker-run: --restart unless-stopped \ ghcr.io/tbphp/gpt-load:latest -# Docker 运行(本地构建) .PHONY: docker-run-local -docker-run-local: +docker-run-local: ## 使用本地构建的镜像运行 Docker 容器 @echo "🐳 运行 Docker 容器(本地构建)..." docker run -d \ --name gpt-load-local \ @@ -168,96 +158,50 @@ docker-run-local: --restart unless-stopped \ gpt-load:latest -# Docker Compose 运行(预构建镜像) .PHONY: compose-up -compose-up: +compose-up: ## 使用 Docker Compose 启动(预构建镜像) @echo "🐳 使用 Docker Compose 启动(预构建镜像)..." docker-compose up -d -# Docker Compose 运行(本地构建) .PHONY: compose-up-dev -compose-up-dev: +compose-up-dev: ## 使用 Docker Compose 启动(本地构建) @echo "🐳 使用 Docker Compose 启动(本地构建)..." docker-compose -f docker-compose.dev.yml up -d -# Docker Compose 停止 .PHONY: compose-down -compose-down: +compose-down: ## 停止所有 Docker Compose 服务 @echo "🐳 停止 Docker Compose..." docker-compose down docker-compose -f docker-compose.dev.yml down 2>/dev/null || true -# 密钥验证 +# ============================================================================== +# 服务管理与工具 +# ============================================================================== .PHONY: validate-keys -validate-keys: +validate-keys: ## 验证 API 密钥的有效性 @echo "🐍 使用 Python 版本验证密钥..." python3 scripts/validate-keys.py -c 100 -t 15 -# 健康检查 .PHONY: health -health: +health: ## 检查服务的健康状况 @echo "💚 健康检查..." - @curl -s http://localhost:3000/health | jq . || echo "请安装 jq 或检查服务是否运行" + @curl -s $(API_BASE_URL)/health | jq . || echo "请安装 jq 或检查服务是否运行" -# 查看统计 .PHONY: stats -stats: +stats: ## 查看服务的统计信息 @echo "📊 查看统计信息..." - @curl -s http://localhost:3000/stats | jq . || echo "请安装 jq 或检查服务是否运行" + @curl -s $(API_BASE_URL)/stats | jq . || echo "请安装 jq 或检查服务是否运行" -# 重置密钥 .PHONY: reset-keys -reset-keys: +reset-keys: ## 重置所有密钥的状态 @echo "🔄 重置密钥状态..." - @curl -s http://localhost:3000/reset-keys | jq . || echo "请安装 jq 或检查服务是否运行" + @curl -s $(API_BASE_URL)/reset-keys | jq . || echo "请安装 jq 或检查服务是否运行" -# 查看黑名单 .PHONY: blacklist -blacklist: +blacklist: ## 查看当前黑名单中的密钥 @echo "🚫 查看黑名单..." - @curl -s http://localhost:3000/blacklist | jq . || echo "请安装 jq 或检查服务是否运行" + @curl -s $(API_BASE_URL)/blacklist | jq . || echo "请安装 jq 或检查服务是否运行" -# 帮助 .PHONY: help -help: - @echo "OpenAI 多密钥代理服务器 v$(VERSION) - 可用命令:" - @echo "" - @echo "构建相关:" - @echo " build - 构建二进制文件" - @echo " build-all - 构建所有平台版本" - @echo " clean - 清理构建文件" - @echo "" - @echo "运行相关:" - @echo " run - 运行服务器" - @echo " dev - 开发模式运行" - @echo "" - @echo "测试相关:" - @echo " test - 运行测试" - @echo " coverage - 生成测试覆盖率报告" - @echo " bench - 运行基准测试" - @echo "" - @echo "代码质量:" - @echo " lint - 代码检查" - @echo " fmt - 格式化代码" - @echo " tidy - 整理依赖" - @echo "" - @echo "安装相关:" - @echo " install - 安装到系统" - @echo " uninstall - 从系统卸载" - @echo "" - @echo "Docker 相关:" - @echo " docker-build - 构建 Docker 镜像" - @echo " docker-run - 运行 Docker 容器(预构建镜像)" - @echo " docker-run-local - 运行 Docker 容器(本地构建)" - @echo " compose-up - Docker Compose 启动(预构建镜像)" - @echo " compose-up-dev - Docker Compose 启动(本地构建)" - @echo " compose-down - Docker Compose 停止" - @echo "" - @echo "管理相关:" - @echo " health - 健康检查" - @echo " stats - 查看统计信息" - @echo " reset-keys - 重置密钥状态" - @echo " blacklist - 查看黑名单" - @echo "" - @echo "密钥验证:" - @echo " validate-keys - 验证 API 密钥" +help: ## 显示此帮助信息 + @awk 'BEGIN {FS = ":.*?## "; printf "Usage:\n make \033[36m\033[0m\n\nTargets:\n"} /^[a-zA-Z0-9_-]+:.*?## / { printf " \033[36m%-20s\033[0m %s\n", $$1, $$2 }' $(MAKEFILE_LIST) diff --git a/cmd/gpt-load/main.go b/cmd/gpt-load/main.go index 96f23c2..386a08a 100644 --- a/cmd/gpt-load/main.go +++ b/cmd/gpt-load/main.go @@ -12,6 +12,7 @@ import ( "path" "path/filepath" "strings" + "sync" "syscall" "time" @@ -50,7 +51,9 @@ func main() { // --- Asynchronous Request Logging Setup --- requestLogChan := make(chan models.RequestLog, 1000) - go startRequestLogger(database, requestLogChan) + var wg sync.WaitGroup + wg.Add(1) + go startRequestLogger(database, requestLogChan, &wg) // --- // Create proxy server @@ -103,9 +106,15 @@ func main() { // Attempt graceful shutdown if err := server.Shutdown(ctx); err != nil { logrus.Errorf("Server forced to shutdown: %v", err) - } else { - logrus.Info("Server exited gracefully") } + + // Close the request log channel and wait for the logger to finish + logrus.Info("Closing request log channel...") + close(requestLogChan) + wg.Wait() + logrus.Info("All logs have been written.") + + logrus.Info("Server exited gracefully") } // setupRoutes configures the HTTP routes @@ -233,7 +242,6 @@ func setupLogger(configManager types.ConfigManager) { // displayStartupInfo shows startup information func displayStartupInfo(configManager types.ConfigManager) { serverConfig := configManager.GetServerConfig() - keysConfig := configManager.GetKeysConfig() openaiConfig := configManager.GetOpenAIConfig() authConfig := configManager.GetAuthConfig() corsConfig := configManager.GetCORSConfig() @@ -242,10 +250,6 @@ func displayStartupInfo(configManager types.ConfigManager) { logrus.Info("Current Configuration:") logrus.Infof(" Server: %s:%d", serverConfig.Host, serverConfig.Port) - logrus.Infof(" Keys file: %s", keysConfig.FilePath) - logrus.Infof(" Start index: %d", keysConfig.StartIndex) - logrus.Infof(" Blacklist threshold: %d errors", keysConfig.BlacklistThreshold) - logrus.Infof(" Max retries: %d", keysConfig.MaxRetries) logrus.Infof(" Upstream URL: %s", openaiConfig.BaseURL) logrus.Infof(" Request timeout: %ds", openaiConfig.RequestTimeout) logrus.Infof(" Response timeout: %ds", openaiConfig.ResponseTimeout) @@ -278,7 +282,8 @@ func displayStartupInfo(configManager types.ConfigManager) { } // startRequestLogger runs a background goroutine to batch-insert request logs. -func startRequestLogger(db *gorm.DB, logChan <-chan models.RequestLog) { +func startRequestLogger(db *gorm.DB, logChan <-chan models.RequestLog, wg *sync.WaitGroup) { + defer wg.Done() ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() diff --git a/internal/config/manager.go b/internal/config/manager.go index bf3008e..564d775 100644 --- a/internal/config/manager.go +++ b/internal/config/manager.go @@ -45,7 +45,6 @@ type Manager struct { // Config represents the application configuration type Config struct { Server types.ServerConfig `json:"server"` - Keys types.KeysConfig `json:"keys"` OpenAI types.OpenAIConfig `json:"openai"` Auth types.AuthConfig `json:"auth"` CORS types.CORSConfig `json:"cors"` @@ -78,12 +77,6 @@ func (m *Manager) ReloadConfig() error { IdleTimeout: parseInteger(os.Getenv("SERVER_IDLE_TIMEOUT"), 120), GracefulShutdownTimeout: parseInteger(os.Getenv("SERVER_GRACEFUL_SHUTDOWN_TIMEOUT"), 60), }, - Keys: types.KeysConfig{ - FilePath: getEnvOrDefault("KEYS_FILE", "keys.txt"), - StartIndex: parseInteger(os.Getenv("START_INDEX"), 0), - BlacklistThreshold: parseInteger(os.Getenv("BLACKLIST_THRESHOLD"), 1), - MaxRetries: parseInteger(os.Getenv("MAX_RETRIES"), 3), - }, OpenAI: types.OpenAIConfig{ BaseURLs: parseArray(os.Getenv("OPENAI_BASE_URL"), []string{"https://api.openai.com"}), RequestTimeout: parseInteger(os.Getenv("REQUEST_TIMEOUT"), DefaultConstants.DefaultTimeout), @@ -131,11 +124,6 @@ func (m *Manager) GetServerConfig() types.ServerConfig { return m.config.Server } -// GetKeysConfig returns keys configuration -func (m *Manager) GetKeysConfig() types.KeysConfig { - return m.config.Keys -} - // GetOpenAIConfig returns OpenAI configuration func (m *Manager) GetOpenAIConfig() types.OpenAIConfig { config := m.config.OpenAI @@ -178,16 +166,6 @@ func (m *Manager) Validate() error { validationErrors = append(validationErrors, fmt.Sprintf("port must be between %d-%d", DefaultConstants.MinPort, DefaultConstants.MaxPort)) } - // Validate start index - if m.config.Keys.StartIndex < 0 { - validationErrors = append(validationErrors, "start index cannot be less than 0") - } - - // Validate blacklist threshold - if m.config.Keys.BlacklistThreshold < 1 { - validationErrors = append(validationErrors, "blacklist threshold cannot be less than 1") - } - // Validate timeout if m.config.OpenAI.RequestTimeout < DefaultConstants.MinTimeout { validationErrors = append(validationErrors, fmt.Sprintf("request timeout cannot be less than %ds", DefaultConstants.MinTimeout)) @@ -223,10 +201,6 @@ func (m *Manager) Validate() error { func (m *Manager) DisplayConfig() { logrus.Info("Current Configuration:") logrus.Infof(" Server: %s:%d", m.config.Server.Host, m.config.Server.Port) - logrus.Infof(" Keys file: %s", m.config.Keys.FilePath) - logrus.Infof(" Start index: %d", m.config.Keys.StartIndex) - logrus.Infof(" Blacklist threshold: %d errors", m.config.Keys.BlacklistThreshold) - logrus.Infof(" Max retries: %d", m.config.Keys.MaxRetries) logrus.Infof(" Upstream URLs: %s", strings.Join(m.config.OpenAI.BaseURLs, ", ")) logrus.Infof(" Request timeout: %ds", m.config.OpenAI.RequestTimeout) logrus.Infof(" Response timeout: %ds", m.config.OpenAI.ResponseTimeout) diff --git a/internal/handler/dashboard_handler.go b/internal/handler/dashboard_handler.go index 1a0c483..82e28c6 100644 --- a/internal/handler/dashboard_handler.go +++ b/internal/handler/dashboard_handler.go @@ -1,54 +1,39 @@ package handler import ( - "net/http" - "github.com/gin-gonic/gin" - "gpt-load/internal/db" "gpt-load/internal/models" "gpt-load/internal/response" ) // GetDashboardStats godoc // @Summary Get dashboard statistics -// @Description Get statistics for the dashboard, including total requests, success rate, and group distribution. +// @Description Get statistics for the dashboard, including key counts and request metrics. // @Tags Dashboard // @Accept json // @Produce json -// @Success 200 {object} models.DashboardStats +// @Success 200 {object} map[string]interface{} // @Router /api/dashboard/stats [get] -func GetDashboardStats(c *gin.Context) { +func (s *Server) Stats(c *gin.Context) { var totalRequests, successRequests int64 var groupStats []models.GroupRequestStat - // Get total requests - if err := db.DB.Model(&models.RequestLog{}).Count(&totalRequests).Error; err != nil { - response.Error(c, http.StatusInternalServerError, "Failed to get total requests") - return - } + // 1. Get total and successful requests from the api_keys table + s.DB.Model(&models.APIKey{}).Select("SUM(request_count)").Row().Scan(&totalRequests) + s.DB.Model(&models.APIKey{}).Select("SUM(request_count) - SUM(failure_count)").Row().Scan(&successRequests) - // Get success requests (status code 2xx) - if err := db.DB.Model(&models.RequestLog{}).Where("status_code >= ? AND status_code < ?", 200, 300).Count(&successRequests).Error; err != nil { - response.Error(c, http.StatusInternalServerError, "Failed to get success requests") - return - } + // 2. Get request counts per group + s.DB.Table("api_keys"). + Select("groups.name as group_name, SUM(api_keys.request_count) as request_count"). + Joins("join groups on groups.id = api_keys.group_id"). + Group("groups.name"). + Order("request_count DESC"). + Scan(&groupStats) - // Calculate success rate + // 3. Calculate success rate var successRate float64 if totalRequests > 0 { - successRate = float64(successRequests) / float64(totalRequests) - } - - // Get group stats - err := db.DB.Table("request_logs"). - Select("groups.name as group_name, count(request_logs.id) as request_count"). - Joins("join groups on groups.id = request_logs.group_id"). - Group("groups.name"). - Order("request_count desc"). - Scan(&groupStats).Error - if err != nil { - response.Error(c, http.StatusInternalServerError, "Failed to get group stats") - return + successRate = float64(successRequests) / float64(totalRequests) * 100 } stats := models.DashboardStats{ diff --git a/internal/handler/handler.go b/internal/handler/handler.go index c4697d1..87aee43 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -3,7 +3,6 @@ package handler import ( "net/http" - "runtime" "time" "gpt-load/internal/models" @@ -53,7 +52,7 @@ func (s *Server) RegisterAPIRoutes(api *gin.RouterGroup) { // Dashboard and logs routes dashboard := api.Group("/dashboard") { - dashboard.GET("/stats", GetDashboardStats) + dashboard.GET("/stats", s.Stats) } api.GET("/logs", GetLogs) @@ -101,53 +100,6 @@ func (s *Server) Health(c *gin.Context) { }) } -// Stats handles statistics requests -func (s *Server) Stats(c *gin.Context) { - var totalKeys, healthyKeys, disabledKeys int64 - s.DB.Model(&models.APIKey{}).Count(&totalKeys) - s.DB.Model(&models.APIKey{}).Where("status = ?", "active").Count(&healthyKeys) - s.DB.Model(&models.APIKey{}).Where("status != ?", "active").Count(&disabledKeys) - - // TODO: Get request counts from the database - var successCount, failureCount int64 - s.DB.Model(&models.RequestLog{}).Where("status_code = ?", http.StatusOK).Count(&successCount) - s.DB.Model(&models.RequestLog{}).Where("status_code != ?", http.StatusOK).Count(&failureCount) - - // Add additional system information - var m runtime.MemStats - runtime.ReadMemStats(&m) - - response := gin.H{ - "keys": gin.H{ - "total": totalKeys, - "healthy": healthyKeys, - "disabled": disabledKeys, - }, - "requests": gin.H{ - "success_count": successCount, - "failure_count": failureCount, - "total_count": successCount + failureCount, - }, - "memory": gin.H{ - "alloc_mb": bToMb(m.Alloc), - "total_alloc_mb": bToMb(m.TotalAlloc), - "sys_mb": bToMb(m.Sys), - "num_gc": m.NumGC, - "last_gc": time.Unix(0, int64(m.LastGC)).Format("2006-01-02 15:04:05"), - "next_gc_mb": bToMb(m.NextGC), - }, - "system": gin.H{ - "goroutines": runtime.NumGoroutine(), - "cpu_count": runtime.NumCPU(), - "go_version": runtime.Version(), - }, - "timestamp": time.Now().UTC().Format(time.RFC3339), - } - - c.JSON(http.StatusOK, response) -} - - // MethodNotAllowed handles 405 requests func (s *Server) MethodNotAllowed(c *gin.Context) { c.JSON(http.StatusMethodNotAllowed, gin.H{ @@ -169,7 +121,6 @@ func (s *Server) GetConfig(c *gin.Context) { } serverConfig := s.config.GetServerConfig() - keysConfig := s.config.GetKeysConfig() openaiConfig := s.config.GetOpenAIConfig() authConfig := s.config.GetAuthConfig() corsConfig := s.config.GetCORSConfig() @@ -182,12 +133,6 @@ func (s *Server) GetConfig(c *gin.Context) { "host": serverConfig.Host, "port": serverConfig.Port, }, - "keys": gin.H{ - "file_path": keysConfig.FilePath, - "start_index": keysConfig.StartIndex, - "blacklist_threshold": keysConfig.BlacklistThreshold, - "max_retries": keysConfig.MaxRetries, - }, "openai": gin.H{ "base_url": openaiConfig.BaseURL, "request_timeout": openaiConfig.RequestTimeout, @@ -230,8 +175,3 @@ func (s *Server) GetConfig(c *gin.Context) { c.JSON(http.StatusOK, sanitizedConfig) } - -// Helper function to convert bytes to megabytes -func bToMb(b uint64) uint64 { - return b / 1024 / 1024 -} diff --git a/internal/keymanager/manager.go b/internal/keymanager/manager.go deleted file mode 100644 index 9097789..0000000 --- a/internal/keymanager/manager.go +++ /dev/null @@ -1,336 +0,0 @@ -// Package keymanager provides high-performance API key management -package keymanager - -import ( - "bufio" - "os" - "regexp" - "runtime" - "strings" - "sync" - "sync/atomic" - "time" - - "gpt-load/internal/errors" - "gpt-load/internal/types" - - "github.com/sirupsen/logrus" -) - -// Manager implements the KeyManager interface -type Manager struct { - keysFilePath string - keys []string - keyPreviews []string - currentIndex int64 - blacklistedKeys sync.Map - successCount int64 - failureCount int64 - keyFailureCounts sync.Map - config types.KeysConfig - - // Performance optimization: pre-compiled regex patterns - permanentErrorPatterns []*regexp.Regexp - - // Memory management - cleanupTicker *time.Ticker - stopCleanup chan bool - - // Read-write lock to protect key list - keysMutex sync.RWMutex -} - -// NewManager creates a new key manager -func NewManager(config types.KeysConfig) (types.KeyManager, error) { - if config.FilePath == "" { - return nil, errors.NewAppError(errors.ErrKeyFileNotFound, "Keys file path is required") - } - - km := &Manager{ - keysFilePath: config.FilePath, - currentIndex: int64(config.StartIndex), - stopCleanup: make(chan bool), - config: config, - - // Pre-compile regex patterns - permanentErrorPatterns: []*regexp.Regexp{ - regexp.MustCompile(`(?i)invalid api key`), - regexp.MustCompile(`(?i)incorrect api key`), - regexp.MustCompile(`(?i)api key not found`), - regexp.MustCompile(`(?i)unauthorized`), - regexp.MustCompile(`(?i)account deactivated`), - regexp.MustCompile(`(?i)billing`), - }, - } - - // Start memory cleanup - km.setupMemoryCleanup() - - // Load keys - if err := km.LoadKeys(); err != nil { - return nil, err - } - - return km, nil -} - -// LoadKeys loads API keys from file -func (km *Manager) LoadKeys() error { - file, err := os.Open(km.keysFilePath) - if err != nil { - return errors.NewAppErrorWithCause(errors.ErrKeyFileNotFound, "Failed to open keys file", err) - } - defer file.Close() - - var keys []string - var keyPreviews []string - - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line != "" && !strings.HasPrefix(line, "#") { - keys = append(keys, line) - // Create preview (first 8 chars + "..." + last 4 chars) - if len(line) > 12 { - preview := line[:8] + "..." + line[len(line)-4:] - keyPreviews = append(keyPreviews, preview) - } else { - keyPreviews = append(keyPreviews, line) - } - } - } - - if err := scanner.Err(); err != nil { - return errors.NewAppErrorWithCause(errors.ErrKeyFileInvalid, "Failed to read keys file", err) - } - - if len(keys) == 0 { - return errors.NewAppError(errors.ErrNoKeysAvailable, "No valid API keys found in file") - } - - km.keysMutex.Lock() - km.keys = keys - km.keyPreviews = keyPreviews - km.keysMutex.Unlock() - - logrus.Infof("Successfully loaded %d API keys", len(keys)) - return nil -} - -// GetNextKey gets the next available key (high-performance version) -func (km *Manager) GetNextKey() (*types.KeyInfo, error) { - km.keysMutex.RLock() - keysLen := len(km.keys) - if keysLen == 0 { - km.keysMutex.RUnlock() - return nil, errors.ErrNoAPIKeysAvailable - } - - // Fast path: directly get next key, avoid blacklist check overhead - currentIdx := atomic.AddInt64(&km.currentIndex, 1) - 1 - keyIndex := int(currentIdx) % keysLen - selectedKey := km.keys[keyIndex] - keyPreview := km.keyPreviews[keyIndex] - km.keysMutex.RUnlock() - - // Check if blacklisted - if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted { - return &types.KeyInfo{ - Key: selectedKey, - Index: keyIndex, - Preview: keyPreview, - }, nil - } - - // Slow path: find next available key - return km.findNextAvailableKey(keyIndex, keysLen) -} - -// findNextAvailableKey finds the next available non-blacklisted key -func (km *Manager) findNextAvailableKey(startIndex, keysLen int) (*types.KeyInfo, error) { - km.keysMutex.RLock() - defer km.keysMutex.RUnlock() - - blacklistedCount := 0 - for i := 0; i < keysLen; i++ { - keyIndex := (startIndex + i) % keysLen - selectedKey := km.keys[keyIndex] - - if _, blacklisted := km.blacklistedKeys.Load(selectedKey); !blacklisted { - return &types.KeyInfo{ - Key: selectedKey, - Index: keyIndex, - Preview: km.keyPreviews[keyIndex], - }, nil - } - blacklistedCount++ - } - - if blacklistedCount >= keysLen { - logrus.Warn("All keys are blacklisted, resetting blacklist") - km.blacklistedKeys = sync.Map{} - km.keyFailureCounts = sync.Map{} - - // Return first key after reset - firstKey := km.keys[0] - firstPreview := km.keyPreviews[0] - - return &types.KeyInfo{ - Key: firstKey, - Index: 0, - Preview: firstPreview, - }, nil - } - - return nil, errors.ErrAllAPIKeysBlacklisted -} - -// RecordSuccess records successful key usage -func (km *Manager) RecordSuccess(key string) { - atomic.AddInt64(&km.successCount, 1) - // Reset failure count for this key on success - km.keyFailureCounts.Delete(key) -} - -// RecordFailure records key failure and potentially blacklists it -func (km *Manager) RecordFailure(key string, err error) { - atomic.AddInt64(&km.failureCount, 1) - - // Check if this is a permanent error - if km.isPermanentError(err) { - km.blacklistedKeys.Store(key, time.Now()) - logrus.Debugf("Key blacklisted due to permanent error: %v", err) - return - } - - // Increment failure count - failCount, _ := km.keyFailureCounts.LoadOrStore(key, new(int64)) - if counter, ok := failCount.(*int64); ok { - newFailCount := atomic.AddInt64(counter, 1) - - // Blacklist if threshold exceeded - if int(newFailCount) >= km.config.BlacklistThreshold { - km.blacklistedKeys.Store(key, time.Now()) - logrus.Debugf("Key blacklisted after %d failures", newFailCount) - } - } -} - -// isPermanentError checks if an error is permanent -func (km *Manager) isPermanentError(err error) bool { - if err == nil { - return false - } - - errorStr := strings.ToLower(err.Error()) - for _, pattern := range km.permanentErrorPatterns { - if pattern.MatchString(errorStr) { - return true - } - } - return false -} - -// GetStats returns current statistics -func (km *Manager) GetStats() types.Stats { - km.keysMutex.RLock() - totalKeys := len(km.keys) - km.keysMutex.RUnlock() - - blacklistedCount := 0 - km.blacklistedKeys.Range(func(key, value any) bool { - blacklistedCount++ - return true - }) - - var m runtime.MemStats - runtime.ReadMemStats(&m) - - return types.Stats{ - CurrentIndex: atomic.LoadInt64(&km.currentIndex), - TotalKeys: totalKeys, - HealthyKeys: totalKeys - blacklistedCount, - BlacklistedKeys: blacklistedCount, - SuccessCount: atomic.LoadInt64(&km.successCount), - FailureCount: atomic.LoadInt64(&km.failureCount), - MemoryUsage: types.MemoryUsage{ - Alloc: m.Alloc, - TotalAlloc: m.TotalAlloc, - Sys: m.Sys, - NumGC: m.NumGC, - LastGCTime: time.Unix(0, int64(m.LastGC)).Format("2006-01-02 15:04:05"), - NextGCTarget: m.NextGC, - }, - } -} - -// ResetBlacklist resets the blacklist -func (km *Manager) ResetBlacklist() { - km.blacklistedKeys = sync.Map{} - km.keyFailureCounts = sync.Map{} - logrus.Info("Blacklist reset successfully") -} - -// GetBlacklist returns current blacklisted keys -func (km *Manager) GetBlacklist() []types.BlacklistEntry { - var blacklist []types.BlacklistEntry - - km.blacklistedKeys.Range(func(key, value any) bool { - keyStr := key.(string) - blacklistTime := value.(time.Time) - - // Create preview - preview := keyStr - if len(keyStr) > 12 { - preview = keyStr[:8] + "..." + keyStr[len(keyStr)-4:] - } - - // Get failure count - failCount := 0 - if count, exists := km.keyFailureCounts.Load(keyStr); exists { - failCount = int(atomic.LoadInt64(count.(*int64))) - } - - blacklist = append(blacklist, types.BlacklistEntry{ - Key: keyStr, - Preview: preview, - Reason: "Exceeded failure threshold", - BlacklistAt: blacklistTime, - FailCount: failCount, - }) - return true - }) - - return blacklist -} - -// setupMemoryCleanup sets up periodic memory cleanup -func (km *Manager) setupMemoryCleanup() { - // Reduce GC frequency to every 15 minutes to avoid performance impact - km.cleanupTicker = time.NewTicker(15 * time.Minute) - go func() { - for { - select { - case <-km.cleanupTicker.C: - // Only trigger GC if memory usage is high - var m runtime.MemStats - runtime.ReadMemStats(&m) - // Trigger GC only if allocated memory is above 100MB - if m.Alloc > 100*1024*1024 { - runtime.GC() - logrus.Debugf("Manual GC triggered, memory usage: %d MB", m.Alloc/1024/1024) - } - case <-km.stopCleanup: - return - } - } - }() -} - -// Close closes the key manager and cleans up resources -func (km *Manager) Close() { - if km.cleanupTicker != nil { - km.cleanupTicker.Stop() - } - close(km.stopCleanup) -} diff --git a/internal/proxy/server.go b/internal/proxy/server.go index bb529a2..c93fbdf 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -8,6 +8,7 @@ import ( "gpt-load/internal/response" "net/http" "sync" + "sync/atomic" "time" "github.com/gin-gonic/gin" @@ -18,7 +19,7 @@ import ( // ProxyServer represents the proxy server type ProxyServer struct { DB *gorm.DB - groupCounters sync.Map // For round-robin key selection + groupCounters sync.Map // map[uint]*atomic.Uint64 requestLogChan chan models.RequestLog } @@ -82,18 +83,22 @@ func (ps *ProxyServer) selectAPIKey(group *models.Group) (*models.APIKey, error) return nil, fmt.Errorf("no active API keys available in group '%s'", group.Name) } - // Get the current counter for the group - counter, _ := ps.groupCounters.LoadOrStore(group.ID, uint64(0)) - currentCounter := counter.(uint64) + // Get or create a counter for the group. The value is a pointer to a uint64. + val, _ := ps.groupCounters.LoadOrStore(group.ID, new(atomic.Uint64)) + counter := val.(*atomic.Uint64) - // Select the key and increment the counter - selectedKey := activeKeys[int(currentCounter%uint64(len(activeKeys)))] - ps.groupCounters.Store(group.ID, currentCounter+1) + // Atomically increment the counter and get the index for this request. + index := counter.Add(1) - 1 + selectedKey := activeKeys[int(index%uint64(len(activeKeys)))] return &selectedKey, nil } func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *models.APIKey, startTime time.Time) { + // Update key stats based on request success + isSuccess := c.Writer.Status() < 400 + go ps.updateKeyStats(key.ID, isSuccess) + logEntry := models.RequestLog{ ID: fmt.Sprintf("req_%d", time.Now().UnixNano()), Timestamp: startTime, @@ -113,6 +118,27 @@ func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *mode } } +// updateKeyStats atomically updates the request and failure counts for a key +func (ps *ProxyServer) updateKeyStats(keyID uint, success bool) { + // Always increment the request count + updates := map[string]interface{}{ + "request_count": gorm.Expr("request_count + 1"), + } + + // Additionally, increment the failure count if the request was not successful + if !success { + updates["failure_count"] = gorm.Expr("failure_count + 1") + } + + result := ps.DB.Model(&models.APIKey{}).Where("id = ?", keyID).Updates(updates) + if result.Error != nil { + logrus.WithFields(logrus.Fields{ + "keyID": keyID, + "error": result.Error, + }).Error("Failed to update key stats") + } +} + // Close cleans up resources func (ps *ProxyServer) Close() { // Nothing to close for now diff --git a/internal/types/types.go b/internal/types/types.go index d46589e..083ac13 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -2,15 +2,12 @@ package types import ( - "time" - "github.com/gin-gonic/gin" ) // ConfigManager defines the interface for configuration management type ConfigManager interface { GetServerConfig() ServerConfig - GetKeysConfig() KeysConfig GetOpenAIConfig() OpenAIConfig GetAuthConfig() AuthConfig GetCORSConfig() CORSConfig @@ -21,18 +18,6 @@ type ConfigManager interface { ReloadConfig() error } -// KeyManager defines the interface for API key management -type KeyManager interface { - LoadKeys() error - GetNextKey() (*KeyInfo, error) - RecordSuccess(key string) - RecordFailure(key string, err error) - GetStats() Stats - ResetBlacklist() - GetBlacklist() []BlacklistEntry - Close() -} - // ProxyServer defines the interface for proxy server type ProxyServer interface { HandleProxy(c *gin.Context) @@ -49,14 +34,6 @@ type ServerConfig struct { GracefulShutdownTimeout int `json:"gracefulShutdownTimeout"` } -// KeysConfig represents keys configuration -type KeysConfig struct { - FilePath string `json:"filePath"` - StartIndex int `json:"startIndex"` - BlacklistThreshold int `json:"blacklistThreshold"` - MaxRetries int `json:"maxRetries"` -} - // OpenAIConfig represents OpenAI API configuration type OpenAIConfig struct { BaseURL string `json:"baseUrl"` @@ -95,48 +72,3 @@ type LogConfig struct { FilePath string `json:"filePath"` EnableRequest bool `json:"enableRequest"` } - -// KeyInfo represents API key information -type KeyInfo struct { - Key string `json:"key"` - Index int `json:"index"` - Preview string `json:"preview"` -} - -// Stats represents system statistics -type Stats struct { - CurrentIndex int64 `json:"currentIndex"` - TotalKeys int `json:"totalKeys"` - HealthyKeys int `json:"healthyKeys"` - BlacklistedKeys int `json:"blacklistedKeys"` - SuccessCount int64 `json:"successCount"` - FailureCount int64 `json:"failureCount"` - MemoryUsage MemoryUsage `json:"memoryUsage"` -} - -// MemoryUsage represents memory usage statistics -type MemoryUsage struct { - Alloc uint64 `json:"alloc"` - TotalAlloc uint64 `json:"totalAlloc"` - Sys uint64 `json:"sys"` - NumGC uint32 `json:"numGC"` - LastGCTime string `json:"lastGCTime"` - NextGCTarget uint64 `json:"nextGCTarget"` -} - -// BlacklistEntry represents a blacklisted key entry -type BlacklistEntry struct { - Key string `json:"key"` - Preview string `json:"preview"` - Reason string `json:"reason"` - BlacklistAt time.Time `json:"blacklistAt"` - FailCount int `json:"failCount"` -} - -// RetryError represents retry error information -type RetryError struct { - StatusCode int `json:"statusCode"` - ErrorMessage string `json:"errorMessage"` - KeyIndex int `json:"keyIndex"` - Attempt int `json:"attempt"` -} \ No newline at end of file diff --git a/package-lock.json b/package-lock.json new file mode 100644 index 0000000..f8d9ae0 --- /dev/null +++ b/package-lock.json @@ -0,0 +1,6 @@ +{ + "name": "gpt-load", + "lockfileVersion": 3, + "requires": true, + "packages": {} +}