feat: 密钥管理
This commit is contained in:
@@ -2,11 +2,14 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gpt-load/internal/channel"
|
||||
app_errors "gpt-load/internal/errors"
|
||||
"gpt-load/internal/models"
|
||||
"gpt-load/internal/response"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -19,14 +22,16 @@ import (
|
||||
// ProxyServer represents the proxy server
|
||||
type ProxyServer struct {
|
||||
DB *gorm.DB
|
||||
channelFactory *channel.Factory
|
||||
groupCounters sync.Map // map[uint]*atomic.Uint64
|
||||
requestLogChan chan models.RequestLog
|
||||
}
|
||||
|
||||
// NewProxyServer creates a new proxy server
|
||||
func NewProxyServer(db *gorm.DB, requestLogChan chan models.RequestLog) (*ProxyServer, error) {
|
||||
func NewProxyServer(db *gorm.DB, channelFactory *channel.Factory, requestLogChan chan models.RequestLog) (*ProxyServer, error) {
|
||||
return &ProxyServer{
|
||||
DB: db,
|
||||
channelFactory: channelFactory,
|
||||
groupCounters: sync.Map{},
|
||||
requestLogChan: requestLogChan,
|
||||
}, nil
|
||||
@@ -52,22 +57,31 @@ func (ps *ProxyServer) HandleProxy(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 3. Get the appropriate channel handler from the factory
|
||||
channelHandler, err := channel.GetChannel(&group)
|
||||
channelHandler, err := ps.channelFactory.GetChannel(&group)
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to get channel for group '%s': %v", groupName, err)))
|
||||
return
|
||||
}
|
||||
|
||||
// 4. Forward the request using the channel handler
|
||||
// 4. Apply parameter overrides if they exist
|
||||
if len(group.ParamOverrides) > 0 {
|
||||
err := ps.applyParamOverrides(c, &group)
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to apply parameter overrides: %v", err)))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Forward the request using the channel handler
|
||||
err = channelHandler.Handle(c, apiKey, &group)
|
||||
|
||||
// 5. Log the request asynchronously
|
||||
// 6. Log the request asynchronously
|
||||
isSuccess := err == nil
|
||||
if !isSuccess {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"group": group.Name,
|
||||
"group": group.Name,
|
||||
"key_id": apiKey.ID,
|
||||
"error": err.Error(),
|
||||
"error": err.Error(),
|
||||
}).Error("Channel handler failed")
|
||||
}
|
||||
go ps.logRequest(c, &group, apiKey, startTime, isSuccess)
|
||||
@@ -145,3 +159,51 @@ func (ps *ProxyServer) updateKeyStats(keyID uint, success bool) {
|
||||
func (ps *ProxyServer) Close() {
|
||||
// Nothing to close for now
|
||||
}
|
||||
|
||||
func (ps *ProxyServer) applyParamOverrides(c *gin.Context, group *models.Group) error {
|
||||
// Read the original request body
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read request body: %w", err)
|
||||
}
|
||||
c.Request.Body.Close() // Close the original body
|
||||
|
||||
// If body is empty, nothing to override, just restore the body
|
||||
if len(bodyBytes) == 0 {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save the original Content-Type
|
||||
originalContentType := c.GetHeader("Content-Type")
|
||||
|
||||
// Unmarshal the body into a map
|
||||
var requestData map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &requestData); err != nil {
|
||||
// If not a valid JSON, just pass it through
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Merge the overrides into the request data
|
||||
for key, value := range group.ParamOverrides {
|
||||
requestData[key] = value
|
||||
}
|
||||
|
||||
// Marshal the new data back to JSON
|
||||
newBodyBytes, err := json.Marshal(requestData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal new request body: %w", err)
|
||||
}
|
||||
|
||||
// Replace the request body with the new one
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(newBodyBytes))
|
||||
c.Request.ContentLength = int64(len(newBodyBytes))
|
||||
|
||||
// Restore the original Content-Type header
|
||||
if originalContentType != "" {
|
||||
c.Request.Header.Set("Content-Type", originalContentType)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user