diff --git a/internal/channel/anthropic_channel.go b/internal/channel/anthropic_channel.go index 1919a32..d81190e 100644 --- a/internal/channel/anthropic_channel.go +++ b/internal/channel/anthropic_channel.go @@ -9,6 +9,7 @@ import ( "gpt-load/internal/models" "io" "net/http" + "net/url" "strings" "github.com/gin-gonic/gin" @@ -67,7 +68,14 @@ func (ch *AnthropicChannel) ValidateKey(ctx context.Context, key string) (bool, return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name) } - reqURL := upstreamURL.String() + "/v1/messages" + validationEndpoint := ch.ValidationEndpoint + if validationEndpoint == "" { + validationEndpoint = "/v1/messages" + } + reqURL, err := url.JoinPath(upstreamURL.String(), validationEndpoint) + if err != nil { + return false, fmt.Errorf("failed to join upstream URL and validation endpoint: %w", err) + } // Use a minimal, low-cost payload for validation payload := gin.H{ diff --git a/internal/channel/base_channel.go b/internal/channel/base_channel.go index d6610b7..bf534d1 100644 --- a/internal/channel/base_channel.go +++ b/internal/channel/base_channel.go @@ -23,12 +23,13 @@ type UpstreamInfo struct { // BaseChannel provides common functionality for channel proxies. type BaseChannel struct { - Name string - Upstreams []UpstreamInfo - HTTPClient *http.Client - StreamClient *http.Client - TestModel string - upstreamLock sync.Mutex + Name string + Upstreams []UpstreamInfo + HTTPClient *http.Client + StreamClient *http.Client + TestModel string + ValidationEndpoint string + upstreamLock sync.Mutex // Cached fields from the group for stale check channelType string @@ -96,6 +97,9 @@ func (b *BaseChannel) IsConfigStale(group *models.Group) bool { if b.TestModel != group.TestModel { return true } + if b.ValidationEndpoint != group.ValidationEndpoint { + return true + } if !bytes.Equal(b.groupUpstreams, group.Upstreams) { return true } diff --git a/internal/channel/factory.go b/internal/channel/factory.go index 42bed43..bed4cd3 100644 --- a/internal/channel/factory.go +++ b/internal/channel/factory.go @@ -140,13 +140,14 @@ func (f *Factory) newBaseChannel(name string, group *models.Group) (*BaseChannel streamClient := f.clientManager.GetClient(&streamConfig) return &BaseChannel{ - Name: name, - Upstreams: upstreamInfos, - HTTPClient: httpClient, - StreamClient: streamClient, - TestModel: group.TestModel, - channelType: group.ChannelType, - groupUpstreams: group.Upstreams, - effectiveConfig: &group.EffectiveConfig, + Name: name, + Upstreams: upstreamInfos, + HTTPClient: httpClient, + StreamClient: streamClient, + TestModel: group.TestModel, + ValidationEndpoint: group.ValidationEndpoint, + channelType: group.ChannelType, + groupUpstreams: group.Upstreams, + effectiveConfig: &group.EffectiveConfig, }, nil } diff --git a/internal/channel/gemini_channel.go b/internal/channel/gemini_channel.go index 4e769d8..5a1676f 100644 --- a/internal/channel/gemini_channel.go +++ b/internal/channel/gemini_channel.go @@ -9,6 +9,7 @@ import ( "gpt-load/internal/models" "io" "net/http" + "net/url" "strings" "github.com/gin-gonic/gin" @@ -65,7 +66,12 @@ func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, err return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name) } - reqURL := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", upstreamURL.String(), ch.TestModel, key) + // Safely join the path segments + reqURL, err := url.JoinPath(upstreamURL.String(), "v1beta", "models", ch.TestModel+":generateContent") + if err != nil { + return false, fmt.Errorf("failed to create gemini validation path: %w", err) + } + reqURL += "?key=" + key payload := gin.H{ "contents": []gin.H{ diff --git a/internal/channel/openai_channel.go b/internal/channel/openai_channel.go index caa0a7b..df43b1a 100644 --- a/internal/channel/openai_channel.go +++ b/internal/channel/openai_channel.go @@ -9,6 +9,7 @@ import ( "gpt-load/internal/models" "io" "net/http" + "net/url" "strings" "github.com/gin-gonic/gin" @@ -66,7 +67,14 @@ func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, err return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name) } - reqURL := upstreamURL.String() + "/v1/chat/completions" + validationEndpoint := ch.ValidationEndpoint + if validationEndpoint == "" { + validationEndpoint = "/v1/chat/completions" + } + reqURL, err := url.JoinPath(upstreamURL.String(), validationEndpoint) + if err != nil { + return false, fmt.Errorf("failed to join upstream URL and validation endpoint: %w", err) + } // Use a minimal, low-cost payload for validation payload := gin.H{ diff --git a/internal/handler/group_handler.go b/internal/handler/group_handler.go index ce6c57a..2caa5c2 100644 --- a/internal/handler/group_handler.go +++ b/internal/handler/group_handler.go @@ -88,6 +88,20 @@ func isValidGroupName(name string) bool { return match } +// isValidValidationEndpoint checks if the validation endpoint is a valid path. +func isValidValidationEndpoint(endpoint string) bool { + if endpoint == "" { + return true + } + if !strings.HasPrefix(endpoint, "/") { + return false + } + if strings.Contains(endpoint, "://") { + return false + } + return true +} + // validateAndCleanConfig validates the group config against the GroupConfig struct and system-defined rules. func (s *Server) validateAndCleanConfig(configMap map[string]any) (map[string]any, error) { if configMap == nil { @@ -180,16 +194,23 @@ func (s *Server) CreateGroup(c *gin.Context) { return } + validationEndpoint := strings.TrimSpace(req.ValidationEndpoint) + if !isValidValidationEndpoint(validationEndpoint) { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "无效的测试路径。如果提供,必须是以 / 开头的有效路径,且不能是完整的URL。")) + return + } + group := models.Group{ - Name: name, - DisplayName: strings.TrimSpace(req.DisplayName), - Description: strings.TrimSpace(req.Description), - Upstreams: cleanedUpstreams, - ChannelType: channelType, - Sort: req.Sort, - TestModel: testModel, - ParamOverrides: req.ParamOverrides, - Config: cleanedConfig, + Name: name, + DisplayName: strings.TrimSpace(req.DisplayName), + Description: strings.TrimSpace(req.Description), + Upstreams: cleanedUpstreams, + ChannelType: channelType, + Sort: req.Sort, + TestModel: testModel, + ValidationEndpoint: validationEndpoint, + ParamOverrides: req.ParamOverrides, + Config: cleanedConfig, } if err := s.DB.Create(&group).Error; err != nil { @@ -222,15 +243,16 @@ func (s *Server) ListGroups(c *gin.Context) { // GroupUpdateRequest defines the payload for updating a group. // Using a dedicated struct avoids issues with zero values being ignored by GORM's Update. type GroupUpdateRequest struct { - Name *string `json:"name,omitempty"` - DisplayName *string `json:"display_name,omitempty"` - Description *string `json:"description,omitempty"` - Upstreams json.RawMessage `json:"upstreams"` - ChannelType *string `json:"channel_type,omitempty"` - Sort *int `json:"sort"` - TestModel string `json:"test_model"` - ParamOverrides map[string]any `json:"param_overrides"` - Config map[string]any `json:"config"` + Name *string `json:"name,omitempty"` + DisplayName *string `json:"display_name,omitempty"` + Description *string `json:"description,omitempty"` + Upstreams json.RawMessage `json:"upstreams"` + ChannelType *string `json:"channel_type,omitempty"` + Sort *int `json:"sort"` + TestModel string `json:"test_model"` + ValidationEndpoint *string `json:"validation_endpoint,omitempty"` + ParamOverrides map[string]any `json:"param_overrides"` + Config map[string]any `json:"config"` } // UpdateGroup handles updating an existing group. @@ -311,6 +333,15 @@ func (s *Server) UpdateGroup(c *gin.Context) { if req.ParamOverrides != nil { group.ParamOverrides = req.ParamOverrides } + if req.ValidationEndpoint != nil { + validationEndpoint := strings.TrimSpace(*req.ValidationEndpoint) + if !isValidValidationEndpoint(validationEndpoint) { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "无效的测试路径。如果提供,必须是以 / 开头的有效路径,且不能是完整的URL。")) + return + } + group.ValidationEndpoint = validationEndpoint + } + if req.Config != nil { cleanedConfig, err := s.validateAndCleanConfig(req.Config) if err != nil { @@ -339,20 +370,21 @@ func (s *Server) UpdateGroup(c *gin.Context) { // GroupResponse defines the structure for a group response, excluding sensitive or large fields. type GroupResponse struct { - ID uint `json:"id"` - Name string `json:"name"` - Endpoint string `json:"endpoint"` - DisplayName string `json:"display_name"` - Description string `json:"description"` - Upstreams datatypes.JSON `json:"upstreams"` - ChannelType string `json:"channel_type"` - Sort int `json:"sort"` - TestModel string `json:"test_model"` - ParamOverrides datatypes.JSONMap `json:"param_overrides"` - Config datatypes.JSONMap `json:"config"` - LastValidatedAt *time.Time `json:"last_validated_at"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID uint `json:"id"` + Name string `json:"name"` + Endpoint string `json:"endpoint"` + DisplayName string `json:"display_name"` + Description string `json:"description"` + Upstreams datatypes.JSON `json:"upstreams"` + ChannelType string `json:"channel_type"` + Sort int `json:"sort"` + TestModel string `json:"test_model"` + ValidationEndpoint string `json:"validation_endpoint"` + ParamOverrides datatypes.JSONMap `json:"param_overrides"` + Config datatypes.JSONMap `json:"config"` + LastValidatedAt *time.Time `json:"last_validated_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // newGroupResponse creates a new GroupResponse from a models.Group. @@ -368,20 +400,21 @@ func (s *Server) newGroupResponse(group *models.Group) *GroupResponse { } return &GroupResponse{ - ID: group.ID, - Name: group.Name, - Endpoint: endpoint, - DisplayName: group.DisplayName, - Description: group.Description, - Upstreams: group.Upstreams, - ChannelType: group.ChannelType, - Sort: group.Sort, - TestModel: group.TestModel, - ParamOverrides: group.ParamOverrides, - Config: group.Config, - LastValidatedAt: group.LastValidatedAt, - CreatedAt: group.CreatedAt, - UpdatedAt: group.UpdatedAt, + ID: group.ID, + Name: group.Name, + Endpoint: endpoint, + DisplayName: group.DisplayName, + Description: group.Description, + Upstreams: group.Upstreams, + ChannelType: group.ChannelType, + Sort: group.Sort, + TestModel: group.TestModel, + ValidationEndpoint: group.ValidationEndpoint, + ParamOverrides: group.ParamOverrides, + Config: group.Config, + LastValidatedAt: group.LastValidatedAt, + CreatedAt: group.CreatedAt, + UpdatedAt: group.UpdatedAt, } } diff --git a/internal/models/types.go b/internal/models/types.go index 0078424..37ec4db 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -40,22 +40,23 @@ type GroupConfig struct { // Group 对应 groups 表 type Group struct { - ID uint `gorm:"primaryKey;autoIncrement" json:"id"` - EffectiveConfig types.SystemSettings `gorm:"-" json:"effective_config,omitempty"` - Name string `gorm:"type:varchar(255);not null;unique" json:"name"` - Endpoint string `gorm:"-" json:"endpoint"` - DisplayName string `gorm:"type:varchar(255)" json:"display_name"` - Description string `gorm:"type:varchar(512)" json:"description"` - Upstreams datatypes.JSON `gorm:"type:json;not null" json:"upstreams"` - ChannelType string `gorm:"type:varchar(50);not null" json:"channel_type"` - Sort int `gorm:"default:0" json:"sort"` - TestModel string `gorm:"type:varchar(255);not null" json:"test_model"` - ParamOverrides datatypes.JSONMap `gorm:"type:json" json:"param_overrides"` - Config datatypes.JSONMap `gorm:"type:json" json:"config"` - APIKeys []APIKey `gorm:"foreignKey:GroupID" json:"api_keys"` - LastValidatedAt *time.Time `json:"last_validated_at"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + EffectiveConfig types.SystemSettings `gorm:"-" json:"effective_config,omitempty"` + Name string `gorm:"type:varchar(255);not null;unique" json:"name"` + Endpoint string `gorm:"-" json:"endpoint"` + DisplayName string `gorm:"type:varchar(255)" json:"display_name"` + Description string `gorm:"type:varchar(512)" json:"description"` + Upstreams datatypes.JSON `gorm:"type:json;not null" json:"upstreams"` + ValidationEndpoint string `gorm:"type:varchar(255)" json:"validation_endpoint"` + ChannelType string `gorm:"type:varchar(50);not null" json:"channel_type"` + Sort int `gorm:"default:0" json:"sort"` + TestModel string `gorm:"type:varchar(255);not null" json:"test_model"` + ParamOverrides datatypes.JSONMap `gorm:"type:json" json:"param_overrides"` + Config datatypes.JSONMap `gorm:"type:json" json:"config"` + APIKeys []APIKey `gorm:"foreignKey:GroupID" json:"api_keys"` + LastValidatedAt *time.Time `json:"last_validated_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // APIKey 对应 api_keys 表 diff --git a/web/src/components/keys/GroupFormModal.vue b/web/src/components/keys/GroupFormModal.vue index 3a4e87b..2d6358e 100644 --- a/web/src/components/keys/GroupFormModal.vue +++ b/web/src/components/keys/GroupFormModal.vue @@ -54,6 +54,7 @@ interface GroupFormData { channel_type: "openai" | "gemini" | "anthropic"; sort: number; test_model: string; + validation_endpoint: string; param_overrides: string; config: Record; configItems: ConfigItem[]; @@ -73,6 +74,7 @@ const formData = reactive({ channel_type: "openai", sort: 1, test_model: "", + validation_endpoint: "", param_overrides: "", config: {}, configItems: [] as ConfigItem[], @@ -177,6 +179,7 @@ function resetForm() { channel_type: "openai", sort: 1, test_model: "", + validation_endpoint: "", param_overrides: "", config: {}, configItems: [], @@ -203,6 +206,7 @@ function loadGroupData() { channel_type: props.group.channel_type || "openai", sort: props.group.sort || 1, test_model: props.group.test_model || "", + validation_endpoint: props.group.validation_endpoint || "", param_overrides: JSON.stringify(props.group.param_overrides || {}, null, 2), config: {}, configItems, @@ -231,6 +235,8 @@ function addUpstream() { function removeUpstream(index: number) { if (formData.upstreams.length > 1) { formData.upstreams.splice(index, 1); + } else { + message.warning("至少需要保留一个上游地址"); } } @@ -305,6 +311,7 @@ async function handleSubmit() { channel_type: formData.channel_type, sort: formData.sort, test_model: formData.test_model, + validation_endpoint: formData.validation_endpoint, param_overrides: formData.param_overrides ? paramOverrides : undefined, config, }; @@ -376,6 +383,17 @@ async function handleSubmit() { + + + + + + + {{ group?.validation_endpoint }} + + {{ group?.description || "-" }} diff --git a/web/src/types/models.ts b/web/src/types/models.ts index c8d43d9..13838e8 100644 --- a/web/src/types/models.ts +++ b/web/src/types/models.ts @@ -38,6 +38,7 @@ export interface Group { test_model: string; channel_type: "openai" | "gemini" | "anthropic"; upstreams: UpstreamInfo[]; + validation_endpoint: string; config: Record; api_keys?: APIKey[]; endpoint?: string;