feat: key接口完善及错误处理

This commit is contained in:
tbphp
2025-07-05 14:50:58 +08:00
parent 8d7b60875e
commit d64ada4181
12 changed files with 487 additions and 208 deletions

View File

@@ -26,6 +26,7 @@ type BaseChannel struct {
Name string
Upstreams []UpstreamInfo
HTTPClient *http.Client
TestModel string
upstreamLock sync.Mutex
}

View File

@@ -60,7 +60,7 @@ func (f *Factory) GetChannel(group *models.Group) (ChannelProxy, error) {
}
// newBaseChannel is a helper function to create and configure a BaseChannel.
func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, groupConfig datatypes.JSONMap) (*BaseChannel, error) {
func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, groupConfig datatypes.JSONMap, testModel string) (*BaseChannel, error) {
type upstreamDef struct {
URL string `json:"url"`
Weight int `json:"weight"`
@@ -103,5 +103,6 @@ func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, grou
Name: name,
Upstreams: upstreamInfos,
HTTPClient: httpClient,
TestModel: testModel,
}, nil
}

View File

@@ -1,11 +1,14 @@
package channel
import (
"bytes"
"context"
"encoding/json"
"fmt"
app_errors "gpt-load/internal/errors"
"gpt-load/internal/models"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
@@ -20,7 +23,7 @@ type GeminiChannel struct {
}
func newGeminiChannel(f *Factory, group *models.Group) (ChannelProxy, error) {
base, err := f.newBaseChannel("gemini", group.Upstreams, group.Config)
base, err := f.newBaseChannel("gemini", group.Upstreams, group.Config, group.TestModel)
if err != nil {
return nil, err
}
@@ -39,20 +42,35 @@ func (ch *GeminiChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *mo
return ch.ProcessRequest(c, apiKey, modifier, ch)
}
// ValidateKey checks if the given API key is valid by making a request to the models endpoint.
// ValidateKey checks if the given API key is valid by making a generateContent request.
func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
upstreamURL := ch.getUpstreamURL()
if upstreamURL == nil {
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
}
// Construct the request URL for listing models.
reqURL := fmt.Sprintf("%s/v1beta/models?key=%s", upstreamURL.String(), key)
// Use the test model specified in the group settings.
// The path format for Gemini is /v1beta/models/{model}:generateContent
reqURL := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", upstreamURL.String(), ch.TestModel, key)
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
// Use a minimal, low-cost payload for validation
payload := gin.H{
"contents": []gin.H{
{"parts": []gin.H{
{"text": "Only output 'ok'"},
}},
},
}
body, err := json.Marshal(payload)
if err != nil {
return false, fmt.Errorf("failed to marshal validation payload: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", reqURL, bytes.NewBuffer(body))
if err != nil {
return false, fmt.Errorf("failed to create validation request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := ch.HTTPClient.Do(req)
if err != nil {
@@ -61,7 +79,20 @@ func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, err
defer resp.Body.Close()
// A 200 OK status code indicates the key is valid.
return resp.StatusCode == http.StatusOK, nil
if resp.StatusCode == http.StatusOK {
return true, nil
}
// For non-200 responses, parse the body to provide a more specific error reason.
errorBody, err := io.ReadAll(resp.Body)
if err != nil {
return false, fmt.Errorf("key is invalid (status %d), but failed to read error body: %w", resp.StatusCode, err)
}
// Use the new parser to extract a clean error message.
parsedError := app_errors.ParseUpstreamError(errorBody)
return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError)
}
// IsStreamingRequest checks if the request is for a streaming response.

View File

@@ -1,9 +1,13 @@
package channel
import (
"bytes"
"context"
"encoding/json"
"fmt"
app_errors "gpt-load/internal/errors"
"gpt-load/internal/models"
"io"
"net/http"
"github.com/gin-gonic/gin"
@@ -19,7 +23,7 @@ type OpenAIChannel struct {
}
func newOpenAIChannel(f *Factory, group *models.Group) (ChannelProxy, error) {
base, err := f.newBaseChannel("openai", group.Upstreams, group.Config)
base, err := f.newBaseChannel("openai", group.Upstreams, group.Config, group.TestModel)
if err != nil {
return nil, err
}
@@ -36,21 +40,34 @@ func (ch *OpenAIChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *mo
return ch.ProcessRequest(c, apiKey, modifier, ch)
}
// ValidateKey checks if the given API key is valid by making a request to the models endpoint.
// ValidateKey checks if the given API key is valid by making a chat completion request.
func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
upstreamURL := ch.getUpstreamURL()
if upstreamURL == nil {
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
}
// Construct the request URL for listing models, a common endpoint for key validation.
reqURL := upstreamURL.String() + "/v1/models"
reqURL := upstreamURL.String() + "/v1/chat/completions"
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
// Use a minimal, low-cost payload for validation
payload := gin.H{
"model": ch.TestModel,
"messages": []gin.H{
{"role": "user", "content": "Only output 'ok'"},
},
"max_tokens": 1,
}
body, err := json.Marshal(payload)
if err != nil {
return false, fmt.Errorf("failed to marshal validation payload: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", reqURL, bytes.NewBuffer(body))
if err != nil {
return false, fmt.Errorf("failed to create validation request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+key)
req.Header.Set("Content-Type", "application/json")
resp, err := ch.HTTPClient.Do(req)
if err != nil {
@@ -58,9 +75,21 @@ func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, err
}
defer resp.Body.Close()
// A 200 OK status code indicates the key is valid.
// Other status codes (e.g., 401 Unauthorized) indicate an invalid key.
return resp.StatusCode == http.StatusOK, nil
// A 200 OK status code indicates the key is valid and can make requests.
if resp.StatusCode == http.StatusOK {
return true, nil
}
// For non-200 responses, parse the body to provide a more specific error reason.
errorBody, err := io.ReadAll(resp.Body)
if err != nil {
return false, fmt.Errorf("key is invalid (status %d), but failed to read error body: %w", resp.StatusCode, err)
}
// Use the new parser to extract a clean error message.
parsedError := app_errors.ParseUpstreamError(errorBody)
return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError)
}
// IsStreamingRequest checks if the request is for a streaming response.

81
internal/errors/parser.go Normal file
View File

@@ -0,0 +1,81 @@
package errors
import (
"encoding/json"
"strings"
)
const (
// maxErrorBodyLength defines the maximum length of an error message to be stored or returned.
maxErrorBodyLength = 2048
)
// standardErrorResponse matches formats like: {"error": {"message": "..."}}
type standardErrorResponse struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
}
// vendorErrorResponse matches formats like: {"error_msg": "..."}
type vendorErrorResponse struct {
ErrorMsg string `json:"error_msg"`
}
// simpleErrorResponse matches formats like: {"error": "..."}
type simpleErrorResponse struct {
Error string `json:"error"`
}
// rootMessageErrorResponse matches formats like: {"message": "..."}
type rootMessageErrorResponse struct {
Message string `json:"message"`
}
// ParseUpstreamError attempts to parse a structured error message from an upstream response body
// using a chain of responsibility pattern. It tries various common formats and gracefully
// degrades to a raw string if all parsing attempts fail.
func ParseUpstreamError(body []byte) string {
// 1. Attempt to parse the standard OpenAI/Gemini format.
var stdErr standardErrorResponse
if err := json.Unmarshal(body, &stdErr); err == nil {
if msg := strings.TrimSpace(stdErr.Error.Message); msg != "" {
return truncateString(msg, maxErrorBodyLength)
}
}
// 2. Attempt to parse vendor-specific format (e.g., Baidu).
var vendorErr vendorErrorResponse
if err := json.Unmarshal(body, &vendorErr); err == nil {
if msg := strings.TrimSpace(vendorErr.ErrorMsg); msg != "" {
return truncateString(msg, maxErrorBodyLength)
}
}
// 3. Attempt to parse simple error format.
var simpleErr simpleErrorResponse
if err := json.Unmarshal(body, &simpleErr); err == nil {
if msg := strings.TrimSpace(simpleErr.Error); msg != "" {
return truncateString(msg, maxErrorBodyLength)
}
}
// 4. Attempt to parse root-level message format.
var rootMsgErr rootMessageErrorResponse
if err := json.Unmarshal(body, &rootMsgErr); err == nil {
if msg := strings.TrimSpace(rootMsgErr.Message); msg != "" {
return truncateString(msg, maxErrorBodyLength)
}
}
// 5. Graceful Degradation: If all parsing fails, return the raw (but safe) body.
return truncateString(string(body), maxErrorBodyLength)
}
// truncateString ensures a string does not exceed a maximum length.
func truncateString(s string, maxLength int) string {
if len(s) > maxLength {
return s[:maxLength]
}
return s
}

View File

@@ -12,51 +12,36 @@ import (
"gorm.io/gorm"
)
// validateGroupID validates and parses group ID from request parameter
func validateGroupID(c *gin.Context) (uint, error) {
groupIDStr := c.Param("id")
// validateGroupIDFromQuery validates and parses group ID from a query parameter.
func validateGroupIDFromQuery(c *gin.Context) (uint, error) {
groupIDStr := c.Query("group_id")
if groupIDStr == "" {
return 0, fmt.Errorf("group ID is required")
return 0, fmt.Errorf("group_id query parameter is required")
}
groupID, err := strconv.Atoi(groupIDStr)
if err != nil || groupID <= 0 {
return 0, fmt.Errorf("invalid group ID format")
return 0, fmt.Errorf("invalid group_id format")
}
return uint(groupID), nil
}
// validateKeyID validates and parses key ID from request parameter
func validateKeyID(c *gin.Context) (uint, error) {
keyIDStr := c.Param("key_id")
if keyIDStr == "" {
return 0, fmt.Errorf("key ID is required")
}
keyID, err := strconv.Atoi(keyIDStr)
if err != nil || keyID <= 0 {
return 0, fmt.Errorf("invalid key ID format")
}
return uint(keyID), nil
}
// validateKeysText validates the keys text input
func validateKeysText(keysText string) error {
if strings.TrimSpace(keysText) == "" {
return fmt.Errorf("keys text cannot be empty")
}
if len(keysText) > 1024*1024 { // 1MB limit
return fmt.Errorf("keys text is too large (max 1MB)")
if len(keysText) > 10*1024*1024 {
return fmt.Errorf("keys text is too large (max 10MB)")
}
return nil
}
// findGroupByID is a helper function to find a group by its ID.
func (s *Server) findGroupByID(c *gin.Context, groupID int) (*models.Group, bool) {
func (s *Server) findGroupByID(c *gin.Context, groupID uint) (*models.Group, bool) {
var group models.Group
if err := s.DB.First(&group, groupID).Error; err != nil {
if err == gorm.ErrRecordNotFound {
@@ -69,22 +54,26 @@ func (s *Server) findGroupByID(c *gin.Context, groupID int) (*models.Group, bool
return &group, true
}
// AddMultipleKeysRequest defines the payload for adding multiple keys from a text block.
type AddMultipleKeysRequest struct {
// KeyTextRequest defines a generic payload for operations requiring a group ID and a text block of keys.
type KeyTextRequest struct {
GroupID uint `json:"group_id" binding:"required"`
KeysText string `json:"keys_text" binding:"required"`
}
// GroupIDRequest defines a generic payload for operations requiring only a group ID.
type GroupIDRequest struct {
GroupID uint `json:"group_id" binding:"required"`
}
// AddMultipleKeys handles creating new keys from a text block within a specific group.
func (s *Server) AddMultipleKeys(c *gin.Context) {
groupID, err := validateGroupID(c)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, err.Error()))
var req KeyTextRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
return
}
var req AddMultipleKeysRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
if _, ok := s.findGroupByID(c, req.GroupID); !ok {
return
}
@@ -93,20 +82,28 @@ func (s *Server) AddMultipleKeys(c *gin.Context) {
return
}
result, err := s.KeyService.AddMultipleKeys(groupID, req.KeysText)
result, err := s.KeyService.AddMultipleKeys(req.GroupID, req.KeysText)
if err != nil {
response.Error(c, app_errors.ParseDBError(err))
if err.Error() == "no valid keys found in the input text" {
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
} else {
response.Error(c, app_errors.ParseDBError(err))
}
return
}
response.Success(c, result)
}
// ListKeysInGroup handles listing all keys within a specific group.
// ListKeysInGroup handles listing all keys within a specific group with pagination.
func (s *Server) ListKeysInGroup(c *gin.Context) {
groupID, err := strconv.Atoi(c.Param("id"))
groupID, err := validateGroupIDFromQuery(c)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, err.Error()))
return
}
if _, ok := s.findGroupByID(c, groupID); !ok {
return
}
@@ -116,72 +113,93 @@ func (s *Server) ListKeysInGroup(c *gin.Context) {
return
}
keys, err := s.KeyService.ListKeysInGroup(uint(groupID), statusFilter)
searchKeyword := c.Query("key")
query := s.KeyService.ListKeysInGroupQuery(groupID, statusFilter, searchKeyword)
var keys []models.APIKey
paginatedResult, err := response.Paginate(c, query, &keys)
if err != nil {
response.Error(c, app_errors.ParseDBError(err))
return
}
response.Success(c, keys)
response.Success(c, paginatedResult)
}
// DeleteSingleKey handles deleting a specific key.
func (s *Server) DeleteSingleKey(c *gin.Context) {
groupID, err := strconv.Atoi(c.Param("id"))
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
// DeleteMultipleKeys handles deleting keys from a text block within a specific group.
func (s *Server) DeleteMultipleKeys(c *gin.Context) {
var req KeyTextRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
return
}
keyID, err := strconv.Atoi(c.Param("key_id"))
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid key ID"))
if _, ok := s.findGroupByID(c, req.GroupID); !ok {
return
}
rowsAffected, err := s.KeyService.DeleteSingleKey(uint(groupID), uint(keyID))
if err := validateKeysText(req.KeysText); err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
return
}
result, err := s.KeyService.DeleteMultipleKeys(req.GroupID, req.KeysText)
if err != nil {
if err.Error() == "no valid keys found in the input text" {
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
} else {
response.Error(c, app_errors.ParseDBError(err))
}
return
}
response.Success(c, result)
}
// TestMultipleKeys handles a one-off validation test for multiple keys.
func (s *Server) TestMultipleKeys(c *gin.Context) {
var req KeyTextRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
return
}
group, ok := s.findGroupByID(c, req.GroupID)
if !ok {
return
}
if err := validateKeysText(req.KeysText); err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
return
}
// Re-use the parsing logic from the key service
keysToTest := s.KeyService.ParseKeysFromText(req.KeysText)
if len(keysToTest) == 0 {
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "no valid keys found in the input text"))
return
}
results, err := s.KeyValidatorService.TestMultipleKeys(c.Request.Context(), group, keysToTest)
if err != nil {
response.Error(c, app_errors.ParseDBError(err))
return
}
if rowsAffected == 0 {
response.Error(c, app_errors.ErrResourceNotFound)
return
}
response.Success(c, gin.H{"message": "Key deleted successfully"})
}
// TestSingleKey handles a one-off validation test for a single key.
func (s *Server) TestSingleKey(c *gin.Context) {
keyID, err := validateKeyID(c)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, err.Error()))
return
}
isValid, validationErr := s.KeyValidatorService.TestSingleKeyByID(c.Request.Context(), keyID)
if validationErr != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadGateway, validationErr.Error()))
return
}
if isValid {
response.Success(c, gin.H{"success": true, "message": "Key is valid."})
} else {
response.Success(c, gin.H{"success": false, "message": "Key is invalid or has insufficient quota."})
}
response.Success(c, results)
}
// ValidateGroupKeys initiates a manual validation task for all keys in a group.
func (s *Server) ValidateGroupKeys(c *gin.Context) {
groupID, err := strconv.Atoi(c.Param("id"))
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
var req GroupIDRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
return
}
group, ok := s.findGroupByID(c, groupID)
group, ok := s.findGroupByID(c, req.GroupID)
if !ok {
return
}
@@ -197,13 +215,17 @@ func (s *Server) ValidateGroupKeys(c *gin.Context) {
// RestoreAllInvalidKeys sets the status of all 'inactive' keys in a group to 'active'.
func (s *Server) RestoreAllInvalidKeys(c *gin.Context) {
groupID, err := strconv.Atoi(c.Param("id"))
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
var req GroupIDRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
return
}
rowsAffected, err := s.KeyService.RestoreAllInvalidKeys(uint(groupID))
if _, ok := s.findGroupByID(c, req.GroupID); !ok {
return
}
rowsAffected, err := s.KeyService.RestoreAllInvalidKeys(req.GroupID)
if err != nil {
response.Error(c, app_errors.ParseDBError(err))
return
@@ -214,13 +236,17 @@ func (s *Server) RestoreAllInvalidKeys(c *gin.Context) {
// ClearAllInvalidKeys deletes all 'inactive' keys from a group.
func (s *Server) ClearAllInvalidKeys(c *gin.Context) {
groupID, err := strconv.Atoi(c.Param("id"))
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
var req GroupIDRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
return
}
rowsAffected, err := s.KeyService.ClearAllInvalidKeys(uint(groupID))
if _, ok := s.findGroupByID(c, req.GroupID); !ok {
return
}
rowsAffected, err := s.KeyService.ClearAllInvalidKeys(req.GroupID)
if err != nil {
response.Error(c, app_errors.ParseDBError(err))
return
@@ -229,20 +255,3 @@ func (s *Server) ClearAllInvalidKeys(c *gin.Context) {
response.Success(c, gin.H{"message": fmt.Sprintf("%d invalid keys cleared.", rowsAffected)})
}
// ExportKeys returns a list of keys for a group, filtered by status.
func (s *Server) ExportKeys(c *gin.Context) {
groupID, err := strconv.Atoi(c.Param("id"))
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
return
}
filter := c.DefaultQuery("filter", "all")
keys, err := s.KeyService.ExportKeys(uint(groupID), filter)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
return
}
response.Success(c, gin.H{"keys": keys})
}

View File

@@ -52,11 +52,12 @@ type Group struct {
// APIKey 对应 api_keys 表
type APIKey struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
GroupID uint `gorm:"not null" json:"group_id"`
KeyValue string `gorm:"type:varchar(512);not null" json:"key_value"`
KeyValue string `gorm:"type:varchar(512);not null;uniqueIndex:idx_group_key" json:"key_value"`
GroupID uint `gorm:"not null;uniqueIndex:idx_group_key" json:"group_id"`
Status string `gorm:"type:varchar(50);not null;default:'active'" json:"status"`
RequestCount int64 `gorm:"not null;default:0" json:"request_count"`
FailureCount int64 `gorm:"not null;default:0" json:"failure_count"`
ErrorReason string `gorm:"type:text" json:"error_reason"`
LastUsedAt *time.Time `json:"last_used_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`

View File

@@ -0,0 +1,74 @@
package response
import (
"math"
"strconv"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
const (
DefaultPageSize = 15
MaxPageSize = 1000
)
// Pagination represents the pagination details in a response.
type Pagination struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalItems int64 `json:"total_items"`
TotalPages int `json:"total_pages"`
}
// PaginatedResponse is the standard structure for all paginated API responses.
type PaginatedResponse struct {
Items interface{} `json:"items"`
Pagination Pagination `json:"pagination"`
}
// Paginate performs pagination on a GORM query and returns a standardized response.
// It takes a Gin context, a GORM query builder, and a destination slice for the results.
func Paginate(c *gin.Context, query *gorm.DB, dest interface{}) (*PaginatedResponse, error) {
// 1. Get page and page size from query parameters
page, err := strconv.Atoi(c.DefaultQuery("page", "1"))
if err != nil || page < 1 {
page = 1
}
pageSize, err := strconv.Atoi(c.DefaultQuery("page_size", strconv.Itoa(DefaultPageSize)))
if err != nil || pageSize <= 0 {
pageSize = DefaultPageSize
}
if pageSize > MaxPageSize {
pageSize = MaxPageSize
}
// 2. Get total count of items
var totalItems int64
if err := query.Count(&totalItems).Error; err != nil {
return nil, err
}
// 3. Calculate offset and total pages
offset := (page - 1) * pageSize
totalPages := int(math.Ceil(float64(totalItems) / float64(pageSize)))
// 4. Retrieve the data for the current page
if err := query.Limit(pageSize).Offset(offset).Find(dest).Error; err != nil {
return nil, err
}
// 5. Construct the paginated response
paginatedData := &PaginatedResponse{
Items: dest,
Pagination: Pagination{
Page: page,
PageSize: pageSize,
TotalItems: totalItems,
TotalPages: totalPages,
},
}
return paginatedData, nil
}

View File

@@ -109,20 +109,18 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser
groups.PUT("/:id", serverHandler.UpdateGroup)
groups.DELETE("/:id", serverHandler.DeleteGroup)
// Key-specific routes
keys := groups.Group("/:id/keys")
{
keys.GET("", serverHandler.ListKeysInGroup)
keys.POST("/add-multiple", serverHandler.AddMultipleKeys)
keys.POST("/restore-all-invalid", serverHandler.RestoreAllInvalidKeys)
keys.POST("/clear-all-invalid", serverHandler.ClearAllInvalidKeys)
keys.GET("/export", serverHandler.ExportKeys)
keys.DELETE("/:key_id", serverHandler.DeleteSingleKey)
keys.POST("/:key_id/test", serverHandler.TestSingleKey)
}
}
// Group-level actions
groups.POST("/:id/validate-keys", serverHandler.ValidateGroupKeys)
// Key Management Routes
keys := api.Group("/keys")
{
keys.GET("", serverHandler.ListKeysInGroup)
keys.POST("/add-multiple", serverHandler.AddMultipleKeys)
keys.POST("/delete-multiple", serverHandler.DeleteMultipleKeys)
keys.POST("/restore-all-invalid", serverHandler.RestoreAllInvalidKeys)
keys.POST("/clear-all-invalid", serverHandler.ClearAllInvalidKeys)
keys.POST("/validate-group", serverHandler.ValidateGroupKeys)
keys.POST("/test-multiple", serverHandler.TestMultipleKeys)
}
// Tasks

View File

@@ -147,21 +147,38 @@ func (s *KeyCronService) validateGroup(ctx context.Context, group *models.Group)
func (s *KeyCronService) worker(ctx context.Context, wg *sync.WaitGroup, group *models.Group, jobs <-chan models.APIKey, results chan<- models.APIKey) {
defer wg.Done()
for key := range jobs {
isValid, err := s.Validator.ValidateSingleKey(ctx, &key, group)
// Only update status if there was no error during validation
if err != nil {
logrus.Warnf("KeyCronService: Failed to validate key ID %d for group %s: %v. Skipping status update.", key.ID, group.Name, err)
continue
isValid, validationErr := s.Validator.ValidateSingleKey(ctx, &key, group)
newStatus := key.Status
newErrorReason := key.ErrorReason
statusChanged := false
if validationErr != nil {
// Validation failed, mark as inactive and record the reason
newStatus = "inactive"
newErrorReason = validationErr.Error()
} else {
// Validation succeeded
if isValid {
newStatus = "active"
newErrorReason = "" // Clear reason on success
} else {
// This case might happen if the key is valid but has no quota, etc.
// The error would be in validationErr, so this branch is less likely.
// We still mark it as inactive but without a specific error from our side.
newStatus = "inactive"
newErrorReason = "Validation returned false without a specific error."
}
}
newStatus := "inactive"
if isValid {
newStatus = "active"
// Check if status or error reason has changed
if key.Status != newStatus || key.ErrorReason != newErrorReason {
statusChanged = true
}
// Only send to results if the status has changed
if key.Status != newStatus {
if statusChanged {
key.Status = newStatus
key.ErrorReason = newErrorReason
results <- key
}
}
@@ -171,36 +188,24 @@ func (s *KeyCronService) batchUpdateKeyStatus(keys []models.APIKey) {
if len(keys) == 0 {
return
}
logrus.Infof("KeyCronService: Batch updating status for %d keys.", len(keys))
activeIDs := []uint{}
inactiveIDs := []uint{}
for _, key := range keys {
if key.Status == "active" {
activeIDs = append(activeIDs, key.ID)
} else {
inactiveIDs = append(inactiveIDs, key.ID)
}
}
logrus.Infof("KeyCronService: Batch updating status/reason for %d keys.", len(keys))
err := s.DB.Transaction(func(tx *gorm.DB) error {
if len(activeIDs) > 0 {
if err := tx.Model(&models.APIKey{}).Where("id IN ?", activeIDs).Update("status", "active").Error; err != nil {
return err
for _, key := range keys {
updates := map[string]interface{}{
"status": key.Status,
"error_reason": key.ErrorReason,
}
logrus.Infof("KeyCronService: Set %d keys to 'active'.", len(activeIDs))
}
if len(inactiveIDs) > 0 {
if err := tx.Model(&models.APIKey{}).Where("id IN ?", inactiveIDs).Update("status", "inactive").Error; err != nil {
return err
if err := tx.Model(&models.APIKey{}).Where("id = ?", key.ID).Updates(updates).Error; err != nil {
// Log the error for this specific key but continue the transaction
logrus.Errorf("KeyCronService: Failed to update key ID %d: %v", key.ID, err)
}
logrus.Infof("KeyCronService: Set %d keys to 'inactive'.", len(inactiveIDs))
}
return nil
return nil // Commit the transaction even if some updates failed
})
if err != nil {
logrus.Errorf("KeyCronService: Failed to batch update key status: %v", err)
// This error is for the transaction itself, not individual updates
logrus.Errorf("KeyCronService: Transaction failed during batch update of key statuses: %v", err)
}
}

View File

@@ -17,6 +17,13 @@ type AddKeysResult struct {
TotalInGroup int64 `json:"total_in_group"`
}
// DeleteKeysResult holds the result of deleting multiple keys.
type DeleteKeysResult struct {
DeletedCount int `json:"deleted_count"`
IgnoredCount int `json:"ignored_count"`
TotalInGroup int64 `json:"total_in_group"`
}
// KeyService provides services related to API keys.
type KeyService struct {
DB *gorm.DB
@@ -30,7 +37,7 @@ func NewKeyService(db *gorm.DB) *KeyService {
// AddMultipleKeys handles the business logic of creating new keys from a text block.
func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysResult, error) {
// 1. Parse keys from the text block
keys := s.parseKeysFromText(keysText)
keys := s.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found in the input text")
}
@@ -101,7 +108,9 @@ func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysRes
}, nil
}
func (s *KeyService) parseKeysFromText(text string) []string {
// ParseKeysFromText parses a string of keys from various formats into a string slice.
// This function is exported to be shared with the handler layer.
func (s *KeyService) ParseKeysFromText(text string) []string {
var keys []string
// First, try to parse as a JSON array of strings
@@ -162,45 +171,52 @@ func (s *KeyService) ClearAllInvalidKeys(groupID uint) (int64, error) {
return result.RowsAffected, result.Error
}
// DeleteSingleKey deletes a specific key from a group.
func (s *KeyService) DeleteSingleKey(groupID, keyID uint) (int64, error) {
result := s.DB.Where("group_id = ? AND id = ?", groupID, keyID).Delete(&models.APIKey{})
return result.RowsAffected, result.Error
}
// ExportKeys returns a list of keys for a group, filtered by status.
func (s *KeyService) ExportKeys(groupID uint, filter string) ([]string, error) {
query := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID)
switch filter {
case "valid":
query = query.Where("status = ?", "active")
case "invalid":
query = query.Where("status = ?", "inactive")
case "all":
// No status filter needed
default:
return nil, fmt.Errorf("invalid filter value. Use 'all', 'valid', or 'invalid'")
// DeleteMultipleKeys handles the business logic of deleting keys from a text block.
func (s *KeyService) DeleteMultipleKeys(groupID uint, keysText string) (*DeleteKeysResult, error) {
// 1. Parse keys from the text block
keysToDelete := s.ParseKeysFromText(keysText)
if len(keysToDelete) == 0 {
return nil, fmt.Errorf("no valid keys found in the input text")
}
var keys []string
if err := query.Pluck("key_value", &keys).Error; err != nil {
// 2. Perform the deletion
// GORM's batch delete doesn't easily return which ones were deleted vs. ignored.
// We perform a bulk delete and then count the remaining to calculate the result.
result := s.DB.Where("group_id = ? AND key_value IN ?", groupID, keysToDelete).Delete(&models.APIKey{})
if result.Error != nil {
return nil, result.Error
}
deletedCount := int(result.RowsAffected)
ignoredCount := len(keysToDelete) - deletedCount
// 3. Get the new total count
var totalInGroup int64
if err := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID).Count(&totalInGroup).Error; err != nil {
return nil, err
}
return keys, nil
return &DeleteKeysResult{
DeletedCount: deletedCount,
IgnoredCount: ignoredCount,
TotalInGroup: totalInGroup,
}, nil
}
// ListKeysInGroup lists all keys within a specific group, filtered by status.
func (s *KeyService) ListKeysInGroup(groupID uint, statusFilter string) ([]models.APIKey, error) {
var keys []models.APIKey
query := s.DB.Where("group_id = ?", groupID)
// ListKeysInGroupQuery builds a query to list all keys within a specific group, filtered by status.
// It returns a GORM query builder, allowing the handler to apply pagination.
func (s *KeyService) ListKeysInGroupQuery(groupID uint, statusFilter string, searchKeyword string) *gorm.DB {
query := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID)
if statusFilter != "" {
query = query.Where("status = ?", statusFilter)
}
if err := query.Find(&keys).Error; err != nil {
return nil, err
if searchKeyword != "" {
// Use LIKE for fuzzy search on the key_value
query = query.Where("key_value LIKE ?", "%"+searchKeyword+"%")
}
return keys, nil
return query
}

View File

@@ -10,6 +10,13 @@ import (
"gorm.io/gorm"
)
// KeyTestResult holds the validation result for a single key.
type KeyTestResult struct {
KeyValue string `json:"key_value"`
IsValid bool `json:"is_valid"`
Error string `json:"error,omitempty"`
}
// KeyValidatorService provides methods to validate API keys.
type KeyValidatorService struct {
DB *gorm.DB
@@ -73,19 +80,45 @@ func (s *KeyValidatorService) ValidateSingleKey(ctx context.Context, key *models
return isValid, nil
}
// TestSingleKeyByID performs a synchronous validation test for a single API key by its ID.
// It is intended for handling user-initiated "Test" actions.
// It does not modify the key's state in the database.
func (s *KeyValidatorService) TestSingleKeyByID(ctx context.Context, keyID uint) (bool, error) {
var apiKey models.APIKey
if err := s.DB.First(&apiKey, keyID).Error; err != nil {
return false, fmt.Errorf("failed to find api key with id %d: %w", keyID, err)
// TestMultipleKeys performs a synchronous validation for a list of key values within a specific group.
func (s *KeyValidatorService) TestMultipleKeys(ctx context.Context, group *models.Group, keyValues []string) ([]KeyTestResult, error) {
results := make([]KeyTestResult, len(keyValues))
ch, err := s.channelFactory.GetChannel(group)
if err != nil {
return nil, fmt.Errorf("failed to get channel for group %s: %w", group.Name, err)
}
var group models.Group
if err := s.DB.First(&group, apiKey.GroupID).Error; err != nil {
return false, fmt.Errorf("failed to find group with id %d: %w", apiKey.GroupID, err)
// Find which of the provided keys actually exist in the database for this group
var existingKeys []models.APIKey
if err := s.DB.Where("group_id = ? AND key_value IN ?", group.ID, keyValues).Find(&existingKeys).Error; err != nil {
return nil, fmt.Errorf("failed to query keys from DB: %w", err)
}
existingKeyMap := make(map[string]bool)
for _, k := range existingKeys {
existingKeyMap[k.KeyValue] = true
}
return s.ValidateSingleKey(ctx, &apiKey, &group)
for i, kv := range keyValues {
// Pre-check: ensure the key belongs to the group to prevent unnecessary API calls
if !existingKeyMap[kv] {
results[i] = KeyTestResult{
KeyValue: kv,
IsValid: false,
Error: "Key does not exist in this group or has been removed.",
}
continue
}
isValid, validationErr := ch.ValidateKey(ctx, kv)
results[i] = KeyTestResult{
KeyValue: kv,
IsValid: isValid,
Error: "", // Explicitly set error to empty string on success
}
if validationErr != nil {
results[i].Error = validationErr.Error()
}
}
return results, nil
}