From 76fe4d4dd3388492c8bd63d7bf92efcb778a9261 Mon Sep 17 00:00:00 2001 From: tbphp Date: Tue, 22 Jul 2025 15:38:55 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=AE=A4=E8=AF=81key=E5=85=BC=E5=AE=B9?= =?UTF-8?q?=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/channel/anthropic_channel.go | 19 ------ internal/channel/channel.go | 3 - internal/channel/gemini_channel.go | 16 ----- internal/channel/openai_channel.go | 12 ---- internal/middleware/middleware.go | 88 ++++++--------------------- internal/router/router.go | 12 ++-- web/src/api/keys.ts | 2 +- web/src/api/logs.ts | 2 +- 8 files changed, 25 insertions(+), 129 deletions(-) diff --git a/internal/channel/anthropic_channel.go b/internal/channel/anthropic_channel.go index 83997c2..1919a32 100644 --- a/internal/channel/anthropic_channel.go +++ b/internal/channel/anthropic_channel.go @@ -60,25 +60,6 @@ func (ch *AnthropicChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bo return false } -// ExtractKey extracts the API key from the x-api-key header. -func (ch *AnthropicChannel) ExtractKey(c *gin.Context) string { - // Check x-api-key header (Anthropic's standard) - if key := c.GetHeader("x-api-key"); key != "" { - return key - } - - // Fallback to Authorization header for compatibility - authHeader := c.GetHeader("Authorization") - if authHeader != "" { - const bearerPrefix = "Bearer " - if strings.HasPrefix(authHeader, bearerPrefix) { - return authHeader[len(bearerPrefix):] - } - } - - return "" -} - // ValidateKey checks if the given API key is valid by making a messages request. func (ch *AnthropicChannel) ValidateKey(ctx context.Context, key string) (bool, error) { upstreamURL := ch.getUpstreamURL() diff --git a/internal/channel/channel.go b/internal/channel/channel.go index aac40cc..9a03c2d 100644 --- a/internal/channel/channel.go +++ b/internal/channel/channel.go @@ -29,9 +29,6 @@ type ChannelProxy interface { // IsStreamRequest checks if the request is for a streaming response, IsStreamRequest(c *gin.Context, bodyBytes []byte) bool - // ExtractKey extracts the API key from the request. - ExtractKey(c *gin.Context) string - // ValidateKey checks if the given API key is valid. ValidateKey(ctx context.Context, key string) (bool, error) } diff --git a/internal/channel/gemini_channel.go b/internal/channel/gemini_channel.go index 6d552c1..4e769d8 100644 --- a/internal/channel/gemini_channel.go +++ b/internal/channel/gemini_channel.go @@ -40,7 +40,6 @@ func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey, req.URL.RawQuery = q.Encode() } - // IsStreamRequest checks if the request is for a streaming response. func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool { path := c.Request.URL.Path @@ -59,21 +58,6 @@ func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool return false } -// ExtractKey extracts the API key from the X-Goog-Api-Key header or the "key" query parameter. -func (ch *GeminiChannel) ExtractKey(c *gin.Context) string { - // 1. Check X-Goog-Api-Key header - if key := c.GetHeader("X-Goog-Api-Key"); key != "" { - return key - } - - // 2. Check "key" query parameter - if key := c.Query("key"); key != "" { - return key - } - - return "" -} - // ValidateKey checks if the given API key is valid by making a generateContent request. func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, error) { upstreamURL := ch.getUpstreamURL() diff --git a/internal/channel/openai_channel.go b/internal/channel/openai_channel.go index 1f014a8..caa0a7b 100644 --- a/internal/channel/openai_channel.go +++ b/internal/channel/openai_channel.go @@ -59,18 +59,6 @@ func (ch *OpenAIChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool return false } -// ExtractKey extracts the API key from the Authorization header. -func (ch *OpenAIChannel) ExtractKey(c *gin.Context) string { - authHeader := c.GetHeader("Authorization") - if authHeader != "" { - const bearerPrefix = "Bearer " - if strings.HasPrefix(authHeader, bearerPrefix) { - return authHeader[len(bearerPrefix):] - } - } - return "" -} - // ValidateKey checks if the given API key is valid by making a chat completion request. func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, error) { upstreamURL := ch.getUpstreamURL() diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 8e8ff62..91b77fd 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -6,10 +6,8 @@ import ( "strings" "time" - "gpt-load/internal/channel" app_errors "gpt-load/internal/errors" "gpt-load/internal/response" - "gpt-load/internal/services" "gpt-load/internal/types" "github.com/gin-gonic/gin" @@ -116,45 +114,16 @@ func CORS(config types.CORSConfig) gin.HandlerFunc { } // Auth creates an authentication middleware -func Auth( - authConfig types.AuthConfig, - groupManager *services.GroupManager, - channelFactory *channel.Factory, -) gin.HandlerFunc { +func Auth(authConfig types.AuthConfig) gin.HandlerFunc { return func(c *gin.Context) { path := c.Request.URL.Path - // Skip authentication for health endpoints if isMonitoringEndpoint(path) { c.Next() return } - var key string - var err error - - if strings.HasPrefix(path, "/api") { - // Handle backend API authentication - key = extractApiKey(c) - } else if strings.HasPrefix(path, "/proxy/") { - // Handle proxy authentication - key, err = extractProxyKey(c, groupManager, channelFactory) - if err != nil { - // The error from extractProxyKey is already an APIError - if apiErr, ok := err.(*app_errors.APIError); ok { - response.Error(c, apiErr) - } else { - response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, err.Error())) - } - c.Abort() - return - } - } else { - // For any other paths, deny access by default - response.Error(c, app_errors.ErrResourceNotFound) - c.Abort() - return - } + key := extractAuthKey(c) if key == "" || key != authConfig.Key { response.Error(c, app_errors.ErrUnauthorized) @@ -162,8 +131,6 @@ func Auth( return } - // Key is extracted, but validation is handled by the proxy logic itself. - // For the backend API, we've already validated it. c.Next() } } @@ -227,8 +194,10 @@ func isMonitoringEndpoint(path string) bool { return false } -// extractBearerKey extracts a key from the "Authorization: Bearer " header. -func extractApiKey(c *gin.Context) string { +// extractAuthKey extracts a auth key. +func extractAuthKey(c *gin.Context) string { + + // Bearer token authHeader := c.GetHeader("Authorization") if authHeader != "" { const bearerPrefix = "Bearer " @@ -237,39 +206,20 @@ func extractApiKey(c *gin.Context) string { } } - authKey := c.Query("auth_key") - if authKey != "" { - return authKey + // X-Api-Key + if key := c.GetHeader("X-Api-Key"); key != "" { + return key + } + + // X-Goog-Api-Key + if key := c.GetHeader("X-Goog-Api-Key"); key != "" { + return key + } + + // Query key + if key := c.Query("key"); key != "" { + return key } return "" } - -// extractProxyKey handles key extraction for proxy routes. -func extractProxyKey( - c *gin.Context, - groupManager *services.GroupManager, - channelFactory *channel.Factory, -) (string, error) { - groupName := c.Param("group_name") - if groupName == "" { - return "", app_errors.NewAPIError(app_errors.ErrBadRequest, "Group name is missing in the URL path") - } - - group, err := groupManager.GetGroupByName(groupName) - if err != nil { - return "", app_errors.NewAPIError(app_errors.ErrResourceNotFound, fmt.Sprintf("Group '%s' not found", groupName)) - } - - channel, err := channelFactory.GetChannel(group) - if err != nil { - return "", app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to get channel for group '%s'", groupName)) - } - - key := channel.ExtractKey(c) - if key == "" { - return "", app_errors.ErrUnauthorized - } - - return key, nil -} diff --git a/internal/router/router.go b/internal/router/router.go index 57acb5a..85ef37a 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -65,8 +65,8 @@ func NewRouter( // 注册路由 registerSystemRoutes(router, serverHandler) - registerAPIRoutes(router, serverHandler, configManager, groupManager, channelFactory) - registerProxyRoutes(router, proxyServer, configManager, groupManager, channelFactory) + registerAPIRoutes(router, serverHandler, configManager) + registerProxyRoutes(router, proxyServer, configManager) registerFrontendRoutes(router, buildFS, indexPage) return router @@ -82,8 +82,6 @@ func registerAPIRoutes( router *gin.Engine, serverHandler *handler.Server, configManager types.ConfigManager, - groupManager *services.GroupManager, - channelFactory *channel.Factory, ) { api := router.Group("/api") authConfig := configManager.GetAuthConfig() @@ -93,7 +91,7 @@ func registerAPIRoutes( // 认证 protectedAPI := api.Group("") - protectedAPI.Use(middleware.Auth(authConfig, groupManager, channelFactory)) + protectedAPI.Use(middleware.Auth(authConfig)) registerProtectedAPIRoutes(protectedAPI, serverHandler) } @@ -162,13 +160,11 @@ func registerProxyRoutes( router *gin.Engine, proxyServer *proxy.ProxyServer, configManager types.ConfigManager, - groupManager *services.GroupManager, - channelFactory *channel.Factory, ) { proxyGroup := router.Group("/proxy") authConfig := configManager.GetAuthConfig() - proxyGroup.Use(middleware.Auth(authConfig, groupManager, channelFactory)) + proxyGroup.Use(middleware.Auth(authConfig)) proxyGroup.Any("/:group_name/*path", proxyServer.HandleProxy) } diff --git a/web/src/api/keys.ts b/web/src/api/keys.ts index 2a10a1e..e1f59c5 100644 --- a/web/src/api/keys.ts +++ b/web/src/api/keys.ts @@ -157,7 +157,7 @@ export const keysApi = { const params = new URLSearchParams({ group_id: groupId.toString(), - auth_key: authKey, + key: authKey, }); if (status !== "all") { diff --git a/web/src/api/logs.ts b/web/src/api/logs.ts index 952b560..578421e 100644 --- a/web/src/api/logs.ts +++ b/web/src/api/logs.ts @@ -31,7 +31,7 @@ export const logApi = { {} as Record ) ); - queryParams.append("auth_key", authKey); + queryParams.append("key", authKey); const url = `${http.defaults.baseURL}/logs/export?${queryParams.toString()}`;