feat: 密钥管理

This commit is contained in:
tbphp
2025-07-04 21:19:15 +08:00
parent 7c10474d19
commit 01b86f7e30
23 changed files with 1427 additions and 250 deletions

View File

@@ -14,6 +14,7 @@ import (
"syscall"
"time"
"gpt-load/internal/channel"
"gpt-load/internal/config"
"gpt-load/internal/db"
"gpt-load/internal/handler"
@@ -74,15 +75,28 @@ func main() {
go startRequestLogger(database, requestLogChan, &wg)
// ---
// --- Service Initialization ---
taskService := services.NewTaskService()
channelFactory := channel.NewFactory(settingsManager)
keyValidatorService := services.NewKeyValidatorService(database, channelFactory)
keyManualValidationService := services.NewKeyManualValidationService(database, keyValidatorService, taskService, settingsManager)
keyCronService := services.NewKeyCronService(database, keyValidatorService, settingsManager)
keyCronService.Start()
defer keyCronService.Stop()
keyService := services.NewKeyService(database)
// ---
// Create proxy server
proxyServer, err := proxy.NewProxyServer(database, requestLogChan)
proxyServer, err := proxy.NewProxyServer(database, channelFactory, requestLogChan)
if err != nil {
logrus.Fatalf("Failed to create proxy server: %v", err)
}
defer proxyServer.Close()
// Create handlers
serverHandler := handler.NewServer(database, configManager)
serverHandler := handler.NewServer(database, configManager, keyValidatorService, keyManualValidationService, taskService, keyService)
logCleanupHandler := handler.NewLogCleanupHandler(logCleanupService)
// Setup routes using the new router package

1
go.mod
View File

@@ -9,6 +9,7 @@ require (
github.com/gin-contrib/static v1.1.5
github.com/gin-gonic/gin v1.10.1
github.com/go-sql-driver/mysql v1.8.1
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
github.com/sirupsen/logrus v1.9.3
gorm.io/datatypes v1.2.1

2
go.sum
View File

@@ -41,6 +41,8 @@ github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EO
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=

View File

@@ -8,38 +8,65 @@ import (
"net/http/httputil"
"net/url"
"strings"
"sync/atomic"
"sync"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/sirupsen/logrus"
)
// UpstreamInfo holds the information for a single upstream server, including its weight.
type UpstreamInfo struct {
URL *url.URL
Weight int
CurrentWeight int
}
// BaseChannel provides common functionality for channel proxies.
type BaseChannel struct {
Name string
Upstreams []*url.URL
Upstreams []UpstreamInfo
HTTPClient *http.Client
roundRobin uint64
upstreamLock sync.Mutex
}
// RequestModifier is a function that can modify the request before it's sent.
type RequestModifier func(req *http.Request, key *models.APIKey)
// getUpstreamURL selects an upstream URL using round-robin.
// getUpstreamURL selects an upstream URL using a smooth weighted round-robin algorithm.
func (b *BaseChannel) getUpstreamURL() *url.URL {
b.upstreamLock.Lock()
defer b.upstreamLock.Unlock()
if len(b.Upstreams) == 0 {
return nil
}
if len(b.Upstreams) == 1 {
return b.Upstreams[0]
return b.Upstreams[0].URL
}
index := atomic.AddUint64(&b.roundRobin, 1) - 1
return b.Upstreams[index%uint64(len(b.Upstreams))]
totalWeight := 0
var best *UpstreamInfo
for i := range b.Upstreams {
up := &b.Upstreams[i]
totalWeight += up.Weight
up.CurrentWeight += up.Weight
if best == nil || up.CurrentWeight > best.CurrentWeight {
best = up
}
}
if best == nil {
return b.Upstreams[0].URL // 降级到第一个可用的
}
best.CurrentWeight -= totalWeight
return best.URL
}
// ProcessRequest handles the common logic of processing and forwarding a request.
func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier) error {
func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier, ch ChannelProxy) error {
upstreamURL := b.getUpstreamURL()
if upstreamURL == nil {
return fmt.Errorf("no upstream URL configured for channel %s", b.Name)
@@ -78,7 +105,7 @@ func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modi
}
// Check if the client request is for a streaming endpoint
if isStreamingRequest(c) {
if ch.IsStreamingRequest(c) {
return b.handleStreaming(c, proxy)
}
@@ -87,6 +114,9 @@ func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modi
}
func (b *BaseChannel) handleStreaming(c *gin.Context, proxy *httputil.ReverseProxy) error {
var wg sync.WaitGroup
wg.Add(1)
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
@@ -96,13 +126,12 @@ func (b *BaseChannel) handleStreaming(c *gin.Context, proxy *httputil.ReversePro
pr, pw := io.Pipe()
defer pr.Close()
// Create a new request with the pipe reader as the body
// This is a bit of a hack to get ReverseProxy to stream
req := c.Request.Clone(c.Request.Context())
req.Body = pr
// Start the proxy in a goroutine
go func() {
defer wg.Done()
defer pw.Close()
proxy.ServeHTTP(c.Writer, req)
}()
@@ -111,32 +140,16 @@ func (b *BaseChannel) handleStreaming(c *gin.Context, proxy *httputil.ReversePro
_, err := io.Copy(pw, c.Request.Body)
if err != nil {
logrus.Errorf("Error copying request body to pipe: %v", err)
wg.Wait() // Wait for the goroutine to finish even if copy fails
return err
}
// Wait for the proxy to finish
wg.Wait()
return nil
}
// isStreamingRequest checks if the request is for a streaming response.
func isStreamingRequest(c *gin.Context) bool {
// For Gemini, streaming is indicated by the path.
if strings.Contains(c.Request.URL.Path, ":streamGenerateContent") {
return true
}
// For OpenAI, streaming is indicated by a "stream": true field in the JSON body.
// We use ShouldBindBodyWith to check the body without consuming it, so it can be read again by the proxy.
type streamPayload struct {
Stream bool `json:"stream"`
}
var p streamPayload
if err := c.ShouldBindBodyWith(&p, binding.JSON); err == nil {
return p.Stream
}
return false
}
// singleJoiningSlash joins two URL paths with a single slash.
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")

View File

@@ -1,6 +1,7 @@
package channel
import (
"context"
"gpt-load/internal/models"
"github.com/gin-gonic/gin"
@@ -11,4 +12,10 @@ type ChannelProxy interface {
// Handle takes a context, an API key, and the original request,
// then forwards the request to the upstream service.
Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error
// ValidateKey checks if the given API key is valid.
ValidateKey(ctx context.Context, key string) (bool, error)
// IsStreamingRequest checks if the request is for a streaming response.
IsStreamingRequest(c *gin.Context) bool
}

View File

@@ -1,6 +1,7 @@
package channel
import (
"encoding/json"
"fmt"
"gpt-load/internal/config"
"gpt-load/internal/models"
@@ -11,36 +12,61 @@ import (
"gorm.io/datatypes"
)
// Factory is responsible for creating channel proxies.
type Factory struct {
settingsManager *config.SystemSettingsManager
}
// NewFactory creates a new channel factory.
func NewFactory(settingsManager *config.SystemSettingsManager) *Factory {
return &Factory{
settingsManager: settingsManager,
}
}
// GetChannel returns a channel proxy based on the group's channel type.
func GetChannel(group *models.Group) (ChannelProxy, error) {
func (f *Factory) GetChannel(group *models.Group) (ChannelProxy, error) {
switch group.ChannelType {
case "openai":
return NewOpenAIChannel(group.Upstreams, group.Config)
return f.NewOpenAIChannel(group)
case "gemini":
return NewGeminiChannel(group.Upstreams, group.Config)
return f.NewGeminiChannel(group)
default:
return nil, fmt.Errorf("unsupported channel type: %s", group.ChannelType)
}
}
// newBaseChannelWithUpstreams is a helper function to create and configure a BaseChannel.
func newBaseChannelWithUpstreams(name string, upstreams []string, groupConfig datatypes.JSONMap) (BaseChannel, error) {
if len(upstreams) == 0 {
return BaseChannel{}, fmt.Errorf("at least one upstream is required for %s channel", name)
// newBaseChannel is a helper function to create and configure a BaseChannel.
func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, groupConfig datatypes.JSONMap) (*BaseChannel, error) {
type upstreamDef struct {
URL string `json:"url"`
Weight int `json:"weight"`
}
var upstreamURLs []*url.URL
for _, us := range upstreams {
u, err := url.Parse(us)
if err != nil {
return BaseChannel{}, fmt.Errorf("failed to parse upstream url '%s' for %s channel: %w", us, name, err)
var defs []upstreamDef
if err := json.Unmarshal(upstreamsJSON, &defs); err != nil {
return nil, fmt.Errorf("failed to unmarshal upstreams for %s channel: %w", name, err)
}
upstreamURLs = append(upstreamURLs, u)
if len(defs) == 0 {
return nil, fmt.Errorf("at least one upstream is required for %s channel", name)
}
var upstreamInfos []UpstreamInfo
for _, def := range defs {
u, err := url.Parse(def.URL)
if err != nil {
return nil, fmt.Errorf("failed to parse upstream url '%s' for %s channel: %w", def.URL, name, err)
}
weight := def.Weight
if weight <= 0 {
weight = 1 // Default weight to 1 if not specified or invalid
}
upstreamInfos = append(upstreamInfos, UpstreamInfo{URL: u, Weight: weight})
}
// Get effective settings by merging system and group configs
settingsManager := config.GetSystemSettingsManager()
effectiveSettings := settingsManager.GetEffectiveConfig(groupConfig)
effectiveSettings := f.settingsManager.GetEffectiveConfig(groupConfig)
// Configure the HTTP client with the effective timeouts
httpClient := &http.Client{
@@ -50,9 +76,9 @@ func newBaseChannelWithUpstreams(name string, upstreams []string, groupConfig da
Timeout: time.Duration(effectiveSettings.RequestTimeout) * time.Second,
}
return BaseChannel{
return &BaseChannel{
Name: name,
Upstreams: upstreamURLs,
Upstreams: upstreamInfos,
HTTPClient: httpClient,
}, nil
}

View File

@@ -1,19 +1,22 @@
package channel
import (
"context"
"fmt"
"gpt-load/internal/models"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"gorm.io/datatypes"
)
type GeminiChannel struct {
BaseChannel
*BaseChannel
}
func NewGeminiChannel(upstreams []string, config datatypes.JSONMap) (*GeminiChannel, error) {
base, err := newBaseChannelWithUpstreams("gemini", upstreams, config)
func (f *Factory) NewGeminiChannel(group *models.Group) (*GeminiChannel, error) {
base, err := f.newBaseChannel("gemini", group.Upstreams, group.Config)
if err != nil {
return nil, err
}
@@ -29,5 +32,40 @@ func (ch *GeminiChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *mo
q.Set("key", key.KeyValue)
req.URL.RawQuery = q.Encode()
}
return ch.ProcessRequest(c, apiKey, modifier)
return ch.ProcessRequest(c, apiKey, modifier, ch)
}
// ValidateKey checks if the given API key is valid by making a request to the models endpoint.
func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
upstreamURL := ch.getUpstreamURL()
if upstreamURL == nil {
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
}
// Construct the request URL for listing models.
reqURL := fmt.Sprintf("%s/v1beta/models?key=%s", upstreamURL.String(), key)
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
if err != nil {
return false, fmt.Errorf("failed to create validation request: %w", err)
}
resp, err := ch.HTTPClient.Do(req)
if err != nil {
return false, fmt.Errorf("failed to send validation request: %w", err)
}
defer resp.Body.Close()
// A 200 OK status code indicates the key is valid.
return resp.StatusCode == http.StatusOK, nil
}
// IsStreamingRequest checks if the request is for a streaming response.
func (ch *GeminiChannel) IsStreamingRequest(c *gin.Context) bool {
// For Gemini, streaming is indicated by the path containing streaming keywords
path := c.Request.URL.Path
return strings.Contains(path, ":streamGenerateContent") ||
strings.Contains(path, "streamGenerateContent") ||
strings.Contains(path, ":stream") ||
strings.Contains(path, "/stream")
}

View File

@@ -1,19 +1,21 @@
package channel
import (
"context"
"fmt"
"gpt-load/internal/models"
"net/http"
"github.com/gin-gonic/gin"
"gorm.io/datatypes"
"github.com/gin-gonic/gin/binding"
)
type OpenAIChannel struct {
BaseChannel
*BaseChannel
}
func NewOpenAIChannel(upstreams []string, config datatypes.JSONMap) (*OpenAIChannel, error) {
base, err := newBaseChannelWithUpstreams("openai", upstreams, config)
func (f *Factory) NewOpenAIChannel(group *models.Group) (*OpenAIChannel, error) {
base, err := f.newBaseChannel("openai", group.Upstreams, group.Config)
if err != nil {
return nil, err
}
@@ -27,5 +29,46 @@ func (ch *OpenAIChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *mo
modifier := func(req *http.Request, key *models.APIKey) {
req.Header.Set("Authorization", "Bearer "+key.KeyValue)
}
return ch.ProcessRequest(c, apiKey, modifier)
return ch.ProcessRequest(c, apiKey, modifier, ch)
}
// ValidateKey checks if the given API key is valid by making a request to the models endpoint.
func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
upstreamURL := ch.getUpstreamURL()
if upstreamURL == nil {
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
}
// Construct the request URL for listing models, a common endpoint for key validation.
reqURL := upstreamURL.String() + "/v1/models"
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
if err != nil {
return false, fmt.Errorf("failed to create validation request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+key)
resp, err := ch.HTTPClient.Do(req)
if err != nil {
return false, fmt.Errorf("failed to send validation request: %w", err)
}
defer resp.Body.Close()
// A 200 OK status code indicates the key is valid.
// Other status codes (e.g., 401 Unauthorized) indicate an invalid key.
return resp.StatusCode == http.StatusOK, nil
}
// IsStreamingRequest checks if the request is for a streaming response.
func (ch *OpenAIChannel) IsStreamingRequest(c *gin.Context) bool {
// For OpenAI, streaming is indicated by a "stream": true field in the JSON body.
// We use ShouldBindBodyWith to check the body without consuming it, so it can be read again by the proxy.
type streamPayload struct {
Stream bool `json:"stream"`
}
var p streamPayload
if err := c.ShouldBindBodyWith(&p, binding.JSON); err == nil {
return p.Stream
}
return false
}

View File

@@ -276,3 +276,16 @@ func getEnvOrDefault(key, defaultValue string) string {
}
return defaultValue
}
// GetInt is a helper function for SystemSettingsManager to get an integer value with a default.
func (s *SystemSettingsManager) GetInt(key string, defaultValue int) int {
s.mu.RLock()
defer s.mu.RUnlock()
if valStr, ok := s.settingsCache[key]; ok {
if valInt, err := strconv.Atoi(valStr); err == nil {
return valInt
}
}
return defaultValue
}

View File

@@ -34,6 +34,11 @@ type SystemSettings struct {
// 请求日志配置(数据库日志)
RequestLogRetentionDays int `json:"request_log_retention_days" default:"30" name:"日志保留天数" category:"日志配置" desc:"请求日志在数据库中的保留天数" validate:"min=1"`
// 密钥验证配置
KeyValidationIntervalMinutes int `json:"key_validation_interval_minutes" default:"60" name:"定时验证周期" category:"密钥验证" desc:"后台定时验证密钥的默认周期(分钟)" validate:"min=5"`
KeyValidationConcurrency int `json:"key_validation_concurrency" default:"10" name:"验证并发数" category:"密钥验证" desc:"执行密钥验证时的并发 goroutine 数量" validate:"min=1,max=100"`
KeyValidationTaskTimeoutMinutes int `json:"key_validation_task_timeout_minutes" default:"60" name:"手动验证超时" category:"密钥验证" desc:"手动触发的全量验证任务的超时时间(分钟)" validate:"min=10"`
}
// GenerateSettingsMetadata 使用反射从 SystemSettings 结构体动态生成元数据
@@ -106,6 +111,7 @@ func DefaultSystemSettings() SystemSettings {
// SystemSettingsManager 管理系统配置
type SystemSettingsManager struct {
settings SystemSettings
settingsCache map[string]string // Cache for raw string values
mu sync.RWMutex
}
@@ -169,6 +175,8 @@ func (sm *SystemSettingsManager) LoadFromDatabase() error {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.settingsCache = settingsMap
// 使用默认值,然后用数据库中的值覆盖
sm.settings = DefaultSystemSettings()
sm.mapToStruct(settingsMap, &sm.settings)
@@ -334,6 +342,8 @@ func (sm *SystemSettingsManager) DisplayCurrentSettings() {
logrus.Infof(" Request timeouts: request=%ds, response=%ds, idle_conn=%ds",
sm.settings.RequestTimeout, sm.settings.ResponseTimeout, sm.settings.IdleConnTimeout)
logrus.Infof(" Request log retention: %d days", sm.settings.RequestLogRetentionDays)
logrus.Infof(" Key validation: interval=%dmin, concurrency=%d, task_timeout=%dmin",
sm.settings.KeyValidationIntervalMinutes, sm.settings.KeyValidationConcurrency, sm.settings.KeyValidationTaskTimeoutMinutes)
}
// 辅助方法

View File

@@ -31,6 +31,8 @@ var (
ErrDatabase = &APIError{HTTPStatus: http.StatusInternalServerError, Code: "DATABASE_ERROR", Message: "Database operation failed"}
ErrUnauthorized = &APIError{HTTPStatus: http.StatusUnauthorized, Code: "UNAUTHORIZED", Message: "Authentication failed"}
ErrForbidden = &APIError{HTTPStatus: http.StatusForbidden, Code: "FORBIDDEN", Message: "You do not have permission to access this resource"}
ErrTaskInProgress = &APIError{HTTPStatus: http.StatusConflict, Code: "TASK_IN_PROGRESS", Message: "A task is already in progress"}
ErrBadGateway = &APIError{HTTPStatus: http.StatusBadGateway, Code: "BAD_GATEWAY", Message: "Upstream service error"}
)
// NewAPIError creates a new APIError with a custom message.

View File

@@ -3,6 +3,7 @@ package handler
import (
"encoding/json"
"fmt"
app_errors "gpt-load/internal/errors"
"gpt-load/internal/models"
"gpt-load/internal/response"
@@ -10,6 +11,7 @@ import (
"strconv"
"github.com/gin-gonic/gin"
"gorm.io/datatypes"
)
// isValidGroupName checks if the group name is valid.
@@ -22,6 +24,50 @@ func isValidGroupName(name string) bool {
return match
}
// validateAndCleanConfig validates the group config against the GroupConfig struct.
func validateAndCleanConfig(configMap map[string]interface{}) (map[string]interface{}, error) {
if configMap == nil {
return nil, nil
}
configBytes, err := json.Marshal(configMap)
if err != nil {
return nil, err
}
var validatedConfig models.GroupConfig
if err := json.Unmarshal(configBytes, &validatedConfig); err != nil {
return nil, err
}
// 验证配置项的合理范围
if validatedConfig.BlacklistThreshold != nil && *validatedConfig.BlacklistThreshold < 0 {
return nil, fmt.Errorf("blacklist_threshold must be >= 0")
}
if validatedConfig.MaxRetries != nil && (*validatedConfig.MaxRetries < 0 || *validatedConfig.MaxRetries > 10) {
return nil, fmt.Errorf("max_retries must be between 0 and 10")
}
if validatedConfig.RequestTimeout != nil && (*validatedConfig.RequestTimeout < 1 || *validatedConfig.RequestTimeout > 3600) {
return nil, fmt.Errorf("request_timeout must be between 1 and 3600 seconds")
}
if validatedConfig.KeyValidationIntervalMinutes != nil && (*validatedConfig.KeyValidationIntervalMinutes < 5 || *validatedConfig.KeyValidationIntervalMinutes > 1440) {
return nil, fmt.Errorf("key_validation_interval_minutes must be between 5 and 1440 minutes")
}
// Marshal back to a map to remove any fields not in GroupConfig
validatedBytes, err := json.Marshal(validatedConfig)
if err != nil {
return nil, err
}
var cleanedMap map[string]interface{}
if err := json.Unmarshal(validatedBytes, &cleanedMap); err != nil {
return nil, err
}
return cleanedMap, nil
}
// CreateGroup handles the creation of a new group.
func (s *Server) CreateGroup(c *gin.Context) {
var group models.Group
@@ -43,6 +89,17 @@ func (s *Server) CreateGroup(c *gin.Context) {
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Channel type is required"))
return
}
if group.TestModel == "" {
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Test model is required"))
return
}
cleanedConfig, err := validateAndCleanConfig(group.Config)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid config format"))
return
}
group.Config = cleanedConfig
if err := s.DB.Create(&group).Error; err != nil {
response.Error(c, app_errors.ParseDBError(err))
@@ -62,6 +119,20 @@ func (s *Server) ListGroups(c *gin.Context) {
response.Success(c, groups)
}
// GroupUpdateRequest defines the payload for updating a group.
// Using a dedicated struct avoids issues with zero values being ignored by GORM's Update.
type GroupUpdateRequest struct {
Name string `json:"name"`
DisplayName string `json:"display_name"`
Description string `json:"description"`
Upstreams json.RawMessage `json:"upstreams"`
ChannelType string `json:"channel_type"`
Sort *int `json:"sort"`
TestModel string `json:"test_model"`
ParamOverrides map[string]interface{} `json:"param_overrides"`
Config map[string]interface{} `json:"config"`
}
// UpdateGroup handles updating an existing group.
func (s *Server) UpdateGroup(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
@@ -76,84 +147,70 @@ func (s *Server) UpdateGroup(c *gin.Context) {
return
}
var updateData models.Group
if err := c.ShouldBindJSON(&updateData); err != nil {
var req GroupUpdateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
return
}
// Validate group name if it's being updated
if updateData.Name != "" && !isValidGroupName(updateData.Name) {
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid group name format. Use 3-30 lowercase letters, numbers, and underscores."))
return
}
// Use a transaction to ensure atomicity
// Start a transaction
tx := s.DB.Begin()
if tx.Error != nil {
response.Error(c, app_errors.ErrDatabase)
return
}
defer tx.Rollback() // Rollback on panic
// Convert updateData to a map to ensure zero values (like Sort: 0) are updated
var updateMap map[string]interface{}
updateBytes, _ := json.Marshal(updateData)
if err := json.Unmarshal(updateBytes, &updateMap); err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Failed to process update data"))
// Apply updates from the request
if req.Name != "" {
if !isValidGroupName(req.Name) {
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid group name format."))
return
}
// If config is being updated, it needs to be marshalled to JSON string for GORM
if config, ok := updateMap["config"]; ok {
if configMap, isMap := config.(map[string]interface{}); isMap {
configJSON, err := json.Marshal(configMap)
group.Name = req.Name
}
if req.DisplayName != "" {
group.DisplayName = req.DisplayName
}
if req.Description != "" {
group.Description = req.Description
}
if req.Upstreams != nil {
group.Upstreams = datatypes.JSON(req.Upstreams)
}
if req.ChannelType != "" {
group.ChannelType = req.ChannelType
}
if req.Sort != nil {
group.Sort = *req.Sort
}
if req.TestModel != "" {
group.TestModel = req.TestModel
}
if req.ParamOverrides != nil {
group.ParamOverrides = req.ParamOverrides
}
if req.Config != nil {
cleanedConfig, err := validateAndCleanConfig(req.Config)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Failed to process config data"))
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid config format"))
return
}
updateMap["config"] = string(configJSON)
}
group.Config = cleanedConfig
}
// Handle upstreams field specifically
if upstreams, ok := updateMap["upstreams"]; ok {
if upstreamsSlice, isSlice := upstreams.([]interface{}); isSlice {
upstreamsJSON, err := json.Marshal(upstreamsSlice)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Failed to process upstreams data"))
return
}
updateMap["upstreams"] = string(upstreamsJSON)
}
}
// Remove fields that are not actual columns or should not be updated from the map
delete(updateMap, "id")
delete(updateMap, "api_keys")
delete(updateMap, "created_at")
delete(updateMap, "updated_at")
// Use Updates with a map to only update provided fields, including zero values
if err := tx.Model(&group).Updates(updateMap).Error; err != nil {
tx.Rollback()
// Save the updated group object
if err := tx.Save(&group).Error; err != nil {
response.Error(c, app_errors.ParseDBError(err))
return
}
if err := tx.Commit().Error; err != nil {
tx.Rollback()
response.Error(c, app_errors.ErrDatabase)
return
}
// Re-fetch the group to return the updated data
var updatedGroup models.Group
if err := s.DB.First(&updatedGroup, id).Error; err != nil {
response.Error(c, app_errors.ParseDBError(err))
return
}
response.Success(c, updatedGroup)
response.Success(c, group)
}
// DeleteGroup handles deleting a group.

View File

@@ -6,6 +6,7 @@ import (
"time"
"gpt-load/internal/models"
"gpt-load/internal/services"
"gpt-load/internal/types"
"github.com/gin-gonic/gin"
@@ -16,13 +17,28 @@ import (
type Server struct {
DB *gorm.DB
config types.ConfigManager
KeyValidatorService *services.KeyValidatorService
KeyManualValidationService *services.KeyManualValidationService
TaskService *services.TaskService
KeyService *services.KeyService
}
// NewServer creates a new handler instance
func NewServer(db *gorm.DB, config types.ConfigManager) *Server {
func NewServer(
db *gorm.DB,
config types.ConfigManager,
keyValidatorService *services.KeyValidatorService,
keyManualValidationService *services.KeyManualValidationService,
taskService *services.TaskService,
keyService *services.KeyService,
) *Server {
return &Server{
DB: db,
config: config,
KeyValidatorService: keyValidatorService,
KeyManualValidationService: keyManualValidationService,
TaskService: taskService,
KeyService: keyService,
}
}

View File

@@ -1,60 +1,123 @@
// Package handler provides HTTP handlers for the application
package handler
import (
"fmt"
app_errors "gpt-load/internal/errors"
"gpt-load/internal/models"
"gpt-load/internal/response"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type CreateKeysRequest struct {
Keys []string `json:"keys" binding:"required"`
// validateGroupID validates and parses group ID from request parameter
func validateGroupID(c *gin.Context) (uint, error) {
groupIDStr := c.Param("id")
if groupIDStr == "" {
return 0, fmt.Errorf("group ID is required")
}
// CreateKeysInGroup handles creating new keys within a specific group.
func (s *Server) CreateKeysInGroup(c *gin.Context) {
groupID, err := strconv.Atoi(c.Param("id"))
groupID, err := strconv.Atoi(groupIDStr)
if err != nil || groupID <= 0 {
return 0, fmt.Errorf("invalid group ID format")
}
return uint(groupID), nil
}
// validateKeyID validates and parses key ID from request parameter
func validateKeyID(c *gin.Context) (uint, error) {
keyIDStr := c.Param("key_id")
if keyIDStr == "" {
return 0, fmt.Errorf("key ID is required")
}
keyID, err := strconv.Atoi(keyIDStr)
if err != nil || keyID <= 0 {
return 0, fmt.Errorf("invalid key ID format")
}
return uint(keyID), nil
}
// validateKeysText validates the keys text input
func validateKeysText(keysText string) error {
if strings.TrimSpace(keysText) == "" {
return fmt.Errorf("keys text cannot be empty")
}
if len(keysText) > 1024*1024 { // 1MB limit
return fmt.Errorf("keys text is too large (max 1MB)")
}
return nil
}
// findGroupByID is a helper function to find a group by its ID.
func (s *Server) findGroupByID(c *gin.Context, groupID int) (*models.Group, bool) {
var group models.Group
if err := s.DB.First(&group, groupID).Error; err != nil {
if err == gorm.ErrRecordNotFound {
response.Error(c, app_errors.ErrResourceNotFound)
} else {
response.Error(c, app_errors.ParseDBError(err))
}
return nil, false
}
return &group, true
}
// AddMultipleKeysRequest defines the payload for adding multiple keys from a text block.
type AddMultipleKeysRequest struct {
KeysText string `json:"keys_text" binding:"required"`
}
// AddMultipleKeys handles creating new keys from a text block within a specific group.
func (s *Server) AddMultipleKeys(c *gin.Context) {
groupID, err := validateGroupID(c)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID format"))
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, err.Error()))
return
}
var req CreateKeysRequest
var req AddMultipleKeysRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
return
}
var newKeys []models.APIKey
for _, keyVal := range req.Keys {
newKeys = append(newKeys, models.APIKey{
GroupID: uint(groupID),
KeyValue: keyVal,
Status: "active",
})
if err := validateKeysText(req.KeysText); err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
return
}
if err := s.DB.Create(&newKeys).Error; err != nil {
result, err := s.KeyService.AddMultipleKeys(groupID, req.KeysText)
if err != nil {
response.Error(c, app_errors.ParseDBError(err))
return
}
response.Success(c, newKeys)
response.Success(c, result)
}
// ListKeysInGroup handles listing all keys within a specific group.
func (s *Server) ListKeysInGroup(c *gin.Context) {
groupID, err := strconv.Atoi(c.Param("id"))
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID format"))
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
return
}
var keys []models.APIKey
if err := s.DB.Where("group_id = ?", groupID).Find(&keys).Error; err != nil {
statusFilter := c.Query("status")
if statusFilter != "" && statusFilter != "active" && statusFilter != "inactive" {
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "Invalid status filter"))
return
}
keys, err := s.KeyService.ListKeysInGroup(uint(groupID), statusFilter)
if err != nil {
response.Error(c, app_errors.ParseDBError(err))
return
}
@@ -62,90 +125,124 @@ func (s *Server) ListKeysInGroup(c *gin.Context) {
response.Success(c, keys)
}
// UpdateKey handles updating a specific key.
func (s *Server) UpdateKey(c *gin.Context) {
// DeleteSingleKey handles deleting a specific key.
func (s *Server) DeleteSingleKey(c *gin.Context) {
groupID, err := strconv.Atoi(c.Param("id"))
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID format"))
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
return
}
keyID, err := strconv.Atoi(c.Param("key_id"))
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid key ID format"))
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid key ID"))
return
}
var key models.APIKey
if err := s.DB.Where("group_id = ? AND id = ?", groupID, keyID).First(&key).Error; err != nil {
rowsAffected, err := s.KeyService.DeleteSingleKey(uint(groupID), uint(keyID))
if err != nil {
response.Error(c, app_errors.ParseDBError(err))
return
}
var updateData struct {
Status string `json:"status"`
}
if err := c.ShouldBindJSON(&updateData); err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
if rowsAffected == 0 {
response.Error(c, app_errors.ErrResourceNotFound)
return
}
key.Status = updateData.Status
if err := s.DB.Save(&key).Error; err != nil {
response.Error(c, app_errors.ParseDBError(err))
response.Success(c, gin.H{"message": "Key deleted successfully"})
}
// TestSingleKey handles a one-off validation test for a single key.
func (s *Server) TestSingleKey(c *gin.Context) {
keyID, err := validateKeyID(c)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, err.Error()))
return
}
response.Success(c, key)
isValid, validationErr := s.KeyValidatorService.TestSingleKeyByID(c.Request.Context(), keyID)
if validationErr != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadGateway, validationErr.Error()))
return
}
type DeleteKeysRequest struct {
KeyIDs []uint `json:"key_ids" binding:"required"`
if isValid {
response.Success(c, gin.H{"success": true, "message": "Key is valid."})
} else {
response.Success(c, gin.H{"success": false, "message": "Key is invalid or has insufficient quota."})
}
}
// DeleteKeys handles deleting one or more keys.
func (s *Server) DeleteKeys(c *gin.Context) {
// ValidateGroupKeys initiates a manual validation task for all keys in a group.
func (s *Server) ValidateGroupKeys(c *gin.Context) {
groupID, err := strconv.Atoi(c.Param("id"))
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID format"))
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
return
}
var req DeleteKeysRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
group, ok := s.findGroupByID(c, groupID)
if !ok {
return
}
if len(req.KeyIDs) == 0 {
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "No key IDs provided"))
taskStatus, err := s.KeyManualValidationService.StartValidationTask(group)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrTaskInProgress, err.Error()))
return
}
// Start a transaction
tx := s.DB.Begin()
response.Success(c, taskStatus)
}
// Verify all keys belong to the specified group
var count int64
if err := tx.Model(&models.APIKey{}).Where("id IN ? AND group_id = ?", req.KeyIDs, groupID).Count(&count).Error; err != nil {
tx.Rollback()
// RestoreAllInvalidKeys sets the status of all 'inactive' keys in a group to 'active'.
func (s *Server) RestoreAllInvalidKeys(c *gin.Context) {
groupID, err := strconv.Atoi(c.Param("id"))
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
return
}
rowsAffected, err := s.KeyService.RestoreAllInvalidKeys(uint(groupID))
if err != nil {
response.Error(c, app_errors.ParseDBError(err))
return
}
if count != int64(len(req.KeyIDs)) {
tx.Rollback()
response.Error(c, app_errors.NewAPIError(app_errors.ErrForbidden, "One or more keys do not belong to the specified group"))
response.Success(c, gin.H{"message": fmt.Sprintf("%d keys restored.", rowsAffected)})
}
// ClearAllInvalidKeys deletes all 'inactive' keys from a group.
func (s *Server) ClearAllInvalidKeys(c *gin.Context) {
groupID, err := strconv.Atoi(c.Param("id"))
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
return
}
// Delete the keys
if err := tx.Where("id IN ?", req.KeyIDs).Delete(&models.APIKey{}).Error; err != nil {
tx.Rollback()
rowsAffected, err := s.KeyService.ClearAllInvalidKeys(uint(groupID))
if err != nil {
response.Error(c, app_errors.ParseDBError(err))
return
}
tx.Commit()
response.Success(c, gin.H{"message": "Keys deleted successfully"})
response.Success(c, gin.H{"message": fmt.Sprintf("%d invalid keys cleared.", rowsAffected)})
}
// ExportKeys returns a list of keys for a group, filtered by status.
func (s *Server) ExportKeys(c *gin.Context) {
groupID, err := strconv.Atoi(c.Param("id"))
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
return
}
filter := c.DefaultQuery("filter", "all")
keys, err := s.KeyService.ExportKeys(uint(groupID), filter)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
return
}
response.Success(c, gin.H{"keys": keys})
}

View File

@@ -0,0 +1,31 @@
package handler
import (
"gpt-load/internal/response"
app_errors "gpt-load/internal/errors"
"github.com/gin-gonic/gin"
)
// GetTaskStatus handles requests for the status of the global long-running task.
func (s *Server) GetTaskStatus(c *gin.Context) {
taskStatus := s.TaskService.GetTaskStatus()
response.Success(c, taskStatus)
}
// GetTaskResult handles requests for the result of a finished task.
func (s *Server) GetTaskResult(c *gin.Context) {
taskID := c.Param("task_id")
if taskID == "" {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Task ID is required"))
return
}
result, found := s.TaskService.GetResult(taskID)
if !found {
response.Error(c, app_errors.ErrResourceNotFound)
return
}
response.Success(c, result)
}

View File

@@ -1,9 +1,6 @@
package models
import (
"database/sql/driver"
"encoding/json"
"errors"
"time"
"gorm.io/datatypes"
@@ -19,26 +16,6 @@ type SystemSetting struct {
UpdatedAt time.Time `json:"updated_at"`
}
// Upstreams 是一个上游地址的切片,可以被 GORM 正确处理
type Upstreams []string
// Value 实现 driver.Valuer 接口,用于将 Upstreams 类型转换为数据库值
func (u Upstreams) Value() (driver.Value, error) {
if len(u) == 0 {
return "[]", nil
}
return json.Marshal(u)
}
// Scan 实现 sql.Scanner 接口,用于将数据库值扫描到 Upstreams 类型
func (u *Upstreams) Scan(value interface{}) error {
bytes, ok := value.([]byte)
if !ok {
return errors.New("type assertion to []byte failed")
}
return json.Unmarshal(bytes, u)
}
// GroupConfig 存储特定于分组的配置
type GroupConfig struct {
BlacklistThreshold *int `json:"blacklist_threshold,omitempty"`
@@ -50,6 +27,7 @@ type GroupConfig struct {
RequestTimeout *int `json:"request_timeout,omitempty"`
ResponseTimeout *int `json:"response_timeout,omitempty"`
IdleConnTimeout *int `json:"idle_conn_timeout,omitempty"`
KeyValidationIntervalMinutes *int `json:"key_validation_interval_minutes,omitempty"`
}
// Group 对应 groups 表
@@ -58,11 +36,14 @@ type Group struct {
Name string `gorm:"type:varchar(255);not null;unique" json:"name"`
DisplayName string `gorm:"type:varchar(255)" json:"display_name"`
Description string `gorm:"type:varchar(512)" json:"description"`
Upstreams Upstreams `gorm:"type:json;not null" json:"upstreams"`
Upstreams datatypes.JSON `gorm:"type:json;not null" json:"upstreams"`
ChannelType string `gorm:"type:varchar(50);not null" json:"channel_type"`
Sort int `gorm:"default:0" json:"sort"`
TestModel string `gorm:"type:varchar(255);not null" json:"test_model"`
ParamOverrides datatypes.JSONMap `gorm:"type:json" json:"param_overrides"`
Config datatypes.JSONMap `gorm:"type:json" json:"config"`
APIKeys []APIKey `gorm:"foreignKey:GroupID" json:"api_keys"`
LastValidatedAt *time.Time `json:"last_validated_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}

View File

@@ -2,11 +2,14 @@
package proxy
import (
"bytes"
"encoding/json"
"fmt"
"gpt-load/internal/channel"
app_errors "gpt-load/internal/errors"
"gpt-load/internal/models"
"gpt-load/internal/response"
"io"
"sync"
"sync/atomic"
"time"
@@ -19,14 +22,16 @@ import (
// ProxyServer represents the proxy server
type ProxyServer struct {
DB *gorm.DB
channelFactory *channel.Factory
groupCounters sync.Map // map[uint]*atomic.Uint64
requestLogChan chan models.RequestLog
}
// NewProxyServer creates a new proxy server
func NewProxyServer(db *gorm.DB, requestLogChan chan models.RequestLog) (*ProxyServer, error) {
func NewProxyServer(db *gorm.DB, channelFactory *channel.Factory, requestLogChan chan models.RequestLog) (*ProxyServer, error) {
return &ProxyServer{
DB: db,
channelFactory: channelFactory,
groupCounters: sync.Map{},
requestLogChan: requestLogChan,
}, nil
@@ -52,16 +57,25 @@ func (ps *ProxyServer) HandleProxy(c *gin.Context) {
}
// 3. Get the appropriate channel handler from the factory
channelHandler, err := channel.GetChannel(&group)
channelHandler, err := ps.channelFactory.GetChannel(&group)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to get channel for group '%s': %v", groupName, err)))
return
}
// 4. Forward the request using the channel handler
// 4. Apply parameter overrides if they exist
if len(group.ParamOverrides) > 0 {
err := ps.applyParamOverrides(c, &group)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to apply parameter overrides: %v", err)))
return
}
}
// 5. Forward the request using the channel handler
err = channelHandler.Handle(c, apiKey, &group)
// 5. Log the request asynchronously
// 6. Log the request asynchronously
isSuccess := err == nil
if !isSuccess {
logrus.WithFields(logrus.Fields{
@@ -145,3 +159,51 @@ func (ps *ProxyServer) updateKeyStats(keyID uint, success bool) {
func (ps *ProxyServer) Close() {
// Nothing to close for now
}
func (ps *ProxyServer) applyParamOverrides(c *gin.Context, group *models.Group) error {
// Read the original request body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
return fmt.Errorf("failed to read request body: %w", err)
}
c.Request.Body.Close() // Close the original body
// If body is empty, nothing to override, just restore the body
if len(bodyBytes) == 0 {
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
return nil
}
// Save the original Content-Type
originalContentType := c.GetHeader("Content-Type")
// Unmarshal the body into a map
var requestData map[string]interface{}
if err := json.Unmarshal(bodyBytes, &requestData); err != nil {
// If not a valid JSON, just pass it through
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
return nil
}
// Merge the overrides into the request data
for key, value := range group.ParamOverrides {
requestData[key] = value
}
// Marshal the new data back to JSON
newBodyBytes, err := json.Marshal(requestData)
if err != nil {
return fmt.Errorf("failed to marshal new request body: %w", err)
}
// Replace the request body with the new one
c.Request.Body = io.NopCloser(bytes.NewBuffer(newBodyBytes))
c.Request.ContentLength = int64(len(newBodyBytes))
// Restore the original Content-Type header
if originalContentType != "" {
c.Request.Header.Set("Content-Type", originalContentType)
}
return nil
}

View File

@@ -106,13 +106,27 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser
groups.PUT("/:id", serverHandler.UpdateGroup)
groups.DELETE("/:id", serverHandler.DeleteGroup)
// Key-specific routes
keys := groups.Group("/:id/keys")
{
keys.POST("", serverHandler.CreateKeysInGroup)
keys.GET("", serverHandler.ListKeysInGroup)
keys.PUT("/:key_id", serverHandler.UpdateKey)
keys.DELETE("", serverHandler.DeleteKeys)
keys.POST("/add-multiple", serverHandler.AddMultipleKeys)
keys.POST("/restore-all-invalid", serverHandler.RestoreAllInvalidKeys)
keys.POST("/clear-all-invalid", serverHandler.ClearAllInvalidKeys)
keys.GET("/export", serverHandler.ExportKeys)
keys.DELETE("/:key_id", serverHandler.DeleteSingleKey)
keys.POST("/:key_id/test", serverHandler.TestSingleKey)
}
// Group-level actions
groups.POST("/:id/validate-keys", serverHandler.ValidateGroupKeys)
}
// Tasks
tasks := api.Group("/tasks")
{
tasks.GET("/key-validation/status", serverHandler.GetTaskStatus)
tasks.GET("/:task_id/result", serverHandler.GetTaskResult)
}
// 仪表板和日志

View File

@@ -0,0 +1,206 @@
package services
import (
"context"
"gpt-load/internal/config"
"gpt-load/internal/models"
"sync"
"time"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// KeyCronService is responsible for periodically validating all API keys.
type KeyCronService struct {
DB *gorm.DB
Validator *KeyValidatorService
SettingsManager *config.SystemSettingsManager
stopChan chan struct{}
wg sync.WaitGroup
}
// NewKeyCronService creates a new KeyCronService.
func NewKeyCronService(db *gorm.DB, validator *KeyValidatorService, settingsManager *config.SystemSettingsManager) *KeyCronService {
return &KeyCronService{
DB: db,
Validator: validator,
SettingsManager: settingsManager,
stopChan: make(chan struct{}),
}
}
// Start begins the cron job.
func (s *KeyCronService) Start() {
logrus.Info("Starting KeyCronService...")
s.wg.Add(1)
go s.run()
}
// Stop stops the cron job.
func (s *KeyCronService) Stop() {
logrus.Info("Stopping KeyCronService...")
close(s.stopChan)
s.wg.Wait()
logrus.Info("KeyCronService stopped.")
}
func (s *KeyCronService) run() {
defer s.wg.Done()
ctx := context.Background()
// Run once on start
s.validateAllGroups(ctx)
for {
// Dynamically get the interval for the next run
intervalMinutes := s.SettingsManager.GetInt("key_validation_interval_minutes", 60)
if intervalMinutes <= 0 {
intervalMinutes = 60 // Fallback to a safe default
}
nextRunTimer := time.NewTimer(time.Duration(intervalMinutes) * time.Minute)
select {
case <-nextRunTimer.C:
s.validateAllGroups(ctx)
case <-s.stopChan:
nextRunTimer.Stop()
return
}
}
}
func (s *KeyCronService) validateAllGroups(ctx context.Context) {
logrus.Info("KeyCronService: Starting validation cycle for all groups.")
var groups []models.Group
if err := s.DB.Find(&groups).Error; err != nil {
logrus.Errorf("KeyCronService: Failed to get groups: %v", err)
return
}
for _, group := range groups {
groupCopy := group // Create a copy for the closure
go func(g models.Group) {
// Get effective settings for the group
effectiveSettings := s.SettingsManager.GetEffectiveConfig(g.Config)
interval := time.Duration(effectiveSettings.KeyValidationIntervalMinutes) * time.Minute
// Check if it's time to validate this group
if g.LastValidatedAt == nil || time.Since(*g.LastValidatedAt) > interval {
s.validateGroup(ctx, &g)
}
}(groupCopy)
}
logrus.Info("KeyCronService: Validation cycle finished.")
}
func (s *KeyCronService) validateGroup(ctx context.Context, group *models.Group) {
var keys []models.APIKey
if err := s.DB.Where("group_id = ?", group.ID).Find(&keys).Error; err != nil {
logrus.Errorf("KeyCronService: Failed to get keys for group %s: %v", group.Name, err)
return
}
if len(keys) == 0 {
return
}
logrus.Infof("KeyCronService: Validating %d keys for group %s", len(keys), group.Name)
jobs := make(chan models.APIKey, len(keys))
results := make(chan models.APIKey, len(keys))
concurrency := s.SettingsManager.GetInt("key_validation_concurrency", 10)
if concurrency <= 0 {
concurrency = 10 // Fallback to a safe default
}
var wg sync.WaitGroup
for i := 0; i < concurrency; i++ {
wg.Add(1)
go s.worker(ctx, &wg, group, jobs, results)
}
for _, key := range keys {
jobs <- key
}
close(jobs)
wg.Wait()
close(results)
var keysToUpdate []models.APIKey
for key := range results {
keysToUpdate = append(keysToUpdate, key)
}
if len(keysToUpdate) > 0 {
s.batchUpdateKeyStatus(keysToUpdate)
}
// Update the last validated timestamp for the group
if err := s.DB.Model(group).Update("last_validated_at", time.Now()).Error; err != nil {
logrus.Errorf("KeyCronService: Failed to update last_validated_at for group %s: %v", group.Name, err)
}
}
func (s *KeyCronService) worker(ctx context.Context, wg *sync.WaitGroup, group *models.Group, jobs <-chan models.APIKey, results chan<- models.APIKey) {
defer wg.Done()
for key := range jobs {
isValid, err := s.Validator.ValidateSingleKey(ctx, &key, group)
// Only update status if there was no error during validation
if err != nil {
logrus.Warnf("KeyCronService: Failed to validate key ID %d for group %s: %v. Skipping status update.", key.ID, group.Name, err)
continue
}
newStatus := "inactive"
if isValid {
newStatus = "active"
}
// Only send to results if the status has changed
if key.Status != newStatus {
key.Status = newStatus
results <- key
}
}
}
func (s *KeyCronService) batchUpdateKeyStatus(keys []models.APIKey) {
if len(keys) == 0 {
return
}
logrus.Infof("KeyCronService: Batch updating status for %d keys.", len(keys))
activeIDs := []uint{}
inactiveIDs := []uint{}
for _, key := range keys {
if key.Status == "active" {
activeIDs = append(activeIDs, key.ID)
} else {
inactiveIDs = append(inactiveIDs, key.ID)
}
}
err := s.DB.Transaction(func(tx *gorm.DB) error {
if len(activeIDs) > 0 {
if err := tx.Model(&models.APIKey{}).Where("id IN ?", activeIDs).Update("status", "active").Error; err != nil {
return err
}
logrus.Infof("KeyCronService: Set %d keys to 'active'.", len(activeIDs))
}
if len(inactiveIDs) > 0 {
if err := tx.Model(&models.APIKey{}).Where("id IN ?", inactiveIDs).Update("status", "inactive").Error; err != nil {
return err
}
logrus.Infof("KeyCronService: Set %d keys to 'inactive'.", len(inactiveIDs))
}
return nil
})
if err != nil {
logrus.Errorf("KeyCronService: Failed to batch update key status: %v", err)
}
}

View File

@@ -0,0 +1,122 @@
package services
import (
"context"
"fmt"
"gpt-load/internal/config"
"gpt-load/internal/models"
"sync"
"time"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// ManualValidationResult holds the result of a manual validation task.
type ManualValidationResult struct {
TotalKeys int `json:"total_keys"`
ValidKeys int `json:"valid_keys"`
InvalidKeys int `json:"invalid_keys"`
}
// KeyManualValidationService handles user-initiated key validation for a group.
type KeyManualValidationService struct {
DB *gorm.DB
Validator *KeyValidatorService
TaskService *TaskService
SettingsManager *config.SystemSettingsManager
}
// NewKeyManualValidationService creates a new KeyManualValidationService.
func NewKeyManualValidationService(db *gorm.DB, validator *KeyValidatorService, taskService *TaskService, settingsManager *config.SystemSettingsManager) *KeyManualValidationService {
return &KeyManualValidationService{
DB: db,
Validator: validator,
TaskService: taskService,
SettingsManager: settingsManager,
}
}
// StartValidationTask starts a new manual validation task for a given group.
func (s *KeyManualValidationService) StartValidationTask(group *models.Group) (*TaskStatus, error) {
var keys []models.APIKey
if err := s.DB.Where("group_id = ?", group.ID).Find(&keys).Error; err != nil {
return nil, fmt.Errorf("failed to get keys for group %s: %w", group.Name, err)
}
if len(keys) == 0 {
return nil, fmt.Errorf("no keys to validate in group %s", group.Name)
}
taskID := uuid.New().String()
timeoutMinutes := s.SettingsManager.GetInt("key_validation_task_timeout_minutes", 60)
timeout := time.Duration(timeoutMinutes) * time.Minute
taskStatus, err := s.TaskService.StartTask(taskID, group.Name, len(keys), timeout)
if err != nil {
return nil, err // A task is already running
}
// Run the validation in a separate goroutine
go s.runValidation(group, keys, taskStatus)
return taskStatus, nil
}
func (s *KeyManualValidationService) runValidation(group *models.Group, keys []models.APIKey, task *TaskStatus) {
defer s.TaskService.EndTask()
logrus.Infof("Starting manual validation for group %s (TaskID: %s)", group.Name, task.TaskID)
jobs := make(chan models.APIKey, len(keys))
results := make(chan bool, len(keys))
concurrency := s.SettingsManager.GetInt("key_validation_concurrency", 10)
if concurrency <= 0 {
concurrency = 10 // Fallback to a safe default
}
var wg sync.WaitGroup
for i := 0; i < concurrency; i++ {
wg.Add(1)
go s.validationWorker(&wg, group, jobs, results)
}
for _, key := range keys {
jobs <- key
}
close(jobs)
wg.Wait()
close(results)
validCount := 0
processedCount := 0
for isValid := range results {
processedCount++
if isValid {
validCount++
}
// Update progress
s.TaskService.UpdateProgress(processedCount)
}
result := ManualValidationResult{
TotalKeys: len(keys),
ValidKeys: validCount,
InvalidKeys: len(keys) - validCount,
}
// Store the final result
s.TaskService.StoreResult(task.TaskID, result)
logrus.Infof("Manual validation finished for group %s (TaskID: %s): %+v", group.Name, task.TaskID, result)
}
func (s *KeyManualValidationService) validationWorker(wg *sync.WaitGroup, group *models.Group, jobs <-chan models.APIKey, results chan<- bool) {
defer wg.Done()
for key := range jobs {
isValid, _ := s.Validator.ValidateSingleKey(context.Background(), &key, group)
results <- isValid
}
}

View File

@@ -0,0 +1,206 @@
package services
import (
"encoding/json"
"fmt"
"gpt-load/internal/models"
"regexp"
"strings"
"gorm.io/gorm"
)
// AddKeysResult holds the result of adding multiple keys.
type AddKeysResult struct {
AddedCount int `json:"added_count"`
IgnoredCount int `json:"ignored_count"`
TotalInGroup int64 `json:"total_in_group"`
}
// KeyService provides services related to API keys.
type KeyService struct {
DB *gorm.DB
}
// NewKeyService creates a new KeyService.
func NewKeyService(db *gorm.DB) *KeyService {
return &KeyService{DB: db}
}
// AddMultipleKeys handles the business logic of creating new keys from a text block.
func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysResult, error) {
// 1. Parse keys from the text block
keys := s.parseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found in the input text")
}
// 2. Get the group information for validation
var group models.Group
if err := s.DB.First(&group, groupID).Error; err != nil {
return nil, fmt.Errorf("failed to find group: %w", err)
}
// 3. Get existing keys in the group for deduplication
var existingKeys []models.APIKey
if err := s.DB.Where("group_id = ?", groupID).Select("key_value").Find(&existingKeys).Error; err != nil {
return nil, err
}
existingKeyMap := make(map[string]bool)
for _, k := range existingKeys {
existingKeyMap[k.KeyValue] = true
}
// 4. Prepare new keys with basic validation only
var newKeysToCreate []models.APIKey
uniqueNewKeys := make(map[string]bool)
for _, keyVal := range keys {
trimmedKey := strings.TrimSpace(keyVal)
if trimmedKey == "" {
continue
}
// Check if key already exists
if existingKeyMap[trimmedKey] || uniqueNewKeys[trimmedKey] {
continue
}
// 通用验证:只做基础格式检查,不做渠道特定验证
if s.isValidKeyFormat(trimmedKey) {
uniqueNewKeys[trimmedKey] = true
newKeysToCreate = append(newKeysToCreate, models.APIKey{
GroupID: groupID,
KeyValue: trimmedKey,
Status: "active",
})
}
}
addedCount := len(newKeysToCreate)
// 更准确的忽略计数:包括重复的和无效的
ignoredCount := len(keys) - addedCount
// 5. Insert new keys if any
if addedCount > 0 {
if err := s.DB.Create(&newKeysToCreate).Error; err != nil {
return nil, err
}
}
// 6. Get the new total count
var totalInGroup int64
if err := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID).Count(&totalInGroup).Error; err != nil {
return nil, err
}
return &AddKeysResult{
AddedCount: addedCount,
IgnoredCount: ignoredCount,
TotalInGroup: totalInGroup,
}, nil
}
func (s *KeyService) parseKeysFromText(text string) []string {
var keys []string
// First, try to parse as a JSON array of strings
if json.Unmarshal([]byte(text), &keys) == nil && len(keys) > 0 {
return s.filterValidKeys(keys)
}
// 通用解析:通过分隔符分割文本,不使用复杂的正则表达式
delimiters := regexp.MustCompile(`[\s,;|\n\r\t]+`)
splitKeys := delimiters.Split(strings.TrimSpace(text), -1)
for _, key := range splitKeys {
key = strings.TrimSpace(key)
if key != "" {
keys = append(keys, key)
}
}
return s.filterValidKeys(keys)
}
// filterValidKeys validates and filters potential API keys
func (s *KeyService) filterValidKeys(keys []string) []string {
var validKeys []string
for _, key := range keys {
key = strings.TrimSpace(key)
if s.isValidKeyFormat(key) {
validKeys = append(validKeys, key)
}
}
return validKeys
}
// isValidKeyFormat performs basic validation on key format
func (s *KeyService) isValidKeyFormat(key string) bool {
if len(key) < 4 || len(key) > 1000 {
return false
}
if key == "" ||
strings.TrimSpace(key) == "" {
return false
}
validChars := regexp.MustCompile(`^[a-zA-Z0-9_\-./+=:]+$`)
return validChars.MatchString(key)
}
// RestoreAllInvalidKeys sets the status of all 'inactive' keys in a group to 'active'.
func (s *KeyService) RestoreAllInvalidKeys(groupID uint) (int64, error) {
result := s.DB.Model(&models.APIKey{}).Where("group_id = ? AND status = ?", groupID, "inactive").Update("status", "active")
return result.RowsAffected, result.Error
}
// ClearAllInvalidKeys deletes all 'inactive' keys from a group.
func (s *KeyService) ClearAllInvalidKeys(groupID uint) (int64, error) {
result := s.DB.Where("group_id = ? AND status = ?", groupID, "inactive").Delete(&models.APIKey{})
return result.RowsAffected, result.Error
}
// DeleteSingleKey deletes a specific key from a group.
func (s *KeyService) DeleteSingleKey(groupID, keyID uint) (int64, error) {
result := s.DB.Where("group_id = ? AND id = ?", groupID, keyID).Delete(&models.APIKey{})
return result.RowsAffected, result.Error
}
// ExportKeys returns a list of keys for a group, filtered by status.
func (s *KeyService) ExportKeys(groupID uint, filter string) ([]string, error) {
query := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID)
switch filter {
case "valid":
query = query.Where("status = ?", "active")
case "invalid":
query = query.Where("status = ?", "inactive")
case "all":
// No status filter needed
default:
return nil, fmt.Errorf("invalid filter value. Use 'all', 'valid', or 'invalid'")
}
var keys []string
if err := query.Pluck("key_value", &keys).Error; err != nil {
return nil, err
}
return keys, nil
}
// ListKeysInGroup lists all keys within a specific group, filtered by status.
func (s *KeyService) ListKeysInGroup(groupID uint, statusFilter string) ([]models.APIKey, error) {
var keys []models.APIKey
query := s.DB.Where("group_id = ?", groupID)
if statusFilter != "" {
query = query.Where("status = ?", statusFilter)
}
if err := query.Find(&keys).Error; err != nil {
return nil, err
}
return keys, nil
}

View File

@@ -0,0 +1,91 @@
package services
import (
"context"
"fmt"
"gpt-load/internal/channel"
"gpt-load/internal/models"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// KeyValidatorService provides methods to validate API keys.
type KeyValidatorService struct {
DB *gorm.DB
channelFactory *channel.Factory
}
// NewKeyValidatorService creates a new KeyValidatorService.
func NewKeyValidatorService(db *gorm.DB, factory *channel.Factory) *KeyValidatorService {
return &KeyValidatorService{
DB: db,
channelFactory: factory,
}
}
// ValidateSingleKey performs a validation check on a single API key.
// It does not modify the key's state in the database.
// It returns true if the key is valid, and an error if it's not.
func (s *KeyValidatorService) ValidateSingleKey(ctx context.Context, key *models.APIKey, group *models.Group) (bool, error) {
// 添加超时保护
if ctx.Err() != nil {
return false, fmt.Errorf("context cancelled or timed out: %w", ctx.Err())
}
ch, err := s.channelFactory.GetChannel(group)
if err != nil {
logrus.WithFields(logrus.Fields{
"group_id": group.ID,
"group_name": group.Name,
"channel_type": group.ChannelType,
"error": err,
}).Error("Failed to get channel for key validation")
return false, fmt.Errorf("failed to get channel for group %s: %w", group.Name, err)
}
// 记录验证开始
logrus.WithFields(logrus.Fields{
"key_id": key.ID,
"group_id": group.ID,
"group_name": group.Name,
}).Debug("Starting key validation")
isValid, validationErr := ch.ValidateKey(ctx, key.KeyValue)
if validationErr != nil {
logrus.WithFields(logrus.Fields{
"key_id": key.ID,
"group_id": group.ID,
"group_name": group.Name,
"error": validationErr,
}).Warn("Key validation failed")
return false, validationErr
}
// 记录验证结果
logrus.WithFields(logrus.Fields{
"key_id": key.ID,
"group_id": group.ID,
"group_name": group.Name,
"is_valid": isValid,
}).Debug("Key validation completed")
return isValid, nil
}
// TestSingleKeyByID performs a synchronous validation test for a single API key by its ID.
// It is intended for handling user-initiated "Test" actions.
// It does not modify the key's state in the database.
func (s *KeyValidatorService) TestSingleKeyByID(ctx context.Context, keyID uint) (bool, error) {
var apiKey models.APIKey
if err := s.DB.First(&apiKey, keyID).Error; err != nil {
return false, fmt.Errorf("failed to find api key with id %d: %w", keyID, err)
}
var group models.Group
if err := s.DB.First(&group, apiKey.GroupID).Error; err != nil {
return false, fmt.Errorf("failed to find group with id %d: %w", apiKey.GroupID, err)
}
return s.ValidateSingleKey(ctx, &apiKey, &group)
}

View File

@@ -0,0 +1,125 @@
package services
import (
"errors"
"sync"
"time"
)
// TaskStatus represents the status of a long-running task.
type TaskStatus struct {
IsRunning bool `json:"is_running"`
GroupName string `json:"group_name,omitempty"`
Processed int `json:"processed,omitempty"`
Total int `json:"total,omitempty"`
TaskID string `json:"task_id,omitempty"`
ExpiresAt time.Time `json:"-"` // Internal field to handle zombie tasks
lastUpdated time.Time
}
// TaskService manages the state of a single, global, long-running task.
type TaskService struct {
mu sync.Mutex
status TaskStatus
resultsCache map[string]interface{}
cacheOrder []string
maxCacheSize int
}
// NewTaskService creates a new TaskService.
func NewTaskService() *TaskService {
return &TaskService{
resultsCache: make(map[string]interface{}),
cacheOrder: make([]string, 0),
maxCacheSize: 100, // Store results for the last 100 tasks
}
}
// StartTask attempts to start a new task. It returns an error if a task is already running.
func (s *TaskService) StartTask(taskID, groupName string, total int, timeout time.Duration) (*TaskStatus, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Zombie task check
if s.status.IsRunning && time.Now().After(s.status.ExpiresAt) {
// The previous task is considered a zombie, reset it.
s.status = TaskStatus{}
}
if s.status.IsRunning {
return nil, errors.New("a task is already running")
}
s.status = TaskStatus{
IsRunning: true,
TaskID: taskID,
GroupName: groupName,
Total: total,
Processed: 0,
ExpiresAt: time.Now().Add(timeout),
lastUpdated: time.Now(),
}
return &s.status, nil
}
// GetTaskStatus returns the current status of the task.
func (s *TaskService) GetTaskStatus() *TaskStatus {
s.mu.Lock()
defer s.mu.Unlock()
// Zombie task check
if s.status.IsRunning && time.Now().After(s.status.ExpiresAt) {
s.status = TaskStatus{} // Reset if expired
}
// Return a copy to prevent race conditions on the caller's side
statusCopy := s.status
return &statusCopy
}
// UpdateProgress updates the progress of the current task.
func (s *TaskService) UpdateProgress(processed int) {
s.mu.Lock()
defer s.mu.Unlock()
if !s.status.IsRunning {
return
}
s.status.Processed = processed
s.status.lastUpdated = time.Now()
}
// EndTask marks the current task as finished.
func (s *TaskService) EndTask() {
s.mu.Lock()
defer s.mu.Unlock()
s.status.IsRunning = false
}
// StoreResult stores the result of a finished task.
func (s *TaskService) StoreResult(taskID string, result interface{}) {
s.mu.Lock()
defer s.mu.Unlock()
if _, exists := s.resultsCache[taskID]; !exists {
if len(s.cacheOrder) >= s.maxCacheSize {
oldestTaskID := s.cacheOrder[0]
delete(s.resultsCache, oldestTaskID)
s.cacheOrder = s.cacheOrder[1:]
}
s.cacheOrder = append(s.cacheOrder, taskID)
}
s.resultsCache[taskID] = result
}
// GetResult retrieves the result of a finished task.
func (s *TaskService) GetResult(taskID string) (interface{}, bool) {
s.mu.Lock()
defer s.mu.Unlock()
result, found := s.resultsCache[taskID]
return result, found
}