feat: key接口完善及错误处理
This commit is contained in:
@@ -26,6 +26,7 @@ type BaseChannel struct {
|
|||||||
Name string
|
Name string
|
||||||
Upstreams []UpstreamInfo
|
Upstreams []UpstreamInfo
|
||||||
HTTPClient *http.Client
|
HTTPClient *http.Client
|
||||||
|
TestModel string
|
||||||
upstreamLock sync.Mutex
|
upstreamLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -60,7 +60,7 @@ func (f *Factory) GetChannel(group *models.Group) (ChannelProxy, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// newBaseChannel is a helper function to create and configure a BaseChannel.
|
// newBaseChannel is a helper function to create and configure a BaseChannel.
|
||||||
func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, groupConfig datatypes.JSONMap) (*BaseChannel, error) {
|
func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, groupConfig datatypes.JSONMap, testModel string) (*BaseChannel, error) {
|
||||||
type upstreamDef struct {
|
type upstreamDef struct {
|
||||||
URL string `json:"url"`
|
URL string `json:"url"`
|
||||||
Weight int `json:"weight"`
|
Weight int `json:"weight"`
|
||||||
@@ -103,5 +103,6 @@ func (f *Factory) newBaseChannel(name string, upstreamsJSON datatypes.JSON, grou
|
|||||||
Name: name,
|
Name: name,
|
||||||
Upstreams: upstreamInfos,
|
Upstreams: upstreamInfos,
|
||||||
HTTPClient: httpClient,
|
HTTPClient: httpClient,
|
||||||
|
TestModel: testModel,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@@ -1,11 +1,14 @@
|
|||||||
package channel
|
package channel
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
app_errors "gpt-load/internal/errors"
|
||||||
"gpt-load/internal/models"
|
"gpt-load/internal/models"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -20,7 +23,7 @@ type GeminiChannel struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newGeminiChannel(f *Factory, group *models.Group) (ChannelProxy, error) {
|
func newGeminiChannel(f *Factory, group *models.Group) (ChannelProxy, error) {
|
||||||
base, err := f.newBaseChannel("gemini", group.Upstreams, group.Config)
|
base, err := f.newBaseChannel("gemini", group.Upstreams, group.Config, group.TestModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -39,20 +42,35 @@ func (ch *GeminiChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *mo
|
|||||||
return ch.ProcessRequest(c, apiKey, modifier, ch)
|
return ch.ProcessRequest(c, apiKey, modifier, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateKey checks if the given API key is valid by making a request to the models endpoint.
|
// ValidateKey checks if the given API key is valid by making a generateContent request.
|
||||||
func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
|
func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
|
||||||
upstreamURL := ch.getUpstreamURL()
|
upstreamURL := ch.getUpstreamURL()
|
||||||
if upstreamURL == nil {
|
if upstreamURL == nil {
|
||||||
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
|
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Construct the request URL for listing models.
|
// Use the test model specified in the group settings.
|
||||||
reqURL := fmt.Sprintf("%s/v1beta/models?key=%s", upstreamURL.String(), key)
|
// The path format for Gemini is /v1beta/models/{model}:generateContent
|
||||||
|
reqURL := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", upstreamURL.String(), ch.TestModel, key)
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
// Use a minimal, low-cost payload for validation
|
||||||
|
payload := gin.H{
|
||||||
|
"contents": []gin.H{
|
||||||
|
{"parts": []gin.H{
|
||||||
|
{"text": "Only output 'ok'"},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to marshal validation payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", reqURL, bytes.NewBuffer(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("failed to create validation request: %w", err)
|
return false, fmt.Errorf("failed to create validation request: %w", err)
|
||||||
}
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
resp, err := ch.HTTPClient.Do(req)
|
resp, err := ch.HTTPClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -61,7 +79,20 @@ func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, err
|
|||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
// A 200 OK status code indicates the key is valid.
|
// A 200 OK status code indicates the key is valid.
|
||||||
return resp.StatusCode == http.StatusOK, nil
|
if resp.StatusCode == http.StatusOK {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// For non-200 responses, parse the body to provide a more specific error reason.
|
||||||
|
errorBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("key is invalid (status %d), but failed to read error body: %w", resp.StatusCode, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the new parser to extract a clean error message.
|
||||||
|
parsedError := app_errors.ParseUpstreamError(errorBody)
|
||||||
|
|
||||||
|
return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsStreamingRequest checks if the request is for a streaming response.
|
// IsStreamingRequest checks if the request is for a streaming response.
|
||||||
|
@@ -1,9 +1,13 @@
|
|||||||
package channel
|
package channel
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
app_errors "gpt-load/internal/errors"
|
||||||
"gpt-load/internal/models"
|
"gpt-load/internal/models"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -19,7 +23,7 @@ type OpenAIChannel struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newOpenAIChannel(f *Factory, group *models.Group) (ChannelProxy, error) {
|
func newOpenAIChannel(f *Factory, group *models.Group) (ChannelProxy, error) {
|
||||||
base, err := f.newBaseChannel("openai", group.Upstreams, group.Config)
|
base, err := f.newBaseChannel("openai", group.Upstreams, group.Config, group.TestModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -36,21 +40,34 @@ func (ch *OpenAIChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *mo
|
|||||||
return ch.ProcessRequest(c, apiKey, modifier, ch)
|
return ch.ProcessRequest(c, apiKey, modifier, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateKey checks if the given API key is valid by making a request to the models endpoint.
|
// ValidateKey checks if the given API key is valid by making a chat completion request.
|
||||||
func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
|
func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
|
||||||
upstreamURL := ch.getUpstreamURL()
|
upstreamURL := ch.getUpstreamURL()
|
||||||
if upstreamURL == nil {
|
if upstreamURL == nil {
|
||||||
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
|
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Construct the request URL for listing models, a common endpoint for key validation.
|
reqURL := upstreamURL.String() + "/v1/chat/completions"
|
||||||
reqURL := upstreamURL.String() + "/v1/models"
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
// Use a minimal, low-cost payload for validation
|
||||||
|
payload := gin.H{
|
||||||
|
"model": ch.TestModel,
|
||||||
|
"messages": []gin.H{
|
||||||
|
{"role": "user", "content": "Only output 'ok'"},
|
||||||
|
},
|
||||||
|
"max_tokens": 1,
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to marshal validation payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", reqURL, bytes.NewBuffer(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("failed to create validation request: %w", err)
|
return false, fmt.Errorf("failed to create validation request: %w", err)
|
||||||
}
|
}
|
||||||
req.Header.Set("Authorization", "Bearer "+key)
|
req.Header.Set("Authorization", "Bearer "+key)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
resp, err := ch.HTTPClient.Do(req)
|
resp, err := ch.HTTPClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -58,9 +75,21 @@ func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, err
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
// A 200 OK status code indicates the key is valid.
|
// A 200 OK status code indicates the key is valid and can make requests.
|
||||||
// Other status codes (e.g., 401 Unauthorized) indicate an invalid key.
|
if resp.StatusCode == http.StatusOK {
|
||||||
return resp.StatusCode == http.StatusOK, nil
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// For non-200 responses, parse the body to provide a more specific error reason.
|
||||||
|
errorBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("key is invalid (status %d), but failed to read error body: %w", resp.StatusCode, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the new parser to extract a clean error message.
|
||||||
|
parsedError := app_errors.ParseUpstreamError(errorBody)
|
||||||
|
|
||||||
|
return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsStreamingRequest checks if the request is for a streaming response.
|
// IsStreamingRequest checks if the request is for a streaming response.
|
||||||
|
81
internal/errors/parser.go
Normal file
81
internal/errors/parser.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package errors
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// maxErrorBodyLength defines the maximum length of an error message to be stored or returned.
|
||||||
|
maxErrorBodyLength = 2048
|
||||||
|
)
|
||||||
|
|
||||||
|
// standardErrorResponse matches formats like: {"error": {"message": "..."}}
|
||||||
|
type standardErrorResponse struct {
|
||||||
|
Error struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
} `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// vendorErrorResponse matches formats like: {"error_msg": "..."}
|
||||||
|
type vendorErrorResponse struct {
|
||||||
|
ErrorMsg string `json:"error_msg"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// simpleErrorResponse matches formats like: {"error": "..."}
|
||||||
|
type simpleErrorResponse struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// rootMessageErrorResponse matches formats like: {"message": "..."}
|
||||||
|
type rootMessageErrorResponse struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseUpstreamError attempts to parse a structured error message from an upstream response body
|
||||||
|
// using a chain of responsibility pattern. It tries various common formats and gracefully
|
||||||
|
// degrades to a raw string if all parsing attempts fail.
|
||||||
|
func ParseUpstreamError(body []byte) string {
|
||||||
|
// 1. Attempt to parse the standard OpenAI/Gemini format.
|
||||||
|
var stdErr standardErrorResponse
|
||||||
|
if err := json.Unmarshal(body, &stdErr); err == nil {
|
||||||
|
if msg := strings.TrimSpace(stdErr.Error.Message); msg != "" {
|
||||||
|
return truncateString(msg, maxErrorBodyLength)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Attempt to parse vendor-specific format (e.g., Baidu).
|
||||||
|
var vendorErr vendorErrorResponse
|
||||||
|
if err := json.Unmarshal(body, &vendorErr); err == nil {
|
||||||
|
if msg := strings.TrimSpace(vendorErr.ErrorMsg); msg != "" {
|
||||||
|
return truncateString(msg, maxErrorBodyLength)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Attempt to parse simple error format.
|
||||||
|
var simpleErr simpleErrorResponse
|
||||||
|
if err := json.Unmarshal(body, &simpleErr); err == nil {
|
||||||
|
if msg := strings.TrimSpace(simpleErr.Error); msg != "" {
|
||||||
|
return truncateString(msg, maxErrorBodyLength)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Attempt to parse root-level message format.
|
||||||
|
var rootMsgErr rootMessageErrorResponse
|
||||||
|
if err := json.Unmarshal(body, &rootMsgErr); err == nil {
|
||||||
|
if msg := strings.TrimSpace(rootMsgErr.Message); msg != "" {
|
||||||
|
return truncateString(msg, maxErrorBodyLength)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Graceful Degradation: If all parsing fails, return the raw (but safe) body.
|
||||||
|
return truncateString(string(body), maxErrorBodyLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
// truncateString ensures a string does not exceed a maximum length.
|
||||||
|
func truncateString(s string, maxLength int) string {
|
||||||
|
if len(s) > maxLength {
|
||||||
|
return s[:maxLength]
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
@@ -12,51 +12,36 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// validateGroupID validates and parses group ID from request parameter
|
// validateGroupIDFromQuery validates and parses group ID from a query parameter.
|
||||||
func validateGroupID(c *gin.Context) (uint, error) {
|
func validateGroupIDFromQuery(c *gin.Context) (uint, error) {
|
||||||
groupIDStr := c.Param("id")
|
groupIDStr := c.Query("group_id")
|
||||||
if groupIDStr == "" {
|
if groupIDStr == "" {
|
||||||
return 0, fmt.Errorf("group ID is required")
|
return 0, fmt.Errorf("group_id query parameter is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
groupID, err := strconv.Atoi(groupIDStr)
|
groupID, err := strconv.Atoi(groupIDStr)
|
||||||
if err != nil || groupID <= 0 {
|
if err != nil || groupID <= 0 {
|
||||||
return 0, fmt.Errorf("invalid group ID format")
|
return 0, fmt.Errorf("invalid group_id format")
|
||||||
}
|
}
|
||||||
|
|
||||||
return uint(groupID), nil
|
return uint(groupID), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateKeyID validates and parses key ID from request parameter
|
|
||||||
func validateKeyID(c *gin.Context) (uint, error) {
|
|
||||||
keyIDStr := c.Param("key_id")
|
|
||||||
if keyIDStr == "" {
|
|
||||||
return 0, fmt.Errorf("key ID is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
keyID, err := strconv.Atoi(keyIDStr)
|
|
||||||
if err != nil || keyID <= 0 {
|
|
||||||
return 0, fmt.Errorf("invalid key ID format")
|
|
||||||
}
|
|
||||||
|
|
||||||
return uint(keyID), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateKeysText validates the keys text input
|
// validateKeysText validates the keys text input
|
||||||
func validateKeysText(keysText string) error {
|
func validateKeysText(keysText string) error {
|
||||||
if strings.TrimSpace(keysText) == "" {
|
if strings.TrimSpace(keysText) == "" {
|
||||||
return fmt.Errorf("keys text cannot be empty")
|
return fmt.Errorf("keys text cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(keysText) > 1024*1024 { // 1MB limit
|
if len(keysText) > 10*1024*1024 {
|
||||||
return fmt.Errorf("keys text is too large (max 1MB)")
|
return fmt.Errorf("keys text is too large (max 10MB)")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// findGroupByID is a helper function to find a group by its ID.
|
// findGroupByID is a helper function to find a group by its ID.
|
||||||
func (s *Server) findGroupByID(c *gin.Context, groupID int) (*models.Group, bool) {
|
func (s *Server) findGroupByID(c *gin.Context, groupID uint) (*models.Group, bool) {
|
||||||
var group models.Group
|
var group models.Group
|
||||||
if err := s.DB.First(&group, groupID).Error; err != nil {
|
if err := s.DB.First(&group, groupID).Error; err != nil {
|
||||||
if err == gorm.ErrRecordNotFound {
|
if err == gorm.ErrRecordNotFound {
|
||||||
@@ -69,22 +54,26 @@ func (s *Server) findGroupByID(c *gin.Context, groupID int) (*models.Group, bool
|
|||||||
return &group, true
|
return &group, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddMultipleKeysRequest defines the payload for adding multiple keys from a text block.
|
// KeyTextRequest defines a generic payload for operations requiring a group ID and a text block of keys.
|
||||||
type AddMultipleKeysRequest struct {
|
type KeyTextRequest struct {
|
||||||
|
GroupID uint `json:"group_id" binding:"required"`
|
||||||
KeysText string `json:"keys_text" binding:"required"`
|
KeysText string `json:"keys_text" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GroupIDRequest defines a generic payload for operations requiring only a group ID.
|
||||||
|
type GroupIDRequest struct {
|
||||||
|
GroupID uint `json:"group_id" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
// AddMultipleKeys handles creating new keys from a text block within a specific group.
|
// AddMultipleKeys handles creating new keys from a text block within a specific group.
|
||||||
func (s *Server) AddMultipleKeys(c *gin.Context) {
|
func (s *Server) AddMultipleKeys(c *gin.Context) {
|
||||||
groupID, err := validateGroupID(c)
|
var req KeyTextRequest
|
||||||
if err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, err.Error()))
|
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req AddMultipleKeysRequest
|
if _, ok := s.findGroupByID(c, req.GroupID); !ok {
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,20 +82,28 @@ func (s *Server) AddMultipleKeys(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := s.KeyService.AddMultipleKeys(groupID, req.KeysText)
|
result, err := s.KeyService.AddMultipleKeys(req.GroupID, req.KeysText)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err.Error() == "no valid keys found in the input text" {
|
||||||
|
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
|
||||||
|
} else {
|
||||||
response.Error(c, app_errors.ParseDBError(err))
|
response.Error(c, app_errors.ParseDBError(err))
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, result)
|
response.Success(c, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListKeysInGroup handles listing all keys within a specific group.
|
// ListKeysInGroup handles listing all keys within a specific group with pagination.
|
||||||
func (s *Server) ListKeysInGroup(c *gin.Context) {
|
func (s *Server) ListKeysInGroup(c *gin.Context) {
|
||||||
groupID, err := strconv.Atoi(c.Param("id"))
|
groupID, err := validateGroupIDFromQuery(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
|
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := s.findGroupByID(c, groupID); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,72 +113,93 @@ func (s *Server) ListKeysInGroup(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
keys, err := s.KeyService.ListKeysInGroup(uint(groupID), statusFilter)
|
searchKeyword := c.Query("key")
|
||||||
|
|
||||||
|
query := s.KeyService.ListKeysInGroupQuery(groupID, statusFilter, searchKeyword)
|
||||||
|
|
||||||
|
var keys []models.APIKey
|
||||||
|
paginatedResult, err := response.Paginate(c, query, &keys)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, app_errors.ParseDBError(err))
|
response.Error(c, app_errors.ParseDBError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, keys)
|
response.Success(c, paginatedResult)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteSingleKey handles deleting a specific key.
|
// DeleteMultipleKeys handles deleting keys from a text block within a specific group.
|
||||||
func (s *Server) DeleteSingleKey(c *gin.Context) {
|
func (s *Server) DeleteMultipleKeys(c *gin.Context) {
|
||||||
groupID, err := strconv.Atoi(c.Param("id"))
|
var req KeyTextRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := s.findGroupByID(c, req.GroupID); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validateKeysText(req.KeysText); err != nil {
|
||||||
|
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := s.KeyService.DeleteMultipleKeys(req.GroupID, req.KeysText)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
|
if err.Error() == "no valid keys found in the input text" {
|
||||||
return
|
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
|
||||||
}
|
|
||||||
|
|
||||||
keyID, err := strconv.Atoi(c.Param("key_id"))
|
|
||||||
if err != nil {
|
|
||||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid key ID"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rowsAffected, err := s.KeyService.DeleteSingleKey(uint(groupID), uint(keyID))
|
|
||||||
if err != nil {
|
|
||||||
response.Error(c, app_errors.ParseDBError(err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if rowsAffected == 0 {
|
|
||||||
response.Error(c, app_errors.ErrResourceNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response.Success(c, gin.H{"message": "Key deleted successfully"})
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSingleKey handles a one-off validation test for a single key.
|
|
||||||
func (s *Server) TestSingleKey(c *gin.Context) {
|
|
||||||
keyID, err := validateKeyID(c)
|
|
||||||
if err != nil {
|
|
||||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
isValid, validationErr := s.KeyValidatorService.TestSingleKeyByID(c.Request.Context(), keyID)
|
|
||||||
if validationErr != nil {
|
|
||||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadGateway, validationErr.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if isValid {
|
|
||||||
response.Success(c, gin.H{"success": true, "message": "Key is valid."})
|
|
||||||
} else {
|
} else {
|
||||||
response.Success(c, gin.H{"success": false, "message": "Key is invalid or has insufficient quota."})
|
response.Error(c, app_errors.ParseDBError(err))
|
||||||
}
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMultipleKeys handles a one-off validation test for multiple keys.
|
||||||
|
func (s *Server) TestMultipleKeys(c *gin.Context) {
|
||||||
|
var req KeyTextRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
group, ok := s.findGroupByID(c, req.GroupID)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validateKeysText(req.KeysText); err != nil {
|
||||||
|
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-use the parsing logic from the key service
|
||||||
|
keysToTest := s.KeyService.ParseKeysFromText(req.KeysText)
|
||||||
|
if len(keysToTest) == 0 {
|
||||||
|
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, "no valid keys found in the input text"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := s.KeyValidatorService.TestMultipleKeys(c.Request.Context(), group, keysToTest)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, app_errors.ParseDBError(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, results)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateGroupKeys initiates a manual validation task for all keys in a group.
|
// ValidateGroupKeys initiates a manual validation task for all keys in a group.
|
||||||
func (s *Server) ValidateGroupKeys(c *gin.Context) {
|
func (s *Server) ValidateGroupKeys(c *gin.Context) {
|
||||||
groupID, err := strconv.Atoi(c.Param("id"))
|
var req GroupIDRequest
|
||||||
if err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
|
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
group, ok := s.findGroupByID(c, groupID)
|
group, ok := s.findGroupByID(c, req.GroupID)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -197,13 +215,17 @@ func (s *Server) ValidateGroupKeys(c *gin.Context) {
|
|||||||
|
|
||||||
// RestoreAllInvalidKeys sets the status of all 'inactive' keys in a group to 'active'.
|
// RestoreAllInvalidKeys sets the status of all 'inactive' keys in a group to 'active'.
|
||||||
func (s *Server) RestoreAllInvalidKeys(c *gin.Context) {
|
func (s *Server) RestoreAllInvalidKeys(c *gin.Context) {
|
||||||
groupID, err := strconv.Atoi(c.Param("id"))
|
var req GroupIDRequest
|
||||||
if err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
|
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rowsAffected, err := s.KeyService.RestoreAllInvalidKeys(uint(groupID))
|
if _, ok := s.findGroupByID(c, req.GroupID); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rowsAffected, err := s.KeyService.RestoreAllInvalidKeys(req.GroupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, app_errors.ParseDBError(err))
|
response.Error(c, app_errors.ParseDBError(err))
|
||||||
return
|
return
|
||||||
@@ -214,13 +236,17 @@ func (s *Server) RestoreAllInvalidKeys(c *gin.Context) {
|
|||||||
|
|
||||||
// ClearAllInvalidKeys deletes all 'inactive' keys from a group.
|
// ClearAllInvalidKeys deletes all 'inactive' keys from a group.
|
||||||
func (s *Server) ClearAllInvalidKeys(c *gin.Context) {
|
func (s *Server) ClearAllInvalidKeys(c *gin.Context) {
|
||||||
groupID, err := strconv.Atoi(c.Param("id"))
|
var req GroupIDRequest
|
||||||
if err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
|
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rowsAffected, err := s.KeyService.ClearAllInvalidKeys(uint(groupID))
|
if _, ok := s.findGroupByID(c, req.GroupID); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rowsAffected, err := s.KeyService.ClearAllInvalidKeys(req.GroupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, app_errors.ParseDBError(err))
|
response.Error(c, app_errors.ParseDBError(err))
|
||||||
return
|
return
|
||||||
@@ -229,20 +255,3 @@ func (s *Server) ClearAllInvalidKeys(c *gin.Context) {
|
|||||||
response.Success(c, gin.H{"message": fmt.Sprintf("%d invalid keys cleared.", rowsAffected)})
|
response.Success(c, gin.H{"message": fmt.Sprintf("%d invalid keys cleared.", rowsAffected)})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExportKeys returns a list of keys for a group, filtered by status.
|
|
||||||
func (s *Server) ExportKeys(c *gin.Context) {
|
|
||||||
groupID, err := strconv.Atoi(c.Param("id"))
|
|
||||||
if err != nil {
|
|
||||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, "Invalid group ID"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
filter := c.DefaultQuery("filter", "all")
|
|
||||||
keys, err := s.KeyService.ExportKeys(uint(groupID), filter)
|
|
||||||
if err != nil {
|
|
||||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response.Success(c, gin.H{"keys": keys})
|
|
||||||
}
|
|
||||||
|
@@ -52,11 +52,12 @@ type Group struct {
|
|||||||
// APIKey 对应 api_keys 表
|
// APIKey 对应 api_keys 表
|
||||||
type APIKey struct {
|
type APIKey struct {
|
||||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||||
GroupID uint `gorm:"not null" json:"group_id"`
|
KeyValue string `gorm:"type:varchar(512);not null;uniqueIndex:idx_group_key" json:"key_value"`
|
||||||
KeyValue string `gorm:"type:varchar(512);not null" json:"key_value"`
|
GroupID uint `gorm:"not null;uniqueIndex:idx_group_key" json:"group_id"`
|
||||||
Status string `gorm:"type:varchar(50);not null;default:'active'" json:"status"`
|
Status string `gorm:"type:varchar(50);not null;default:'active'" json:"status"`
|
||||||
RequestCount int64 `gorm:"not null;default:0" json:"request_count"`
|
RequestCount int64 `gorm:"not null;default:0" json:"request_count"`
|
||||||
FailureCount int64 `gorm:"not null;default:0" json:"failure_count"`
|
FailureCount int64 `gorm:"not null;default:0" json:"failure_count"`
|
||||||
|
ErrorReason string `gorm:"type:text" json:"error_reason"`
|
||||||
LastUsedAt *time.Time `json:"last_used_at"`
|
LastUsedAt *time.Time `json:"last_used_at"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
74
internal/response/pagination.go
Normal file
74
internal/response/pagination.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package response
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultPageSize = 15
|
||||||
|
MaxPageSize = 1000
|
||||||
|
)
|
||||||
|
|
||||||
|
// Pagination represents the pagination details in a response.
|
||||||
|
type Pagination struct {
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
TotalItems int64 `json:"total_items"`
|
||||||
|
TotalPages int `json:"total_pages"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PaginatedResponse is the standard structure for all paginated API responses.
|
||||||
|
type PaginatedResponse struct {
|
||||||
|
Items interface{} `json:"items"`
|
||||||
|
Pagination Pagination `json:"pagination"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Paginate performs pagination on a GORM query and returns a standardized response.
|
||||||
|
// It takes a Gin context, a GORM query builder, and a destination slice for the results.
|
||||||
|
func Paginate(c *gin.Context, query *gorm.DB, dest interface{}) (*PaginatedResponse, error) {
|
||||||
|
// 1. Get page and page size from query parameters
|
||||||
|
page, err := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
if err != nil || page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
pageSize, err := strconv.Atoi(c.DefaultQuery("page_size", strconv.Itoa(DefaultPageSize)))
|
||||||
|
if err != nil || pageSize <= 0 {
|
||||||
|
pageSize = DefaultPageSize
|
||||||
|
}
|
||||||
|
if pageSize > MaxPageSize {
|
||||||
|
pageSize = MaxPageSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Get total count of items
|
||||||
|
var totalItems int64
|
||||||
|
if err := query.Count(&totalItems).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Calculate offset and total pages
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
totalPages := int(math.Ceil(float64(totalItems) / float64(pageSize)))
|
||||||
|
|
||||||
|
// 4. Retrieve the data for the current page
|
||||||
|
if err := query.Limit(pageSize).Offset(offset).Find(dest).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Construct the paginated response
|
||||||
|
paginatedData := &PaginatedResponse{
|
||||||
|
Items: dest,
|
||||||
|
Pagination: Pagination{
|
||||||
|
Page: page,
|
||||||
|
PageSize: pageSize,
|
||||||
|
TotalItems: totalItems,
|
||||||
|
TotalPages: totalPages,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return paginatedData, nil
|
||||||
|
}
|
@@ -109,20 +109,18 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser
|
|||||||
groups.PUT("/:id", serverHandler.UpdateGroup)
|
groups.PUT("/:id", serverHandler.UpdateGroup)
|
||||||
groups.DELETE("/:id", serverHandler.DeleteGroup)
|
groups.DELETE("/:id", serverHandler.DeleteGroup)
|
||||||
|
|
||||||
// Key-specific routes
|
}
|
||||||
keys := groups.Group("/:id/keys")
|
|
||||||
|
// Key Management Routes
|
||||||
|
keys := api.Group("/keys")
|
||||||
{
|
{
|
||||||
keys.GET("", serverHandler.ListKeysInGroup)
|
keys.GET("", serverHandler.ListKeysInGroup)
|
||||||
keys.POST("/add-multiple", serverHandler.AddMultipleKeys)
|
keys.POST("/add-multiple", serverHandler.AddMultipleKeys)
|
||||||
|
keys.POST("/delete-multiple", serverHandler.DeleteMultipleKeys)
|
||||||
keys.POST("/restore-all-invalid", serverHandler.RestoreAllInvalidKeys)
|
keys.POST("/restore-all-invalid", serverHandler.RestoreAllInvalidKeys)
|
||||||
keys.POST("/clear-all-invalid", serverHandler.ClearAllInvalidKeys)
|
keys.POST("/clear-all-invalid", serverHandler.ClearAllInvalidKeys)
|
||||||
keys.GET("/export", serverHandler.ExportKeys)
|
keys.POST("/validate-group", serverHandler.ValidateGroupKeys)
|
||||||
keys.DELETE("/:key_id", serverHandler.DeleteSingleKey)
|
keys.POST("/test-multiple", serverHandler.TestMultipleKeys)
|
||||||
keys.POST("/:key_id/test", serverHandler.TestSingleKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Group-level actions
|
|
||||||
groups.POST("/:id/validate-keys", serverHandler.ValidateGroupKeys)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tasks
|
// Tasks
|
||||||
|
@@ -147,21 +147,38 @@ func (s *KeyCronService) validateGroup(ctx context.Context, group *models.Group)
|
|||||||
func (s *KeyCronService) worker(ctx context.Context, wg *sync.WaitGroup, group *models.Group, jobs <-chan models.APIKey, results chan<- models.APIKey) {
|
func (s *KeyCronService) worker(ctx context.Context, wg *sync.WaitGroup, group *models.Group, jobs <-chan models.APIKey, results chan<- models.APIKey) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
for key := range jobs {
|
for key := range jobs {
|
||||||
isValid, err := s.Validator.ValidateSingleKey(ctx, &key, group)
|
isValid, validationErr := s.Validator.ValidateSingleKey(ctx, &key, group)
|
||||||
// Only update status if there was no error during validation
|
|
||||||
if err != nil {
|
|
||||||
logrus.Warnf("KeyCronService: Failed to validate key ID %d for group %s: %v. Skipping status update.", key.ID, group.Name, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
newStatus := "inactive"
|
newStatus := key.Status
|
||||||
|
newErrorReason := key.ErrorReason
|
||||||
|
statusChanged := false
|
||||||
|
|
||||||
|
if validationErr != nil {
|
||||||
|
// Validation failed, mark as inactive and record the reason
|
||||||
|
newStatus = "inactive"
|
||||||
|
newErrorReason = validationErr.Error()
|
||||||
|
} else {
|
||||||
|
// Validation succeeded
|
||||||
if isValid {
|
if isValid {
|
||||||
newStatus = "active"
|
newStatus = "active"
|
||||||
|
newErrorReason = "" // Clear reason on success
|
||||||
|
} else {
|
||||||
|
// This case might happen if the key is valid but has no quota, etc.
|
||||||
|
// The error would be in validationErr, so this branch is less likely.
|
||||||
|
// We still mark it as inactive but without a specific error from our side.
|
||||||
|
newStatus = "inactive"
|
||||||
|
newErrorReason = "Validation returned false without a specific error."
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only send to results if the status has changed
|
// Check if status or error reason has changed
|
||||||
if key.Status != newStatus {
|
if key.Status != newStatus || key.ErrorReason != newErrorReason {
|
||||||
|
statusChanged = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if statusChanged {
|
||||||
key.Status = newStatus
|
key.Status = newStatus
|
||||||
|
key.ErrorReason = newErrorReason
|
||||||
results <- key
|
results <- key
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -171,36 +188,24 @@ func (s *KeyCronService) batchUpdateKeyStatus(keys []models.APIKey) {
|
|||||||
if len(keys) == 0 {
|
if len(keys) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
logrus.Infof("KeyCronService: Batch updating status for %d keys.", len(keys))
|
logrus.Infof("KeyCronService: Batch updating status/reason for %d keys.", len(keys))
|
||||||
|
|
||||||
activeIDs := []uint{}
|
|
||||||
inactiveIDs := []uint{}
|
|
||||||
|
|
||||||
for _, key := range keys {
|
|
||||||
if key.Status == "active" {
|
|
||||||
activeIDs = append(activeIDs, key.ID)
|
|
||||||
} else {
|
|
||||||
inactiveIDs = append(inactiveIDs, key.ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err := s.DB.Transaction(func(tx *gorm.DB) error {
|
err := s.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
if len(activeIDs) > 0 {
|
for _, key := range keys {
|
||||||
if err := tx.Model(&models.APIKey{}).Where("id IN ?", activeIDs).Update("status", "active").Error; err != nil {
|
updates := map[string]interface{}{
|
||||||
return err
|
"status": key.Status,
|
||||||
|
"error_reason": key.ErrorReason,
|
||||||
}
|
}
|
||||||
logrus.Infof("KeyCronService: Set %d keys to 'active'.", len(activeIDs))
|
if err := tx.Model(&models.APIKey{}).Where("id = ?", key.ID).Updates(updates).Error; err != nil {
|
||||||
|
// Log the error for this specific key but continue the transaction
|
||||||
|
logrus.Errorf("KeyCronService: Failed to update key ID %d: %v", key.ID, err)
|
||||||
}
|
}
|
||||||
if len(inactiveIDs) > 0 {
|
|
||||||
if err := tx.Model(&models.APIKey{}).Where("id IN ?", inactiveIDs).Update("status", "inactive").Error; err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
logrus.Infof("KeyCronService: Set %d keys to 'inactive'.", len(inactiveIDs))
|
return nil // Commit the transaction even if some updates failed
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorf("KeyCronService: Failed to batch update key status: %v", err)
|
// This error is for the transaction itself, not individual updates
|
||||||
|
logrus.Errorf("KeyCronService: Transaction failed during batch update of key statuses: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -17,6 +17,13 @@ type AddKeysResult struct {
|
|||||||
TotalInGroup int64 `json:"total_in_group"`
|
TotalInGroup int64 `json:"total_in_group"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteKeysResult holds the result of deleting multiple keys.
|
||||||
|
type DeleteKeysResult struct {
|
||||||
|
DeletedCount int `json:"deleted_count"`
|
||||||
|
IgnoredCount int `json:"ignored_count"`
|
||||||
|
TotalInGroup int64 `json:"total_in_group"`
|
||||||
|
}
|
||||||
|
|
||||||
// KeyService provides services related to API keys.
|
// KeyService provides services related to API keys.
|
||||||
type KeyService struct {
|
type KeyService struct {
|
||||||
DB *gorm.DB
|
DB *gorm.DB
|
||||||
@@ -30,7 +37,7 @@ func NewKeyService(db *gorm.DB) *KeyService {
|
|||||||
// AddMultipleKeys handles the business logic of creating new keys from a text block.
|
// AddMultipleKeys handles the business logic of creating new keys from a text block.
|
||||||
func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysResult, error) {
|
func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysResult, error) {
|
||||||
// 1. Parse keys from the text block
|
// 1. Parse keys from the text block
|
||||||
keys := s.parseKeysFromText(keysText)
|
keys := s.ParseKeysFromText(keysText)
|
||||||
if len(keys) == 0 {
|
if len(keys) == 0 {
|
||||||
return nil, fmt.Errorf("no valid keys found in the input text")
|
return nil, fmt.Errorf("no valid keys found in the input text")
|
||||||
}
|
}
|
||||||
@@ -101,7 +108,9 @@ func (s *KeyService) AddMultipleKeys(groupID uint, keysText string) (*AddKeysRes
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *KeyService) parseKeysFromText(text string) []string {
|
// ParseKeysFromText parses a string of keys from various formats into a string slice.
|
||||||
|
// This function is exported to be shared with the handler layer.
|
||||||
|
func (s *KeyService) ParseKeysFromText(text string) []string {
|
||||||
var keys []string
|
var keys []string
|
||||||
|
|
||||||
// First, try to parse as a JSON array of strings
|
// First, try to parse as a JSON array of strings
|
||||||
@@ -162,45 +171,52 @@ func (s *KeyService) ClearAllInvalidKeys(groupID uint) (int64, error) {
|
|||||||
return result.RowsAffected, result.Error
|
return result.RowsAffected, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteSingleKey deletes a specific key from a group.
|
// DeleteMultipleKeys handles the business logic of deleting keys from a text block.
|
||||||
func (s *KeyService) DeleteSingleKey(groupID, keyID uint) (int64, error) {
|
func (s *KeyService) DeleteMultipleKeys(groupID uint, keysText string) (*DeleteKeysResult, error) {
|
||||||
result := s.DB.Where("group_id = ? AND id = ?", groupID, keyID).Delete(&models.APIKey{})
|
// 1. Parse keys from the text block
|
||||||
return result.RowsAffected, result.Error
|
keysToDelete := s.ParseKeysFromText(keysText)
|
||||||
}
|
if len(keysToDelete) == 0 {
|
||||||
|
return nil, fmt.Errorf("no valid keys found in the input text")
|
||||||
// ExportKeys returns a list of keys for a group, filtered by status.
|
|
||||||
func (s *KeyService) ExportKeys(groupID uint, filter string) ([]string, error) {
|
|
||||||
query := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID)
|
|
||||||
|
|
||||||
switch filter {
|
|
||||||
case "valid":
|
|
||||||
query = query.Where("status = ?", "active")
|
|
||||||
case "invalid":
|
|
||||||
query = query.Where("status = ?", "inactive")
|
|
||||||
case "all":
|
|
||||||
// No status filter needed
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("invalid filter value. Use 'all', 'valid', or 'invalid'")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var keys []string
|
// 2. Perform the deletion
|
||||||
if err := query.Pluck("key_value", &keys).Error; err != nil {
|
// GORM's batch delete doesn't easily return which ones were deleted vs. ignored.
|
||||||
|
// We perform a bulk delete and then count the remaining to calculate the result.
|
||||||
|
result := s.DB.Where("group_id = ? AND key_value IN ?", groupID, keysToDelete).Delete(&models.APIKey{})
|
||||||
|
if result.Error != nil {
|
||||||
|
return nil, result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
deletedCount := int(result.RowsAffected)
|
||||||
|
ignoredCount := len(keysToDelete) - deletedCount
|
||||||
|
|
||||||
|
// 3. Get the new total count
|
||||||
|
var totalInGroup int64
|
||||||
|
if err := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID).Count(&totalInGroup).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return keys, nil
|
|
||||||
|
return &DeleteKeysResult{
|
||||||
|
DeletedCount: deletedCount,
|
||||||
|
IgnoredCount: ignoredCount,
|
||||||
|
TotalInGroup: totalInGroup,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListKeysInGroup lists all keys within a specific group, filtered by status.
|
// ListKeysInGroupQuery builds a query to list all keys within a specific group, filtered by status.
|
||||||
func (s *KeyService) ListKeysInGroup(groupID uint, statusFilter string) ([]models.APIKey, error) {
|
// It returns a GORM query builder, allowing the handler to apply pagination.
|
||||||
var keys []models.APIKey
|
func (s *KeyService) ListKeysInGroupQuery(groupID uint, statusFilter string, searchKeyword string) *gorm.DB {
|
||||||
query := s.DB.Where("group_id = ?", groupID)
|
query := s.DB.Model(&models.APIKey{}).Where("group_id = ?", groupID)
|
||||||
|
|
||||||
if statusFilter != "" {
|
if statusFilter != "" {
|
||||||
query = query.Where("status = ?", statusFilter)
|
query = query.Where("status = ?", statusFilter)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := query.Find(&keys).Error; err != nil {
|
if searchKeyword != "" {
|
||||||
return nil, err
|
// Use LIKE for fuzzy search on the key_value
|
||||||
|
query = query.Where("key_value LIKE ?", "%"+searchKeyword+"%")
|
||||||
}
|
}
|
||||||
return keys, nil
|
|
||||||
|
return query
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -10,6 +10,13 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// KeyTestResult holds the validation result for a single key.
|
||||||
|
type KeyTestResult struct {
|
||||||
|
KeyValue string `json:"key_value"`
|
||||||
|
IsValid bool `json:"is_valid"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// KeyValidatorService provides methods to validate API keys.
|
// KeyValidatorService provides methods to validate API keys.
|
||||||
type KeyValidatorService struct {
|
type KeyValidatorService struct {
|
||||||
DB *gorm.DB
|
DB *gorm.DB
|
||||||
@@ -73,19 +80,45 @@ func (s *KeyValidatorService) ValidateSingleKey(ctx context.Context, key *models
|
|||||||
return isValid, nil
|
return isValid, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestSingleKeyByID performs a synchronous validation test for a single API key by its ID.
|
// TestMultipleKeys performs a synchronous validation for a list of key values within a specific group.
|
||||||
// It is intended for handling user-initiated "Test" actions.
|
func (s *KeyValidatorService) TestMultipleKeys(ctx context.Context, group *models.Group, keyValues []string) ([]KeyTestResult, error) {
|
||||||
// It does not modify the key's state in the database.
|
results := make([]KeyTestResult, len(keyValues))
|
||||||
func (s *KeyValidatorService) TestSingleKeyByID(ctx context.Context, keyID uint) (bool, error) {
|
ch, err := s.channelFactory.GetChannel(group)
|
||||||
var apiKey models.APIKey
|
if err != nil {
|
||||||
if err := s.DB.First(&apiKey, keyID).Error; err != nil {
|
return nil, fmt.Errorf("failed to get channel for group %s: %w", group.Name, err)
|
||||||
return false, fmt.Errorf("failed to find api key with id %d: %w", keyID, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var group models.Group
|
// Find which of the provided keys actually exist in the database for this group
|
||||||
if err := s.DB.First(&group, apiKey.GroupID).Error; err != nil {
|
var existingKeys []models.APIKey
|
||||||
return false, fmt.Errorf("failed to find group with id %d: %w", apiKey.GroupID, err)
|
if err := s.DB.Where("group_id = ? AND key_value IN ?", group.ID, keyValues).Find(&existingKeys).Error; err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query keys from DB: %w", err)
|
||||||
|
}
|
||||||
|
existingKeyMap := make(map[string]bool)
|
||||||
|
for _, k := range existingKeys {
|
||||||
|
existingKeyMap[k.KeyValue] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.ValidateSingleKey(ctx, &apiKey, &group)
|
for i, kv := range keyValues {
|
||||||
|
// Pre-check: ensure the key belongs to the group to prevent unnecessary API calls
|
||||||
|
if !existingKeyMap[kv] {
|
||||||
|
results[i] = KeyTestResult{
|
||||||
|
KeyValue: kv,
|
||||||
|
IsValid: false,
|
||||||
|
Error: "Key does not exist in this group or has been removed.",
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
isValid, validationErr := ch.ValidateKey(ctx, kv)
|
||||||
|
results[i] = KeyTestResult{
|
||||||
|
KeyValue: kv,
|
||||||
|
IsValid: isValid,
|
||||||
|
Error: "", // Explicitly set error to empty string on success
|
||||||
|
}
|
||||||
|
if validationErr != nil {
|
||||||
|
results[i].Error = validationErr.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user