feat: 代理认证中间件
This commit is contained in:
@@ -3,12 +3,15 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
app_errors "gpt-load/internal/errors"
|
app_errors "gpt-load/internal/errors"
|
||||||
"gpt-load/internal/response"
|
"gpt-load/internal/response"
|
||||||
|
"gpt-load/internal/services"
|
||||||
"gpt-load/internal/types"
|
"gpt-load/internal/types"
|
||||||
|
"gpt-load/internal/utils"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -135,6 +138,43 @@ func Auth(authConfig types.AuthConfig) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProxyAuth
|
||||||
|
func ProxyAuth(gm *services.GroupManager) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// Check key
|
||||||
|
key := extractAuthKey(c)
|
||||||
|
if key == "" {
|
||||||
|
response.Error(c, app_errors.ErrUnauthorized)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
group, err := gm.GetGroupByName(c.Param("group_name"))
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, "Failed to retrieve proxy group"))
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Group keys first
|
||||||
|
groupKeys := utils.SplitAndTrim(group.ProxyKeys, ",")
|
||||||
|
if slices.Contains(groupKeys, key) {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then check System-wide keys
|
||||||
|
systemKeys := utils.SplitAndTrim(group.EffectiveConfig.ProxyKeys, ",")
|
||||||
|
if slices.Contains(systemKeys, key) {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Error(c, app_errors.ErrUnauthorized)
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Recovery creates a recovery middleware with custom error handling
|
// Recovery creates a recovery middleware with custom error handling
|
||||||
func Recovery() gin.HandlerFunc {
|
func Recovery() gin.HandlerFunc {
|
||||||
return gin.CustomRecovery(func(c *gin.Context, recovered any) {
|
return gin.CustomRecovery(func(c *gin.Context, recovered any) {
|
||||||
|
@@ -2,7 +2,6 @@ package router
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"embed"
|
"embed"
|
||||||
"gpt-load/internal/channel"
|
|
||||||
"gpt-load/internal/handler"
|
"gpt-load/internal/handler"
|
||||||
"gpt-load/internal/middleware"
|
"gpt-load/internal/middleware"
|
||||||
"gpt-load/internal/proxy"
|
"gpt-load/internal/proxy"
|
||||||
@@ -43,7 +42,6 @@ func NewRouter(
|
|||||||
proxyServer *proxy.ProxyServer,
|
proxyServer *proxy.ProxyServer,
|
||||||
configManager types.ConfigManager,
|
configManager types.ConfigManager,
|
||||||
groupManager *services.GroupManager,
|
groupManager *services.GroupManager,
|
||||||
channelFactory *channel.Factory,
|
|
||||||
buildFS embed.FS,
|
buildFS embed.FS,
|
||||||
indexPage []byte,
|
indexPage []byte,
|
||||||
) *gin.Engine {
|
) *gin.Engine {
|
||||||
@@ -66,7 +64,7 @@ func NewRouter(
|
|||||||
// 注册路由
|
// 注册路由
|
||||||
registerSystemRoutes(router, serverHandler)
|
registerSystemRoutes(router, serverHandler)
|
||||||
registerAPIRoutes(router, serverHandler, configManager)
|
registerAPIRoutes(router, serverHandler, configManager)
|
||||||
registerProxyRoutes(router, proxyServer, configManager)
|
registerProxyRoutes(router, proxyServer, groupManager)
|
||||||
registerFrontendRoutes(router, buildFS, indexPage)
|
registerFrontendRoutes(router, buildFS, indexPage)
|
||||||
|
|
||||||
return router
|
return router
|
||||||
@@ -159,12 +157,11 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser
|
|||||||
func registerProxyRoutes(
|
func registerProxyRoutes(
|
||||||
router *gin.Engine,
|
router *gin.Engine,
|
||||||
proxyServer *proxy.ProxyServer,
|
proxyServer *proxy.ProxyServer,
|
||||||
configManager types.ConfigManager,
|
groupManager *services.GroupManager,
|
||||||
) {
|
) {
|
||||||
proxyGroup := router.Group("/proxy")
|
proxyGroup := router.Group("/proxy")
|
||||||
authConfig := configManager.GetAuthConfig()
|
|
||||||
|
|
||||||
proxyGroup.Use(middleware.Auth(authConfig))
|
proxyGroup.Use(middleware.ProxyAuth(groupManager))
|
||||||
|
|
||||||
proxyGroup.Any("/:group_name/*path", proxyServer.HandleProxy)
|
proxyGroup.Any("/:group_name/*path", proxyServer.HandleProxy)
|
||||||
}
|
}
|
||||||
|
@@ -1,6 +1,9 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
// MaskAPIKey masks an API key for safe logging.
|
// MaskAPIKey masks an API key for safe logging.
|
||||||
func MaskAPIKey(key string) string {
|
func MaskAPIKey(key string) string {
|
||||||
@@ -18,3 +21,22 @@ func TruncateString(s string, maxLength int) string {
|
|||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SplitAndTrim splits a string by a separator
|
||||||
|
func SplitAndTrim(s string, sep string) []string {
|
||||||
|
if s == "" {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(s, sep)
|
||||||
|
result := make([]string, 0, len(parts))
|
||||||
|
|
||||||
|
for _, part := range parts {
|
||||||
|
trimmed := strings.TrimSpace(part)
|
||||||
|
if trimmed != "" {
|
||||||
|
result = append(result, trimmed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user