feat: 认证key兼容处理

This commit is contained in:
tbphp
2025-07-22 15:38:55 +08:00
parent 681f0de81c
commit 76fe4d4dd3
8 changed files with 25 additions and 129 deletions

View File

@@ -60,25 +60,6 @@ func (ch *AnthropicChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bo
return false return false
} }
// ExtractKey extracts the API key from the x-api-key header.
func (ch *AnthropicChannel) ExtractKey(c *gin.Context) string {
// Check x-api-key header (Anthropic's standard)
if key := c.GetHeader("x-api-key"); key != "" {
return key
}
// Fallback to Authorization header for compatibility
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
const bearerPrefix = "Bearer "
if strings.HasPrefix(authHeader, bearerPrefix) {
return authHeader[len(bearerPrefix):]
}
}
return ""
}
// ValidateKey checks if the given API key is valid by making a messages request. // ValidateKey checks if the given API key is valid by making a messages request.
func (ch *AnthropicChannel) ValidateKey(ctx context.Context, key string) (bool, error) { func (ch *AnthropicChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
upstreamURL := ch.getUpstreamURL() upstreamURL := ch.getUpstreamURL()

View File

@@ -29,9 +29,6 @@ type ChannelProxy interface {
// IsStreamRequest checks if the request is for a streaming response, // IsStreamRequest checks if the request is for a streaming response,
IsStreamRequest(c *gin.Context, bodyBytes []byte) bool IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
// ExtractKey extracts the API key from the request.
ExtractKey(c *gin.Context) string
// ValidateKey checks if the given API key is valid. // ValidateKey checks if the given API key is valid.
ValidateKey(ctx context.Context, key string) (bool, error) ValidateKey(ctx context.Context, key string) (bool, error)
} }

View File

@@ -40,7 +40,6 @@ func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey,
req.URL.RawQuery = q.Encode() req.URL.RawQuery = q.Encode()
} }
// IsStreamRequest checks if the request is for a streaming response. // IsStreamRequest checks if the request is for a streaming response.
func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool { func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool {
path := c.Request.URL.Path path := c.Request.URL.Path
@@ -59,21 +58,6 @@ func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
return false return false
} }
// ExtractKey extracts the API key from the X-Goog-Api-Key header or the "key" query parameter.
func (ch *GeminiChannel) ExtractKey(c *gin.Context) string {
// 1. Check X-Goog-Api-Key header
if key := c.GetHeader("X-Goog-Api-Key"); key != "" {
return key
}
// 2. Check "key" query parameter
if key := c.Query("key"); key != "" {
return key
}
return ""
}
// ValidateKey checks if the given API key is valid by making a generateContent request. // ValidateKey checks if the given API key is valid by making a generateContent request.
func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, error) { func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
upstreamURL := ch.getUpstreamURL() upstreamURL := ch.getUpstreamURL()

View File

@@ -59,18 +59,6 @@ func (ch *OpenAIChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
return false return false
} }
// ExtractKey extracts the API key from the Authorization header.
func (ch *OpenAIChannel) ExtractKey(c *gin.Context) string {
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
const bearerPrefix = "Bearer "
if strings.HasPrefix(authHeader, bearerPrefix) {
return authHeader[len(bearerPrefix):]
}
}
return ""
}
// ValidateKey checks if the given API key is valid by making a chat completion request. // ValidateKey checks if the given API key is valid by making a chat completion request.
func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, error) { func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
upstreamURL := ch.getUpstreamURL() upstreamURL := ch.getUpstreamURL()

View File

@@ -6,10 +6,8 @@ import (
"strings" "strings"
"time" "time"
"gpt-load/internal/channel"
app_errors "gpt-load/internal/errors" app_errors "gpt-load/internal/errors"
"gpt-load/internal/response" "gpt-load/internal/response"
"gpt-load/internal/services"
"gpt-load/internal/types" "gpt-load/internal/types"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -116,45 +114,16 @@ func CORS(config types.CORSConfig) gin.HandlerFunc {
} }
// Auth creates an authentication middleware // Auth creates an authentication middleware
func Auth( func Auth(authConfig types.AuthConfig) gin.HandlerFunc {
authConfig types.AuthConfig,
groupManager *services.GroupManager,
channelFactory *channel.Factory,
) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
path := c.Request.URL.Path path := c.Request.URL.Path
// Skip authentication for health endpoints
if isMonitoringEndpoint(path) { if isMonitoringEndpoint(path) {
c.Next() c.Next()
return return
} }
var key string key := extractAuthKey(c)
var err error
if strings.HasPrefix(path, "/api") {
// Handle backend API authentication
key = extractApiKey(c)
} else if strings.HasPrefix(path, "/proxy/") {
// Handle proxy authentication
key, err = extractProxyKey(c, groupManager, channelFactory)
if err != nil {
// The error from extractProxyKey is already an APIError
if apiErr, ok := err.(*app_errors.APIError); ok {
response.Error(c, apiErr)
} else {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, err.Error()))
}
c.Abort()
return
}
} else {
// For any other paths, deny access by default
response.Error(c, app_errors.ErrResourceNotFound)
c.Abort()
return
}
if key == "" || key != authConfig.Key { if key == "" || key != authConfig.Key {
response.Error(c, app_errors.ErrUnauthorized) response.Error(c, app_errors.ErrUnauthorized)
@@ -162,8 +131,6 @@ func Auth(
return return
} }
// Key is extracted, but validation is handled by the proxy logic itself.
// For the backend API, we've already validated it.
c.Next() c.Next()
} }
} }
@@ -227,8 +194,10 @@ func isMonitoringEndpoint(path string) bool {
return false return false
} }
// extractBearerKey extracts a key from the "Authorization: Bearer <key>" header. // extractAuthKey extracts a auth key.
func extractApiKey(c *gin.Context) string { func extractAuthKey(c *gin.Context) string {
// Bearer token
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
if authHeader != "" { if authHeader != "" {
const bearerPrefix = "Bearer " const bearerPrefix = "Bearer "
@@ -237,39 +206,20 @@ func extractApiKey(c *gin.Context) string {
} }
} }
authKey := c.Query("auth_key") // X-Api-Key
if authKey != "" { if key := c.GetHeader("X-Api-Key"); key != "" {
return authKey return key
}
// X-Goog-Api-Key
if key := c.GetHeader("X-Goog-Api-Key"); key != "" {
return key
}
// Query key
if key := c.Query("key"); key != "" {
return key
} }
return "" return ""
} }
// extractProxyKey handles key extraction for proxy routes.
func extractProxyKey(
c *gin.Context,
groupManager *services.GroupManager,
channelFactory *channel.Factory,
) (string, error) {
groupName := c.Param("group_name")
if groupName == "" {
return "", app_errors.NewAPIError(app_errors.ErrBadRequest, "Group name is missing in the URL path")
}
group, err := groupManager.GetGroupByName(groupName)
if err != nil {
return "", app_errors.NewAPIError(app_errors.ErrResourceNotFound, fmt.Sprintf("Group '%s' not found", groupName))
}
channel, err := channelFactory.GetChannel(group)
if err != nil {
return "", app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to get channel for group '%s'", groupName))
}
key := channel.ExtractKey(c)
if key == "" {
return "", app_errors.ErrUnauthorized
}
return key, nil
}

View File

@@ -65,8 +65,8 @@ func NewRouter(
// 注册路由 // 注册路由
registerSystemRoutes(router, serverHandler) registerSystemRoutes(router, serverHandler)
registerAPIRoutes(router, serverHandler, configManager, groupManager, channelFactory) registerAPIRoutes(router, serverHandler, configManager)
registerProxyRoutes(router, proxyServer, configManager, groupManager, channelFactory) registerProxyRoutes(router, proxyServer, configManager)
registerFrontendRoutes(router, buildFS, indexPage) registerFrontendRoutes(router, buildFS, indexPage)
return router return router
@@ -82,8 +82,6 @@ func registerAPIRoutes(
router *gin.Engine, router *gin.Engine,
serverHandler *handler.Server, serverHandler *handler.Server,
configManager types.ConfigManager, configManager types.ConfigManager,
groupManager *services.GroupManager,
channelFactory *channel.Factory,
) { ) {
api := router.Group("/api") api := router.Group("/api")
authConfig := configManager.GetAuthConfig() authConfig := configManager.GetAuthConfig()
@@ -93,7 +91,7 @@ func registerAPIRoutes(
// 认证 // 认证
protectedAPI := api.Group("") protectedAPI := api.Group("")
protectedAPI.Use(middleware.Auth(authConfig, groupManager, channelFactory)) protectedAPI.Use(middleware.Auth(authConfig))
registerProtectedAPIRoutes(protectedAPI, serverHandler) registerProtectedAPIRoutes(protectedAPI, serverHandler)
} }
@@ -162,13 +160,11 @@ func registerProxyRoutes(
router *gin.Engine, router *gin.Engine,
proxyServer *proxy.ProxyServer, proxyServer *proxy.ProxyServer,
configManager types.ConfigManager, configManager types.ConfigManager,
groupManager *services.GroupManager,
channelFactory *channel.Factory,
) { ) {
proxyGroup := router.Group("/proxy") proxyGroup := router.Group("/proxy")
authConfig := configManager.GetAuthConfig() authConfig := configManager.GetAuthConfig()
proxyGroup.Use(middleware.Auth(authConfig, groupManager, channelFactory)) proxyGroup.Use(middleware.Auth(authConfig))
proxyGroup.Any("/:group_name/*path", proxyServer.HandleProxy) proxyGroup.Any("/:group_name/*path", proxyServer.HandleProxy)
} }

View File

@@ -157,7 +157,7 @@ export const keysApi = {
const params = new URLSearchParams({ const params = new URLSearchParams({
group_id: groupId.toString(), group_id: groupId.toString(),
auth_key: authKey, key: authKey,
}); });
if (status !== "all") { if (status !== "all") {

View File

@@ -31,7 +31,7 @@ export const logApi = {
{} as Record<string, string> {} as Record<string, string>
) )
); );
queryParams.append("auth_key", authKey); queryParams.append("key", authKey);
const url = `${http.defaults.baseURL}/logs/export?${queryParams.toString()}`; const url = `${http.defaults.baseURL}/logs/export?${queryParams.toString()}`;