From 70554b8fe5762b031189a3d6d67ada5e3697480d Mon Sep 17 00:00:00 2001 From: tbphp Date: Fri, 11 Jul 2025 22:39:28 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=AA=8C=E8=AF=81key=E5=88=86=E7=A6=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/channel/channel.go | 21 +++--- internal/channel/gemini_channel.go | 53 +++++++++------ internal/channel/openai_channel.go | 55 +++++++++------ internal/middleware/middleware.go | 103 +++++++++++++++++++++-------- internal/router/router.go | 28 ++++++-- 5 files changed, 175 insertions(+), 85 deletions(-) diff --git a/internal/channel/channel.go b/internal/channel/channel.go index de074d6..aac40cc 100644 --- a/internal/channel/channel.go +++ b/internal/channel/channel.go @@ -14,15 +14,6 @@ type ChannelProxy interface { // BuildUpstreamURL constructs the target URL for the upstream service. BuildUpstreamURL(originalURL *url.URL, group *models.Group) (string, error) - // ModifyRequest allows the channel to add specific headers or modify the request - ModifyRequest(req *http.Request, apiKey *models.APIKey, group *models.Group) - - // IsStreamRequest checks if the request is for a streaming response, - IsStreamRequest(c *gin.Context, bodyBytes []byte) bool - - // ValidateKey checks if the given API key is valid. - ValidateKey(ctx context.Context, key string) (bool, error) - // IsConfigStale checks if the channel's configuration is stale compared to the provided group. IsConfigStale(group *models.Group) bool @@ -31,4 +22,16 @@ type ChannelProxy interface { // GetStreamClient returns the client for streaming requests. GetStreamClient() *http.Client + + // ModifyRequest allows the channel to add specific headers or modify the request + ModifyRequest(req *http.Request, apiKey *models.APIKey, group *models.Group) + + // IsStreamRequest checks if the request is for a streaming response, + IsStreamRequest(c *gin.Context, bodyBytes []byte) bool + + // ExtractKey extracts the API key from the request. + ExtractKey(c *gin.Context) string + + // ValidateKey checks if the given API key is valid. + ValidateKey(ctx context.Context, key string) (bool, error) } diff --git a/internal/channel/gemini_channel.go b/internal/channel/gemini_channel.go index 9b463fd..053eef4 100644 --- a/internal/channel/gemini_channel.go +++ b/internal/channel/gemini_channel.go @@ -40,6 +40,40 @@ func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey, req.URL.RawQuery = q.Encode() } + +// IsStreamRequest checks if the request is for a streaming response. +func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool { + path := c.Request.URL.Path + if strings.HasSuffix(path, ":streamGenerateContent") { + return true + } + + // Also check for standard streaming indicators as a fallback. + if strings.Contains(c.GetHeader("Accept"), "text/event-stream") { + return true + } + if c.Query("stream") == "true" { + return true + } + + return false +} + +// ExtractKey extracts the API key from the X-Goog-Api-Key header or the "key" query parameter. +func (ch *GeminiChannel) ExtractKey(c *gin.Context) string { + // 1. Check X-Goog-Api-Key header + if key := c.GetHeader("X-Goog-Api-Key"); key != "" { + return key + } + + // 2. Check "key" query parameter + if key := c.Query("key"); key != "" { + return key + } + + return "" +} + // ValidateKey checks if the given API key is valid by making a generateContent request. func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, error) { upstreamURL := ch.getUpstreamURL() @@ -89,22 +123,3 @@ func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, err return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError) } - -// IsStreamRequest checks if the request is for a streaming response. -// For Gemini, this is primarily determined by the URL path. -func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool { - path := c.Request.URL.Path - if strings.HasSuffix(path, ":streamGenerateContent") { - return true - } - - // Also check for standard streaming indicators as a fallback. - if strings.Contains(c.GetHeader("Accept"), "text/event-stream") { - return true - } - if c.Query("stream") == "true" { - return true - } - - return false -} diff --git a/internal/channel/openai_channel.go b/internal/channel/openai_channel.go index 2bc5963..016e0d1 100644 --- a/internal/channel/openai_channel.go +++ b/internal/channel/openai_channel.go @@ -33,12 +33,44 @@ func newOpenAIChannel(f *Factory, group *models.Group) (ChannelProxy, error) { }, nil } - // ModifyRequest sets the Authorization header for the OpenAI service. func (ch *OpenAIChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey, group *models.Group) { req.Header.Set("Authorization", "Bearer "+apiKey.KeyValue) } +// IsStreamRequest checks if the request is for a streaming response using the pre-read body. +func (ch *OpenAIChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool { + if strings.Contains(c.GetHeader("Accept"), "text/event-stream") { + return true + } + + if c.Query("stream") == "true" { + return true + } + + type streamPayload struct { + Stream bool `json:"stream"` + } + var p streamPayload + if err := json.Unmarshal(bodyBytes, &p); err == nil { + return p.Stream + } + + return false +} + +// ExtractKey extracts the API key from the Authorization header. +func (ch *OpenAIChannel) ExtractKey(c *gin.Context) string { + authHeader := c.GetHeader("Authorization") + if authHeader != "" { + const bearerPrefix = "Bearer " + if strings.HasPrefix(authHeader, bearerPrefix) { + return authHeader[len(bearerPrefix):] + } + } + return "" +} + // ValidateKey checks if the given API key is valid by making a chat completion request. func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, error) { upstreamURL := ch.getUpstreamURL() @@ -90,24 +122,3 @@ func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, err return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError) } - -// IsStreamRequest checks if the request is for a streaming response using the pre-read body. -func (ch *OpenAIChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool { - if strings.Contains(c.GetHeader("Accept"), "text/event-stream") { - return true - } - - if c.Query("stream") == "true" { - return true - } - - type streamPayload struct { - Stream bool `json:"stream"` - } - var p streamPayload - if err := json.Unmarshal(bodyBytes, &p); err == nil { - return p.Stream - } - - return false -} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index cc4c83c..fdb5880 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -8,7 +8,8 @@ import ( "gpt-load/internal/response" "gpt-load/internal/types" - + "gpt-load/internal/channel" + "gpt-load/internal/services" app_errors "gpt-load/internal/errors" "github.com/gin-gonic/gin" @@ -115,31 +116,59 @@ func CORS(config types.CORSConfig) gin.HandlerFunc { } // Auth creates an authentication middleware -func Auth(config types.AuthConfig) gin.HandlerFunc { +func Auth( + authConfig types.AuthConfig, + groupManager *services.GroupManager, + channelFactory *channel.Factory, +) gin.HandlerFunc { return func(c *gin.Context) { - - // Skip authentication for management endpoints path := c.Request.URL.Path - if path == "/health" || path == "/stats" { + + // Skip authentication for health/stats endpoints + if isMonitoringEndpoint(path) { c.Next() return } - // Extract key from multiple sources - key := extractKey(c) + var key string + var err error + + if strings.HasPrefix(path, "/api") { + // Handle backend API authentication + key = extractBearerKey(c) + if key == "" || key != authConfig.Key { + response.Error(c, app_errors.ErrUnauthorized) + c.Abort() + return + } + } else if strings.HasPrefix(path, "/proxy/") { + // Handle proxy authentication + key, err = extractProxyKey(c, groupManager, channelFactory) + if err != nil { + // The error from extractProxyKey is already an APIError + if apiErr, ok := err.(*app_errors.APIError); ok { + response.Error(c, apiErr) + } else { + response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, err.Error())) + } + c.Abort() + return + } + } else { + // For any other paths, deny access by default + response.Error(c, app_errors.ErrResourceNotFound) + c.Abort() + return + } + if key == "" { response.Error(c, app_errors.ErrUnauthorized) c.Abort() return } - // Validate key - if key != config.Key { - response.Error(c, app_errors.ErrUnauthorized) - c.Abort() - return - } - + // Key is extracted, but validation is handled by the proxy logic itself. + // For the backend API, we've already validated it. c.Next() } } @@ -203,10 +232,8 @@ func isMonitoringEndpoint(path string) bool { return false } -// extractKey extracts the API key from the request, checking the Authorization header, -// the X-Goog-Api-Key header, and the "key" query parameter. -func extractKey(c *gin.Context) string { - // 1. Check Authorization header +// extractBearerKey extracts a key from the "Authorization: Bearer " header. +func extractBearerKey(c *gin.Context) string { authHeader := c.GetHeader("Authorization") if authHeader != "" { const bearerPrefix = "Bearer " @@ -214,16 +241,34 @@ func extractKey(c *gin.Context) string { return authHeader[len(bearerPrefix):] } } - - // 2. Check X-Goog-Api-Key header - if key := c.GetHeader("X-Goog-Api-Key"); key != "" { - return key - } - - // 3. Check "key" query parameter - if key := c.Query("key"); key != "" { - return key - } - return "" } + +// extractProxyKey handles key extraction for proxy routes. +func extractProxyKey( + c *gin.Context, + groupManager *services.GroupManager, + channelFactory *channel.Factory, +) (string, error) { + groupName := c.Param("group_name") + if groupName == "" { + return "", app_errors.NewAPIError(app_errors.ErrBadRequest, "Group name is missing in the URL path") + } + + group, err := groupManager.GetGroupByName(groupName) + if err != nil { + return "", app_errors.NewAPIError(app_errors.ErrResourceNotFound, fmt.Sprintf("Group '%s' not found", groupName)) + } + + channel, err := channelFactory.GetChannel(group) + if err != nil { + return "", app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to get channel for group '%s'", groupName)) + } + + key := channel.ExtractKey(c) + if key == "" { + return "", app_errors.ErrUnauthorized + } + + return key, nil +} diff --git a/internal/router/router.go b/internal/router/router.go index 64eef71..0894659 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -2,9 +2,11 @@ package router import ( "embed" + "gpt-load/internal/channel" "gpt-load/internal/handler" "gpt-load/internal/middleware" "gpt-load/internal/proxy" + "gpt-load/internal/services" "gpt-load/internal/types" "io/fs" "net/http" @@ -40,6 +42,8 @@ func NewRouter( serverHandler *handler.Server, proxyServer *proxy.ProxyServer, configManager types.ConfigManager, + groupManager *services.GroupManager, + channelFactory *channel.Factory, buildFS embed.FS, indexPage []byte, ) *gin.Engine { @@ -60,8 +64,8 @@ func NewRouter( // 注册路由 registerSystemRoutes(router, serverHandler) - registerAPIRoutes(router, serverHandler, configManager) - registerProxyRoutes(router, proxyServer, configManager) + registerAPIRoutes(router, serverHandler, configManager, groupManager, channelFactory) + registerProxyRoutes(router, proxyServer, configManager, groupManager, channelFactory) registerFrontendRoutes(router, buildFS, indexPage) return router @@ -74,7 +78,13 @@ func registerSystemRoutes(router *gin.Engine, serverHandler *handler.Server) { } // registerAPIRoutes 注册API路由 -func registerAPIRoutes(router *gin.Engine, serverHandler *handler.Server, configManager types.ConfigManager) { +func registerAPIRoutes( + router *gin.Engine, + serverHandler *handler.Server, + configManager types.ConfigManager, + groupManager *services.GroupManager, + channelFactory *channel.Factory, +) { api := router.Group("/api") authConfig := configManager.GetAuthConfig() @@ -83,7 +93,7 @@ func registerAPIRoutes(router *gin.Engine, serverHandler *handler.Server, config // 认证 protectedAPI := api.Group("") - protectedAPI.Use(middleware.Auth(authConfig)) + protectedAPI.Use(middleware.Auth(authConfig, groupManager, channelFactory)) registerProtectedAPIRoutes(protectedAPI, serverHandler) } @@ -140,11 +150,17 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser } // registerProxyRoutes 注册代理路由 -func registerProxyRoutes(router *gin.Engine, proxyServer *proxy.ProxyServer, configManager types.ConfigManager) { +func registerProxyRoutes( + router *gin.Engine, + proxyServer *proxy.ProxyServer, + configManager types.ConfigManager, + groupManager *services.GroupManager, + channelFactory *channel.Factory, +) { proxyGroup := router.Group("/proxy") authConfig := configManager.GetAuthConfig() - proxyGroup.Use(middleware.Auth(authConfig)) + proxyGroup.Use(middleware.Auth(authConfig, groupManager, channelFactory)) proxyGroup.Any("/:group_name/*path", proxyServer.HandleProxy) }