feat: 密钥管理
This commit is contained in:
@@ -14,6 +14,7 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"gpt-load/internal/channel"
|
||||
"gpt-load/internal/config"
|
||||
"gpt-load/internal/db"
|
||||
"gpt-load/internal/handler"
|
||||
@@ -74,15 +75,28 @@ func main() {
|
||||
go startRequestLogger(database, requestLogChan, &wg)
|
||||
// ---
|
||||
|
||||
// --- Service Initialization ---
|
||||
taskService := services.NewTaskService()
|
||||
channelFactory := channel.NewFactory(settingsManager)
|
||||
keyValidatorService := services.NewKeyValidatorService(database, channelFactory)
|
||||
|
||||
keyManualValidationService := services.NewKeyManualValidationService(database, keyValidatorService, taskService, settingsManager)
|
||||
keyCronService := services.NewKeyCronService(database, keyValidatorService, settingsManager)
|
||||
keyCronService.Start()
|
||||
defer keyCronService.Stop()
|
||||
|
||||
keyService := services.NewKeyService(database)
|
||||
// ---
|
||||
|
||||
// Create proxy server
|
||||
proxyServer, err := proxy.NewProxyServer(database, requestLogChan)
|
||||
proxyServer, err := proxy.NewProxyServer(database, channelFactory, requestLogChan)
|
||||
if err != nil {
|
||||
logrus.Fatalf("Failed to create proxy server: %v", err)
|
||||
}
|
||||
defer proxyServer.Close()
|
||||
|
||||
// Create handlers
|
||||
serverHandler := handler.NewServer(database, configManager)
|
||||
serverHandler := handler.NewServer(database, configManager, keyValidatorService, keyManualValidationService, taskService, keyService)
|
||||
logCleanupHandler := handler.NewLogCleanupHandler(logCleanupService)
|
||||
|
||||
// Setup routes using the new router package
|
||||
|
1
go.mod
1
go.mod
@@ -9,6 +9,7 @@ require (
|
||||
github.com/gin-contrib/static v1.1.5
|
||||
github.com/gin-gonic/gin v1.10.1
|
||||
github.com/go-sql-driver/mysql v1.8.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
gorm.io/datatypes v1.2.1
|
||||
|
2
go.sum
2
go.sum
@@ -41,6 +41,8 @@ github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EO
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
|
||||
|
@@ -8,38 +8,65 @@ import (
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// UpstreamInfo holds the information for a single upstream server, including its weight.
|
||||
type UpstreamInfo struct {
|
||||
URL *url.URL
|
||||
Weight int
|
||||
CurrentWeight int
|
||||
}
|
||||
|
||||
// BaseChannel provides common functionality for channel proxies.
|
||||
type BaseChannel struct {
|
||||
Name string
|
||||
Upstreams []*url.URL
|
||||
Upstreams []UpstreamInfo
|
||||
HTTPClient *http.Client
|
||||
roundRobin uint64
|
||||
upstreamLock sync.Mutex
|
||||
}
|
||||
|
||||
// RequestModifier is a function that can modify the request before it's sent.
|
||||
type RequestModifier func(req *http.Request, key *models.APIKey)
|
||||
|
||||
// getUpstreamURL selects an upstream URL using round-robin.
|
||||
// getUpstreamURL selects an upstream URL using a smooth weighted round-robin algorithm.
|
||||
func (b *BaseChannel) getUpstreamURL() *url.URL {
|
||||
b.upstreamLock.Lock()
|
||||
defer b.upstreamLock.Unlock()
|
||||
|
||||
if len(b.Upstreams) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(b.Upstreams) == 1 {
|
||||
return b.Upstreams[0]
|
||||
return b.Upstreams[0].URL
|
||||
}
|
||||
index := atomic.AddUint64(&b.roundRobin, 1) - 1
|
||||
return b.Upstreams[index%uint64(len(b.Upstreams))]
|
||||
|
||||
totalWeight := 0
|
||||
var best *UpstreamInfo
|
||||
|
||||
for i := range b.Upstreams {
|
||||
up := &b.Upstreams[i]
|
||||
totalWeight += up.Weight
|
||||
up.CurrentWeight += up.Weight
|
||||
|
||||
if best == nil || up.CurrentWeight > best.CurrentWeight {
|
||||
best = up
|
||||
}
|
||||
}
|
||||
|
||||
if best == nil {
|
||||
return b.Upstreams[0].URL // 降级到第一个可用的
|
||||
}
|
||||
|
||||
best.CurrentWeight -= totalWeight
|
||||
return best.URL
|
||||
}
|
||||
|
||||
// ProcessRequest handles the common logic of processing and forwarding a request.
|
||||
func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier) error {
|
||||
func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier, ch ChannelProxy) error {
|
||||
upstreamURL := b.getUpstreamURL()
|
||||
if upstreamURL == nil {
|
||||
return fmt.Errorf("no upstream URL configured for channel %s", b.Name)
|
||||
@@ -78,7 +105,7 @@ func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modi
|
||||
}
|
||||
|
||||
// Check if the client request is for a streaming endpoint
|
||||
if isStreamingRequest(c) {
|
||||
if ch.IsStreamingRequest(c) {
|
||||
return b.handleStreaming(c, proxy)
|
||||
}
|
||||
|
||||
@@ -87,6 +114,9 @@ func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modi
|
||||
}
|
||||
|
||||
func (b *BaseChannel) handleStreaming(c *gin.Context, proxy *httputil.ReverseProxy) error {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
@@ -96,13 +126,12 @@ func (b *BaseChannel) handleStreaming(c *gin.Context, proxy *httputil.ReversePro
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close()
|
||||
|
||||
// Create a new request with the pipe reader as the body
|
||||
// This is a bit of a hack to get ReverseProxy to stream
|
||||
req := c.Request.Clone(c.Request.Context())
|
||||
req.Body = pr
|
||||
|
||||
// Start the proxy in a goroutine
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer pw.Close()
|
||||
proxy.ServeHTTP(c.Writer, req)
|
||||
}()
|
||||
@@ -111,32 +140,16 @@ func (b *BaseChannel) handleStreaming(c *gin.Context, proxy *httputil.ReversePro
|
||||
_, err := io.Copy(pw, c.Request.Body)
|
||||
if err != nil {
|
||||
logrus.Errorf("Error copying request body to pipe: %v", err)
|
||||
wg.Wait() // Wait for the goroutine to finish even if copy fails
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for the proxy to finish
|
||||
wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isStreamingRequest checks if the request is for a streaming response.
|
||||
func isStreamingRequest(c *gin.Context) bool {
|
||||
// For Gemini, streaming is indicated by the path.
|
||||
if strings.Contains(c.Request.URL.Path, ":streamGenerateContent") {
|
||||
return true
|
||||
}
|
||||
|
||||
// For OpenAI, streaming is indicated by a "stream": true field in the JSON body.
|
||||
// We use ShouldBindBodyWith to check the body without consuming it, so it can be read again by the proxy.
|
||||
type streamPayload struct {
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
var p streamPayload
|
||||
if err := c.ShouldBindBodyWith(&p, binding.JSON); err == nil {
|
||||
return p.Stream
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// singleJoiningSlash joins two URL paths with a single slash.
|
||||
func singleJoiningSlash(a, b string) string {
|
||||
aslash := strings.HasSuffix(a, "/")
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"gpt-load/internal/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -11,4 +12,10 @@ type ChannelProxy interface {
|
||||
// Handle takes a context, an API key, and the original request,
|
||||
// then forwards the request to the upstream service.
|
||||
Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error
|
||||
|
||||
// ValidateKey checks if the given API key is valid.
|
||||
ValidateKey(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// IsStreamingRequest checks if the request is for a streaming response.
|
||||
IsStreamingRequest(c *gin.Context) bool
|
||||
}
|
@@ -1,6 +1,7 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gpt-load/internal/config"
|
||||
"gpt-load/internal/models"
|
||||
@@ -11,36 +12,61 @@ import (
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
// Factory is responsible for creating channel proxies.
|
||||
type Factory struct {
|
||||
settingsManager *config.SystemSettingsManager
|
||||
}
|
||||
|
||||
// NewFactory creates a new channel factory.
|
||||
func NewFactory(settingsManager *config.SystemSettingsManager) *Factory {
|
||||
return &Factory{
|
||||
settingsManager: settingsManager,
|
||||
}
|
||||
}
|
||||
|
||||
// GetChannel returns a channel proxy based on the group's channel type.
|
||||
func GetChannel(group *models.Group) (ChannelProxy, error) {
|
||||
func (f *Factory) GetChannel(group *models.Group) (ChannelProxy, error) {
|
||||
switch group.ChannelType {
|
||||
case "openai":
|
||||
return NewOpenAIChannel(group.Upstreams, group.Config)
|
||||
return f.NewOpenAIChannel(group)
|
||||
case "gemini":
|
||||
return NewGeminiChannel(group.Upstreams, group.Config)
|
||||
return f.NewGeminiChannel(group)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported channel type: %s", group.ChannelType)
|
||||
}
|
||||
}
|
||||
|
||||
// newBaseChannelWithUpstreams is a helper function to create and configure a BaseChannel.
|
||||
func newBaseChannelWithUpstreams(name string, upstreams []string, groupConfig datatypes.JSONMap) (BaseChannel, error) {
|
||||
if len(upstreams) == 0 {
|
||||
return BaseChannel{}, fmt.Errorf("at least one upstream is required for %s channel", name)
|
||||
// newBaseChannel is a helper function to create and configure a BaseChannel.
|
||||
func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, groupConfig datatypes.JSONMap) (*BaseChannel, error) {
|
||||
type upstreamDef struct {
|
||||
URL string `json:"url"`
|
||||
Weight int `json:"weight"`
|
||||
}
|
||||
|
||||
var upstreamURLs []*url.URL
|
||||
for _, us := range upstreams {
|
||||
u, err := url.Parse(us)
|
||||
if err != nil {
|
||||
return BaseChannel{}, fmt.Errorf("failed to parse upstream url '%s' for %s channel: %w", us, name, err)
|
||||
var defs []upstreamDef
|
||||
if err := json.Unmarshal(upstreamsJSON, &defs); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal upstreams for %s channel: %w", name, err)
|
||||
}
|
||||
upstreamURLs = append(upstreamURLs, u)
|
||||
|
||||
if len(defs) == 0 {
|
||||
return nil, fmt.Errorf("at least one upstream is required for %s channel", name)
|
||||
}
|
||||
|
||||
var upstreamInfos []UpstreamInfo
|
||||
for _, def := range defs {
|
||||
u, err := url.Parse(def.URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse upstream url '%s' for %s channel: %w", def.URL, name, err)
|
||||
}
|
||||
weight := def.Weight
|
||||
if weight <= 0 {
|
||||
weight = 1 // Default weight to 1 if not specified or invalid
|
||||
}
|
||||
upstreamInfos = append(upstreamInfos, UpstreamInfo{URL: u, Weight: weight})
|
||||
}
|
||||
|
||||
// Get effective settings by merging system and group configs
|
||||
settingsManager := config.GetSystemSettingsManager()
|
||||
effectiveSettings := settingsManager.GetEffectiveConfig(groupConfig)
|
||||
effectiveSettings := f.settingsManager.GetEffectiveConfig(groupConfig)
|
||||
|
||||
// Configure the HTTP client with the effective timeouts
|
||||
httpClient := &http.Client{
|
||||
@@ -50,9 +76,9 @@ func newBaseChannelWithUpstreams(name string, upstreams []string, groupConfig da
|
||||
Timeout: time.Duration(effectiveSettings.RequestTimeout) * time.Second,
|
||||
}
|
||||
|
||||
return BaseChannel{
|
||||
return &BaseChannel{
|
||||
Name: name,
|
||||
Upstreams: upstreamURLs,
|
||||
Upstreams: upstreamInfos,
|
||||
HTTPClient: httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
@@ -1,19 +1,22 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gpt-load/internal/models"
|
||||
"net/http"
|
||||
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
type GeminiChannel struct {
|
||||
BaseChannel
|
||||
*BaseChannel
|
||||
}
|
||||
|
||||
func NewGeminiChannel(upstreams []string, config datatypes.JSONMap) (*GeminiChannel, error) {
|
||||
base, err := newBaseChannelWithUpstreams("gemini", upstreams, config)
|
||||
func (f *Factory) NewGeminiChannel(group *models.Group) (*GeminiChannel, error) {
|
||||
base, err := f.newBaseChannel("gemini", group.Upstreams, group.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -29,5 +32,40 @@ func (ch *GeminiChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *mo
|
||||
q.Set("key", key.KeyValue)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
}
|
||||
return ch.ProcessRequest(c, apiKey, modifier)
|
||||
return ch.ProcessRequest(c, apiKey, modifier, ch)
|
||||
}
|
||||
|
||||
// ValidateKey checks if the given API key is valid by making a request to the models endpoint.
|
||||
func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
|
||||
upstreamURL := ch.getUpstreamURL()
|
||||
if upstreamURL == nil {
|
||||
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
|
||||
}
|
||||
|
||||
// Construct the request URL for listing models.
|
||||
reqURL := fmt.Sprintf("%s/v1beta/models?key=%s", upstreamURL.String(), key)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create validation request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := ch.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to send validation request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// A 200 OK status code indicates the key is valid.
|
||||
return resp.StatusCode == http.StatusOK, nil
|
||||
}
|
||||
|
||||
// IsStreamingRequest checks if the request is for a streaming response.
|
||||
func (ch *GeminiChannel) IsStreamingRequest(c *gin.Context) bool {
|
||||
// For Gemini, streaming is indicated by the path containing streaming keywords
|
||||
path := c.Request.URL.Path
|
||||
return strings.Contains(path, ":streamGenerateContent") ||
|
||||
strings.Contains(path, "streamGenerateContent") ||
|
||||
strings.Contains(path, ":stream") ||
|
||||
strings.Contains(path, "/stream")
|
||||
}
|
||||
|
@@ -1,19 +1,21 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gpt-load/internal/models"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/datatypes"
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
)
|
||||
|
||||
type OpenAIChannel struct {
|
||||
BaseChannel
|
||||
*BaseChannel
|
||||
}
|
||||
|
||||
func NewOpenAIChannel(upstreams []string, config datatypes.JSONMap) (*OpenAIChannel, error) {
|
||||
base, err := newBaseChannelWithUpstreams("openai", upstreams, config)
|
||||
func (f *Factory) NewOpenAIChannel(group *models.Group) (*OpenAIChannel, error) {
|
||||
base, err := f.newBaseChannel("openai", group.Upstreams, group.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -27,5 +29,46 @@ func (ch *OpenAIChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *mo
|
||||
modifier := func(req *http.Request, key *models.APIKey) {
|
||||
req.Header.Set("Authorization", "Bearer "+key.KeyValue)
|
||||
}
|
||||
return ch.ProcessRequest(c, apiKey, modifier)
|
||||
return ch.ProcessRequest(c, apiKey, modifier, ch)
|
||||
}
|
||||
|
||||
// ValidateKey checks if the given API key is valid by making a request to the models endpoint.
|
||||
func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
|
||||
upstreamURL := ch.getUpstreamURL()
|
||||
if upstreamURL == nil {
|
||||
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
|
||||
}
|
||||
|
||||
// Construct the request URL for listing models, a common endpoint for key validation.
|
||||
reqURL := upstreamURL.String() + "/v1/models"
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create validation request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
resp, err := ch.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to send validation request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// A 200 OK status code indicates the key is valid.
|
||||
// Other status codes (e.g., 401 Unauthorized) indicate an invalid key.
|
||||
return resp.StatusCode == http.StatusOK, nil
|
||||
}
|
||||
|
||||
// IsStreamingRequest checks if the request is for a streaming response.
|
||||
func (ch *OpenAIChannel) IsStreamingRequest(c *gin.Context) bool {
|
||||
// For OpenAI, streaming is indicated by a "stream": true field in the JSON body.
|
||||
// We use ShouldBindBodyWith to check the body without consuming it, so it can be read again by the proxy.
|
||||
type streamPayload struct {
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
var p streamPayload
|
||||
if err := c.ShouldBindBodyWith(&p, binding.JSON); err == nil {
|
||||
return p.Stream
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@@ -276,3 +276,16 @@ func getEnvOrDefault(key, defaultValue string) string {
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// GetInt is a helper function for SystemSettingsManager to get an integer value with a default.
|
||||
func (s *SystemSettingsManager) GetInt(key string, defaultValue int) int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if valStr, ok := s.settingsCache[key]; ok {
|
||||
if valInt, err := strconv.Atoi(valStr); err == nil {
|
||||
return valInt
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
@@ -34,6 +34,11 @@ type SystemSettings struct {
|
||||
|
||||
// 请求日志配置(数据库日志)
|
||||
RequestLogRetentionDays int `json:"request_log_retention_days" default:"30" name:"日志保留天数" category:"日志配置" desc:"请求日志在数据库中的保留天数" validate:"min=1"`
|
||||
|
||||
// 密钥验证配置
|
||||
KeyValidationIntervalMinutes int `json:"key_validation_interval_minutes" default:"60" name:"定时验证周期" category:"密钥验证" desc:"后台定时验证密钥的默认周期(分钟)" validate:"min=5"`
|
||||
KeyValidationConcurrency int `json:"key_validation_concurrency" default:"10" name:"验证并发数" category:"密钥验证" desc:"执行密钥验证时的并发 goroutine 数量" validate:"min=1,max=100"`
|
||||
KeyValidationTaskTimeoutMinutes int `json:"key_validation_task_timeout_minutes" default:"60" name:"手动验证超时" category:"密钥验证" desc:"手动触发的全量验证任务的超时时间(分钟)" validate:"min=10"`
|
||||
}
|
||||
|
||||
// GenerateSettingsMetadata 使用反射从 SystemSettings 结构体动态生成元数据
|
||||
@@ -106,6 +111,7 @@ func DefaultSystemSettings() SystemSettings {
|
||||
// SystemSettingsManager 管理系统配置
|
||||
type SystemSettingsManager struct {
|
||||
settings SystemSettings
|
||||
settingsCache map[string]string // Cache for raw string values
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
@@ -169,6 +175,8 @@ func (sm *SystemSettingsManager) LoadFromDatabase() error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
sm.settingsCache = settingsMap
|
||||
|
||||
// 使用默认值,然后用数据库中的值覆盖
|
||||
sm.settings = DefaultSystemSettings()
|
||||
sm.mapToStruct(settingsMap, &sm.settings)
|
||||
@@ -334,6 +342,8 @@ func (sm *SystemSettingsManager) DisplayCurrentSettings() {
|
||||
logrus.Infof(" Request timeouts: request=%ds, response=%ds, idle_conn=%ds",
|
||||
sm.settings.RequestTimeout, sm.settings.ResponseTimeout, sm.settings.IdleConnTimeout)
|
||||
logrus.Infof(" Request log retention: %d days", sm.settings.RequestLogRetentionDays)
|
||||
logrus.Infof(" Key validation: interval=%dmin, concurrency=%d, task_timeout=%dmin",
|
||||
sm.settings.KeyValidationIntervalMinutes, sm.settings.KeyValidationConcurrency, sm.settings.KeyValidationTaskTimeoutMinutes)
|
||||
}
|
||||
|
||||
// 辅助方法
|
||||
|
@@ -31,6 +31,8 @@ var (
|
||||
ErrDatabase = &APIError{HTTPStatus: http.StatusInternalServerError, Code: "DATABASE_ERROR", Message: "Database operation failed"}
|
||||
ErrUnauthorized = &APIError{HTTPStatus: http.StatusUnauthorized, Code: "UNAUTHORIZED", Message: "Authentication failed"}
|
||||
ErrForbidden = &APIError{HTTPStatus: http.StatusForbidden, Code: "FORBIDDEN", Message: "You do not have permission to access this resource"}
|
||||
ErrTaskInProgress = &APIError{HTTPStatus: http.StatusConflict, Code: "TASK_IN_PROGRESS", Message: "A task is already in progress"}
|
||||
ErrBadGateway = &APIError{HTTPStatus: http.StatusBadGateway, Code: "BAD_GATEWAY", Message: "Upstream service error"}
|
||||
)
|
||||
|
||||
// NewAPIError creates a new APIError with a custom message.
|
||||
|
@@ -3,6 +3,7 @@ package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
app_errors "gpt-load/internal/errors"
|
||||
"gpt-load/internal/models"
|
||||
"gpt-load/internal/response"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
// isValidGroupName checks if the group name is valid.
|
||||
@@ -22,6 +24,50 @@ func isValidGroupName(name string) bool {
|
||||
return match
|
||||
}
|
||||
|
||||
// validateAndCleanConfig validates the group config against the GroupConfig struct.
|
||||
func validateAndCleanConfig(configMap map[string]interface{}) (map[string]interface{}, error) {
|
||||
if configMap == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
configBytes, err := json.Marshal(configMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var validatedConfig models.GroupConfig
|
||||
if err := json.Unmarshal(configBytes, &validatedConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证配置项的合理范围
|
||||
if validatedConfig.BlacklistThreshold != nil && *validatedConfig.BlacklistThreshold < 0 {
|
||||
return nil, fmt.Errorf("blacklist_threshold must be >= 0")
|
||||
}
|
||||
if validatedConfig.MaxRetries != nil && (*validatedConfig.MaxRetries < 0 || *validatedConfig.MaxRetries > 10) {
|
||||
return nil, fmt.Errorf("max_retries must be between 0 and 10")
|
||||
}
|
||||
if validatedConfig.RequestTimeout != nil && (*validatedConfig.RequestTimeout < 1 || *validatedConfig.RequestTimeout > 3600) {
|
||||
return nil, fmt.Errorf("request_timeout must be between 1 and 3600 seconds")
|
||||
}
|
||||
if validatedConfig.KeyValidationIntervalMinutes != nil && (*validatedConfig.KeyValidationIntervalMinutes < 5 || *validatedConfig.KeyValidationIntervalMinutes > 1440) {
|
||||
return nil, fmt.Errorf("key_validation_interval_minutes must be between 5 and 1440 minutes")
|
||||
}
|
||||
|
||||
// Marshal back to a map to remove any fields not in GroupConfig
|
||||
validatedBytes, err := json.Marshal(validatedConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cleanedMap map[string]interface{}
|
||||
if err := json.Unmarshal(validatedBytes, &cleanedMap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cleanedMap, nil
|
||||
}
|
||||
|
||||
// CreateGroup handles the creation of a new group.
|
||||
func (s *Server) CreateGroup(c *gin.Context) {
|
||||
var group models.Group
|
||||
@@ -43,6 +89,17 @@ func (s *Server) CreateGroup(c *gin.Context) {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Channel type is required"))
|
||||
return
|
||||
}
|
||||
if group.TestModel == "" {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Test model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
cleanedConfig, err := validateAndCleanConfig(group.Config)
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid config format"))
|
||||
return
|
||||
}
|
||||
group.Config = cleanedConfig
|
||||
|
||||
if err := s.DB.Create(&group).Error; err != nil {
|
||||
response.Error(c, app_errors.ParseDBError(err))
|
||||
@@ -62,6 +119,20 @@ func (s *Server) ListGroups(c *gin.Context) {
|
||||
response.Success(c, groups)
|
||||
}
|
||||
|
||||
// GroupUpdateRequest defines the payload for updating a group.
|
||||
// Using a dedicated struct avoids issues with zero values being ignored by GORM's Update.
|
||||
type GroupUpdateRequest struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Description string `json:"description"`
|
||||
Upstreams json.RawMessage `json:"upstreams"`
|
||||
ChannelType string `json:"channel_type"`
|
||||
Sort *int `json:"sort"`
|
||||
TestModel string `json:"test_model"`
|
||||
ParamOverrides map[string]interface{} `json:"param_overrides"`
|
||||
Config map[string]interface{} `json:"config"`
|
||||
}
|
||||
|
||||
// UpdateGroup handles updating an existing group.
|
||||
func (s *Server) UpdateGroup(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
@@ -76,84 +147,70 @@ func (s *Server) UpdateGroup(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var updateData models.Group
|
||||
if err := c.ShouldBindJSON(&updateData); err != nil {
|
||||
var req GroupUpdateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate group name if it's being updated
|
||||
if updateData.Name != "" && !isValidGroupName(updateData.Name) {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid group name format. Use 3-30 lowercase letters, numbers, and underscores."))
|
||||
return
|
||||
}
|
||||
|
||||
// Use a transaction to ensure atomicity
|
||||
// Start a transaction
|
||||
tx := s.DB.Begin()
|
||||
if tx.Error != nil {
|
||||
response.Error(c, app_errors.ErrDatabase)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback() // Rollback on panic
|
||||
|
||||
// Convert updateData to a map to ensure zero values (like Sort: 0) are updated
|
||||
var updateMap map[string]interface{}
|
||||
updateBytes, _ := json.Marshal(updateData)
|
||||
if err := json.Unmarshal(updateBytes, &updateMap); err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Failed to process update data"))
|
||||
// Apply updates from the request
|
||||
if req.Name != "" {
|
||||
if !isValidGroupName(req.Name) {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid group name format."))
|
||||
return
|
||||
}
|
||||
|
||||
// If config is being updated, it needs to be marshalled to JSON string for GORM
|
||||
if config, ok := updateMap["config"]; ok {
|
||||
if configMap, isMap := config.(map[string]interface{}); isMap {
|
||||
configJSON, err := json.Marshal(configMap)
|
||||
group.Name = req.Name
|
||||
}
|
||||
if req.DisplayName != "" {
|
||||
group.DisplayName = req.DisplayName
|
||||
}
|
||||
if req.Description != "" {
|
||||
group.Description = req.Description
|
||||
}
|
||||
if req.Upstreams != nil {
|
||||
group.Upstreams = datatypes.JSON(req.Upstreams)
|
||||
}
|
||||
if req.ChannelType != "" {
|
||||
group.ChannelType = req.ChannelType
|
||||
}
|
||||
if req.Sort != nil {
|
||||
group.Sort = *req.Sort
|
||||
}
|
||||
if req.TestModel != "" {
|
||||
group.TestModel = req.TestModel
|
||||
}
|
||||
if req.ParamOverrides != nil {
|
||||
group.ParamOverrides = req.ParamOverrides
|
||||
}
|
||||
if req.Config != nil {
|
||||
cleanedConfig, err := validateAndCleanConfig(req.Config)
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Failed to process config data"))
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid config format"))
|
||||
return
|
||||
}
|
||||
updateMap["config"] = string(configJSON)
|
||||
}
|
||||
group.Config = cleanedConfig
|
||||
}
|
||||
|
||||
// Handle upstreams field specifically
|
||||
if upstreams, ok := updateMap["upstreams"]; ok {
|
||||
if upstreamsSlice, isSlice := upstreams.([]interface{}); isSlice {
|
||||
upstreamsJSON, err := json.Marshal(upstreamsSlice)
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Failed to process upstreams data"))
|
||||
return
|
||||
}
|
||||
updateMap["upstreams"] = string(upstreamsJSON)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove fields that are not actual columns or should not be updated from the map
|
||||
delete(updateMap, "id")
|
||||
delete(updateMap, "api_keys")
|
||||
delete(updateMap, "created_at")
|
||||
delete(updateMap, "updated_at")
|
||||
|
||||
// Use Updates with a map to only update provided fields, including zero values
|
||||
if err := tx.Model(&group).Updates(updateMap).Error; err != nil {
|
||||
tx.Rollback()
|
||||
// Save the updated group object
|
||||
if err := tx.Save(&group).Error; err != nil {
|
||||
response.Error(c, app_errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
tx.Rollback()
|
||||
response.Error(c, app_errors.ErrDatabase)
|
||||
return
|
||||
}
|
||||
|
||||
// Re-fetch the group to return the updated data
|
||||
var updatedGroup models.Group
|
||||
if err := s.DB.First(&updatedGroup, id).Error; err != nil {
|
||||
response.Error(c, app_errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, updatedGroup)
|
||||
response.Success(c, group)
|
||||
}
|
||||
|
||||
// DeleteGroup handles deleting a group.
|
||||
|
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"gpt-load/internal/models"
|
||||
"gpt-load/internal/services"
|
||||
"gpt-load/internal/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -16,13 +17,28 @@ import (
|
||||
type Server struct {
|
||||
DB *gorm.DB
|
||||
config types.ConfigManager
|
||||
KeyValidatorService *services.KeyValidatorService
|
||||
KeyManualValidationService *services.KeyManualValidationService
|
||||
TaskService *services.TaskService
|
||||
KeyService *services.KeyService
|
||||
}
|
||||
|
||||
// NewServer creates a new handler instance
|
||||
func NewServer(db *gorm.DB, config types.ConfigManager) *Server {
|
||||
func NewServer(
|
||||
db *gorm.DB,
|
||||
config types.ConfigManager,
|
||||
keyValidatorService *services.KeyValidatorService,
|
||||
keyManualValidationService *services.KeyManualValidationService,
|
||||
taskService *services.TaskService,
|
||||
keyService *services.KeyService,
|
||||
) *Server {
|
||||
return &Server{
|
||||
DB: db,
|
||||
config: config,
|
||||
KeyValidatorService: keyValidatorService,
|
||||
KeyManualValidationService: keyManualValidationService,
|
||||
TaskService: taskService,
|
||||
KeyService: keyService,
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1,60 +1,123 @@
|
||||
// Package handler provides HTTP handlers for the application
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
app_errors "gpt-load/internal/errors"
|
||||
"gpt-load/internal/models"
|
||||
"gpt-load/internal/response"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type CreateKeysRequest struct {
|
||||
Keys []string `json:"keys" binding:"required"`
|
||||
// validateGroupID validates and parses group ID from request parameter
|
||||
func validateGroupID(c *gin.Context) (uint, error) {
|
||||
groupIDStr := c.Param("id")
|
||||
if groupIDStr == "" {
|
||||
return 0, fmt.Errorf("group ID is required")
|
||||
}
|
||||
|
||||
// CreateKeysInGroup handles creating new keys within a specific group.
|
||||
func (s *Server) CreateKeysInGroup(c *gin.Context) {
|
||||
groupID, err := strconv.Atoi(c.Param("id"))
|
||||
groupID, err := strconv.Atoi(groupIDStr)
|
||||
if err != nil || groupID <= 0 {
|
||||
return 0, fmt.Errorf("invalid group ID format")
|
||||
}
|
||||
|
||||
return uint(groupID), nil
|
||||
}
|
||||
|
||||
// validateKeyID validates and parses key ID from request parameter
|
||||
func validateKeyID(c *gin.Context) (uint, error) {
|
||||
keyIDStr := c.Param("key_id")
|
||||
if keyIDStr == "" {
|
||||
return 0, fmt.Errorf("key ID is required")
|
||||
}
|
||||
|
||||
keyID, err := strconv.Atoi(keyIDStr)
|
||||
if err != nil || keyID <= 0 {
|
||||
return 0, fmt.Errorf("invalid key ID format")
|
||||
}
|
||||
|
||||
return uint(keyID), nil
|
||||
}
|
||||
|
||||
// validateKeysText validates the keys text input
|
||||
func validateKeysText(keysText string) error {
|
||||
if strings.TrimSpace(keysText) == "" {
|
||||
return fmt.Errorf("keys text cannot be empty")
|
||||
}
|
||||
|
||||
if len(keysText) > 1024*1024 { // 1MB limit
|
||||
return fmt.Errorf("keys text is too large (max 1MB)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findGroupByID is a helper function to find a group by its ID.
|
||||
func (s *Server) findGroupByID(c *gin.Context, groupID int) (*models.Group, bool) {
|
||||
var group models.Group
|
||||
if err := s.DB.First(&group, groupID).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
response.Error(c, app_errors.ErrResourceNotFound)
|
||||
} else {
|
||||
response.Error(c, app_errors.ParseDBError(err))
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
return &group, true
|
||||
}
|
||||
|
||||
// AddMultipleKeysRequest defines the payload for adding multiple keys from a text block.
|
||||
type AddMultipleKeysRequest struct {
|
||||
KeysText string `json:"keys_text" binding:"required"`
|
||||
}
|
||||
|
||||
// AddMultipleKeys handles creating new keys from a text block within a specific group.
|
||||
func (s *Server) AddMultipleKeys(c *gin.Context) {
|
||||
groupID, err := validateGroupID(c)
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID format"))
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateKeysRequest
|
||||
var req AddMultipleKeysRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var newKeys []models.APIKey
|
||||
for _, keyVal := range req.Keys {
|
||||
newKeys = append(newKeys, models.APIKey{
|
||||
GroupID: uint(groupID),
|
||||
KeyValue: keyVal,
|
||||
Status: "active",
|
||||
})
|
||||
if err := validateKeysText(req.KeysText); err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.DB.Create(&newKeys).Error; err != nil {
|
||||
result, err := s.KeyService.AddMultipleKeys(groupID, req.KeysText)
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, newKeys)
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// ListKeysInGroup handles listing all keys within a specific group.
|
||||
func (s *Server) ListKeysInGroup(c *gin.Context) {
|
||||
groupID, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID format"))
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
|
||||
return
|
||||
}
|
||||
|
||||
var keys []models.APIKey
|
||||
if err := s.DB.Where("group_id = ?", groupID).Find(&keys).Error; err != nil {
|
||||
statusFilter := c.Query("status")
|
||||
if statusFilter != "" && statusFilter != "active" && statusFilter != "inactive" {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid status filter"))
|
||||
return
|
||||
}
|
||||
|
||||
keys, err := s.KeyService.ListKeysInGroup(uint(groupID), statusFilter)
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
@@ -62,90 +125,124 @@ func (s *Server) ListKeysInGroup(c *gin.Context) {
|
||||
response.Success(c, keys)
|
||||
}
|
||||
|
||||
// UpdateKey handles updating a specific key.
|
||||
func (s *Server) UpdateKey(c *gin.Context) {
|
||||
// DeleteSingleKey handles deleting a specific key.
|
||||
func (s *Server) DeleteSingleKey(c *gin.Context) {
|
||||
groupID, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID format"))
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
|
||||
return
|
||||
}
|
||||
|
||||
keyID, err := strconv.Atoi(c.Param("key_id"))
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid key ID format"))
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid key ID"))
|
||||
return
|
||||
}
|
||||
|
||||
var key models.APIKey
|
||||
if err := s.DB.Where("group_id = ? AND id = ?", groupID, keyID).First(&key).Error; err != nil {
|
||||
rowsAffected, err := s.KeyService.DeleteSingleKey(uint(groupID), uint(keyID))
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
|
||||
var updateData struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&updateData); err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
|
||||
if rowsAffected == 0 {
|
||||
response.Error(c, app_errors.ErrResourceNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
key.Status = updateData.Status
|
||||
if err := s.DB.Save(&key).Error; err != nil {
|
||||
response.Error(c, app_errors.ParseDBError(err))
|
||||
response.Success(c, gin.H{"message": "Key deleted successfully"})
|
||||
}
|
||||
|
||||
// TestSingleKey handles a one-off validation test for a single key.
|
||||
func (s *Server) TestSingleKey(c *gin.Context) {
|
||||
keyID, err := validateKeyID(c)
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, key)
|
||||
isValid, validationErr := s.KeyValidatorService.TestSingleKeyByID(c.Request.Context(), keyID)
|
||||
if validationErr != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadGateway, validationErr.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
type DeleteKeysRequest struct {
|
||||
KeyIDs []uint `json:"key_ids" binding:"required"`
|
||||
if isValid {
|
||||
response.Success(c, gin.H{"success": true, "message": "Key is valid."})
|
||||
} else {
|
||||
response.Success(c, gin.H{"success": false, "message": "Key is invalid or has insufficient quota."})
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteKeys handles deleting one or more keys.
|
||||
func (s *Server) DeleteKeys(c *gin.Context) {
|
||||
// ValidateGroupKeys initiates a manual validation task for all keys in a group.
|
||||
func (s *Server) ValidateGroupKeys(c *gin.Context) {
|
||||
groupID, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID format"))
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
|
||||
return
|
||||
}
|
||||
|
||||
var req DeleteKeysRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
|
||||
group, ok := s.findGroupByID(c, groupID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.KeyIDs) == 0 {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "No key IDs provided"))
|
||||
taskStatus, err := s.KeyManualValidationService.StartValidationTask(group)
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrTaskInProgress, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx := s.DB.Begin()
|
||||
response.Success(c, taskStatus)
|
||||
}
|
||||
|
||||
// Verify all keys belong to the specified group
|
||||
var count int64
|
||||
if err := tx.Model(&models.APIKey{}).Where("id IN ? AND group_id = ?", req.KeyIDs, groupID).Count(&count).Error; err != nil {
|
||||
tx.Rollback()
|
||||
// RestoreAllInvalidKeys sets the status of all 'inactive' keys in a group to 'active'.
|
||||
func (s *Server) RestoreAllInvalidKeys(c *gin.Context) {
|
||||
groupID, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
|
||||
return
|
||||
}
|
||||
|
||||
rowsAffected, err := s.KeyService.RestoreAllInvalidKeys(uint(groupID))
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
|
||||
if count != int64(len(req.KeyIDs)) {
|
||||
tx.Rollback()
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrForbidden, "One or more keys do not belong to the specified group"))
|
||||
response.Success(c, gin.H{"message": fmt.Sprintf("%d keys restored.", rowsAffected)})
|
||||
}
|
||||
|
||||
// ClearAllInvalidKeys deletes all 'inactive' keys from a group.
|
||||
func (s *Server) ClearAllInvalidKeys(c *gin.Context) {
|
||||
groupID, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the keys
|
||||
if err := tx.Where("id IN ?", req.KeyIDs).Delete(&models.APIKey{}).Error; err != nil {
|
||||
tx.Rollback()
|
||||
rowsAffected, err := s.KeyService.ClearAllInvalidKeys(uint(groupID))
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
response.Success(c, gin.H{"message": "Keys deleted successfully"})
|
||||
response.Success(c, gin.H{"message": fmt.Sprintf("%d invalid keys cleared.", rowsAffected)})
|
||||
}
|
||||
|
||||
// ExportKeys returns a list of keys for a group, filtered by status.
|
||||
func (s *Server) ExportKeys(c *gin.Context) {
|
||||
groupID, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
|
||||
return
|
||||
}
|
||||
|
||||
filter := c.DefaultQuery("filter", "all")
|
||||
keys, err := s.KeyService.ExportKeys(uint(groupID), filter)
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"keys": keys})
|
||||
}
|
||||
|
31
internal/handler/task_handler.go
Normal file
31
internal/handler/task_handler.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"gpt-load/internal/response"
|
||||
app_errors "gpt-load/internal/errors"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetTaskStatus handles requests for the status of the global long-running task.
|
||||
func (s *Server) GetTaskStatus(c *gin.Context) {
|
||||
taskStatus := s.TaskService.GetTaskStatus()
|
||||
response.Success(c, taskStatus)
|
||||
}
|
||||
|
||||
// GetTaskResult handles requests for the result of a finished task.
|
||||
func (s *Server) GetTaskResult(c *gin.Context) {
|
||||
taskID := c.Param("task_id")
|
||||
if taskID == "" {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Task ID is required"))
|
||||
return
|
||||
}
|
||||
|
||||
result, found := s.TaskService.GetResult(taskID)
|
||||
if !found {
|
||||
response.Error(c, app_errors.ErrResourceNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
@@ -1,9 +1,6 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/datatypes"
|
||||
@@ -19,26 +16,6 @@ type SystemSetting struct {
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// Upstreams 是一个上游地址的切片,可以被 GORM 正确处理
|
||||
type Upstreams []string
|
||||
|
||||
// Value 实现 driver.Valuer 接口,用于将 Upstreams 类型转换为数据库值
|
||||
func (u Upstreams) Value() (driver.Value, error) {
|
||||
if len(u) == 0 {
|
||||
return "[]", nil
|
||||
}
|
||||
return json.Marshal(u)
|
||||
}
|
||||
|
||||
// Scan 实现 sql.Scanner 接口,用于将数据库值扫描到 Upstreams 类型
|
||||
func (u *Upstreams) Scan(value interface{}) error {
|
||||
bytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return errors.New("type assertion to []byte failed")
|
||||
}
|
||||
return json.Unmarshal(bytes, u)
|
||||
}
|
||||
|
||||
// GroupConfig 存储特定于分组的配置
|
||||
type GroupConfig struct {
|
||||
BlacklistThreshold *int `json:"blacklist_threshold,omitempty"`
|
||||
@@ -50,6 +27,7 @@ type GroupConfig struct {
|
||||
RequestTimeout *int `json:"request_timeout,omitempty"`
|
||||
ResponseTimeout *int `json:"response_timeout,omitempty"`
|
||||
IdleConnTimeout *int `json:"idle_conn_timeout,omitempty"`
|
||||
KeyValidationIntervalMinutes *int `json:"key_validation_interval_minutes,omitempty"`
|
||||
}
|
||||
|
||||
// Group 对应 groups 表
|
||||
@@ -58,11 +36,14 @@ type Group struct {
|
||||
Name string `gorm:"type:varchar(255);not null;unique" json:"name"`
|
||||
DisplayName string `gorm:"type:varchar(255)" json:"display_name"`
|
||||
Description string `gorm:"type:varchar(512)" json:"description"`
|
||||
Upstreams Upstreams `gorm:"type:json;not null" json:"upstreams"`
|
||||
Upstreams datatypes.JSON `gorm:"type:json;not null" json:"upstreams"`
|
||||
ChannelType string `gorm:"type:varchar(50);not null" json:"channel_type"`
|
||||
Sort int `gorm:"default:0" json:"sort"`
|
||||
TestModel string `gorm:"type:varchar(255);not null" json:"test_model"`
|
||||
ParamOverrides datatypes.JSONMap `gorm:"type:json" json:"param_overrides"`
|
||||
Config datatypes.JSONMap `gorm:"type:json" json:"config"`
|
||||
APIKeys []APIKey `gorm:"foreignKey:GroupID" json:"api_keys"`
|
||||
LastValidatedAt *time.Time `json:"last_validated_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
@@ -2,11 +2,14 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gpt-load/internal/channel"
|
||||
app_errors "gpt-load/internal/errors"
|
||||
"gpt-load/internal/models"
|
||||
"gpt-load/internal/response"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -19,14 +22,16 @@ import (
|
||||
// ProxyServer represents the proxy server
|
||||
type ProxyServer struct {
|
||||
DB *gorm.DB
|
||||
channelFactory *channel.Factory
|
||||
groupCounters sync.Map // map[uint]*atomic.Uint64
|
||||
requestLogChan chan models.RequestLog
|
||||
}
|
||||
|
||||
// NewProxyServer creates a new proxy server
|
||||
func NewProxyServer(db *gorm.DB, requestLogChan chan models.RequestLog) (*ProxyServer, error) {
|
||||
func NewProxyServer(db *gorm.DB, channelFactory *channel.Factory, requestLogChan chan models.RequestLog) (*ProxyServer, error) {
|
||||
return &ProxyServer{
|
||||
DB: db,
|
||||
channelFactory: channelFactory,
|
||||
groupCounters: sync.Map{},
|
||||
requestLogChan: requestLogChan,
|
||||
}, nil
|
||||
@@ -52,16 +57,25 @@ func (ps *ProxyServer) HandleProxy(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 3. Get the appropriate channel handler from the factory
|
||||
channelHandler, err := channel.GetChannel(&group)
|
||||
channelHandler, err := ps.channelFactory.GetChannel(&group)
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to get channel for group '%s': %v", groupName, err)))
|
||||
return
|
||||
}
|
||||
|
||||
// 4. Forward the request using the channel handler
|
||||
// 4. Apply parameter overrides if they exist
|
||||
if len(group.ParamOverrides) > 0 {
|
||||
err := ps.applyParamOverrides(c, &group)
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to apply parameter overrides: %v", err)))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Forward the request using the channel handler
|
||||
err = channelHandler.Handle(c, apiKey, &group)
|
||||
|
||||
// 5. Log the request asynchronously
|
||||
// 6. Log the request asynchronously
|
||||
isSuccess := err == nil
|
||||
if !isSuccess {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
@@ -145,3 +159,51 @@ func (ps *ProxyServer) updateKeyStats(keyID uint, success bool) {
|
||||
func (ps *ProxyServer) Close() {
|
||||
// Nothing to close for now
|
||||
}
|
||||
|
||||
func (ps *ProxyServer) applyParamOverrides(c *gin.Context, group *models.Group) error {
|
||||
// Read the original request body
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read request body: %w", err)
|
||||
}
|
||||
c.Request.Body.Close() // Close the original body
|
||||
|
||||
// If body is empty, nothing to override, just restore the body
|
||||
if len(bodyBytes) == 0 {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save the original Content-Type
|
||||
originalContentType := c.GetHeader("Content-Type")
|
||||
|
||||
// Unmarshal the body into a map
|
||||
var requestData map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &requestData); err != nil {
|
||||
// If not a valid JSON, just pass it through
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Merge the overrides into the request data
|
||||
for key, value := range group.ParamOverrides {
|
||||
requestData[key] = value
|
||||
}
|
||||
|
||||
// Marshal the new data back to JSON
|
||||
newBodyBytes, err := json.Marshal(requestData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal new request body: %w", err)
|
||||
}
|
||||
|
||||
// Replace the request body with the new one
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(newBodyBytes))
|
||||
c.Request.ContentLength = int64(len(newBodyBytes))
|
||||
|
||||
// Restore the original Content-Type header
|
||||
if originalContentType != "" {
|
||||
c.Request.Header.Set("Content-Type", originalContentType)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -106,13 +106,27 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser
|
||||
groups.PUT("/:id", serverHandler.UpdateGroup)
|
||||
groups.DELETE("/:id", serverHandler.DeleteGroup)
|
||||
|
||||
// Key-specific routes
|
||||
keys := groups.Group("/:id/keys")
|
||||
{
|
||||
keys.POST("", serverHandler.CreateKeysInGroup)
|
||||
keys.GET("", serverHandler.ListKeysInGroup)
|
||||
keys.PUT("/:key_id", serverHandler.UpdateKey)
|
||||
keys.DELETE("", serverHandler.DeleteKeys)
|
||||
keys.POST("/add-multiple", serverHandler.AddMultipleKeys)
|
||||
keys.POST("/restore-all-invalid", serverHandler.RestoreAllInvalidKeys)
|
||||
keys.POST("/clear-all-invalid", serverHandler.ClearAllInvalidKeys)
|
||||
keys.GET("/export", serverHandler.ExportKeys)
|
||||
keys.DELETE("/:key_id", serverHandler.DeleteSingleKey)
|
||||
keys.POST("/:key_id/test", serverHandler.TestSingleKey)
|
||||
}
|
||||
|
||||
// Group-level actions
|
||||
groups.POST("/:id/validate-keys", serverHandler.ValidateGroupKeys)
|
||||
}
|
||||
|
||||
// Tasks
|
||||
tasks := api.Group("/tasks")
|
||||
{
|
||||
tasks.GET("/key-validation/status", serverHandler.GetTaskStatus)
|
||||
tasks.GET("/:task_id/result", serverHandler.GetTaskResult)
|
||||
}
|
||||
|
||||
// 仪表板和日志
|
||||
|
206
internal/services/key_cron_service.go
Normal file
206
internal/services/key_cron_service.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"gpt-load/internal/config"
|
||||
"gpt-load/internal/models"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// KeyCronService is responsible for periodically validating all API keys.
|
||||
type KeyCronService struct {
|
||||
DB *gorm.DB
|
||||
Validator *KeyValidatorService
|
||||
SettingsManager *config.SystemSettingsManager
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewKeyCronService creates a new KeyCronService.
|
||||
func NewKeyCronService(db *gorm.DB, validator *KeyValidatorService, settingsManager *config.SystemSettingsManager) *KeyCronService {
|
||||
return &KeyCronService{
|
||||
DB: db,
|
||||
Validator: validator,
|
||||
SettingsManager: settingsManager,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the cron job.
|
||||
func (s *KeyCronService) Start() {
|
||||
logrus.Info("Starting KeyCronService...")
|
||||
s.wg.Add(1)
|
||||
go s.run()
|
||||
}
|
||||
|
||||
// Stop stops the cron job.
|
||||
func (s *KeyCronService) Stop() {
|
||||
logrus.Info("Stopping KeyCronService...")
|
||||
close(s.stopChan)
|
||||
s.wg.Wait()
|
||||
logrus.Info("KeyCronService stopped.")
|
||||
}
|
||||
|
||||
func (s *KeyCronService) run() {
|
||||
defer s.wg.Done()
|
||||
ctx := context.Background()
|
||||
|
||||
// Run once on start
|
||||
s.validateAllGroups(ctx)
|
||||
|
||||
for {
|
||||
// Dynamically get the interval for the next run
|
||||
intervalMinutes := s.SettingsManager.GetInt("key_validation_interval_minutes", 60)
|
||||
if intervalMinutes <= 0 {
|
||||
intervalMinutes = 60 // Fallback to a safe default
|
||||
}
|
||||
nextRunTimer := time.NewTimer(time.Duration(intervalMinutes) * time.Minute)
|
||||
|
||||
select {
|
||||
case <-nextRunTimer.C:
|
||||
s.validateAllGroups(ctx)
|
||||
case <-s.stopChan:
|
||||
nextRunTimer.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KeyCronService) validateAllGroups(ctx context.Context) {
|
||||
logrus.Info("KeyCronService: Starting validation cycle for all groups.")
|
||||
var groups []models.Group
|
||||
if err := s.DB.Find(&groups).Error; err != nil {
|
||||
logrus.Errorf("KeyCronService: Failed to get groups: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
groupCopy := group // Create a copy for the closure
|
||||
go func(g models.Group) {
|
||||
// Get effective settings for the group
|
||||
effectiveSettings := s.SettingsManager.GetEffectiveConfig(g.Config)
|
||||
interval := time.Duration(effectiveSettings.KeyValidationIntervalMinutes) * time.Minute
|
||||
|
||||
// Check if it's time to validate this group
|
||||
if g.LastValidatedAt == nil || time.Since(*g.LastValidatedAt) > interval {
|
||||
s.validateGroup(ctx, &g)
|
||||
}
|
||||
}(groupCopy)
|
||||
}
|
||||
logrus.Info("KeyCronService: Validation cycle finished.")
|
||||
}
|
||||
|
||||
func (s *KeyCronService) validateGroup(ctx context.Context, group *models.Group) {
|
||||
var keys []models.APIKey
|
||||
if err := s.DB.Where("group_id = ?", group.ID).Find(&keys).Error; err != nil {
|
||||
logrus.Errorf("KeyCronService: Failed to get keys for group %s: %v", group.Name, err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
logrus.Infof("KeyCronService: Validating %d keys for group %s", len(keys), group.Name)
|
||||
|
||||
jobs := make(chan models.APIKey, len(keys))
|
||||
results := make(chan models.APIKey, len(keys))
|
||||
|
||||
concurrency := s.SettingsManager.GetInt("key_validation_concurrency", 10)
|
||||
if concurrency <= 0 {
|
||||
concurrency = 10 // Fallback to a safe default
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go s.worker(ctx, &wg, group, jobs, results)
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
jobs <- key
|
||||
}
|
||||
close(jobs)
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
var keysToUpdate []models.APIKey
|
||||
for key := range results {
|
||||
keysToUpdate = append(keysToUpdate, key)
|
||||
}
|
||||
|
||||
if len(keysToUpdate) > 0 {
|
||||
s.batchUpdateKeyStatus(keysToUpdate)
|
||||
}
|
||||
|
||||
// Update the last validated timestamp for the group
|
||||
if err := s.DB.Model(group).Update("last_validated_at", time.Now()).Error; err != nil {
|
||||
logrus.Errorf("KeyCronService: Failed to update last_validated_at for group %s: %v", group.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KeyCronService) worker(ctx context.Context, wg *sync.WaitGroup, group *models.Group, jobs <-chan models.APIKey, results chan<- models.APIKey) {
|
||||
defer wg.Done()
|
||||
for key := range jobs {
|
||||
isValid, err := s.Validator.ValidateSingleKey(ctx, &key, group)
|
||||
// Only update status if there was no error during validation
|
||||
if err != nil {
|
||||
logrus.Warnf("KeyCronService: Failed to validate key ID %d for group %s: %v. Skipping status update.", key.ID, group.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
newStatus := "inactive"
|
||||
if isValid {
|
||||
newStatus = "active"
|
||||
}
|
||||
|
||||
// Only send to results if the status has changed
|
||||
if key.Status != newStatus {
|
||||
key.Status = newStatus
|
||||
results <- key
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KeyCronService) batchUpdateKeyStatus(keys []models.APIKey) {
|
||||
if len(keys) == 0 {
|
||||
return
|
||||
}
|
||||
logrus.Infof("KeyCronService: Batch updating status for %d keys.", len(keys))
|
||||
|
||||
activeIDs := []uint{}
|
||||
inactiveIDs := []uint{}
|
||||
|
||||
for _, key := range keys {
|
||||
if key.Status == "active" {
|
||||
activeIDs = append(activeIDs, key.ID)
|
||||
} else {
|
||||
inactiveIDs = append(inactiveIDs, key.ID)
|
||||
}
|
||||
}
|
||||
|
||||
err := s.DB.Transaction(func(tx *gorm.DB) error {
|
||||
if len(activeIDs) > 0 {
|
||||
if err := tx.Model(&models.APIKey{}).Where("id IN ?", activeIDs).Update("status", "active").Error; err != nil {
|
||||
return err
|
||||
}
|
||||
logrus.Infof("KeyCronService: Set %d keys to 'active'.", len(activeIDs))
|
||||
}
|
||||
if len(inactiveIDs) > 0 {
|
||||
if err := tx.Model(&models.APIKey{}).Where("id IN ?", inactiveIDs).Update("status", "inactive").Error; err != nil {
|
||||
return err
|
||||
}
|
||||
logrus.Infof("KeyCronService: Set %d keys to 'inactive'.", len(inactiveIDs))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logrus.Errorf("KeyCronService: Failed to batch update key status: %v", err)
|
||||
}
|
||||
}
|
122
internal/services/key_manual_validation_service.go
Normal file
122
internal/services/key_manual_validation_service.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gpt-load/internal/config"
|
||||
"gpt-load/internal/models"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ManualValidationResult holds the result of a manual validation task.
|
||||
type ManualValidationResult struct {
|
||||
TotalKeys int `json:"total_keys"`
|
||||
ValidKeys int `json:"valid_keys"`
|
||||
InvalidKeys int `json:"invalid_keys"`
|
||||
}
|
||||
|
||||
// KeyManualValidationService handles user-initiated key validation for a group.
|
||||
type KeyManualValidationService struct {
|
||||
DB *gorm.DB
|
||||
Validator *KeyValidatorService
|
||||
TaskService *TaskService
|
||||
SettingsManager *config.SystemSettingsManager
|
||||
}
|
||||
|
||||
// NewKeyManualValidationService creates a new KeyManualValidationService.
|
||||
func NewKeyManualValidationService(db *gorm.DB, validator *KeyValidatorService, taskService *TaskService, settingsManager *config.SystemSettingsManager) *KeyManualValidationService {
|
||||
return &KeyManualValidationService{
|
||||
DB: db,
|
||||
Validator: validator,
|
||||
TaskService: taskService,
|
||||
SettingsManager: settingsManager,
|
||||
}
|
||||
}
|
||||
|
||||
// StartValidationTask starts a new manual validation task for a given group.
|
||||
func (s *KeyManualValidationService) StartValidationTask(group *models.Group) (*TaskStatus, error) {
|
||||
var keys []models.APIKey
|
||||
if err := s.DB.Where("group_id = ?", group.ID).Find(&keys).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to get keys for group %s: %w", group.Name, err)
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no keys to validate in group %s", group.Name)
|
||||
}
|
||||
|
||||
taskID := uuid.New().String()
|
||||
timeoutMinutes := s.SettingsManager.GetInt("key_validation_task_timeout_minutes", 60)
|
||||
timeout := time.Duration(timeoutMinutes) * time.Minute
|
||||
|
||||
taskStatus, err := s.TaskService.StartTask(taskID, group.Name, len(keys), timeout)
|
||||
if err != nil {
|
||||
return nil, err // A task is already running
|
||||
}
|
||||
|
||||
// Run the validation in a separate goroutine
|
||||
go s.runValidation(group, keys, taskStatus)
|
||||
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyManualValidationService) runValidation(group *models.Group, keys []models.APIKey, task *TaskStatus) {
|
||||
defer s.TaskService.EndTask()
|
||||
|
||||
logrus.Infof("Starting manual validation for group %s (TaskID: %s)", group.Name, task.TaskID)
|
||||
|
||||
jobs := make(chan models.APIKey, len(keys))
|
||||
results := make(chan bool, len(keys))
|
||||
|
||||
concurrency := s.SettingsManager.GetInt("key_validation_concurrency", 10)
|
||||
if concurrency <= 0 {
|
||||
concurrency = 10 // Fallback to a safe default
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go s.validationWorker(&wg, group, jobs, results)
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
jobs <- key
|
||||
}
|
||||
close(jobs)
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
validCount := 0
|
||||
processedCount := 0
|
||||
for isValid := range results {
|
||||
processedCount++
|
||||
if isValid {
|
||||
validCount++
|
||||
}
|
||||
// Update progress
|
||||
s.TaskService.UpdateProgress(processedCount)
|
||||
}
|
||||
|
||||
result := ManualValidationResult{
|
||||
TotalKeys: len(keys),
|
||||
ValidKeys: validCount,
|
||||
InvalidKeys: len(keys) - validCount,
|
||||
}
|
||||
|
||||
// Store the final result
|
||||
s.TaskService.StoreResult(task.TaskID, result)
|
||||
logrus.Infof("Manual validation finished for group %s (TaskID: %s): %+v", group.Name, task.TaskID, result)
|
||||
}
|
||||
|
||||
func (s *KeyManualValidationService) validationWorker(wg *sync.WaitGroup, group *models.Group, jobs <-chan models.APIKey, results chan<- bool) {
|
||||
defer wg.Done()
|
||||
for key := range jobs {
|
||||
isValid, _ := s.Validator.ValidateSingleKey(context.Background(), &key, group)
|
||||
results <- isValid
|
||||
}
|
||||
}
|
206
internal/services/key_service.go
Normal file
206
internal/services/key_service.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gpt-load/internal/models"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AddKeysResult holds the result of adding multiple keys.
|
||||
type AddKeysResult struct {
|
||||
AddedCount int `json:"added_count"`
|
||||
IgnoredCount int `json:"ignored_count"`
|
||||
TotalInGroup int64 `json:"total_in_group"`
|
||||
}
|
||||
|
||||
// KeyService provides services related to API keys.
|
||||
type KeyService struct {
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
// NewKeyService creates a new KeyService.
|
||||
func NewKeyService(db *gorm.DB) *KeyService {
|
||||
return &KeyService{DB: db}
|
||||
}
|
||||
|
||||
// AddMultipleKeys handles the business logic of creating new keys from a text block.
|
||||
func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysResult, error) {
|
||||
// 1. Parse keys from the text block
|
||||
keys := s.parseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found in the input text")
|
||||
}
|
||||
|
||||
// 2. Get the group information for validation
|
||||
var group models.Group
|
||||
if err := s.DB.First(&group, groupID).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to find group: %w", err)
|
||||
}
|
||||
|
||||
// 3. Get existing keys in the group for deduplication
|
||||
var existingKeys []models.APIKey
|
||||
if err := s.DB.Where("group_id = ?", groupID).Select("key_value").Find(&existingKeys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
existingKeyMap := make(map[string]bool)
|
||||
for _, k := range existingKeys {
|
||||
existingKeyMap[k.KeyValue] = true
|
||||
}
|
||||
|
||||
// 4. Prepare new keys with basic validation only
|
||||
var newKeysToCreate []models.APIKey
|
||||
uniqueNewKeys := make(map[string]bool)
|
||||
|
||||
for _, keyVal := range keys {
|
||||
trimmedKey := strings.TrimSpace(keyVal)
|
||||
if trimmedKey == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if key already exists
|
||||
if existingKeyMap[trimmedKey] || uniqueNewKeys[trimmedKey] {
|
||||
continue
|
||||
}
|
||||
|
||||
// 通用验证:只做基础格式检查,不做渠道特定验证
|
||||
if s.isValidKeyFormat(trimmedKey) {
|
||||
uniqueNewKeys[trimmedKey] = true
|
||||
newKeysToCreate = append(newKeysToCreate, models.APIKey{
|
||||
GroupID: groupID,
|
||||
KeyValue: trimmedKey,
|
||||
Status: "active",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
addedCount := len(newKeysToCreate)
|
||||
// 更准确的忽略计数:包括重复的和无效的
|
||||
ignoredCount := len(keys) - addedCount
|
||||
|
||||
// 5. Insert new keys if any
|
||||
if addedCount > 0 {
|
||||
if err := s.DB.Create(&newKeysToCreate).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Get the new total count
|
||||
var totalInGroup int64
|
||||
if err := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID).Count(&totalInGroup).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AddKeysResult{
|
||||
AddedCount: addedCount,
|
||||
IgnoredCount: ignoredCount,
|
||||
TotalInGroup: totalInGroup,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *KeyService) parseKeysFromText(text string) []string {
|
||||
var keys []string
|
||||
|
||||
// First, try to parse as a JSON array of strings
|
||||
if json.Unmarshal([]byte(text), &keys) == nil && len(keys) > 0 {
|
||||
return s.filterValidKeys(keys)
|
||||
}
|
||||
|
||||
// 通用解析:通过分隔符分割文本,不使用复杂的正则表达式
|
||||
delimiters := regexp.MustCompile(`[\s,;|\n\r\t]+`)
|
||||
splitKeys := delimiters.Split(strings.TrimSpace(text), -1)
|
||||
|
||||
for _, key := range splitKeys {
|
||||
key = strings.TrimSpace(key)
|
||||
if key != "" {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
|
||||
return s.filterValidKeys(keys)
|
||||
}
|
||||
|
||||
// filterValidKeys validates and filters potential API keys
|
||||
func (s *KeyService) filterValidKeys(keys []string) []string {
|
||||
var validKeys []string
|
||||
for _, key := range keys {
|
||||
key = strings.TrimSpace(key)
|
||||
if s.isValidKeyFormat(key) {
|
||||
validKeys = append(validKeys, key)
|
||||
}
|
||||
}
|
||||
return validKeys
|
||||
}
|
||||
|
||||
// isValidKeyFormat performs basic validation on key format
|
||||
func (s *KeyService) isValidKeyFormat(key string) bool {
|
||||
if len(key) < 4 || len(key) > 1000 {
|
||||
return false
|
||||
}
|
||||
|
||||
if key == "" ||
|
||||
strings.TrimSpace(key) == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
validChars := regexp.MustCompile(`^[a-zA-Z0-9_\-./+=:]+$`)
|
||||
return validChars.MatchString(key)
|
||||
}
|
||||
|
||||
// RestoreAllInvalidKeys sets the status of all 'inactive' keys in a group to 'active'.
|
||||
func (s *KeyService) RestoreAllInvalidKeys(groupID uint) (int64, error) {
|
||||
result := s.DB.Model(&models.APIKey{}).Where("group_id = ? AND status = ?", groupID, "inactive").Update("status", "active")
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// ClearAllInvalidKeys deletes all 'inactive' keys from a group.
|
||||
func (s *KeyService) ClearAllInvalidKeys(groupID uint) (int64, error) {
|
||||
result := s.DB.Where("group_id = ? AND status = ?", groupID, "inactive").Delete(&models.APIKey{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// DeleteSingleKey deletes a specific key from a group.
|
||||
func (s *KeyService) DeleteSingleKey(groupID, keyID uint) (int64, error) {
|
||||
result := s.DB.Where("group_id = ? AND id = ?", groupID, keyID).Delete(&models.APIKey{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// ExportKeys returns a list of keys for a group, filtered by status.
|
||||
func (s *KeyService) ExportKeys(groupID uint, filter string) ([]string, error) {
|
||||
query := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID)
|
||||
|
||||
switch filter {
|
||||
case "valid":
|
||||
query = query.Where("status = ?", "active")
|
||||
case "invalid":
|
||||
query = query.Where("status = ?", "inactive")
|
||||
case "all":
|
||||
// No status filter needed
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid filter value. Use 'all', 'valid', or 'invalid'")
|
||||
}
|
||||
|
||||
var keys []string
|
||||
if err := query.Pluck("key_value", &keys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// ListKeysInGroup lists all keys within a specific group, filtered by status.
|
||||
func (s *KeyService) ListKeysInGroup(groupID uint, statusFilter string) ([]models.APIKey, error) {
|
||||
var keys []models.APIKey
|
||||
query := s.DB.Where("group_id = ?", groupID)
|
||||
|
||||
if statusFilter != "" {
|
||||
query = query.Where("status = ?", statusFilter)
|
||||
}
|
||||
|
||||
if err := query.Find(&keys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
91
internal/services/key_validator_service.go
Normal file
91
internal/services/key_validator_service.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gpt-load/internal/channel"
|
||||
"gpt-load/internal/models"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// KeyValidatorService provides methods to validate API keys.
|
||||
type KeyValidatorService struct {
|
||||
DB *gorm.DB
|
||||
channelFactory *channel.Factory
|
||||
}
|
||||
|
||||
// NewKeyValidatorService creates a new KeyValidatorService.
|
||||
func NewKeyValidatorService(db *gorm.DB, factory *channel.Factory) *KeyValidatorService {
|
||||
return &KeyValidatorService{
|
||||
DB: db,
|
||||
channelFactory: factory,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateSingleKey performs a validation check on a single API key.
|
||||
// It does not modify the key's state in the database.
|
||||
// It returns true if the key is valid, and an error if it's not.
|
||||
func (s *KeyValidatorService) ValidateSingleKey(ctx context.Context, key *models.APIKey, group *models.Group) (bool, error) {
|
||||
// 添加超时保护
|
||||
if ctx.Err() != nil {
|
||||
return false, fmt.Errorf("context cancelled or timed out: %w", ctx.Err())
|
||||
}
|
||||
|
||||
ch, err := s.channelFactory.GetChannel(group)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"group_id": group.ID,
|
||||
"group_name": group.Name,
|
||||
"channel_type": group.ChannelType,
|
||||
"error": err,
|
||||
}).Error("Failed to get channel for key validation")
|
||||
return false, fmt.Errorf("failed to get channel for group %s: %w", group.Name, err)
|
||||
}
|
||||
|
||||
// 记录验证开始
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"key_id": key.ID,
|
||||
"group_id": group.ID,
|
||||
"group_name": group.Name,
|
||||
}).Debug("Starting key validation")
|
||||
|
||||
isValid, validationErr := ch.ValidateKey(ctx, key.KeyValue)
|
||||
if validationErr != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"key_id": key.ID,
|
||||
"group_id": group.ID,
|
||||
"group_name": group.Name,
|
||||
"error": validationErr,
|
||||
}).Warn("Key validation failed")
|
||||
return false, validationErr
|
||||
}
|
||||
|
||||
// 记录验证结果
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"key_id": key.ID,
|
||||
"group_id": group.ID,
|
||||
"group_name": group.Name,
|
||||
"is_valid": isValid,
|
||||
}).Debug("Key validation completed")
|
||||
|
||||
return isValid, nil
|
||||
}
|
||||
|
||||
// TestSingleKeyByID performs a synchronous validation test for a single API key by its ID.
|
||||
// It is intended for handling user-initiated "Test" actions.
|
||||
// It does not modify the key's state in the database.
|
||||
func (s *KeyValidatorService) TestSingleKeyByID(ctx context.Context, keyID uint) (bool, error) {
|
||||
var apiKey models.APIKey
|
||||
if err := s.DB.First(&apiKey, keyID).Error; err != nil {
|
||||
return false, fmt.Errorf("failed to find api key with id %d: %w", keyID, err)
|
||||
}
|
||||
|
||||
var group models.Group
|
||||
if err := s.DB.First(&group, apiKey.GroupID).Error; err != nil {
|
||||
return false, fmt.Errorf("failed to find group with id %d: %w", apiKey.GroupID, err)
|
||||
}
|
||||
|
||||
return s.ValidateSingleKey(ctx, &apiKey, &group)
|
||||
}
|
125
internal/services/task_service.go
Normal file
125
internal/services/task_service.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TaskStatus represents the status of a long-running task.
|
||||
type TaskStatus struct {
|
||||
IsRunning bool `json:"is_running"`
|
||||
GroupName string `json:"group_name,omitempty"`
|
||||
Processed int `json:"processed,omitempty"`
|
||||
Total int `json:"total,omitempty"`
|
||||
TaskID string `json:"task_id,omitempty"`
|
||||
ExpiresAt time.Time `json:"-"` // Internal field to handle zombie tasks
|
||||
lastUpdated time.Time
|
||||
}
|
||||
|
||||
// TaskService manages the state of a single, global, long-running task.
|
||||
type TaskService struct {
|
||||
mu sync.Mutex
|
||||
status TaskStatus
|
||||
resultsCache map[string]interface{}
|
||||
cacheOrder []string
|
||||
maxCacheSize int
|
||||
}
|
||||
|
||||
// NewTaskService creates a new TaskService.
|
||||
func NewTaskService() *TaskService {
|
||||
return &TaskService{
|
||||
resultsCache: make(map[string]interface{}),
|
||||
cacheOrder: make([]string, 0),
|
||||
maxCacheSize: 100, // Store results for the last 100 tasks
|
||||
}
|
||||
}
|
||||
|
||||
// StartTask attempts to start a new task. It returns an error if a task is already running.
|
||||
func (s *TaskService) StartTask(taskID, groupName string, total int, timeout time.Duration) (*TaskStatus, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Zombie task check
|
||||
if s.status.IsRunning && time.Now().After(s.status.ExpiresAt) {
|
||||
// The previous task is considered a zombie, reset it.
|
||||
s.status = TaskStatus{}
|
||||
}
|
||||
|
||||
if s.status.IsRunning {
|
||||
return nil, errors.New("a task is already running")
|
||||
}
|
||||
|
||||
s.status = TaskStatus{
|
||||
IsRunning: true,
|
||||
TaskID: taskID,
|
||||
GroupName: groupName,
|
||||
Total: total,
|
||||
Processed: 0,
|
||||
ExpiresAt: time.Now().Add(timeout),
|
||||
lastUpdated: time.Now(),
|
||||
}
|
||||
|
||||
return &s.status, nil
|
||||
}
|
||||
|
||||
// GetTaskStatus returns the current status of the task.
|
||||
func (s *TaskService) GetTaskStatus() *TaskStatus {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Zombie task check
|
||||
if s.status.IsRunning && time.Now().After(s.status.ExpiresAt) {
|
||||
s.status = TaskStatus{} // Reset if expired
|
||||
}
|
||||
|
||||
// Return a copy to prevent race conditions on the caller's side
|
||||
statusCopy := s.status
|
||||
return &statusCopy
|
||||
}
|
||||
|
||||
// UpdateProgress updates the progress of the current task.
|
||||
func (s *TaskService) UpdateProgress(processed int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.status.IsRunning {
|
||||
return
|
||||
}
|
||||
|
||||
s.status.Processed = processed
|
||||
s.status.lastUpdated = time.Now()
|
||||
}
|
||||
|
||||
// EndTask marks the current task as finished.
|
||||
func (s *TaskService) EndTask() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.status.IsRunning = false
|
||||
}
|
||||
|
||||
// StoreResult stores the result of a finished task.
|
||||
func (s *TaskService) StoreResult(taskID string, result interface{}) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, exists := s.resultsCache[taskID]; !exists {
|
||||
if len(s.cacheOrder) >= s.maxCacheSize {
|
||||
oldestTaskID := s.cacheOrder[0]
|
||||
delete(s.resultsCache, oldestTaskID)
|
||||
s.cacheOrder = s.cacheOrder[1:]
|
||||
}
|
||||
s.cacheOrder = append(s.cacheOrder, taskID)
|
||||
}
|
||||
s.resultsCache[taskID] = result
|
||||
}
|
||||
|
||||
// GetResult retrieves the result of a finished task.
|
||||
func (s *TaskService) GetResult(taskID string) (interface{}, bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
result, found := s.resultsCache[taskID]
|
||||
return result, found
|
||||
}
|
Reference in New Issue
Block a user