Merge branch 'main' into feat-fe

This commit is contained in:
tbphp
2025-07-23 14:46:22 +08:00
16 changed files with 198 additions and 208 deletions

View File

@@ -9,6 +9,7 @@ import (
"gpt-load/internal/models" "gpt-load/internal/models"
"io" "io"
"net/http" "net/http"
"net/url"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -60,25 +61,6 @@ func (ch *AnthropicChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bo
return false return false
} }
// ExtractKey extracts the API key from the x-api-key header.
func (ch *AnthropicChannel) ExtractKey(c *gin.Context) string {
// Check x-api-key header (Anthropic's standard)
if key := c.GetHeader("x-api-key"); key != "" {
return key
}
// Fallback to Authorization header for compatibility
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
const bearerPrefix = "Bearer "
if strings.HasPrefix(authHeader, bearerPrefix) {
return authHeader[len(bearerPrefix):]
}
}
return ""
}
// ValidateKey checks if the given API key is valid by making a messages request. // ValidateKey checks if the given API key is valid by making a messages request.
func (ch *AnthropicChannel) ValidateKey(ctx context.Context, key string) (bool, error) { func (ch *AnthropicChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
upstreamURL := ch.getUpstreamURL() upstreamURL := ch.getUpstreamURL()
@@ -86,7 +68,14 @@ func (ch *AnthropicChannel) ValidateKey(ctx context.Context, key string) (bool,
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name) return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
} }
reqURL := upstreamURL.String() + "/v1/messages" validationEndpoint := ch.ValidationEndpoint
if validationEndpoint == "" {
validationEndpoint = "/v1/messages"
}
reqURL, err := url.JoinPath(upstreamURL.String(), validationEndpoint)
if err != nil {
return false, fmt.Errorf("failed to join upstream URL and validation endpoint: %w", err)
}
// Use a minimal, low-cost payload for validation // Use a minimal, low-cost payload for validation
payload := gin.H{ payload := gin.H{

View File

@@ -23,12 +23,13 @@ type UpstreamInfo struct {
// BaseChannel provides common functionality for channel proxies. // BaseChannel provides common functionality for channel proxies.
type BaseChannel struct { type BaseChannel struct {
Name string Name string
Upstreams []UpstreamInfo Upstreams []UpstreamInfo
HTTPClient *http.Client HTTPClient *http.Client
StreamClient *http.Client StreamClient *http.Client
TestModel string TestModel string
upstreamLock sync.Mutex ValidationEndpoint string
upstreamLock sync.Mutex
// Cached fields from the group for stale check // Cached fields from the group for stale check
channelType string channelType string
@@ -96,6 +97,9 @@ func (b *BaseChannel) IsConfigStale(group *models.Group) bool {
if b.TestModel != group.TestModel { if b.TestModel != group.TestModel {
return true return true
} }
if b.ValidationEndpoint != group.ValidationEndpoint {
return true
}
if !bytes.Equal(b.groupUpstreams, group.Upstreams) { if !bytes.Equal(b.groupUpstreams, group.Upstreams) {
return true return true
} }

View File

@@ -29,9 +29,6 @@ type ChannelProxy interface {
// IsStreamRequest checks if the request is for a streaming response, // IsStreamRequest checks if the request is for a streaming response,
IsStreamRequest(c *gin.Context, bodyBytes []byte) bool IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
// ExtractKey extracts the API key from the request.
ExtractKey(c *gin.Context) string
// ValidateKey checks if the given API key is valid. // ValidateKey checks if the given API key is valid.
ValidateKey(ctx context.Context, key string) (bool, error) ValidateKey(ctx context.Context, key string) (bool, error)
} }

View File

@@ -140,13 +140,14 @@ func (f *Factory) newBaseChannel(name string, group *models.Group) (*BaseChannel
streamClient := f.clientManager.GetClient(&streamConfig) streamClient := f.clientManager.GetClient(&streamConfig)
return &BaseChannel{ return &BaseChannel{
Name: name, Name: name,
Upstreams: upstreamInfos, Upstreams: upstreamInfos,
HTTPClient: httpClient, HTTPClient: httpClient,
StreamClient: streamClient, StreamClient: streamClient,
TestModel: group.TestModel, TestModel: group.TestModel,
channelType: group.ChannelType, ValidationEndpoint: group.ValidationEndpoint,
groupUpstreams: group.Upstreams, channelType: group.ChannelType,
effectiveConfig: &group.EffectiveConfig, groupUpstreams: group.Upstreams,
effectiveConfig: &group.EffectiveConfig,
}, nil }, nil
} }

View File

@@ -9,6 +9,7 @@ import (
"gpt-load/internal/models" "gpt-load/internal/models"
"io" "io"
"net/http" "net/http"
"net/url"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -40,7 +41,6 @@ func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey,
req.URL.RawQuery = q.Encode() req.URL.RawQuery = q.Encode()
} }
// IsStreamRequest checks if the request is for a streaming response. // IsStreamRequest checks if the request is for a streaming response.
func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool { func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool {
path := c.Request.URL.Path path := c.Request.URL.Path
@@ -59,21 +59,6 @@ func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
return false return false
} }
// ExtractKey extracts the API key from the X-Goog-Api-Key header or the "key" query parameter.
func (ch *GeminiChannel) ExtractKey(c *gin.Context) string {
// 1. Check X-Goog-Api-Key header
if key := c.GetHeader("X-Goog-Api-Key"); key != "" {
return key
}
// 2. Check "key" query parameter
if key := c.Query("key"); key != "" {
return key
}
return ""
}
// ValidateKey checks if the given API key is valid by making a generateContent request. // ValidateKey checks if the given API key is valid by making a generateContent request.
func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, error) { func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
upstreamURL := ch.getUpstreamURL() upstreamURL := ch.getUpstreamURL()
@@ -81,7 +66,12 @@ func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, err
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name) return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
} }
reqURL := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", upstreamURL.String(), ch.TestModel, key) // Safely join the path segments
reqURL, err := url.JoinPath(upstreamURL.String(), "v1beta", "models", ch.TestModel+":generateContent")
if err != nil {
return false, fmt.Errorf("failed to create gemini validation path: %w", err)
}
reqURL += "?key=" + key
payload := gin.H{ payload := gin.H{
"contents": []gin.H{ "contents": []gin.H{

View File

@@ -9,6 +9,7 @@ import (
"gpt-load/internal/models" "gpt-load/internal/models"
"io" "io"
"net/http" "net/http"
"net/url"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -59,18 +60,6 @@ func (ch *OpenAIChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
return false return false
} }
// ExtractKey extracts the API key from the Authorization header.
func (ch *OpenAIChannel) ExtractKey(c *gin.Context) string {
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
const bearerPrefix = "Bearer "
if strings.HasPrefix(authHeader, bearerPrefix) {
return authHeader[len(bearerPrefix):]
}
}
return ""
}
// ValidateKey checks if the given API key is valid by making a chat completion request. // ValidateKey checks if the given API key is valid by making a chat completion request.
func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, error) { func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
upstreamURL := ch.getUpstreamURL() upstreamURL := ch.getUpstreamURL()
@@ -78,7 +67,14 @@ func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, err
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name) return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
} }
reqURL := upstreamURL.String() + "/v1/chat/completions" validationEndpoint := ch.ValidationEndpoint
if validationEndpoint == "" {
validationEndpoint = "/v1/chat/completions"
}
reqURL, err := url.JoinPath(upstreamURL.String(), validationEndpoint)
if err != nil {
return false, fmt.Errorf("failed to join upstream URL and validation endpoint: %w", err)
}
// Use a minimal, low-cost payload for validation // Use a minimal, low-cost payload for validation
payload := gin.H{ payload := gin.H{

View File

@@ -88,6 +88,20 @@ func isValidGroupName(name string) bool {
return match return match
} }
// isValidValidationEndpoint checks if the validation endpoint is a valid path.
func isValidValidationEndpoint(endpoint string) bool {
if endpoint == "" {
return true
}
if !strings.HasPrefix(endpoint, "/") {
return false
}
if strings.Contains(endpoint, "://") {
return false
}
return true
}
// validateAndCleanConfig validates the group config against the GroupConfig struct and system-defined rules. // validateAndCleanConfig validates the group config against the GroupConfig struct and system-defined rules.
func (s *Server) validateAndCleanConfig(configMap map[string]any) (map[string]any, error) { func (s *Server) validateAndCleanConfig(configMap map[string]any) (map[string]any, error) {
if configMap == nil { if configMap == nil {
@@ -180,16 +194,23 @@ func (s *Server) CreateGroup(c *gin.Context) {
return return
} }
validationEndpoint := strings.TrimSpace(req.ValidationEndpoint)
if !isValidValidationEndpoint(validationEndpoint) {
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "无效的测试路径。如果提供,必须是以 / 开头的有效路径且不能是完整的URL。"))
return
}
group := models.Group{ group := models.Group{
Name: name, Name: name,
DisplayName: strings.TrimSpace(req.DisplayName), DisplayName: strings.TrimSpace(req.DisplayName),
Description: strings.TrimSpace(req.Description), Description: strings.TrimSpace(req.Description),
Upstreams: cleanedUpstreams, Upstreams: cleanedUpstreams,
ChannelType: channelType, ChannelType: channelType,
Sort: req.Sort, Sort: req.Sort,
TestModel: testModel, TestModel: testModel,
ParamOverrides: req.ParamOverrides, ValidationEndpoint: validationEndpoint,
Config: cleanedConfig, ParamOverrides: req.ParamOverrides,
Config: cleanedConfig,
} }
if err := s.DB.Create(&group).Error; err != nil { if err := s.DB.Create(&group).Error; err != nil {
@@ -222,15 +243,16 @@ func (s *Server) ListGroups(c *gin.Context) {
// GroupUpdateRequest defines the payload for updating a group. // GroupUpdateRequest defines the payload for updating a group.
// Using a dedicated struct avoids issues with zero values being ignored by GORM's Update. // Using a dedicated struct avoids issues with zero values being ignored by GORM's Update.
type GroupUpdateRequest struct { type GroupUpdateRequest struct {
Name *string `json:"name,omitempty"` Name *string `json:"name,omitempty"`
DisplayName *string `json:"display_name,omitempty"` DisplayName *string `json:"display_name,omitempty"`
Description *string `json:"description,omitempty"` Description *string `json:"description,omitempty"`
Upstreams json.RawMessage `json:"upstreams"` Upstreams json.RawMessage `json:"upstreams"`
ChannelType *string `json:"channel_type,omitempty"` ChannelType *string `json:"channel_type,omitempty"`
Sort *int `json:"sort"` Sort *int `json:"sort"`
TestModel string `json:"test_model"` TestModel string `json:"test_model"`
ParamOverrides map[string]any `json:"param_overrides"` ValidationEndpoint *string `json:"validation_endpoint,omitempty"`
Config map[string]any `json:"config"` ParamOverrides map[string]any `json:"param_overrides"`
Config map[string]any `json:"config"`
} }
// UpdateGroup handles updating an existing group. // UpdateGroup handles updating an existing group.
@@ -311,6 +333,15 @@ func (s *Server) UpdateGroup(c *gin.Context) {
if req.ParamOverrides != nil { if req.ParamOverrides != nil {
group.ParamOverrides = req.ParamOverrides group.ParamOverrides = req.ParamOverrides
} }
if req.ValidationEndpoint != nil {
validationEndpoint := strings.TrimSpace(*req.ValidationEndpoint)
if !isValidValidationEndpoint(validationEndpoint) {
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "无效的测试路径。如果提供,必须是以 / 开头的有效路径且不能是完整的URL。"))
return
}
group.ValidationEndpoint = validationEndpoint
}
if req.Config != nil { if req.Config != nil {
cleanedConfig, err := s.validateAndCleanConfig(req.Config) cleanedConfig, err := s.validateAndCleanConfig(req.Config)
if err != nil { if err != nil {
@@ -339,20 +370,21 @@ func (s *Server) UpdateGroup(c *gin.Context) {
// GroupResponse defines the structure for a group response, excluding sensitive or large fields. // GroupResponse defines the structure for a group response, excluding sensitive or large fields.
type GroupResponse struct { type GroupResponse struct {
ID uint `json:"id"` ID uint `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Endpoint string `json:"endpoint"` Endpoint string `json:"endpoint"`
DisplayName string `json:"display_name"` DisplayName string `json:"display_name"`
Description string `json:"description"` Description string `json:"description"`
Upstreams datatypes.JSON `json:"upstreams"` Upstreams datatypes.JSON `json:"upstreams"`
ChannelType string `json:"channel_type"` ChannelType string `json:"channel_type"`
Sort int `json:"sort"` Sort int `json:"sort"`
TestModel string `json:"test_model"` TestModel string `json:"test_model"`
ParamOverrides datatypes.JSONMap `json:"param_overrides"` ValidationEndpoint string `json:"validation_endpoint"`
Config datatypes.JSONMap `json:"config"` ParamOverrides datatypes.JSONMap `json:"param_overrides"`
LastValidatedAt *time.Time `json:"last_validated_at"` Config datatypes.JSONMap `json:"config"`
CreatedAt time.Time `json:"created_at"` LastValidatedAt *time.Time `json:"last_validated_at"`
UpdatedAt time.Time `json:"updated_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
} }
// newGroupResponse creates a new GroupResponse from a models.Group. // newGroupResponse creates a new GroupResponse from a models.Group.
@@ -368,20 +400,21 @@ func (s *Server) newGroupResponse(group *models.Group) *GroupResponse {
} }
return &GroupResponse{ return &GroupResponse{
ID: group.ID, ID: group.ID,
Name: group.Name, Name: group.Name,
Endpoint: endpoint, Endpoint: endpoint,
DisplayName: group.DisplayName, DisplayName: group.DisplayName,
Description: group.Description, Description: group.Description,
Upstreams: group.Upstreams, Upstreams: group.Upstreams,
ChannelType: group.ChannelType, ChannelType: group.ChannelType,
Sort: group.Sort, Sort: group.Sort,
TestModel: group.TestModel, TestModel: group.TestModel,
ParamOverrides: group.ParamOverrides, ValidationEndpoint: group.ValidationEndpoint,
Config: group.Config, ParamOverrides: group.ParamOverrides,
LastValidatedAt: group.LastValidatedAt, Config: group.Config,
CreatedAt: group.CreatedAt, LastValidatedAt: group.LastValidatedAt,
UpdatedAt: group.UpdatedAt, CreatedAt: group.CreatedAt,
UpdatedAt: group.UpdatedAt,
} }
} }

View File

@@ -6,10 +6,8 @@ import (
"strings" "strings"
"time" "time"
"gpt-load/internal/channel"
app_errors "gpt-load/internal/errors" app_errors "gpt-load/internal/errors"
"gpt-load/internal/response" "gpt-load/internal/response"
"gpt-load/internal/services"
"gpt-load/internal/types" "gpt-load/internal/types"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -116,45 +114,16 @@ func CORS(config types.CORSConfig) gin.HandlerFunc {
} }
// Auth creates an authentication middleware // Auth creates an authentication middleware
func Auth( func Auth(authConfig types.AuthConfig) gin.HandlerFunc {
authConfig types.AuthConfig,
groupManager *services.GroupManager,
channelFactory *channel.Factory,
) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
path := c.Request.URL.Path path := c.Request.URL.Path
// Skip authentication for health endpoints
if isMonitoringEndpoint(path) { if isMonitoringEndpoint(path) {
c.Next() c.Next()
return return
} }
var key string key := extractAuthKey(c)
var err error
if strings.HasPrefix(path, "/api") {
// Handle backend API authentication
key = extractApiKey(c)
} else if strings.HasPrefix(path, "/proxy/") {
// Handle proxy authentication
key, err = extractProxyKey(c, groupManager, channelFactory)
if err != nil {
// The error from extractProxyKey is already an APIError
if apiErr, ok := err.(*app_errors.APIError); ok {
response.Error(c, apiErr)
} else {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, err.Error()))
}
c.Abort()
return
}
} else {
// For any other paths, deny access by default
response.Error(c, app_errors.ErrResourceNotFound)
c.Abort()
return
}
if key == "" || key != authConfig.Key { if key == "" || key != authConfig.Key {
response.Error(c, app_errors.ErrUnauthorized) response.Error(c, app_errors.ErrUnauthorized)
@@ -162,8 +131,6 @@ func Auth(
return return
} }
// Key is extracted, but validation is handled by the proxy logic itself.
// For the backend API, we've already validated it.
c.Next() c.Next()
} }
} }
@@ -227,8 +194,10 @@ func isMonitoringEndpoint(path string) bool {
return false return false
} }
// extractBearerKey extracts a key from the "Authorization: Bearer <key>" header. // extractAuthKey extracts a auth key.
func extractApiKey(c *gin.Context) string { func extractAuthKey(c *gin.Context) string {
// Bearer token
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
if authHeader != "" { if authHeader != "" {
const bearerPrefix = "Bearer " const bearerPrefix = "Bearer "
@@ -237,39 +206,20 @@ func extractApiKey(c *gin.Context) string {
} }
} }
authKey := c.Query("auth_key") // X-Api-Key
if authKey != "" { if key := c.GetHeader("X-Api-Key"); key != "" {
return authKey return key
}
// X-Goog-Api-Key
if key := c.GetHeader("X-Goog-Api-Key"); key != "" {
return key
}
// Query key
if key := c.Query("key"); key != "" {
return key
} }
return "" return ""
} }
// extractProxyKey handles key extraction for proxy routes.
func extractProxyKey(
c *gin.Context,
groupManager *services.GroupManager,
channelFactory *channel.Factory,
) (string, error) {
groupName := c.Param("group_name")
if groupName == "" {
return "", app_errors.NewAPIError(app_errors.ErrBadRequest, "Group name is missing in the URL path")
}
group, err := groupManager.GetGroupByName(groupName)
if err != nil {
return "", app_errors.NewAPIError(app_errors.ErrResourceNotFound, fmt.Sprintf("Group '%s' not found", groupName))
}
channel, err := channelFactory.GetChannel(group)
if err != nil {
return "", app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to get channel for group '%s'", groupName))
}
key := channel.ExtractKey(c)
if key == "" {
return "", app_errors.ErrUnauthorized
}
return key, nil
}

View File

@@ -40,22 +40,23 @@ type GroupConfig struct {
// Group 对应 groups 表 // Group 对应 groups 表
type Group struct { type Group struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"` ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
EffectiveConfig types.SystemSettings `gorm:"-" json:"effective_config,omitempty"` EffectiveConfig types.SystemSettings `gorm:"-" json:"effective_config,omitempty"`
Name string `gorm:"type:varchar(255);not null;unique" json:"name"` Name string `gorm:"type:varchar(255);not null;unique" json:"name"`
Endpoint string `gorm:"-" json:"endpoint"` Endpoint string `gorm:"-" json:"endpoint"`
DisplayName string `gorm:"type:varchar(255)" json:"display_name"` DisplayName string `gorm:"type:varchar(255)" json:"display_name"`
Description string `gorm:"type:varchar(512)" json:"description"` Description string `gorm:"type:varchar(512)" json:"description"`
Upstreams datatypes.JSON `gorm:"type:json;not null" json:"upstreams"` Upstreams datatypes.JSON `gorm:"type:json;not null" json:"upstreams"`
ChannelType string `gorm:"type:varchar(50);not null" json:"channel_type"` ValidationEndpoint string `gorm:"type:varchar(255)" json:"validation_endpoint"`
Sort int `gorm:"default:0" json:"sort"` ChannelType string `gorm:"type:varchar(50);not null" json:"channel_type"`
TestModel string `gorm:"type:varchar(255);not null" json:"test_model"` Sort int `gorm:"default:0" json:"sort"`
ParamOverrides datatypes.JSONMap `gorm:"type:json" json:"param_overrides"` TestModel string `gorm:"type:varchar(255);not null" json:"test_model"`
Config datatypes.JSONMap `gorm:"type:json" json:"config"` ParamOverrides datatypes.JSONMap `gorm:"type:json" json:"param_overrides"`
APIKeys []APIKey `gorm:"foreignKey:GroupID" json:"api_keys"` Config datatypes.JSONMap `gorm:"type:json" json:"config"`
LastValidatedAt *time.Time `json:"last_validated_at"` APIKeys []APIKey `gorm:"foreignKey:GroupID" json:"api_keys"`
CreatedAt time.Time `json:"created_at"` LastValidatedAt *time.Time `json:"last_validated_at"`
UpdatedAt time.Time `json:"updated_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
} }
// APIKey 对应 api_keys 表 // APIKey 对应 api_keys 表

View File

@@ -156,6 +156,15 @@ func (ps *ProxyServer) executeRequestWithRetry(
req.ContentLength = int64(len(bodyBytes)) req.ContentLength = int64(len(bodyBytes))
req.Header = c.Request.Header.Clone() req.Header = c.Request.Header.Clone()
// Clean up client auth key
req.Header.Del("Authorization")
req.Header.Del("X-Api-Key")
req.Header.Del("X-Goog-Api-Key")
q := req.URL.Query()
q.Del("key")
req.URL.RawQuery = q.Encode()
channelHandler.ModifyRequest(req, apiKey, group) channelHandler.ModifyRequest(req, apiKey, group)
var client *http.Client var client *http.Client

View File

@@ -65,8 +65,8 @@ func NewRouter(
// 注册路由 // 注册路由
registerSystemRoutes(router, serverHandler) registerSystemRoutes(router, serverHandler)
registerAPIRoutes(router, serverHandler, configManager, groupManager, channelFactory) registerAPIRoutes(router, serverHandler, configManager)
registerProxyRoutes(router, proxyServer, configManager, groupManager, channelFactory) registerProxyRoutes(router, proxyServer, configManager)
registerFrontendRoutes(router, buildFS, indexPage) registerFrontendRoutes(router, buildFS, indexPage)
return router return router
@@ -82,8 +82,6 @@ func registerAPIRoutes(
router *gin.Engine, router *gin.Engine,
serverHandler *handler.Server, serverHandler *handler.Server,
configManager types.ConfigManager, configManager types.ConfigManager,
groupManager *services.GroupManager,
channelFactory *channel.Factory,
) { ) {
api := router.Group("/api") api := router.Group("/api")
authConfig := configManager.GetAuthConfig() authConfig := configManager.GetAuthConfig()
@@ -93,7 +91,7 @@ func registerAPIRoutes(
// 认证 // 认证
protectedAPI := api.Group("") protectedAPI := api.Group("")
protectedAPI.Use(middleware.Auth(authConfig, groupManager, channelFactory)) protectedAPI.Use(middleware.Auth(authConfig))
registerProtectedAPIRoutes(protectedAPI, serverHandler) registerProtectedAPIRoutes(protectedAPI, serverHandler)
} }
@@ -162,13 +160,11 @@ func registerProxyRoutes(
router *gin.Engine, router *gin.Engine,
proxyServer *proxy.ProxyServer, proxyServer *proxy.ProxyServer,
configManager types.ConfigManager, configManager types.ConfigManager,
groupManager *services.GroupManager,
channelFactory *channel.Factory,
) { ) {
proxyGroup := router.Group("/proxy") proxyGroup := router.Group("/proxy")
authConfig := configManager.GetAuthConfig() authConfig := configManager.GetAuthConfig()
proxyGroup.Use(middleware.Auth(authConfig, groupManager, channelFactory)) proxyGroup.Use(middleware.Auth(authConfig))
proxyGroup.Any("/:group_name/*path", proxyServer.HandleProxy) proxyGroup.Any("/:group_name/*path", proxyServer.HandleProxy)
} }

View File

@@ -157,7 +157,7 @@ export const keysApi = {
const params = new URLSearchParams({ const params = new URLSearchParams({
group_id: groupId.toString(), group_id: groupId.toString(),
auth_key: authKey, key: authKey,
}); });
if (status !== "all") { if (status !== "all") {

View File

@@ -31,7 +31,7 @@ export const logApi = {
{} as Record<string, string> {} as Record<string, string>
) )
); );
queryParams.append("auth_key", authKey); queryParams.append("key", authKey);
const url = `${http.defaults.baseURL}/logs/export?${queryParams.toString()}`; const url = `${http.defaults.baseURL}/logs/export?${queryParams.toString()}`;

View File

@@ -54,6 +54,7 @@ interface GroupFormData {
channel_type: "openai" | "gemini" | "anthropic"; channel_type: "openai" | "gemini" | "anthropic";
sort: number; sort: number;
test_model: string; test_model: string;
validation_endpoint: string;
param_overrides: string; param_overrides: string;
config: Record<string, number>; config: Record<string, number>;
configItems: ConfigItem[]; configItems: ConfigItem[];
@@ -73,6 +74,7 @@ const formData = reactive<GroupFormData>({
channel_type: "openai", channel_type: "openai",
sort: 1, sort: 1,
test_model: "", test_model: "",
validation_endpoint: "",
param_overrides: "", param_overrides: "",
config: {}, config: {},
configItems: [] as ConfigItem[], configItems: [] as ConfigItem[],
@@ -177,6 +179,7 @@ function resetForm() {
channel_type: "openai", channel_type: "openai",
sort: 1, sort: 1,
test_model: "", test_model: "",
validation_endpoint: "",
param_overrides: "", param_overrides: "",
config: {}, config: {},
configItems: [], configItems: [],
@@ -203,6 +206,7 @@ function loadGroupData() {
channel_type: props.group.channel_type || "openai", channel_type: props.group.channel_type || "openai",
sort: props.group.sort || 1, sort: props.group.sort || 1,
test_model: props.group.test_model || "", test_model: props.group.test_model || "",
validation_endpoint: props.group.validation_endpoint || "",
param_overrides: JSON.stringify(props.group.param_overrides || {}, null, 2), param_overrides: JSON.stringify(props.group.param_overrides || {}, null, 2),
config: {}, config: {},
configItems, configItems,
@@ -231,6 +235,8 @@ function addUpstream() {
function removeUpstream(index: number) { function removeUpstream(index: number) {
if (formData.upstreams.length > 1) { if (formData.upstreams.length > 1) {
formData.upstreams.splice(index, 1); formData.upstreams.splice(index, 1);
} else {
message.warning("至少需要保留一个上游地址");
} }
} }
@@ -305,6 +311,7 @@ async function handleSubmit() {
channel_type: formData.channel_type, channel_type: formData.channel_type,
sort: formData.sort, sort: formData.sort,
test_model: formData.test_model, test_model: formData.test_model,
validation_endpoint: formData.validation_endpoint,
param_overrides: formData.param_overrides ? paramOverrides : undefined, param_overrides: formData.param_overrides ? paramOverrides : undefined,
config, config,
}; };
@@ -376,6 +383,17 @@ async function handleSubmit() {
<n-input v-model:value="formData.test_model" :placeholder="testModelPlaceholder" /> <n-input v-model:value="formData.test_model" :placeholder="testModelPlaceholder" />
</n-form-item> </n-form-item>
<n-form-item
label="测试路径"
path="validation_endpoint"
v-if="formData.channel_type !== 'gemini'"
>
<n-input
v-model:value="formData.validation_endpoint"
placeholder="可选自定义用于验证key的API路径"
/>
</n-form-item>
<n-form-item label="排序" path="sort"> <n-form-item label="排序" path="sort">
<n-input-number <n-input-number
v-model:value="formData.sort" v-model:value="formData.sort"

View File

@@ -318,6 +318,11 @@ function resetPage() {
{{ group?.sort || 0 }} {{ group?.sort || 0 }}
</n-form-item> </n-form-item>
</n-grid-item> </n-grid-item>
<n-grid-item v-if="group?.channel_type !== 'gemini'">
<n-form-item label="测试路径:">
{{ group?.validation_endpoint }}
</n-form-item>
</n-grid-item>
<n-grid-item> <n-grid-item>
<n-form-item label="描述:"> <n-form-item label="描述:">
{{ group?.description || "-" }} {{ group?.description || "-" }}

View File

@@ -38,6 +38,7 @@ export interface Group {
test_model: string; test_model: string;
channel_type: "openai" | "gemini" | "anthropic"; channel_type: "openai" | "gemini" | "anthropic";
upstreams: UpstreamInfo[]; upstreams: UpstreamInfo[];
validation_endpoint: string;
config: Record<string, unknown>; config: Record<string, unknown>;
api_keys?: APIKey[]; api_keys?: APIKey[];
endpoint?: string; endpoint?: string;