feat: 代理认证中间件

This commit is contained in:
tbphp
2025-07-24 17:17:18 +08:00
parent 746c9f3108
commit abb8fa1d19
3 changed files with 66 additions and 7 deletions

View File

@@ -3,12 +3,15 @@ package middleware
import (
"fmt"
"slices"
"strings"
"time"
app_errors "gpt-load/internal/errors"
"gpt-load/internal/response"
"gpt-load/internal/services"
"gpt-load/internal/types"
"gpt-load/internal/utils"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
@@ -135,6 +138,43 @@ func Auth(authConfig types.AuthConfig) gin.HandlerFunc {
}
}
// ProxyAuth
func ProxyAuth(gm *services.GroupManager) gin.HandlerFunc {
return func(c *gin.Context) {
// Check key
key := extractAuthKey(c)
if key == "" {
response.Error(c, app_errors.ErrUnauthorized)
c.Abort()
return
}
group, err := gm.GetGroupByName(c.Param("group_name"))
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, "Failed to retrieve proxy group"))
c.Abort()
return
}
// Check Group keys first
groupKeys := utils.SplitAndTrim(group.ProxyKeys, ",")
if slices.Contains(groupKeys, key) {
c.Next()
return
}
// Then check System-wide keys
systemKeys := utils.SplitAndTrim(group.EffectiveConfig.ProxyKeys, ",")
if slices.Contains(systemKeys, key) {
c.Next()
return
}
response.Error(c, app_errors.ErrUnauthorized)
c.Abort()
}
}
// Recovery creates a recovery middleware with custom error handling
func Recovery() gin.HandlerFunc {
return gin.CustomRecovery(func(c *gin.Context, recovered any) {

View File

@@ -2,7 +2,6 @@ package router
import (
"embed"
"gpt-load/internal/channel"
"gpt-load/internal/handler"
"gpt-load/internal/middleware"
"gpt-load/internal/proxy"
@@ -43,7 +42,6 @@ func NewRouter(
proxyServer *proxy.ProxyServer,
configManager types.ConfigManager,
groupManager *services.GroupManager,
channelFactory *channel.Factory,
buildFS embed.FS,
indexPage []byte,
) *gin.Engine {
@@ -66,7 +64,7 @@ func NewRouter(
// 注册路由
registerSystemRoutes(router, serverHandler)
registerAPIRoutes(router, serverHandler, configManager)
registerProxyRoutes(router, proxyServer, configManager)
registerProxyRoutes(router, proxyServer, groupManager)
registerFrontendRoutes(router, buildFS, indexPage)
return router
@@ -159,12 +157,11 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser
func registerProxyRoutes(
router *gin.Engine,
proxyServer *proxy.ProxyServer,
configManager types.ConfigManager,
groupManager *services.GroupManager,
) {
proxyGroup := router.Group("/proxy")
authConfig := configManager.GetAuthConfig()
proxyGroup.Use(middleware.Auth(authConfig))
proxyGroup.Use(middleware.ProxyAuth(groupManager))
proxyGroup.Any("/:group_name/*path", proxyServer.HandleProxy)
}

View File

@@ -1,6 +1,9 @@
package utils
import "fmt"
import (
"fmt"
"strings"
)
// MaskAPIKey masks an API key for safe logging.
func MaskAPIKey(key string) string {
@@ -18,3 +21,22 @@ func TruncateString(s string, maxLength int) string {
}
return s
}
// SplitAndTrim splits a string by a separator
func SplitAndTrim(s string, sep string) []string {
if s == "" {
return []string{}
}
parts := strings.Split(s, sep)
result := make([]string, 0, len(parts))
for _, part := range parts {
trimmed := strings.TrimSpace(part)
if trimmed != "" {
result = append(result, trimmed)
}
}
return result
}