From abb8fa1d1955763eb793f3c67b744c8046c2e4e9 Mon Sep 17 00:00:00 2001 From: tbphp Date: Thu, 24 Jul 2025 17:17:18 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BB=A3=E7=90=86=E8=AE=A4=E8=AF=81?= =?UTF-8?q?=E4=B8=AD=E9=97=B4=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/middleware/middleware.go | 40 +++++++++++++++++++++++++++++++ internal/router/router.go | 9 +++---- internal/utils/string_utils.go | 24 ++++++++++++++++++- 3 files changed, 66 insertions(+), 7 deletions(-) diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index a1f10dc..a3a971e 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -3,12 +3,15 @@ package middleware import ( "fmt" + "slices" "strings" "time" app_errors "gpt-load/internal/errors" "gpt-load/internal/response" + "gpt-load/internal/services" "gpt-load/internal/types" + "gpt-load/internal/utils" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" @@ -135,6 +138,43 @@ func Auth(authConfig types.AuthConfig) gin.HandlerFunc { } } +// ProxyAuth +func ProxyAuth(gm *services.GroupManager) gin.HandlerFunc { + return func(c *gin.Context) { + // Check key + key := extractAuthKey(c) + if key == "" { + response.Error(c, app_errors.ErrUnauthorized) + c.Abort() + return + } + + group, err := gm.GetGroupByName(c.Param("group_name")) + if err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, "Failed to retrieve proxy group")) + c.Abort() + return + } + + // Check Group keys first + groupKeys := utils.SplitAndTrim(group.ProxyKeys, ",") + if slices.Contains(groupKeys, key) { + c.Next() + return + } + + // Then check System-wide keys + systemKeys := utils.SplitAndTrim(group.EffectiveConfig.ProxyKeys, ",") + if slices.Contains(systemKeys, key) { + c.Next() + return + } + + response.Error(c, app_errors.ErrUnauthorized) + c.Abort() + } +} + // Recovery creates a recovery middleware with custom error handling func Recovery() gin.HandlerFunc { return gin.CustomRecovery(func(c *gin.Context, recovered any) { diff --git a/internal/router/router.go b/internal/router/router.go index f5703ec..fa9989c 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -2,7 +2,6 @@ package router import ( "embed" - "gpt-load/internal/channel" "gpt-load/internal/handler" "gpt-load/internal/middleware" "gpt-load/internal/proxy" @@ -43,7 +42,6 @@ func NewRouter( proxyServer *proxy.ProxyServer, configManager types.ConfigManager, groupManager *services.GroupManager, - channelFactory *channel.Factory, buildFS embed.FS, indexPage []byte, ) *gin.Engine { @@ -66,7 +64,7 @@ func NewRouter( // 注册路由 registerSystemRoutes(router, serverHandler) registerAPIRoutes(router, serverHandler, configManager) - registerProxyRoutes(router, proxyServer, configManager) + registerProxyRoutes(router, proxyServer, groupManager) registerFrontendRoutes(router, buildFS, indexPage) return router @@ -159,12 +157,11 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser func registerProxyRoutes( router *gin.Engine, proxyServer *proxy.ProxyServer, - configManager types.ConfigManager, + groupManager *services.GroupManager, ) { proxyGroup := router.Group("/proxy") - authConfig := configManager.GetAuthConfig() - proxyGroup.Use(middleware.Auth(authConfig)) + proxyGroup.Use(middleware.ProxyAuth(groupManager)) proxyGroup.Any("/:group_name/*path", proxyServer.HandleProxy) } diff --git a/internal/utils/string_utils.go b/internal/utils/string_utils.go index 36f12b9..ebf4ad7 100644 --- a/internal/utils/string_utils.go +++ b/internal/utils/string_utils.go @@ -1,6 +1,9 @@ package utils -import "fmt" +import ( + "fmt" + "strings" +) // MaskAPIKey masks an API key for safe logging. func MaskAPIKey(key string) string { @@ -18,3 +21,22 @@ func TruncateString(s string, maxLength int) string { } return s } + +// SplitAndTrim splits a string by a separator +func SplitAndTrim(s string, sep string) []string { + if s == "" { + return []string{} + } + + parts := strings.Split(s, sep) + result := make([]string, 0, len(parts)) + + for _, part := range parts { + trimmed := strings.TrimSpace(part) + if trimmed != "" { + result = append(result, trimmed) + } + } + + return result +}