feat: 密钥管理
This commit is contained in:
@@ -8,38 +8,65 @@ import (
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// UpstreamInfo holds the information for a single upstream server, including its weight.
|
||||
type UpstreamInfo struct {
|
||||
URL *url.URL
|
||||
Weight int
|
||||
CurrentWeight int
|
||||
}
|
||||
|
||||
// BaseChannel provides common functionality for channel proxies.
|
||||
type BaseChannel struct {
|
||||
Name string
|
||||
Upstreams []*url.URL
|
||||
HTTPClient *http.Client
|
||||
roundRobin uint64
|
||||
Name string
|
||||
Upstreams []UpstreamInfo
|
||||
HTTPClient *http.Client
|
||||
upstreamLock sync.Mutex
|
||||
}
|
||||
|
||||
// RequestModifier is a function that can modify the request before it's sent.
|
||||
type RequestModifier func(req *http.Request, key *models.APIKey)
|
||||
|
||||
// getUpstreamURL selects an upstream URL using round-robin.
|
||||
// getUpstreamURL selects an upstream URL using a smooth weighted round-robin algorithm.
|
||||
func (b *BaseChannel) getUpstreamURL() *url.URL {
|
||||
b.upstreamLock.Lock()
|
||||
defer b.upstreamLock.Unlock()
|
||||
|
||||
if len(b.Upstreams) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(b.Upstreams) == 1 {
|
||||
return b.Upstreams[0]
|
||||
return b.Upstreams[0].URL
|
||||
}
|
||||
index := atomic.AddUint64(&b.roundRobin, 1) - 1
|
||||
return b.Upstreams[index%uint64(len(b.Upstreams))]
|
||||
|
||||
totalWeight := 0
|
||||
var best *UpstreamInfo
|
||||
|
||||
for i := range b.Upstreams {
|
||||
up := &b.Upstreams[i]
|
||||
totalWeight += up.Weight
|
||||
up.CurrentWeight += up.Weight
|
||||
|
||||
if best == nil || up.CurrentWeight > best.CurrentWeight {
|
||||
best = up
|
||||
}
|
||||
}
|
||||
|
||||
if best == nil {
|
||||
return b.Upstreams[0].URL // 降级到第一个可用的
|
||||
}
|
||||
|
||||
best.CurrentWeight -= totalWeight
|
||||
return best.URL
|
||||
}
|
||||
|
||||
// ProcessRequest handles the common logic of processing and forwarding a request.
|
||||
func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier) error {
|
||||
func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier, ch ChannelProxy) error {
|
||||
upstreamURL := b.getUpstreamURL()
|
||||
if upstreamURL == nil {
|
||||
return fmt.Errorf("no upstream URL configured for channel %s", b.Name)
|
||||
@@ -78,7 +105,7 @@ func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modi
|
||||
}
|
||||
|
||||
// Check if the client request is for a streaming endpoint
|
||||
if isStreamingRequest(c) {
|
||||
if ch.IsStreamingRequest(c) {
|
||||
return b.handleStreaming(c, proxy)
|
||||
}
|
||||
|
||||
@@ -87,6 +114,9 @@ func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modi
|
||||
}
|
||||
|
||||
func (b *BaseChannel) handleStreaming(c *gin.Context, proxy *httputil.ReverseProxy) error {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
@@ -96,13 +126,12 @@ func (b *BaseChannel) handleStreaming(c *gin.Context, proxy *httputil.ReversePro
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close()
|
||||
|
||||
// Create a new request with the pipe reader as the body
|
||||
// This is a bit of a hack to get ReverseProxy to stream
|
||||
req := c.Request.Clone(c.Request.Context())
|
||||
req.Body = pr
|
||||
|
||||
// Start the proxy in a goroutine
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer pw.Close()
|
||||
proxy.ServeHTTP(c.Writer, req)
|
||||
}()
|
||||
@@ -111,32 +140,16 @@ func (b *BaseChannel) handleStreaming(c *gin.Context, proxy *httputil.ReversePro
|
||||
_, err := io.Copy(pw, c.Request.Body)
|
||||
if err != nil {
|
||||
logrus.Errorf("Error copying request body to pipe: %v", err)
|
||||
wg.Wait() // Wait for the goroutine to finish even if copy fails
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for the proxy to finish
|
||||
wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isStreamingRequest checks if the request is for a streaming response.
|
||||
func isStreamingRequest(c *gin.Context) bool {
|
||||
// For Gemini, streaming is indicated by the path.
|
||||
if strings.Contains(c.Request.URL.Path, ":streamGenerateContent") {
|
||||
return true
|
||||
}
|
||||
|
||||
// For OpenAI, streaming is indicated by a "stream": true field in the JSON body.
|
||||
// We use ShouldBindBodyWith to check the body without consuming it, so it can be read again by the proxy.
|
||||
type streamPayload struct {
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
var p streamPayload
|
||||
if err := c.ShouldBindBodyWith(&p, binding.JSON); err == nil {
|
||||
return p.Stream
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// singleJoiningSlash joins two URL paths with a single slash.
|
||||
func singleJoiningSlash(a, b string) string {
|
||||
aslash := strings.HasSuffix(a, "/")
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"gpt-load/internal/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -11,4 +12,10 @@ type ChannelProxy interface {
|
||||
// Handle takes a context, an API key, and the original request,
|
||||
// then forwards the request to the upstream service.
|
||||
Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error
|
||||
}
|
||||
|
||||
// ValidateKey checks if the given API key is valid.
|
||||
ValidateKey(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// IsStreamingRequest checks if the request is for a streaming response.
|
||||
IsStreamingRequest(c *gin.Context) bool
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gpt-load/internal/config"
|
||||
"gpt-load/internal/models"
|
||||
@@ -11,36 +12,61 @@ import (
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
// Factory is responsible for creating channel proxies.
|
||||
type Factory struct {
|
||||
settingsManager *config.SystemSettingsManager
|
||||
}
|
||||
|
||||
// NewFactory creates a new channel factory.
|
||||
func NewFactory(settingsManager *config.SystemSettingsManager) *Factory {
|
||||
return &Factory{
|
||||
settingsManager: settingsManager,
|
||||
}
|
||||
}
|
||||
|
||||
// GetChannel returns a channel proxy based on the group's channel type.
|
||||
func GetChannel(group *models.Group) (ChannelProxy, error) {
|
||||
func (f *Factory) GetChannel(group *models.Group) (ChannelProxy, error) {
|
||||
switch group.ChannelType {
|
||||
case "openai":
|
||||
return NewOpenAIChannel(group.Upstreams, group.Config)
|
||||
return f.NewOpenAIChannel(group)
|
||||
case "gemini":
|
||||
return NewGeminiChannel(group.Upstreams, group.Config)
|
||||
return f.NewGeminiChannel(group)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported channel type: %s", group.ChannelType)
|
||||
}
|
||||
}
|
||||
|
||||
// newBaseChannelWithUpstreams is a helper function to create and configure a BaseChannel.
|
||||
func newBaseChannelWithUpstreams(name string, upstreams []string, groupConfig datatypes.JSONMap) (BaseChannel, error) {
|
||||
if len(upstreams) == 0 {
|
||||
return BaseChannel{}, fmt.Errorf("at least one upstream is required for %s channel", name)
|
||||
// newBaseChannel is a helper function to create and configure a BaseChannel.
|
||||
func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, groupConfig datatypes.JSONMap) (*BaseChannel, error) {
|
||||
type upstreamDef struct {
|
||||
URL string `json:"url"`
|
||||
Weight int `json:"weight"`
|
||||
}
|
||||
|
||||
var upstreamURLs []*url.URL
|
||||
for _, us := range upstreams {
|
||||
u, err := url.Parse(us)
|
||||
var defs []upstreamDef
|
||||
if err := json.Unmarshal(upstreamsJSON, &defs); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal upstreams for %s channel: %w", name, err)
|
||||
}
|
||||
|
||||
if len(defs) == 0 {
|
||||
return nil, fmt.Errorf("at least one upstream is required for %s channel", name)
|
||||
}
|
||||
|
||||
var upstreamInfos []UpstreamInfo
|
||||
for _, def := range defs {
|
||||
u, err := url.Parse(def.URL)
|
||||
if err != nil {
|
||||
return BaseChannel{}, fmt.Errorf("failed to parse upstream url '%s' for %s channel: %w", us, name, err)
|
||||
return nil, fmt.Errorf("failed to parse upstream url '%s' for %s channel: %w", def.URL, name, err)
|
||||
}
|
||||
upstreamURLs = append(upstreamURLs, u)
|
||||
weight := def.Weight
|
||||
if weight <= 0 {
|
||||
weight = 1 // Default weight to 1 if not specified or invalid
|
||||
}
|
||||
upstreamInfos = append(upstreamInfos, UpstreamInfo{URL: u, Weight: weight})
|
||||
}
|
||||
|
||||
// Get effective settings by merging system and group configs
|
||||
settingsManager := config.GetSystemSettingsManager()
|
||||
effectiveSettings := settingsManager.GetEffectiveConfig(groupConfig)
|
||||
effectiveSettings := f.settingsManager.GetEffectiveConfig(groupConfig)
|
||||
|
||||
// Configure the HTTP client with the effective timeouts
|
||||
httpClient := &http.Client{
|
||||
@@ -50,9 +76,9 @@ func newBaseChannelWithUpstreams(name string, upstreams []string, groupConfig da
|
||||
Timeout: time.Duration(effectiveSettings.RequestTimeout) * time.Second,
|
||||
}
|
||||
|
||||
return BaseChannel{
|
||||
return &BaseChannel{
|
||||
Name: name,
|
||||
Upstreams: upstreamURLs,
|
||||
Upstreams: upstreamInfos,
|
||||
HTTPClient: httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
@@ -1,19 +1,22 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gpt-load/internal/models"
|
||||
"net/http"
|
||||
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
type GeminiChannel struct {
|
||||
BaseChannel
|
||||
*BaseChannel
|
||||
}
|
||||
|
||||
func NewGeminiChannel(upstreams []string, config datatypes.JSONMap) (*GeminiChannel, error) {
|
||||
base, err := newBaseChannelWithUpstreams("gemini", upstreams, config)
|
||||
func (f *Factory) NewGeminiChannel(group *models.Group) (*GeminiChannel, error) {
|
||||
base, err := f.newBaseChannel("gemini", group.Upstreams, group.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -29,5 +32,40 @@ func (ch *GeminiChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *mo
|
||||
q.Set("key", key.KeyValue)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
}
|
||||
return ch.ProcessRequest(c, apiKey, modifier)
|
||||
return ch.ProcessRequest(c, apiKey, modifier, ch)
|
||||
}
|
||||
|
||||
// ValidateKey checks if the given API key is valid by making a request to the models endpoint.
|
||||
func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
|
||||
upstreamURL := ch.getUpstreamURL()
|
||||
if upstreamURL == nil {
|
||||
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
|
||||
}
|
||||
|
||||
// Construct the request URL for listing models.
|
||||
reqURL := fmt.Sprintf("%s/v1beta/models?key=%s", upstreamURL.String(), key)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create validation request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := ch.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to send validation request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// A 200 OK status code indicates the key is valid.
|
||||
return resp.StatusCode == http.StatusOK, nil
|
||||
}
|
||||
|
||||
// IsStreamingRequest checks if the request is for a streaming response.
|
||||
func (ch *GeminiChannel) IsStreamingRequest(c *gin.Context) bool {
|
||||
// For Gemini, streaming is indicated by the path containing streaming keywords
|
||||
path := c.Request.URL.Path
|
||||
return strings.Contains(path, ":streamGenerateContent") ||
|
||||
strings.Contains(path, "streamGenerateContent") ||
|
||||
strings.Contains(path, ":stream") ||
|
||||
strings.Contains(path, "/stream")
|
||||
}
|
||||
|
@@ -1,19 +1,21 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gpt-load/internal/models"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/datatypes"
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
)
|
||||
|
||||
type OpenAIChannel struct {
|
||||
BaseChannel
|
||||
*BaseChannel
|
||||
}
|
||||
|
||||
func NewOpenAIChannel(upstreams []string, config datatypes.JSONMap) (*OpenAIChannel, error) {
|
||||
base, err := newBaseChannelWithUpstreams("openai", upstreams, config)
|
||||
func (f *Factory) NewOpenAIChannel(group *models.Group) (*OpenAIChannel, error) {
|
||||
base, err := f.newBaseChannel("openai", group.Upstreams, group.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -27,5 +29,46 @@ func (ch *OpenAIChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *mo
|
||||
modifier := func(req *http.Request, key *models.APIKey) {
|
||||
req.Header.Set("Authorization", "Bearer "+key.KeyValue)
|
||||
}
|
||||
return ch.ProcessRequest(c, apiKey, modifier)
|
||||
return ch.ProcessRequest(c, apiKey, modifier, ch)
|
||||
}
|
||||
|
||||
// ValidateKey checks if the given API key is valid by making a request to the models endpoint.
|
||||
func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
|
||||
upstreamURL := ch.getUpstreamURL()
|
||||
if upstreamURL == nil {
|
||||
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
|
||||
}
|
||||
|
||||
// Construct the request URL for listing models, a common endpoint for key validation.
|
||||
reqURL := upstreamURL.String() + "/v1/models"
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create validation request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
resp, err := ch.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to send validation request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// A 200 OK status code indicates the key is valid.
|
||||
// Other status codes (e.g., 401 Unauthorized) indicate an invalid key.
|
||||
return resp.StatusCode == http.StatusOK, nil
|
||||
}
|
||||
|
||||
// IsStreamingRequest checks if the request is for a streaming response.
|
||||
func (ch *OpenAIChannel) IsStreamingRequest(c *gin.Context) bool {
|
||||
// For OpenAI, streaming is indicated by a "stream": true field in the JSON body.
|
||||
// We use ShouldBindBodyWith to check the body without consuming it, so it can be read again by the proxy.
|
||||
type streamPayload struct {
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
var p streamPayload
|
||||
if err := c.ShouldBindBodyWith(&p, binding.JSON); err == nil {
|
||||
return p.Stream
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
Reference in New Issue
Block a user