Files
gpt-load/internal/proxy/server.go
2025-06-29 21:59:32 +08:00

120 lines
3.5 KiB
Go

// Package proxy provides a high-performance OpenAI multi-key proxy server.
package proxy
import (
"fmt"
"gpt-load/internal/channel"
"gpt-load/internal/models"
"gpt-load/internal/response"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// ProxyServer holds the state shared by the proxy request handlers.
type ProxyServer struct {
// DB is the GORM handle used to look up a group and its API keys by name.
DB *gorm.DB
// groupCounters maps a group ID to a uint64 request counter; the counter
// modulo the number of active keys decides which key is used next.
groupCounters sync.Map // For round-robin key selection
// requestLogChan receives the RequestLog entries built by logRequest.
// The consumer of this channel is not defined in this file.
requestLogChan chan models.RequestLog
}
// NewProxyServer builds a ProxyServer around the given database handle and
// request-log channel. The round-robin counters start out empty (the zero
// value of sync.Map is ready to use).
//
// The returned error is always nil.
func NewProxyServer(db *gorm.DB, requestLogChan chan models.RequestLog) (*ProxyServer, error) {
	server := new(ProxyServer)
	server.DB = db
	server.requestLogChan = requestLogChan
	return server, nil
}
// RegisterProxyRoutes attaches HandleProxy to the given router group for every
// HTTP method. The first path segment is captured as "group_name" and the
// remainder of the path as "path".
func (ps *ProxyServer) RegisterProxyRoutes(proxy *gin.RouterGroup) {
	const routePattern = "/:group_name/*path"
	handler := ps.HandleProxy
	proxy.Any(routePattern, handler)
}
// HandleProxy is the gin handler for every proxied request. It resolves the
// group named by the "group_name" path parameter, picks one of the group's
// active API keys, obtains the channel handler for the group, lets that
// handler forward the request, and finally records a request-log entry.
//
// Failures are reported to the client through response.Error:
//   - 404 when the group lookup returns an error (any error, not only a
//     missing row, is currently reported this way),
//   - 503 when no API key can be selected,
//   - 500 when channel.GetChannel fails.
//
// Requests that fail before being forwarded are not logged.
func (ps *ProxyServer) HandleProxy(c *gin.Context) {
	startTime := time.Now()
	groupName := c.Param("group_name")

	// 1. Find the group by name, preloading its API keys.
	var group models.Group
	if err := ps.DB.Preload("APIKeys").Where("name = ?", groupName).First(&group).Error; err != nil {
		response.Error(c, http.StatusNotFound, fmt.Sprintf("Group '%s' not found", groupName))
		return
	}

	// 2. Select an available API key from the group.
	apiKey, err := ps.selectAPIKey(&group)
	if err != nil {
		response.Error(c, http.StatusServiceUnavailable, err.Error())
		return
	}

	// 3. Get the appropriate channel handler from the factory.
	channelHandler, err := channel.GetChannel(&group)
	if err != nil {
		response.Error(c, http.StatusInternalServerError, fmt.Sprintf("Failed to get channel for group '%s': %v", groupName, err))
		return
	}

	// 4. Forward the request using the channel handler.
	channelHandler.Handle(c, apiKey, &group)

	// 5. Record the request. This is a plain call, not a goroutine, on
	// purpose: logRequest reads the client IP, response status and request
	// path from c, and gin reuses the context for other requests once this
	// handler returns. A background goroutine could therefore race with, or
	// report data from, an unrelated request. logRequest only performs a
	// non-blocking channel send, so running it inline does not stall the
	// handler.
	ps.logRequest(c, &group, apiKey, startTime)
}
// selectAPIKey picks the next API key of group in round-robin order.
//
// Only keys whose Status is "active" are candidates. If the group has none,
// a nil key and an error naming the group are returned.
//
// The rotation position is a uint64 counter stored in ps.groupCounters under
// the group's ID. It is advanced with a compare-and-swap retry loop so that
// concurrent callers each claim a distinct counter value; a separate
// load-then-store would let two callers read the same value and lose an
// increment. sync.Map.CompareAndSwap requires Go 1.20 or newer.
//
// The returned pointer refers to a copy of the selected key, not to the
// element inside group.APIKeys.
func (ps *ProxyServer) selectAPIKey(group *models.Group) (*models.APIKey, error) {
	activeKeys := make([]models.APIKey, 0, len(group.APIKeys))
	for _, key := range group.APIKeys {
		if key.Status == "active" {
			activeKeys = append(activeKeys, key)
		}
	}
	if len(activeKeys) == 0 {
		return nil, fmt.Errorf("no active API keys available in group '%s'", group.Name)
	}

	for {
		// LoadOrStore seeds the counter at 0 the first time a group is seen.
		stored, _ := ps.groupCounters.LoadOrStore(group.ID, uint64(0))
		current := stored.(uint64)

		// Claim this counter value; if another caller advanced the counter
		// in the meantime, reload and try again.
		if !ps.groupCounters.CompareAndSwap(group.ID, stored, current+1) {
			continue
		}

		selectedKey := activeKeys[current%uint64(len(activeKeys))]
		return &selectedKey, nil
	}
}
// logRequest builds a RequestLog entry for a handled request and offers it to
// ps.requestLogChan without blocking.
//
// The entry ID is derived from the current time in nanoseconds, the timestamp
// is startTime, and the client IP, response status and request path are read
// from c. If the channel cannot accept the entry immediately, the entry is
// dropped and a warning is logged.
func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *models.APIKey, startTime time.Time) {
	var entry models.RequestLog
	entry.ID = fmt.Sprintf("req_%d", time.Now().UnixNano())
	entry.Timestamp = startTime
	entry.GroupID = group.ID
	entry.KeyID = key.ID
	entry.SourceIP = c.ClientIP()
	entry.StatusCode = c.Writer.Status()
	entry.RequestPath = c.Request.URL.Path
	entry.RequestBodySnippet = "" // Not captured yet; can be implemented later if needed.

	// Non-blocking send: return on success, otherwise fall through and warn.
	select {
	case ps.requestLogChan <- entry:
		return
	default:
	}
	logrus.Warn("Request log channel is full. Dropping log entry.")
}
// Close is the shutdown hook for the proxy server. It is currently a no-op
// and exists as a placeholder for future cleanup.
func (ps *ProxyServer) Close() {
// Nothing to close for now
}