From d64ada41810caf0184293fe13c63839bb791d31a Mon Sep 17 00:00:00 2001 From: tbphp Date: Sat, 5 Jul 2025 14:50:58 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20key=E6=8E=A5=E5=8F=A3=E5=AE=8C=E5=96=84?= =?UTF-8?q?=E5=8F=8A=E9=94=99=E8=AF=AF=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/channel/base_channel.go | 1 + internal/channel/factory.go | 3 +- internal/channel/gemini_channel.go | 45 ++++- internal/channel/openai_channel.go | 45 ++++- internal/errors/parser.go | 81 ++++++++ internal/handler/key_handler.go | 213 +++++++++++---------- internal/models/types.go | 5 +- internal/response/pagination.go | 74 +++++++ internal/router/router.go | 24 ++- internal/services/key_cron_service.go | 71 +++---- internal/services/key_service.go | 78 +++++--- internal/services/key_validator_service.go | 55 ++++-- 12 files changed, 487 insertions(+), 208 deletions(-) create mode 100644 internal/errors/parser.go create mode 100644 internal/response/pagination.go diff --git a/internal/channel/base_channel.go b/internal/channel/base_channel.go index fa6a1b0..0b14675 100644 --- a/internal/channel/base_channel.go +++ b/internal/channel/base_channel.go @@ -26,6 +26,7 @@ type BaseChannel struct { Name string Upstreams []UpstreamInfo HTTPClient *http.Client + TestModel string upstreamLock sync.Mutex } diff --git a/internal/channel/factory.go b/internal/channel/factory.go index 1703f83..84c0818 100644 --- a/internal/channel/factory.go +++ b/internal/channel/factory.go @@ -60,7 +60,7 @@ func (f *Factory) GetChannel(group *models.Group) (ChannelProxy, error) { } // newBaseChannel is a helper function to create and configure a BaseChannel. -func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, groupConfig datatypes.JSONMap) (*BaseChannel, error) { +func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, groupConfig datatypes.JSONMap, testModel string) (*BaseChannel, error) { type upstreamDef struct { URL string `json:"url"` Weight int `json:"weight"` @@ -103,5 +103,6 @@ func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, grou Name: name, Upstreams: upstreamInfos, HTTPClient: httpClient, + TestModel: testModel, }, nil } diff --git a/internal/channel/gemini_channel.go b/internal/channel/gemini_channel.go index 3412e11..1fa0042 100644 --- a/internal/channel/gemini_channel.go +++ b/internal/channel/gemini_channel.go @@ -1,11 +1,14 @@ package channel import ( + "bytes" "context" + "encoding/json" "fmt" + app_errors "gpt-load/internal/errors" "gpt-load/internal/models" + "io" "net/http" - "strings" "github.com/gin-gonic/gin" @@ -20,7 +23,7 @@ type GeminiChannel struct { } func newGeminiChannel(f *Factory, group *models.Group) (ChannelProxy, error) { - base, err := f.newBaseChannel("gemini", group.Upstreams, group.Config) + base, err := f.newBaseChannel("gemini", group.Upstreams, group.Config, group.TestModel) if err != nil { return nil, err } @@ -39,20 +42,35 @@ func (ch *GeminiChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *mo return ch.ProcessRequest(c, apiKey, modifier, ch) } -// ValidateKey checks if the given API key is valid by making a request to the models endpoint. +// ValidateKey checks if the given API key is valid by making a generateContent request. func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, error) { upstreamURL := ch.getUpstreamURL() if upstreamURL == nil { return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name) } - // Construct the request URL for listing models. - reqURL := fmt.Sprintf("%s/v1beta/models?key=%s", upstreamURL.String(), key) + // Use the test model specified in the group settings. + // The path format for Gemini is /v1beta/models/{model}:generateContent + reqURL := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", upstreamURL.String(), ch.TestModel, key) - req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) + // Use a minimal, low-cost payload for validation + payload := gin.H{ + "contents": []gin.H{ + {"parts": []gin.H{ + {"text": "Only output 'ok'"}, + }}, + }, + } + body, err := json.Marshal(payload) + if err != nil { + return false, fmt.Errorf("failed to marshal validation payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", reqURL, bytes.NewBuffer(body)) if err != nil { return false, fmt.Errorf("failed to create validation request: %w", err) } + req.Header.Set("Content-Type", "application/json") resp, err := ch.HTTPClient.Do(req) if err != nil { @@ -61,7 +79,20 @@ func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, err defer resp.Body.Close() // A 200 OK status code indicates the key is valid. - return resp.StatusCode == http.StatusOK, nil + if resp.StatusCode == http.StatusOK { + return true, nil + } + + // For non-200 responses, parse the body to provide a more specific error reason. + errorBody, err := io.ReadAll(resp.Body) + if err != nil { + return false, fmt.Errorf("key is invalid (status %d), but failed to read error body: %w", resp.StatusCode, err) + } + + // Use the new parser to extract a clean error message. + parsedError := app_errors.ParseUpstreamError(errorBody) + + return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError) } // IsStreamingRequest checks if the request is for a streaming response. diff --git a/internal/channel/openai_channel.go b/internal/channel/openai_channel.go index a7fea5e..2b054da 100644 --- a/internal/channel/openai_channel.go +++ b/internal/channel/openai_channel.go @@ -1,9 +1,13 @@ package channel import ( + "bytes" "context" + "encoding/json" "fmt" + app_errors "gpt-load/internal/errors" "gpt-load/internal/models" + "io" "net/http" "github.com/gin-gonic/gin" @@ -19,7 +23,7 @@ type OpenAIChannel struct { } func newOpenAIChannel(f *Factory, group *models.Group) (ChannelProxy, error) { - base, err := f.newBaseChannel("openai", group.Upstreams, group.Config) + base, err := f.newBaseChannel("openai", group.Upstreams, group.Config, group.TestModel) if err != nil { return nil, err } @@ -36,21 +40,34 @@ func (ch *OpenAIChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *mo return ch.ProcessRequest(c, apiKey, modifier, ch) } -// ValidateKey checks if the given API key is valid by making a request to the models endpoint. +// ValidateKey checks if the given API key is valid by making a chat completion request. func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, error) { upstreamURL := ch.getUpstreamURL() if upstreamURL == nil { return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name) } - // Construct the request URL for listing models, a common endpoint for key validation. - reqURL := upstreamURL.String() + "/v1/models" + reqURL := upstreamURL.String() + "/v1/chat/completions" - req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) + // Use a minimal, low-cost payload for validation + payload := gin.H{ + "model": ch.TestModel, + "messages": []gin.H{ + {"role": "user", "content": "Only output 'ok'"}, + }, + "max_tokens": 1, + } + body, err := json.Marshal(payload) + if err != nil { + return false, fmt.Errorf("failed to marshal validation payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", reqURL, bytes.NewBuffer(body)) if err != nil { return false, fmt.Errorf("failed to create validation request: %w", err) } req.Header.Set("Authorization", "Bearer "+key) + req.Header.Set("Content-Type", "application/json") resp, err := ch.HTTPClient.Do(req) if err != nil { @@ -58,9 +75,21 @@ func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, err } defer resp.Body.Close() - // A 200 OK status code indicates the key is valid. - // Other status codes (e.g., 401 Unauthorized) indicate an invalid key. - return resp.StatusCode == http.StatusOK, nil + // A 200 OK status code indicates the key is valid and can make requests. + if resp.StatusCode == http.StatusOK { + return true, nil + } + + // For non-200 responses, parse the body to provide a more specific error reason. + errorBody, err := io.ReadAll(resp.Body) + if err != nil { + return false, fmt.Errorf("key is invalid (status %d), but failed to read error body: %w", resp.StatusCode, err) + } + + // Use the new parser to extract a clean error message. + parsedError := app_errors.ParseUpstreamError(errorBody) + + return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError) } // IsStreamingRequest checks if the request is for a streaming response. diff --git a/internal/errors/parser.go b/internal/errors/parser.go new file mode 100644 index 0000000..cda5eff --- /dev/null +++ b/internal/errors/parser.go @@ -0,0 +1,81 @@ +package errors + +import ( + "encoding/json" + "strings" +) + +const ( + // maxErrorBodyLength defines the maximum length of an error message to be stored or returned. + maxErrorBodyLength = 2048 +) + +// standardErrorResponse matches formats like: {"error": {"message": "..."}} +type standardErrorResponse struct { + Error struct { + Message string `json:"message"` + } `json:"error"` +} + +// vendorErrorResponse matches formats like: {"error_msg": "..."} +type vendorErrorResponse struct { + ErrorMsg string `json:"error_msg"` +} + +// simpleErrorResponse matches formats like: {"error": "..."} +type simpleErrorResponse struct { + Error string `json:"error"` +} + +// rootMessageErrorResponse matches formats like: {"message": "..."} +type rootMessageErrorResponse struct { + Message string `json:"message"` +} + +// ParseUpstreamError attempts to parse a structured error message from an upstream response body +// using a chain of responsibility pattern. It tries various common formats and gracefully +// degrades to a raw string if all parsing attempts fail. +func ParseUpstreamError(body []byte) string { + // 1. Attempt to parse the standard OpenAI/Gemini format. + var stdErr standardErrorResponse + if err := json.Unmarshal(body, &stdErr); err == nil { + if msg := strings.TrimSpace(stdErr.Error.Message); msg != "" { + return truncateString(msg, maxErrorBodyLength) + } + } + + // 2. Attempt to parse vendor-specific format (e.g., Baidu). + var vendorErr vendorErrorResponse + if err := json.Unmarshal(body, &vendorErr); err == nil { + if msg := strings.TrimSpace(vendorErr.ErrorMsg); msg != "" { + return truncateString(msg, maxErrorBodyLength) + } + } + + // 3. Attempt to parse simple error format. + var simpleErr simpleErrorResponse + if err := json.Unmarshal(body, &simpleErr); err == nil { + if msg := strings.TrimSpace(simpleErr.Error); msg != "" { + return truncateString(msg, maxErrorBodyLength) + } + } + + // 4. Attempt to parse root-level message format. + var rootMsgErr rootMessageErrorResponse + if err := json.Unmarshal(body, &rootMsgErr); err == nil { + if msg := strings.TrimSpace(rootMsgErr.Message); msg != "" { + return truncateString(msg, maxErrorBodyLength) + } + } + + // 5. Graceful Degradation: If all parsing fails, return the raw (but safe) body. + return truncateString(string(body), maxErrorBodyLength) +} + +// truncateString ensures a string does not exceed a maximum length. +func truncateString(s string, maxLength int) string { + if len(s) > maxLength { + return s[:maxLength] + } + return s +} diff --git a/internal/handler/key_handler.go b/internal/handler/key_handler.go index 501fe5c..1668d99 100644 --- a/internal/handler/key_handler.go +++ b/internal/handler/key_handler.go @@ -12,51 +12,36 @@ import ( "gorm.io/gorm" ) -// validateGroupID validates and parses group ID from request parameter -func validateGroupID(c *gin.Context) (uint, error) { - groupIDStr := c.Param("id") +// validateGroupIDFromQuery validates and parses group ID from a query parameter. +func validateGroupIDFromQuery(c *gin.Context) (uint, error) { + groupIDStr := c.Query("group_id") if groupIDStr == "" { - return 0, fmt.Errorf("group ID is required") + return 0, fmt.Errorf("group_id query parameter is required") } groupID, err := strconv.Atoi(groupIDStr) if err != nil || groupID <= 0 { - return 0, fmt.Errorf("invalid group ID format") + return 0, fmt.Errorf("invalid group_id format") } return uint(groupID), nil } -// validateKeyID validates and parses key ID from request parameter -func validateKeyID(c *gin.Context) (uint, error) { - keyIDStr := c.Param("key_id") - if keyIDStr == "" { - return 0, fmt.Errorf("key ID is required") - } - - keyID, err := strconv.Atoi(keyIDStr) - if err != nil || keyID <= 0 { - return 0, fmt.Errorf("invalid key ID format") - } - - return uint(keyID), nil -} - // validateKeysText validates the keys text input func validateKeysText(keysText string) error { if strings.TrimSpace(keysText) == "" { return fmt.Errorf("keys text cannot be empty") } - if len(keysText) > 1024*1024 { // 1MB limit - return fmt.Errorf("keys text is too large (max 1MB)") + if len(keysText) > 10*1024*1024 { + return fmt.Errorf("keys text is too large (max 10MB)") } return nil } // findGroupByID is a helper function to find a group by its ID. -func (s *Server) findGroupByID(c *gin.Context, groupID int) (*models.Group, bool) { +func (s *Server) findGroupByID(c *gin.Context, groupID uint) (*models.Group, bool) { var group models.Group if err := s.DB.First(&group, groupID).Error; err != nil { if err == gorm.ErrRecordNotFound { @@ -69,22 +54,26 @@ func (s *Server) findGroupByID(c *gin.Context, groupID int) (*models.Group, bool return &group, true } -// AddMultipleKeysRequest defines the payload for adding multiple keys from a text block. -type AddMultipleKeysRequest struct { +// KeyTextRequest defines a generic payload for operations requiring a group ID and a text block of keys. +type KeyTextRequest struct { + GroupID uint `json:"group_id" binding:"required"` KeysText string `json:"keys_text" binding:"required"` } +// GroupIDRequest defines a generic payload for operations requiring only a group ID. +type GroupIDRequest struct { + GroupID uint `json:"group_id" binding:"required"` +} + // AddMultipleKeys handles creating new keys from a text block within a specific group. func (s *Server) AddMultipleKeys(c *gin.Context) { - groupID, err := validateGroupID(c) - if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, err.Error())) + var req KeyTextRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error())) return } - var req AddMultipleKeysRequest - if err := c.ShouldBindJSON(&req); err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error())) + if _, ok := s.findGroupByID(c, req.GroupID); !ok { return } @@ -93,20 +82,28 @@ func (s *Server) AddMultipleKeys(c *gin.Context) { return } - result, err := s.KeyService.AddMultipleKeys(groupID, req.KeysText) + result, err := s.KeyService.AddMultipleKeys(req.GroupID, req.KeysText) if err != nil { - response.Error(c, app_errors.ParseDBError(err)) + if err.Error() == "no valid keys found in the input text" { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error())) + } else { + response.Error(c, app_errors.ParseDBError(err)) + } return } response.Success(c, result) } -// ListKeysInGroup handles listing all keys within a specific group. +// ListKeysInGroup handles listing all keys within a specific group with pagination. func (s *Server) ListKeysInGroup(c *gin.Context) { - groupID, err := strconv.Atoi(c.Param("id")) + groupID, err := validateGroupIDFromQuery(c) if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID")) + response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, err.Error())) + return + } + + if _, ok := s.findGroupByID(c, groupID); !ok { return } @@ -116,72 +113,93 @@ func (s *Server) ListKeysInGroup(c *gin.Context) { return } - keys, err := s.KeyService.ListKeysInGroup(uint(groupID), statusFilter) + searchKeyword := c.Query("key") + + query := s.KeyService.ListKeysInGroupQuery(groupID, statusFilter, searchKeyword) + + var keys []models.APIKey + paginatedResult, err := response.Paginate(c, query, &keys) if err != nil { response.Error(c, app_errors.ParseDBError(err)) return } - response.Success(c, keys) + response.Success(c, paginatedResult) } -// DeleteSingleKey handles deleting a specific key. -func (s *Server) DeleteSingleKey(c *gin.Context) { - groupID, err := strconv.Atoi(c.Param("id")) - if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID")) +// DeleteMultipleKeys handles deleting keys from a text block within a specific group. +func (s *Server) DeleteMultipleKeys(c *gin.Context) { + var req KeyTextRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error())) return } - keyID, err := strconv.Atoi(c.Param("key_id")) - if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid key ID")) + if _, ok := s.findGroupByID(c, req.GroupID); !ok { return } - rowsAffected, err := s.KeyService.DeleteSingleKey(uint(groupID), uint(keyID)) + if err := validateKeysText(req.KeysText); err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error())) + return + } + + result, err := s.KeyService.DeleteMultipleKeys(req.GroupID, req.KeysText) + if err != nil { + if err.Error() == "no valid keys found in the input text" { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error())) + } else { + response.Error(c, app_errors.ParseDBError(err)) + } + return + } + + response.Success(c, result) +} + +// TestMultipleKeys handles a one-off validation test for multiple keys. +func (s *Server) TestMultipleKeys(c *gin.Context) { + var req KeyTextRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error())) + return + } + + group, ok := s.findGroupByID(c, req.GroupID) + if !ok { + return + } + + if err := validateKeysText(req.KeysText); err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error())) + return + } + + // Re-use the parsing logic from the key service + keysToTest := s.KeyService.ParseKeysFromText(req.KeysText) + if len(keysToTest) == 0 { + response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "no valid keys found in the input text")) + return + } + + results, err := s.KeyValidatorService.TestMultipleKeys(c.Request.Context(), group, keysToTest) if err != nil { response.Error(c, app_errors.ParseDBError(err)) return } - if rowsAffected == 0 { - response.Error(c, app_errors.ErrResourceNotFound) - return - } - response.Success(c, gin.H{"message": "Key deleted successfully"}) -} - -// TestSingleKey handles a one-off validation test for a single key. -func (s *Server) TestSingleKey(c *gin.Context) { - keyID, err := validateKeyID(c) - if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, err.Error())) - return - } - - isValid, validationErr := s.KeyValidatorService.TestSingleKeyByID(c.Request.Context(), keyID) - if validationErr != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadGateway, validationErr.Error())) - return - } - - if isValid { - response.Success(c, gin.H{"success": true, "message": "Key is valid."}) - } else { - response.Success(c, gin.H{"success": false, "message": "Key is invalid or has insufficient quota."}) - } + response.Success(c, results) } // ValidateGroupKeys initiates a manual validation task for all keys in a group. func (s *Server) ValidateGroupKeys(c *gin.Context) { - groupID, err := strconv.Atoi(c.Param("id")) - if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID")) + var req GroupIDRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error())) return } - group, ok := s.findGroupByID(c, groupID) + group, ok := s.findGroupByID(c, req.GroupID) if !ok { return } @@ -197,13 +215,17 @@ func (s *Server) ValidateGroupKeys(c *gin.Context) { // RestoreAllInvalidKeys sets the status of all 'inactive' keys in a group to 'active'. func (s *Server) RestoreAllInvalidKeys(c *gin.Context) { - groupID, err := strconv.Atoi(c.Param("id")) - if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID")) + var req GroupIDRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error())) return } - rowsAffected, err := s.KeyService.RestoreAllInvalidKeys(uint(groupID)) + if _, ok := s.findGroupByID(c, req.GroupID); !ok { + return + } + + rowsAffected, err := s.KeyService.RestoreAllInvalidKeys(req.GroupID) if err != nil { response.Error(c, app_errors.ParseDBError(err)) return @@ -214,13 +236,17 @@ func (s *Server) RestoreAllInvalidKeys(c *gin.Context) { // ClearAllInvalidKeys deletes all 'inactive' keys from a group. func (s *Server) ClearAllInvalidKeys(c *gin.Context) { - groupID, err := strconv.Atoi(c.Param("id")) - if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID")) + var req GroupIDRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error())) return } - rowsAffected, err := s.KeyService.ClearAllInvalidKeys(uint(groupID)) + if _, ok := s.findGroupByID(c, req.GroupID); !ok { + return + } + + rowsAffected, err := s.KeyService.ClearAllInvalidKeys(req.GroupID) if err != nil { response.Error(c, app_errors.ParseDBError(err)) return @@ -229,20 +255,3 @@ func (s *Server) ClearAllInvalidKeys(c *gin.Context) { response.Success(c, gin.H{"message": fmt.Sprintf("%d invalid keys cleared.", rowsAffected)}) } -// ExportKeys returns a list of keys for a group, filtered by status. -func (s *Server) ExportKeys(c *gin.Context) { - groupID, err := strconv.Atoi(c.Param("id")) - if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID")) - return - } - - filter := c.DefaultQuery("filter", "all") - keys, err := s.KeyService.ExportKeys(uint(groupID), filter) - if err != nil { - response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error())) - return - } - - response.Success(c, gin.H{"keys": keys}) -} diff --git a/internal/models/types.go b/internal/models/types.go index 550afc0..05d71c2 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -52,11 +52,12 @@ type Group struct { // APIKey 对应 api_keys 表 type APIKey struct { ID uint `gorm:"primaryKey;autoIncrement" json:"id"` - GroupID uint `gorm:"not null" json:"group_id"` - KeyValue string `gorm:"type:varchar(512);not null" json:"key_value"` + KeyValue string `gorm:"type:varchar(512);not null;uniqueIndex:idx_group_key" json:"key_value"` + GroupID uint `gorm:"not null;uniqueIndex:idx_group_key" json:"group_id"` Status string `gorm:"type:varchar(50);not null;default:'active'" json:"status"` RequestCount int64 `gorm:"not null;default:0" json:"request_count"` FailureCount int64 `gorm:"not null;default:0" json:"failure_count"` + ErrorReason string `gorm:"type:text" json:"error_reason"` LastUsedAt *time.Time `json:"last_used_at"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` diff --git a/internal/response/pagination.go b/internal/response/pagination.go new file mode 100644 index 0000000..e46f3cd --- /dev/null +++ b/internal/response/pagination.go @@ -0,0 +1,74 @@ +package response + +import ( + "math" + "strconv" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +const ( + DefaultPageSize = 15 + MaxPageSize = 1000 +) + +// Pagination represents the pagination details in a response. +type Pagination struct { + Page int `json:"page"` + PageSize int `json:"page_size"` + TotalItems int64 `json:"total_items"` + TotalPages int `json:"total_pages"` +} + +// PaginatedResponse is the standard structure for all paginated API responses. +type PaginatedResponse struct { + Items interface{} `json:"items"` + Pagination Pagination `json:"pagination"` +} + +// Paginate performs pagination on a GORM query and returns a standardized response. +// It takes a Gin context, a GORM query builder, and a destination slice for the results. +func Paginate(c *gin.Context, query *gorm.DB, dest interface{}) (*PaginatedResponse, error) { + // 1. Get page and page size from query parameters + page, err := strconv.Atoi(c.DefaultQuery("page", "1")) + if err != nil || page < 1 { + page = 1 + } + + pageSize, err := strconv.Atoi(c.DefaultQuery("page_size", strconv.Itoa(DefaultPageSize))) + if err != nil || pageSize <= 0 { + pageSize = DefaultPageSize + } + if pageSize > MaxPageSize { + pageSize = MaxPageSize + } + + // 2. Get total count of items + var totalItems int64 + if err := query.Count(&totalItems).Error; err != nil { + return nil, err + } + + // 3. Calculate offset and total pages + offset := (page - 1) * pageSize + totalPages := int(math.Ceil(float64(totalItems) / float64(pageSize))) + + // 4. Retrieve the data for the current page + if err := query.Limit(pageSize).Offset(offset).Find(dest).Error; err != nil { + return nil, err + } + + // 5. Construct the paginated response + paginatedData := &PaginatedResponse{ + Items: dest, + Pagination: Pagination{ + Page: page, + PageSize: pageSize, + TotalItems: totalItems, + TotalPages: totalPages, + }, + } + + return paginatedData, nil +} diff --git a/internal/router/router.go b/internal/router/router.go index ef81757..84db09e 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -109,20 +109,18 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser groups.PUT("/:id", serverHandler.UpdateGroup) groups.DELETE("/:id", serverHandler.DeleteGroup) - // Key-specific routes - keys := groups.Group("/:id/keys") - { - keys.GET("", serverHandler.ListKeysInGroup) - keys.POST("/add-multiple", serverHandler.AddMultipleKeys) - keys.POST("/restore-all-invalid", serverHandler.RestoreAllInvalidKeys) - keys.POST("/clear-all-invalid", serverHandler.ClearAllInvalidKeys) - keys.GET("/export", serverHandler.ExportKeys) - keys.DELETE("/:key_id", serverHandler.DeleteSingleKey) - keys.POST("/:key_id/test", serverHandler.TestSingleKey) - } + } - // Group-level actions - groups.POST("/:id/validate-keys", serverHandler.ValidateGroupKeys) + // Key Management Routes + keys := api.Group("/keys") + { + keys.GET("", serverHandler.ListKeysInGroup) + keys.POST("/add-multiple", serverHandler.AddMultipleKeys) + keys.POST("/delete-multiple", serverHandler.DeleteMultipleKeys) + keys.POST("/restore-all-invalid", serverHandler.RestoreAllInvalidKeys) + keys.POST("/clear-all-invalid", serverHandler.ClearAllInvalidKeys) + keys.POST("/validate-group", serverHandler.ValidateGroupKeys) + keys.POST("/test-multiple", serverHandler.TestMultipleKeys) } // Tasks diff --git a/internal/services/key_cron_service.go b/internal/services/key_cron_service.go index 6fc9914..ec50fb2 100644 --- a/internal/services/key_cron_service.go +++ b/internal/services/key_cron_service.go @@ -147,21 +147,38 @@ func (s *KeyCronService) validateGroup(ctx context.Context, group *models.Group) func (s *KeyCronService) worker(ctx context.Context, wg *sync.WaitGroup, group *models.Group, jobs <-chan models.APIKey, results chan<- models.APIKey) { defer wg.Done() for key := range jobs { - isValid, err := s.Validator.ValidateSingleKey(ctx, &key, group) - // Only update status if there was no error during validation - if err != nil { - logrus.Warnf("KeyCronService: Failed to validate key ID %d for group %s: %v. Skipping status update.", key.ID, group.Name, err) - continue + isValid, validationErr := s.Validator.ValidateSingleKey(ctx, &key, group) + + newStatus := key.Status + newErrorReason := key.ErrorReason + statusChanged := false + + if validationErr != nil { + // Validation failed, mark as inactive and record the reason + newStatus = "inactive" + newErrorReason = validationErr.Error() + } else { + // Validation succeeded + if isValid { + newStatus = "active" + newErrorReason = "" // Clear reason on success + } else { + // This case might happen if the key is valid but has no quota, etc. + // The error would be in validationErr, so this branch is less likely. + // We still mark it as inactive but without a specific error from our side. + newStatus = "inactive" + newErrorReason = "Validation returned false without a specific error." + } } - newStatus := "inactive" - if isValid { - newStatus = "active" + // Check if status or error reason has changed + if key.Status != newStatus || key.ErrorReason != newErrorReason { + statusChanged = true } - // Only send to results if the status has changed - if key.Status != newStatus { + if statusChanged { key.Status = newStatus + key.ErrorReason = newErrorReason results <- key } } @@ -171,36 +188,24 @@ func (s *KeyCronService) batchUpdateKeyStatus(keys []models.APIKey) { if len(keys) == 0 { return } - logrus.Infof("KeyCronService: Batch updating status for %d keys.", len(keys)) - - activeIDs := []uint{} - inactiveIDs := []uint{} - - for _, key := range keys { - if key.Status == "active" { - activeIDs = append(activeIDs, key.ID) - } else { - inactiveIDs = append(inactiveIDs, key.ID) - } - } + logrus.Infof("KeyCronService: Batch updating status/reason for %d keys.", len(keys)) err := s.DB.Transaction(func(tx *gorm.DB) error { - if len(activeIDs) > 0 { - if err := tx.Model(&models.APIKey{}).Where("id IN ?", activeIDs).Update("status", "active").Error; err != nil { - return err + for _, key := range keys { + updates := map[string]interface{}{ + "status": key.Status, + "error_reason": key.ErrorReason, } - logrus.Infof("KeyCronService: Set %d keys to 'active'.", len(activeIDs)) - } - if len(inactiveIDs) > 0 { - if err := tx.Model(&models.APIKey{}).Where("id IN ?", inactiveIDs).Update("status", "inactive").Error; err != nil { - return err + if err := tx.Model(&models.APIKey{}).Where("id = ?", key.ID).Updates(updates).Error; err != nil { + // Log the error for this specific key but continue the transaction + logrus.Errorf("KeyCronService: Failed to update key ID %d: %v", key.ID, err) } - logrus.Infof("KeyCronService: Set %d keys to 'inactive'.", len(inactiveIDs)) } - return nil + return nil // Commit the transaction even if some updates failed }) if err != nil { - logrus.Errorf("KeyCronService: Failed to batch update key status: %v", err) + // This error is for the transaction itself, not individual updates + logrus.Errorf("KeyCronService: Transaction failed during batch update of key statuses: %v", err) } } diff --git a/internal/services/key_service.go b/internal/services/key_service.go index 624a2eb..e812e23 100644 --- a/internal/services/key_service.go +++ b/internal/services/key_service.go @@ -17,6 +17,13 @@ type AddKeysResult struct { TotalInGroup int64 `json:"total_in_group"` } +// DeleteKeysResult holds the result of deleting multiple keys. +type DeleteKeysResult struct { + DeletedCount int `json:"deleted_count"` + IgnoredCount int `json:"ignored_count"` + TotalInGroup int64 `json:"total_in_group"` +} + // KeyService provides services related to API keys. type KeyService struct { DB *gorm.DB @@ -30,7 +37,7 @@ func NewKeyService(db *gorm.DB) *KeyService { // AddMultipleKeys handles the business logic of creating new keys from a text block. func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysResult, error) { // 1. Parse keys from the text block - keys := s.parseKeysFromText(keysText) + keys := s.ParseKeysFromText(keysText) if len(keys) == 0 { return nil, fmt.Errorf("no valid keys found in the input text") } @@ -101,7 +108,9 @@ func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysRes }, nil } -func (s *KeyService) parseKeysFromText(text string) []string { +// ParseKeysFromText parses a string of keys from various formats into a string slice. +// This function is exported to be shared with the handler layer. +func (s *KeyService) ParseKeysFromText(text string) []string { var keys []string // First, try to parse as a JSON array of strings @@ -162,45 +171,52 @@ func (s *KeyService) ClearAllInvalidKeys(groupID uint) (int64, error) { return result.RowsAffected, result.Error } -// DeleteSingleKey deletes a specific key from a group. -func (s *KeyService) DeleteSingleKey(groupID, keyID uint) (int64, error) { - result := s.DB.Where("group_id = ? AND id = ?", groupID, keyID).Delete(&models.APIKey{}) - return result.RowsAffected, result.Error -} - -// ExportKeys returns a list of keys for a group, filtered by status. -func (s *KeyService) ExportKeys(groupID uint, filter string) ([]string, error) { - query := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID) - - switch filter { - case "valid": - query = query.Where("status = ?", "active") - case "invalid": - query = query.Where("status = ?", "inactive") - case "all": - // No status filter needed - default: - return nil, fmt.Errorf("invalid filter value. Use 'all', 'valid', or 'invalid'") +// DeleteMultipleKeys handles the business logic of deleting keys from a text block. +func (s *KeyService) DeleteMultipleKeys(groupID uint, keysText string) (*DeleteKeysResult, error) { + // 1. Parse keys from the text block + keysToDelete := s.ParseKeysFromText(keysText) + if len(keysToDelete) == 0 { + return nil, fmt.Errorf("no valid keys found in the input text") } - var keys []string - if err := query.Pluck("key_value", &keys).Error; err != nil { + // 2. Perform the deletion + // GORM's batch delete doesn't easily return which ones were deleted vs. ignored. + // We perform a bulk delete and then count the remaining to calculate the result. + result := s.DB.Where("group_id = ? AND key_value IN ?", groupID, keysToDelete).Delete(&models.APIKey{}) + if result.Error != nil { + return nil, result.Error + } + + deletedCount := int(result.RowsAffected) + ignoredCount := len(keysToDelete) - deletedCount + + // 3. Get the new total count + var totalInGroup int64 + if err := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID).Count(&totalInGroup).Error; err != nil { return nil, err } - return keys, nil + + return &DeleteKeysResult{ + DeletedCount: deletedCount, + IgnoredCount: ignoredCount, + TotalInGroup: totalInGroup, + }, nil } -// ListKeysInGroup lists all keys within a specific group, filtered by status. -func (s *KeyService) ListKeysInGroup(groupID uint, statusFilter string) ([]models.APIKey, error) { - var keys []models.APIKey - query := s.DB.Where("group_id = ?", groupID) +// ListKeysInGroupQuery builds a query to list all keys within a specific group, filtered by status. +// It returns a GORM query builder, allowing the handler to apply pagination. +func (s *KeyService) ListKeysInGroupQuery(groupID uint, statusFilter string, searchKeyword string) *gorm.DB { + query := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID) if statusFilter != "" { query = query.Where("status = ?", statusFilter) } - if err := query.Find(&keys).Error; err != nil { - return nil, err + if searchKeyword != "" { + // Use LIKE for fuzzy search on the key_value + query = query.Where("key_value LIKE ?", "%"+searchKeyword+"%") } - return keys, nil + + return query } + diff --git a/internal/services/key_validator_service.go b/internal/services/key_validator_service.go index 54d49dc..72c9e25 100644 --- a/internal/services/key_validator_service.go +++ b/internal/services/key_validator_service.go @@ -10,6 +10,13 @@ import ( "gorm.io/gorm" ) +// KeyTestResult holds the validation result for a single key. +type KeyTestResult struct { + KeyValue string `json:"key_value"` + IsValid bool `json:"is_valid"` + Error string `json:"error,omitempty"` +} + // KeyValidatorService provides methods to validate API keys. type KeyValidatorService struct { DB *gorm.DB @@ -73,19 +80,45 @@ func (s *KeyValidatorService) ValidateSingleKey(ctx context.Context, key *models return isValid, nil } -// TestSingleKeyByID performs a synchronous validation test for a single API key by its ID. -// It is intended for handling user-initiated "Test" actions. -// It does not modify the key's state in the database. -func (s *KeyValidatorService) TestSingleKeyByID(ctx context.Context, keyID uint) (bool, error) { - var apiKey models.APIKey - if err := s.DB.First(&apiKey, keyID).Error; err != nil { - return false, fmt.Errorf("failed to find api key with id %d: %w", keyID, err) +// TestMultipleKeys performs a synchronous validation for a list of key values within a specific group. +func (s *KeyValidatorService) TestMultipleKeys(ctx context.Context, group *models.Group, keyValues []string) ([]KeyTestResult, error) { + results := make([]KeyTestResult, len(keyValues)) + ch, err := s.channelFactory.GetChannel(group) + if err != nil { + return nil, fmt.Errorf("failed to get channel for group %s: %w", group.Name, err) } - var group models.Group - if err := s.DB.First(&group, apiKey.GroupID).Error; err != nil { - return false, fmt.Errorf("failed to find group with id %d: %w", apiKey.GroupID, err) + // Find which of the provided keys actually exist in the database for this group + var existingKeys []models.APIKey + if err := s.DB.Where("group_id = ? AND key_value IN ?", group.ID, keyValues).Find(&existingKeys).Error; err != nil { + return nil, fmt.Errorf("failed to query keys from DB: %w", err) + } + existingKeyMap := make(map[string]bool) + for _, k := range existingKeys { + existingKeyMap[k.KeyValue] = true } - return s.ValidateSingleKey(ctx, &apiKey, &group) + for i, kv := range keyValues { + // Pre-check: ensure the key belongs to the group to prevent unnecessary API calls + if !existingKeyMap[kv] { + results[i] = KeyTestResult{ + KeyValue: kv, + IsValid: false, + Error: "Key does not exist in this group or has been removed.", + } + continue + } + + isValid, validationErr := ch.ValidateKey(ctx, kv) + results[i] = KeyTestResult{ + KeyValue: kv, + IsValid: isValid, + Error: "", // Explicitly set error to empty string on success + } + if validationErr != nil { + results[i].Error = validationErr.Error() + } + } + + return results, nil }