feat: 验证key分离

This commit is contained in:
tbphp
2025-07-11 22:39:28 +08:00
parent 0cf590111f
commit 70554b8fe5
5 changed files with 175 additions and 85 deletions

View File

@@ -14,15 +14,6 @@ type ChannelProxy interface {
// BuildUpstreamURL constructs the target URL for the upstream service. // BuildUpstreamURL constructs the target URL for the upstream service.
BuildUpstreamURL(originalURL *url.URL, group *models.Group) (string, error) BuildUpstreamURL(originalURL *url.URL, group *models.Group) (string, error)
// ModifyRequest allows the channel to add specific headers or modify the request
ModifyRequest(req *http.Request, apiKey *models.APIKey, group *models.Group)
// IsStreamRequest checks if the request is for a streaming response,
IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
// ValidateKey checks if the given API key is valid.
ValidateKey(ctx context.Context, key string) (bool, error)
// IsConfigStale checks if the channel's configuration is stale compared to the provided group. // IsConfigStale checks if the channel's configuration is stale compared to the provided group.
IsConfigStale(group *models.Group) bool IsConfigStale(group *models.Group) bool
@@ -31,4 +22,16 @@ type ChannelProxy interface {
// GetStreamClient returns the client for streaming requests. // GetStreamClient returns the client for streaming requests.
GetStreamClient() *http.Client GetStreamClient() *http.Client
// ModifyRequest allows the channel to add specific headers or modify the request
ModifyRequest(req *http.Request, apiKey *models.APIKey, group *models.Group)
// IsStreamRequest checks if the request is for a streaming response,
IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
// ExtractKey extracts the API key from the request.
ExtractKey(c *gin.Context) string
// ValidateKey checks if the given API key is valid.
ValidateKey(ctx context.Context, key string) (bool, error)
} }

View File

@@ -40,6 +40,40 @@ func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey,
req.URL.RawQuery = q.Encode() req.URL.RawQuery = q.Encode()
} }
// IsStreamRequest checks if the request is for a streaming response.
func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool {
path := c.Request.URL.Path
if strings.HasSuffix(path, ":streamGenerateContent") {
return true
}
// Also check for standard streaming indicators as a fallback.
if strings.Contains(c.GetHeader("Accept"), "text/event-stream") {
return true
}
if c.Query("stream") == "true" {
return true
}
return false
}
// ExtractKey extracts the API key from the X-Goog-Api-Key header or the "key" query parameter.
func (ch *GeminiChannel) ExtractKey(c *gin.Context) string {
// 1. Check X-Goog-Api-Key header
if key := c.GetHeader("X-Goog-Api-Key"); key != "" {
return key
}
// 2. Check "key" query parameter
if key := c.Query("key"); key != "" {
return key
}
return ""
}
// ValidateKey checks if the given API key is valid by making a generateContent request. // ValidateKey checks if the given API key is valid by making a generateContent request.
func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, error) { func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
upstreamURL := ch.getUpstreamURL() upstreamURL := ch.getUpstreamURL()
@@ -89,22 +123,3 @@ func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, err
return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError) return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError)
} }
// IsStreamRequest checks if the request is for a streaming response.
// For Gemini, this is primarily determined by the URL path.
func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool {
path := c.Request.URL.Path
if strings.HasSuffix(path, ":streamGenerateContent") {
return true
}
// Also check for standard streaming indicators as a fallback.
if strings.Contains(c.GetHeader("Accept"), "text/event-stream") {
return true
}
if c.Query("stream") == "true" {
return true
}
return false
}

View File

@@ -33,12 +33,44 @@ func newOpenAIChannel(f *Factory, group *models.Group) (ChannelProxy, error) {
}, nil }, nil
} }
// ModifyRequest sets the Authorization header for the OpenAI service. // ModifyRequest sets the Authorization header for the OpenAI service.
func (ch *OpenAIChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey, group *models.Group) { func (ch *OpenAIChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey, group *models.Group) {
req.Header.Set("Authorization", "Bearer "+apiKey.KeyValue) req.Header.Set("Authorization", "Bearer "+apiKey.KeyValue)
} }
// IsStreamRequest checks if the request is for a streaming response using the pre-read body.
func (ch *OpenAIChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool {
if strings.Contains(c.GetHeader("Accept"), "text/event-stream") {
return true
}
if c.Query("stream") == "true" {
return true
}
type streamPayload struct {
Stream bool `json:"stream"`
}
var p streamPayload
if err := json.Unmarshal(bodyBytes, &p); err == nil {
return p.Stream
}
return false
}
// ExtractKey extracts the API key from the Authorization header.
func (ch *OpenAIChannel) ExtractKey(c *gin.Context) string {
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
const bearerPrefix = "Bearer "
if strings.HasPrefix(authHeader, bearerPrefix) {
return authHeader[len(bearerPrefix):]
}
}
return ""
}
// ValidateKey checks if the given API key is valid by making a chat completion request. // ValidateKey checks if the given API key is valid by making a chat completion request.
func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, error) { func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
upstreamURL := ch.getUpstreamURL() upstreamURL := ch.getUpstreamURL()
@@ -90,24 +122,3 @@ func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, err
return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError) return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError)
} }
// IsStreamRequest checks if the request is for a streaming response using the pre-read body.
func (ch *OpenAIChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool {
if strings.Contains(c.GetHeader("Accept"), "text/event-stream") {
return true
}
if c.Query("stream") == "true" {
return true
}
type streamPayload struct {
Stream bool `json:"stream"`
}
var p streamPayload
if err := json.Unmarshal(bodyBytes, &p); err == nil {
return p.Stream
}
return false
}

View File

@@ -8,7 +8,8 @@ import (
"gpt-load/internal/response" "gpt-load/internal/response"
"gpt-load/internal/types" "gpt-load/internal/types"
"gpt-load/internal/channel"
"gpt-load/internal/services"
app_errors "gpt-load/internal/errors" app_errors "gpt-load/internal/errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -115,31 +116,59 @@ func CORS(config types.CORSConfig) gin.HandlerFunc {
} }
// Auth creates an authentication middleware // Auth creates an authentication middleware
func Auth(config types.AuthConfig) gin.HandlerFunc { func Auth(
authConfig types.AuthConfig,
groupManager *services.GroupManager,
channelFactory *channel.Factory,
) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// Skip authentication for management endpoints
path := c.Request.URL.Path path := c.Request.URL.Path
if path == "/health" || path == "/stats" {
// Skip authentication for health/stats endpoints
if isMonitoringEndpoint(path) {
c.Next() c.Next()
return return
} }
// Extract key from multiple sources var key string
key := extractKey(c) var err error
if strings.HasPrefix(path, "/api") {
// Handle backend API authentication
key = extractBearerKey(c)
if key == "" || key != authConfig.Key {
response.Error(c, app_errors.ErrUnauthorized)
c.Abort()
return
}
} else if strings.HasPrefix(path, "/proxy/") {
// Handle proxy authentication
key, err = extractProxyKey(c, groupManager, channelFactory)
if err != nil {
// The error from extractProxyKey is already an APIError
if apiErr, ok := err.(*app_errors.APIError); ok {
response.Error(c, apiErr)
} else {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, err.Error()))
}
c.Abort()
return
}
} else {
// For any other paths, deny access by default
response.Error(c, app_errors.ErrResourceNotFound)
c.Abort()
return
}
if key == "" { if key == "" {
response.Error(c, app_errors.ErrUnauthorized) response.Error(c, app_errors.ErrUnauthorized)
c.Abort() c.Abort()
return return
} }
// Validate key // Key is extracted, but validation is handled by the proxy logic itself.
if key != config.Key { // For the backend API, we've already validated it.
response.Error(c, app_errors.ErrUnauthorized)
c.Abort()
return
}
c.Next() c.Next()
} }
} }
@@ -203,10 +232,8 @@ func isMonitoringEndpoint(path string) bool {
return false return false
} }
// extractKey extracts the API key from the request, checking the Authorization header, // extractBearerKey extracts a key from the "Authorization: Bearer <key>" header.
// the X-Goog-Api-Key header, and the "key" query parameter. func extractBearerKey(c *gin.Context) string {
func extractKey(c *gin.Context) string {
// 1. Check Authorization header
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
if authHeader != "" { if authHeader != "" {
const bearerPrefix = "Bearer " const bearerPrefix = "Bearer "
@@ -214,16 +241,34 @@ func extractKey(c *gin.Context) string {
return authHeader[len(bearerPrefix):] return authHeader[len(bearerPrefix):]
} }
} }
// 2. Check X-Goog-Api-Key header
if key := c.GetHeader("X-Goog-Api-Key"); key != "" {
return key
}
// 3. Check "key" query parameter
if key := c.Query("key"); key != "" {
return key
}
return "" return ""
} }
// extractProxyKey handles key extraction for proxy routes.
func extractProxyKey(
c *gin.Context,
groupManager *services.GroupManager,
channelFactory *channel.Factory,
) (string, error) {
groupName := c.Param("group_name")
if groupName == "" {
return "", app_errors.NewAPIError(app_errors.ErrBadRequest, "Group name is missing in the URL path")
}
group, err := groupManager.GetGroupByName(groupName)
if err != nil {
return "", app_errors.NewAPIError(app_errors.ErrResourceNotFound, fmt.Sprintf("Group '%s' not found", groupName))
}
channel, err := channelFactory.GetChannel(group)
if err != nil {
return "", app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to get channel for group '%s'", groupName))
}
key := channel.ExtractKey(c)
if key == "" {
return "", app_errors.ErrUnauthorized
}
return key, nil
}

View File

@@ -2,9 +2,11 @@ package router
import ( import (
"embed" "embed"
"gpt-load/internal/channel"
"gpt-load/internal/handler" "gpt-load/internal/handler"
"gpt-load/internal/middleware" "gpt-load/internal/middleware"
"gpt-load/internal/proxy" "gpt-load/internal/proxy"
"gpt-load/internal/services"
"gpt-load/internal/types" "gpt-load/internal/types"
"io/fs" "io/fs"
"net/http" "net/http"
@@ -40,6 +42,8 @@ func NewRouter(
serverHandler *handler.Server, serverHandler *handler.Server,
proxyServer *proxy.ProxyServer, proxyServer *proxy.ProxyServer,
configManager types.ConfigManager, configManager types.ConfigManager,
groupManager *services.GroupManager,
channelFactory *channel.Factory,
buildFS embed.FS, buildFS embed.FS,
indexPage []byte, indexPage []byte,
) *gin.Engine { ) *gin.Engine {
@@ -60,8 +64,8 @@ func NewRouter(
// 注册路由 // 注册路由
registerSystemRoutes(router, serverHandler) registerSystemRoutes(router, serverHandler)
registerAPIRoutes(router, serverHandler, configManager) registerAPIRoutes(router, serverHandler, configManager, groupManager, channelFactory)
registerProxyRoutes(router, proxyServer, configManager) registerProxyRoutes(router, proxyServer, configManager, groupManager, channelFactory)
registerFrontendRoutes(router, buildFS, indexPage) registerFrontendRoutes(router, buildFS, indexPage)
return router return router
@@ -74,7 +78,13 @@ func registerSystemRoutes(router *gin.Engine, serverHandler *handler.Server) {
} }
// registerAPIRoutes 注册API路由 // registerAPIRoutes 注册API路由
func registerAPIRoutes(router *gin.Engine, serverHandler *handler.Server, configManager types.ConfigManager) { func registerAPIRoutes(
router *gin.Engine,
serverHandler *handler.Server,
configManager types.ConfigManager,
groupManager *services.GroupManager,
channelFactory *channel.Factory,
) {
api := router.Group("/api") api := router.Group("/api")
authConfig := configManager.GetAuthConfig() authConfig := configManager.GetAuthConfig()
@@ -83,7 +93,7 @@ func registerAPIRoutes(router *gin.Engine, serverHandler *handler.Server, config
// 认证 // 认证
protectedAPI := api.Group("") protectedAPI := api.Group("")
protectedAPI.Use(middleware.Auth(authConfig)) protectedAPI.Use(middleware.Auth(authConfig, groupManager, channelFactory))
registerProtectedAPIRoutes(protectedAPI, serverHandler) registerProtectedAPIRoutes(protectedAPI, serverHandler)
} }
@@ -140,11 +150,17 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser
} }
// registerProxyRoutes 注册代理路由 // registerProxyRoutes 注册代理路由
func registerProxyRoutes(router *gin.Engine, proxyServer *proxy.ProxyServer, configManager types.ConfigManager) { func registerProxyRoutes(
router *gin.Engine,
proxyServer *proxy.ProxyServer,
configManager types.ConfigManager,
groupManager *services.GroupManager,
channelFactory *channel.Factory,
) {
proxyGroup := router.Group("/proxy") proxyGroup := router.Group("/proxy")
authConfig := configManager.GetAuthConfig() authConfig := configManager.GetAuthConfig()
proxyGroup.Use(middleware.Auth(authConfig)) proxyGroup.Use(middleware.Auth(authConfig, groupManager, channelFactory))
proxyGroup.Any("/:group_name/*path", proxyServer.HandleProxy) proxyGroup.Any("/:group_name/*path", proxyServer.HandleProxy)
} }