refactor: 调整渠道注册机制

This commit is contained in:
tbphp
2025-07-05 08:51:37 +08:00
parent 40f59524f0
commit 92e4a53659
3 changed files with 40 additions and 9 deletions

View File

@@ -12,6 +12,32 @@ import (
"gorm.io/datatypes"
)
// channelConstructor defines the function signature for creating a new channel proxy.
type channelConstructor func(f *Factory, group *models.Group) (ChannelProxy, error)
var (
// channelRegistry holds the mapping from channel type string to its constructor.
channelRegistry = make(map[string]channelConstructor)
)
// Register adds a new channel constructor to the registry.
// This function is intended to be called from the init() function of each channel implementation.
func Register(channelType string, constructor channelConstructor) {
if _, exists := channelRegistry[channelType]; exists {
panic(fmt.Sprintf("channel type '%s' is already registered", channelType))
}
channelRegistry[channelType] = constructor
}
// GetChannels returns a slice of all registered channel type names.
func GetChannels() []string {
supportedTypes := make([]string, 0, len(channelRegistry))
for t := range channelRegistry {
supportedTypes = append(supportedTypes, t)
}
return supportedTypes
}
// Factory is responsible for creating channel proxies.
type Factory struct {
settingsManager *config.SystemSettingsManager
@@ -26,14 +52,11 @@ func NewFactory(settingsManager *config.SystemSettingsManager) *Factory {
// GetChannel returns a channel proxy based on the group's channel type.
func (f *Factory) GetChannel(group *models.Group) (ChannelProxy, error) {
switch group.ChannelType {
case "openai":
return f.NewOpenAIChannel(group)
case "gemini":
return f.NewGeminiChannel(group)
default:
constructor, ok := channelRegistry[group.ChannelType]
if !ok {
return nil, fmt.Errorf("unsupported channel type: %s", group.ChannelType)
}
return constructor(f, group)
}
// newBaseChannel is a helper function to create and configure a BaseChannel.
@@ -60,7 +83,7 @@ func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, grou
}
weight := def.Weight
if weight <= 0 {
weight = 1 // Default weight to 1 if not specified or invalid
weight = 1
}
upstreamInfos = append(upstreamInfos, UpstreamInfo{URL: u, Weight: weight})
}

View File

@@ -11,11 +11,15 @@ import (
"github.com/gin-gonic/gin"
)
func init() {
Register("gemini", newGeminiChannel)
}
type GeminiChannel struct {
*BaseChannel
}
func (f *Factory) NewGeminiChannel(group *models.Group) (*GeminiChannel, error) {
func newGeminiChannel(f *Factory, group *models.Group) (ChannelProxy, error) {
base, err := f.newBaseChannel("gemini", group.Upstreams, group.Config)
if err != nil {
return nil, err

View File

@@ -10,11 +10,15 @@ import (
"github.com/gin-gonic/gin/binding"
)
func init() {
Register("openai", newOpenAIChannel)
}
type OpenAIChannel struct {
*BaseChannel
}
func (f *Factory) NewOpenAIChannel(group *models.Group) (*OpenAIChannel, error) {
func newOpenAIChannel(f *Factory, group *models.Group) (ChannelProxy, error) {
base, err := f.newBaseChannel("openai", group.Upstreams, group.Config)
if err != nil {
return nil, err