diff --git a/internal/channel/base_channel.go b/internal/channel/base_channel.go index 0b14675..7257d8a 100644 --- a/internal/channel/base_channel.go +++ b/internal/channel/base_channel.go @@ -1,6 +1,8 @@ package channel import ( + "bytes" + "encoding/json" "fmt" "gpt-load/internal/models" "io" @@ -12,6 +14,7 @@ import ( "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" + "gorm.io/datatypes" ) // UpstreamInfo holds the information for a single upstream server, including its weight. @@ -23,11 +26,13 @@ type UpstreamInfo struct { // BaseChannel provides common functionality for channel proxies. type BaseChannel struct { - Name string - Upstreams []UpstreamInfo - HTTPClient *http.Client - TestModel string - upstreamLock sync.Mutex + Name string + Upstreams []UpstreamInfo + HTTPClient *http.Client + TestModel string + upstreamLock sync.Mutex + groupUpstreams datatypes.JSON + groupConfig datatypes.JSONMap } // RequestModifier is a function that can modify the request before it's sent. @@ -66,6 +71,34 @@ func (b *BaseChannel) getUpstreamURL() *url.URL { return best.URL } +// IsConfigStale checks if the channel's configuration is stale compared to the provided group. +func (b *BaseChannel) IsConfigStale(group *models.Group) bool { + // It's important to compare the raw JSON here to detect any changes. + if !bytes.Equal(b.groupUpstreams, group.Upstreams) { + return true + } + + // For JSONMap, we need to marshal it to compare. + currentConfigBytes, err := json.Marshal(b.groupConfig) + if err != nil { + // Log the error and assume it's stale to be safe + logrus.Errorf("failed to marshal current group config: %v", err) + return true + } + newConfigBytes, err := json.Marshal(group.Config) + if err != nil { + // Log the error and assume it's stale + logrus.Errorf("failed to marshal new group config: %v", err) + return true + } + + if !bytes.Equal(currentConfigBytes, newConfigBytes) { + return true + } + + return false +} + // ProcessRequest handles the common logic of processing and forwarding a request. func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier, ch ChannelProxy) error { upstreamURL := b.getUpstreamURL() diff --git a/internal/channel/channel.go b/internal/channel/channel.go index 4317cfe..96604e2 100644 --- a/internal/channel/channel.go +++ b/internal/channel/channel.go @@ -18,4 +18,7 @@ type ChannelProxy interface { // IsStreamingRequest checks if the request is for a streaming response. IsStreamingRequest(c *gin.Context) bool + + // IsConfigStale checks if the channel's configuration is stale compared to the provided group. + IsConfigStale(group *models.Group) bool } diff --git a/internal/channel/factory.go b/internal/channel/factory.go index 84c0818..ba8addf 100644 --- a/internal/channel/factory.go +++ b/internal/channel/factory.go @@ -7,9 +7,10 @@ import ( "gpt-load/internal/models" "net/http" "net/url" + "sync" "time" - "gorm.io/datatypes" + "github.com/sirupsen/logrus" ) // channelConstructor defines the function signature for creating a new channel proxy. @@ -41,33 +42,53 @@ func GetChannels() []string { // Factory is responsible for creating channel proxies. type Factory struct { settingsManager *config.SystemSettingsManager + channelCache map[uint]ChannelProxy + cacheLock sync.Mutex } // NewFactory creates a new channel factory. func NewFactory(settingsManager *config.SystemSettingsManager) *Factory { return &Factory{ settingsManager: settingsManager, + channelCache: make(map[uint]ChannelProxy), } } // GetChannel returns a channel proxy based on the group's channel type. +// It uses a cache to ensure that only one instance of a channel is created for each group. func (f *Factory) GetChannel(group *models.Group) (ChannelProxy, error) { + f.cacheLock.Lock() + defer f.cacheLock.Unlock() + + if channel, ok := f.channelCache[group.ID]; ok { + if !channel.IsConfigStale(group) { + return channel, nil + } + } + + logrus.Infof("Creating new channel for group %d with type '%s'", group.ID, group.ChannelType) + constructor, ok := channelRegistry[group.ChannelType] if !ok { return nil, fmt.Errorf("unsupported channel type: %s", group.ChannelType) } - return constructor(f, group) + channel, err := constructor(f, group) + if err != nil { + return nil, err + } + f.channelCache[group.ID] = channel + return channel, nil } // newBaseChannel is a helper function to create and configure a BaseChannel. -func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, groupConfig datatypes.JSONMap, testModel string) (*BaseChannel, error) { +func (f *Factory) newBaseChannel(name string, group *models.Group) (*BaseChannel, error) { type upstreamDef struct { URL string `json:"url"` Weight int `json:"weight"` } var defs []upstreamDef - if err := json.Unmarshal(upstreamsJSON, &defs); err != nil { + if err := json.Unmarshal(group.Upstreams, &defs); err != nil { return nil, fmt.Errorf("failed to unmarshal upstreams for %s channel: %w", name, err) } @@ -89,7 +110,7 @@ func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, grou } // Get effective settings by merging system and group configs - effectiveSettings := f.settingsManager.GetEffectiveConfig(groupConfig) + effectiveSettings := f.settingsManager.GetEffectiveConfig(group.Config) // Configure the HTTP client with the effective timeouts httpClient := &http.Client{ @@ -100,9 +121,11 @@ func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, grou } return &BaseChannel{ - Name: name, - Upstreams: upstreamInfos, - HTTPClient: httpClient, - TestModel: testModel, + Name: name, + Upstreams: upstreamInfos, + HTTPClient: httpClient, + TestModel: group.TestModel, + groupUpstreams: group.Upstreams, + groupConfig: group.Config, }, nil } diff --git a/internal/channel/gemini_channel.go b/internal/channel/gemini_channel.go index 1fa0042..9f9077b 100644 --- a/internal/channel/gemini_channel.go +++ b/internal/channel/gemini_channel.go @@ -23,7 +23,7 @@ type GeminiChannel struct { } func newGeminiChannel(f *Factory, group *models.Group) (ChannelProxy, error) { - base, err := f.newBaseChannel("gemini", group.Upstreams, group.Config, group.TestModel) + base, err := f.newBaseChannel("gemini", group) if err != nil { return nil, err } diff --git a/internal/channel/openai_channel.go b/internal/channel/openai_channel.go index 2b054da..3ecd7e5 100644 --- a/internal/channel/openai_channel.go +++ b/internal/channel/openai_channel.go @@ -23,7 +23,7 @@ type OpenAIChannel struct { } func newOpenAIChannel(f *Factory, group *models.Group) (ChannelProxy, error) { - base, err := f.newBaseChannel("openai", group.Upstreams, group.Config, group.TestModel) + base, err := f.newBaseChannel("openai", group) if err != nil { return nil, err }