feat: 密钥管理

This commit is contained in:
tbphp
2025-07-04 21:19:15 +08:00
parent 7c10474d19
commit 01b86f7e30
23 changed files with 1427 additions and 250 deletions

View File

@@ -2,11 +2,14 @@
package proxy
import (
"bytes"
"encoding/json"
"fmt"
"gpt-load/internal/channel"
app_errors "gpt-load/internal/errors"
"gpt-load/internal/models"
"gpt-load/internal/response"
"io"
"sync"
"sync/atomic"
"time"
@@ -19,14 +22,16 @@ import (
// ProxyServer represents the proxy server
type ProxyServer struct {
DB *gorm.DB
channelFactory *channel.Factory
groupCounters sync.Map // map[uint]*atomic.Uint64
requestLogChan chan models.RequestLog
}
// NewProxyServer creates a new proxy server
func NewProxyServer(db *gorm.DB, requestLogChan chan models.RequestLog) (*ProxyServer, error) {
func NewProxyServer(db *gorm.DB, channelFactory *channel.Factory, requestLogChan chan models.RequestLog) (*ProxyServer, error) {
return &ProxyServer{
DB: db,
channelFactory: channelFactory,
groupCounters: sync.Map{},
requestLogChan: requestLogChan,
}, nil
@@ -52,22 +57,31 @@ func (ps *ProxyServer) HandleProxy(c *gin.Context) {
}
// 3. Get the appropriate channel handler from the factory
channelHandler, err := channel.GetChannel(&group)
channelHandler, err := ps.channelFactory.GetChannel(&group)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to get channel for group '%s': %v", groupName, err)))
return
}
// 4. Forward the request using the channel handler
// 4. Apply parameter overrides if they exist
if len(group.ParamOverrides) > 0 {
err := ps.applyParamOverrides(c, &group)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to apply parameter overrides: %v", err)))
return
}
}
// 5. Forward the request using the channel handler
err = channelHandler.Handle(c, apiKey, &group)
// 5. Log the request asynchronously
// 6. Log the request asynchronously
isSuccess := err == nil
if !isSuccess {
logrus.WithFields(logrus.Fields{
"group": group.Name,
"group": group.Name,
"key_id": apiKey.ID,
"error": err.Error(),
"error": err.Error(),
}).Error("Channel handler failed")
}
go ps.logRequest(c, &group, apiKey, startTime, isSuccess)
@@ -145,3 +159,51 @@ func (ps *ProxyServer) updateKeyStats(keyID uint, success bool) {
func (ps *ProxyServer) Close() {
// Nothing to close for now
}
func (ps *ProxyServer) applyParamOverrides(c *gin.Context, group *models.Group) error {
// Read the original request body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
return fmt.Errorf("failed to read request body: %w", err)
}
c.Request.Body.Close() // Close the original body
// If body is empty, nothing to override, just restore the body
if len(bodyBytes) == 0 {
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
return nil
}
// Save the original Content-Type
originalContentType := c.GetHeader("Content-Type")
// Unmarshal the body into a map
var requestData map[string]interface{}
if err := json.Unmarshal(bodyBytes, &requestData); err != nil {
// If not a valid JSON, just pass it through
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
return nil
}
// Merge the overrides into the request data
for key, value := range group.ParamOverrides {
requestData[key] = value
}
// Marshal the new data back to JSON
newBodyBytes, err := json.Marshal(requestData)
if err != nil {
return fmt.Errorf("failed to marshal new request body: %w", err)
}
// Replace the request body with the new one
c.Request.Body = io.NopCloser(bytes.NewBuffer(newBodyBytes))
c.Request.ContentLength = int64(len(newBodyBytes))
// Restore the original Content-Type header
if originalContentType != "" {
c.Request.Header.Set("Content-Type", originalContentType)
}
return nil
}