feat: 单例group

This commit is contained in:
tbphp
2025-07-06 01:38:30 +08:00
parent ed921352e5
commit e6fe973ea4
5 changed files with 75 additions and 16 deletions

View File

@@ -1,6 +1,8 @@
package channel package channel
import ( import (
"bytes"
"encoding/json"
"fmt" "fmt"
"gpt-load/internal/models" "gpt-load/internal/models"
"io" "io"
@@ -12,6 +14,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gorm.io/datatypes"
) )
// UpstreamInfo holds the information for a single upstream server, including its weight. // UpstreamInfo holds the information for a single upstream server, including its weight.
@@ -28,6 +31,8 @@ type BaseChannel struct {
HTTPClient *http.Client HTTPClient *http.Client
TestModel string TestModel string
upstreamLock sync.Mutex upstreamLock sync.Mutex
groupUpstreams datatypes.JSON
groupConfig datatypes.JSONMap
} }
// RequestModifier is a function that can modify the request before it's sent. // RequestModifier is a function that can modify the request before it's sent.
@@ -66,6 +71,34 @@ func (b *BaseChannel) getUpstreamURL() *url.URL {
return best.URL return best.URL
} }
// IsConfigStale checks if the channel's configuration is stale compared to the provided group.
func (b *BaseChannel) IsConfigStale(group *models.Group) bool {
// It's important to compare the raw JSON here to detect any changes.
if !bytes.Equal(b.groupUpstreams, group.Upstreams) {
return true
}
// For JSONMap, we need to marshal it to compare.
currentConfigBytes, err := json.Marshal(b.groupConfig)
if err != nil {
// Log the error and assume it's stale to be safe
logrus.Errorf("failed to marshal current group config: %v", err)
return true
}
newConfigBytes, err := json.Marshal(group.Config)
if err != nil {
// Log the error and assume it's stale
logrus.Errorf("failed to marshal new group config: %v", err)
return true
}
if !bytes.Equal(currentConfigBytes, newConfigBytes) {
return true
}
return false
}
// ProcessRequest handles the common logic of processing and forwarding a request. // ProcessRequest handles the common logic of processing and forwarding a request.
func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier, ch ChannelProxy) error { func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier, ch ChannelProxy) error {
upstreamURL := b.getUpstreamURL() upstreamURL := b.getUpstreamURL()

View File

@@ -18,4 +18,7 @@ type ChannelProxy interface {
// IsStreamingRequest checks if the request is for a streaming response. // IsStreamingRequest checks if the request is for a streaming response.
IsStreamingRequest(c *gin.Context) bool IsStreamingRequest(c *gin.Context) bool
// IsConfigStale checks if the channel's configuration is stale compared to the provided group.
IsConfigStale(group *models.Group) bool
} }

View File

@@ -7,9 +7,10 @@ import (
"gpt-load/internal/models" "gpt-load/internal/models"
"net/http" "net/http"
"net/url" "net/url"
"sync"
"time" "time"
"gorm.io/datatypes" "github.com/sirupsen/logrus"
) )
// channelConstructor defines the function signature for creating a new channel proxy. // channelConstructor defines the function signature for creating a new channel proxy.
@@ -41,33 +42,53 @@ func GetChannels() []string {
// Factory is responsible for creating channel proxies. // Factory is responsible for creating channel proxies.
type Factory struct { type Factory struct {
settingsManager *config.SystemSettingsManager settingsManager *config.SystemSettingsManager
channelCache map[uint]ChannelProxy
cacheLock sync.Mutex
} }
// NewFactory creates a new channel factory. // NewFactory creates a new channel factory.
func NewFactory(settingsManager *config.SystemSettingsManager) *Factory { func NewFactory(settingsManager *config.SystemSettingsManager) *Factory {
return &Factory{ return &Factory{
settingsManager: settingsManager, settingsManager: settingsManager,
channelCache: make(map[uint]ChannelProxy),
} }
} }
// GetChannel returns a channel proxy based on the group's channel type. // GetChannel returns a channel proxy based on the group's channel type.
// It uses a cache to ensure that only one instance of a channel is created for each group.
func (f *Factory) GetChannel(group *models.Group) (ChannelProxy, error) { func (f *Factory) GetChannel(group *models.Group) (ChannelProxy, error) {
f.cacheLock.Lock()
defer f.cacheLock.Unlock()
if channel, ok := f.channelCache[group.ID]; ok {
if !channel.IsConfigStale(group) {
return channel, nil
}
}
logrus.Infof("Creating new channel for group %d with type '%s'", group.ID, group.ChannelType)
constructor, ok := channelRegistry[group.ChannelType] constructor, ok := channelRegistry[group.ChannelType]
if !ok { if !ok {
return nil, fmt.Errorf("unsupported channel type: %s", group.ChannelType) return nil, fmt.Errorf("unsupported channel type: %s", group.ChannelType)
} }
return constructor(f, group) channel, err := constructor(f, group)
if err != nil {
return nil, err
}
f.channelCache[group.ID] = channel
return channel, nil
} }
// newBaseChannel is a helper function to create and configure a BaseChannel. // newBaseChannel is a helper function to create and configure a BaseChannel.
func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, groupConfig datatypes.JSONMap, testModel string) (*BaseChannel, error) { func (f *Factory) newBaseChannel(name string, group *models.Group) (*BaseChannel, error) {
type upstreamDef struct { type upstreamDef struct {
URL string `json:"url"` URL string `json:"url"`
Weight int `json:"weight"` Weight int `json:"weight"`
} }
var defs []upstreamDef var defs []upstreamDef
if err := json.Unmarshal(upstreamsJSON, &defs); err != nil { if err := json.Unmarshal(group.Upstreams, &defs); err != nil {
return nil, fmt.Errorf("failed to unmarshal upstreams for %s channel: %w", name, err) return nil, fmt.Errorf("failed to unmarshal upstreams for %s channel: %w", name, err)
} }
@@ -89,7 +110,7 @@ func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, grou
} }
// Get effective settings by merging system and group configs // Get effective settings by merging system and group configs
effectiveSettings := f.settingsManager.GetEffectiveConfig(groupConfig) effectiveSettings := f.settingsManager.GetEffectiveConfig(group.Config)
// Configure the HTTP client with the effective timeouts // Configure the HTTP client with the effective timeouts
httpClient := &http.Client{ httpClient := &http.Client{
@@ -103,6 +124,8 @@ func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, grou
Name: name, Name: name,
Upstreams: upstreamInfos, Upstreams: upstreamInfos,
HTTPClient: httpClient, HTTPClient: httpClient,
TestModel: testModel, TestModel: group.TestModel,
groupUpstreams: group.Upstreams,
groupConfig: group.Config,
}, nil }, nil
} }

View File

@@ -23,7 +23,7 @@ type GeminiChannel struct {
} }
func newGeminiChannel(f *Factory, group *models.Group) (ChannelProxy, error) { func newGeminiChannel(f *Factory, group *models.Group) (ChannelProxy, error) {
base, err := f.newBaseChannel("gemini", group.Upstreams, group.Config, group.TestModel) base, err := f.newBaseChannel("gemini", group)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -23,7 +23,7 @@ type OpenAIChannel struct {
} }
func newOpenAIChannel(f *Factory, group *models.Group) (ChannelProxy, error) { func newOpenAIChannel(f *Factory, group *models.Group) (ChannelProxy, error) {
base, err := f.newBaseChannel("openai", group.Upstreams, group.Config, group.TestModel) base, err := f.newBaseChannel("openai", group)
if err != nil { if err != nil {
return nil, err return nil, err
} }