diff --git a/internal/channel/factory.go b/internal/channel/factory.go index c3e38dd..1703f83 100644 --- a/internal/channel/factory.go +++ b/internal/channel/factory.go @@ -12,6 +12,32 @@ import ( "gorm.io/datatypes" ) +// channelConstructor defines the function signature for creating a new channel proxy. +type channelConstructor func(f *Factory, group *models.Group) (ChannelProxy, error) + +var ( + // channelRegistry holds the mapping from channel type string to its constructor. + channelRegistry = make(map[string]channelConstructor) +) + +// Register adds a new channel constructor to the registry. +// This function is intended to be called from the init() function of each channel implementation. +func Register(channelType string, constructor channelConstructor) { + if _, exists := channelRegistry[channelType]; exists { + panic(fmt.Sprintf("channel type '%s' is already registered", channelType)) + } + channelRegistry[channelType] = constructor +} + +// GetChannels returns a slice of all registered channel type names. +func GetChannels() []string { + supportedTypes := make([]string, 0, len(channelRegistry)) + for t := range channelRegistry { + supportedTypes = append(supportedTypes, t) + } + return supportedTypes +} + // Factory is responsible for creating channel proxies. type Factory struct { settingsManager *config.SystemSettingsManager @@ -26,14 +52,11 @@ func NewFactory(settingsManager *config.SystemSettingsManager) *Factory { // GetChannel returns a channel proxy based on the group's channel type. func (f *Factory) GetChannel(group *models.Group) (ChannelProxy, error) { - switch group.ChannelType { - case "openai": - return f.NewOpenAIChannel(group) - case "gemini": - return f.NewGeminiChannel(group) - default: + constructor, ok := channelRegistry[group.ChannelType] + if !ok { return nil, fmt.Errorf("unsupported channel type: %s", group.ChannelType) } + return constructor(f, group) } // newBaseChannel is a helper function to create and configure a BaseChannel. @@ -60,7 +83,7 @@ func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, grou } weight := def.Weight if weight <= 0 { - weight = 1 // Default weight to 1 if not specified or invalid + weight = 1 } upstreamInfos = append(upstreamInfos, UpstreamInfo{URL: u, Weight: weight}) } diff --git a/internal/channel/gemini_channel.go b/internal/channel/gemini_channel.go index 0353eae..3412e11 100644 --- a/internal/channel/gemini_channel.go +++ b/internal/channel/gemini_channel.go @@ -11,11 +11,15 @@ import ( "github.com/gin-gonic/gin" ) +func init() { + Register("gemini", newGeminiChannel) +} + type GeminiChannel struct { *BaseChannel } -func (f *Factory) NewGeminiChannel(group *models.Group) (*GeminiChannel, error) { +func newGeminiChannel(f *Factory, group *models.Group) (ChannelProxy, error) { base, err := f.newBaseChannel("gemini", group.Upstreams, group.Config) if err != nil { return nil, err diff --git a/internal/channel/openai_channel.go b/internal/channel/openai_channel.go index d2dfe1d..a7fea5e 100644 --- a/internal/channel/openai_channel.go +++ b/internal/channel/openai_channel.go @@ -10,11 +10,15 @@ import ( "github.com/gin-gonic/gin/binding" ) +func init() { + Register("openai", newOpenAIChannel) +} + type OpenAIChannel struct { *BaseChannel } -func (f *Factory) NewOpenAIChannel(group *models.Group) (*OpenAIChannel, error) { +func newOpenAIChannel(f *Factory, group *models.Group) (ChannelProxy, error) { base, err := f.newBaseChannel("openai", group.Upstreams, group.Config) if err != nil { return nil, err