feat: key provider
This commit is contained in:
@@ -7,11 +7,10 @@ import (
|
||||
"fmt"
|
||||
"gpt-load/internal/channel"
|
||||
app_errors "gpt-load/internal/errors"
|
||||
"gpt-load/internal/keypool"
|
||||
"gpt-load/internal/models"
|
||||
"gpt-load/internal/response"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -23,16 +22,21 @@ import (
|
||||
type ProxyServer struct {
|
||||
DB *gorm.DB
|
||||
channelFactory *channel.Factory
|
||||
groupCounters sync.Map // map[uint]*atomic.Uint64
|
||||
keyProvider *keypool.KeyProvider
|
||||
requestLogChan chan models.RequestLog
|
||||
}
|
||||
|
||||
// NewProxyServer creates a new proxy server
|
||||
func NewProxyServer(db *gorm.DB, channelFactory *channel.Factory, requestLogChan chan models.RequestLog) (*ProxyServer, error) {
|
||||
func NewProxyServer(
|
||||
db *gorm.DB,
|
||||
channelFactory *channel.Factory,
|
||||
keyProvider *keypool.KeyProvider,
|
||||
requestLogChan chan models.RequestLog,
|
||||
) (*ProxyServer, error) {
|
||||
return &ProxyServer{
|
||||
DB: db,
|
||||
channelFactory: channelFactory,
|
||||
groupCounters: sync.Map{},
|
||||
keyProvider: keyProvider,
|
||||
requestLogChan: requestLogChan,
|
||||
}, nil
|
||||
}
|
||||
@@ -42,17 +46,22 @@ func (ps *ProxyServer) HandleProxy(c *gin.Context) {
|
||||
startTime := time.Now()
|
||||
groupName := c.Param("group_name")
|
||||
|
||||
// 1. Find the group by name
|
||||
// 1. Find the group by name (without preloading keys)
|
||||
var group models.Group
|
||||
if err := ps.DB.Preload("APIKeys").Where("name = ?", groupName).First(&group).Error; err != nil {
|
||||
if err := ps.DB.Where("name = ?", groupName).First(&group).Error; err != nil {
|
||||
response.Error(c, app_errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
|
||||
// 2. Select an available API key from the group
|
||||
apiKey, err := ps.selectAPIKey(&group)
|
||||
// 2. Select an available API key from the KeyPool
|
||||
apiKey, err := ps.keyProvider.SelectKey(group.ID)
|
||||
if err != nil {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, err.Error()))
|
||||
// Properly handle the case where no keys are available
|
||||
if apiErr, ok := err.(*app_errors.APIError); ok {
|
||||
response.Error(c, apiErr)
|
||||
} else {
|
||||
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, err.Error()))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -75,8 +84,10 @@ func (ps *ProxyServer) HandleProxy(c *gin.Context) {
|
||||
// 5. Forward the request using the channel handler
|
||||
err = channelHandler.Handle(c, apiKey, &group)
|
||||
|
||||
// 6. Log the request asynchronously
|
||||
// 6. Update key status and log the request asynchronously
|
||||
isSuccess := err == nil
|
||||
ps.keyProvider.UpdateStatus(apiKey.ID, group.ID, isSuccess)
|
||||
|
||||
if !isSuccess {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"group": group.Name,
|
||||
@@ -84,37 +95,10 @@ func (ps *ProxyServer) HandleProxy(c *gin.Context) {
|
||||
"error": err.Error(),
|
||||
}).Error("Channel handler failed")
|
||||
}
|
||||
go ps.logRequest(c, &group, apiKey, startTime, isSuccess)
|
||||
go ps.logRequest(c, &group, apiKey, startTime)
|
||||
}
|
||||
|
||||
// selectAPIKey selects an API key from a group using round-robin
|
||||
func (ps *ProxyServer) selectAPIKey(group *models.Group) (*models.APIKey, error) {
|
||||
activeKeys := make([]models.APIKey, 0, len(group.APIKeys))
|
||||
for _, key := range group.APIKeys {
|
||||
if key.Status == "active" {
|
||||
activeKeys = append(activeKeys, key)
|
||||
}
|
||||
}
|
||||
|
||||
if len(activeKeys) == 0 {
|
||||
return nil, fmt.Errorf("no active API keys available in group '%s'", group.Name)
|
||||
}
|
||||
|
||||
// Get or create a counter for the group. The value is a pointer to a uint64.
|
||||
val, _ := ps.groupCounters.LoadOrStore(group.ID, new(atomic.Uint64))
|
||||
counter := val.(*atomic.Uint64)
|
||||
|
||||
// Atomically increment the counter and get the index for this request.
|
||||
index := counter.Add(1) - 1
|
||||
selectedKey := activeKeys[int(index%uint64(len(activeKeys)))]
|
||||
|
||||
return &selectedKey, nil
|
||||
}
|
||||
|
||||
func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *models.APIKey, startTime time.Time, isSuccess bool) {
|
||||
// Update key stats based on request success
|
||||
go ps.updateKeyStats(key.ID, isSuccess)
|
||||
|
||||
func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *models.APIKey, startTime time.Time) {
|
||||
logEntry := models.RequestLog{
|
||||
ID: fmt.Sprintf("req_%d", time.Now().UnixNano()),
|
||||
Timestamp: startTime,
|
||||
@@ -134,27 +118,6 @@ func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *mode
|
||||
}
|
||||
}
|
||||
|
||||
// updateKeyStats atomically updates the request and failure counts for a key
|
||||
func (ps *ProxyServer) updateKeyStats(keyID uint, success bool) {
|
||||
// Always increment the request count
|
||||
updates := map[string]any{
|
||||
"request_count": gorm.Expr("request_count + 1"),
|
||||
}
|
||||
|
||||
// Additionally, increment the failure count if the request was not successful
|
||||
if !success {
|
||||
updates["failure_count"] = gorm.Expr("failure_count + 1")
|
||||
}
|
||||
|
||||
result := ps.DB.Model(&models.APIKey{}).Where("id = ?", keyID).Updates(updates)
|
||||
if result.Error != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"keyID": keyID,
|
||||
"error": result.Error,
|
||||
}).Error("Failed to update key stats")
|
||||
}
|
||||
}
|
||||
|
||||
// Close cleans up resources
|
||||
func (ps *ProxyServer) Close() {
|
||||
// Nothing to close for now
|
||||
|
Reference in New Issue
Block a user