feat: 添加基础通道实现和请求处理逻辑

This commit is contained in:
tbphp
2025-07-02 09:46:08 +08:00
parent 5818c3cb1d
commit 762dfe48e8
5 changed files with 185 additions and 68 deletions

View File

@@ -0,0 +1,122 @@
package channel
import (
"bytes"
"compress/gzip"
"fmt"
"gpt-load/internal/models"
"gpt-load/internal/response"
"io"
"net/http"
"net/url"
"strings"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
// RequestModifier defines a function that can modify the upstream request,
// for example, by adding authentication headers.
type RequestModifier func(req *http.Request, apiKey *models.APIKey)
// BaseChannel provides a foundation for specific channel implementations.
type BaseChannel struct {
Name string
BaseURL *url.URL
HTTPClient *http.Client
}
// ProcessRequest handles the generic logic of creating, sending, and handling an upstream request.
func (ch *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier) error {
// 1. Create the upstream request
req, err := ch.createUpstreamRequest(c, apiKey, modifier)
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to create upstream request")
return fmt.Errorf("create upstream request failed: %w", err)
}
// 2. Send the request
resp, err := ch.HTTPClient.Do(req)
if err != nil {
response.Error(c, http.StatusServiceUnavailable, "Upstream service unavailable")
return fmt.Errorf("upstream request failed: %w", err)
}
defer resp.Body.Close()
// 3. Handle non-200 status codes
if resp.StatusCode != http.StatusOK {
errorMsg := ch.getErrorMessage(resp)
response.Error(c, resp.StatusCode, errorMsg)
return fmt.Errorf("upstream returned status %d: %s", resp.StatusCode, errorMsg)
}
// 4. Stream the successful response back to the client
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
c.Status(http.StatusOK)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
logrus.Errorf("Failed to copy response body to client: %v", err)
return fmt.Errorf("copy response body failed: %w", err)
}
return nil
}
func (ch *BaseChannel) createUpstreamRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier) (*http.Request, error) {
targetURL := *ch.BaseURL
targetURL.Path = c.Param("path")
body, err := io.ReadAll(c.Request.Body)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %w", err)
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("failed to create new request: %w", err)
}
req.Header = c.Request.Header.Clone()
req.Host = ch.BaseURL.Host
// Apply the channel-specific modifications
if modifier != nil {
modifier(req, apiKey)
}
return req, nil
}
func (ch *BaseChannel) getErrorMessage(resp *http.Response) string {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Sprintf("HTTP %d (failed to read error body: %v)", resp.StatusCode, err)
}
var errorMessage string
if resp.Header.Get("Content-Encoding") == "gzip" {
reader, gErr := gzip.NewReader(bytes.NewReader(bodyBytes))
if gErr != nil {
return string(bodyBytes)
}
defer reader.Close()
uncompressedBytes, rErr := io.ReadAll(reader)
if rErr != nil {
return fmt.Sprintf("gzip read error: %v", rErr)
}
errorMessage = string(uncompressedBytes)
} else {
errorMessage = string(bodyBytes)
}
if strings.TrimSpace(errorMessage) == "" {
return fmt.Sprintf("HTTP %d: %s", resp.StatusCode, http.StatusText(resp.StatusCode))
}
return errorMessage
}

View File

@@ -10,5 +10,5 @@ import (
type ChannelProxy interface {
// Handle takes a context, an API key, and the original request,
// then forwards the request to the upstream service.
Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group)
Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error
}

View File

@@ -1,55 +1,50 @@
package channel
import (
"encoding/json"
"fmt"
"gpt-load/internal/models"
"net/http"
"net/http/httputil"
"net/url"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
const GeminiBaseURL = "https://generativelanguage.googleapis.com"
type GeminiChannel struct {
BaseURL *url.URL
BaseChannel
}
type GeminiChannelConfig struct {
BaseURL string `json:"base_url"`
}
func NewGeminiChannel(group *models.Group) (*GeminiChannel, error) {
baseURL, err := url.Parse(GeminiBaseURL)
if err != nil {
return nil, err // Should not happen with a constant
var config GeminiChannelConfig
if err := json.Unmarshal([]byte(group.Config), &config); err != nil {
return nil, fmt.Errorf("failed to unmarshal channel config: %w", err)
}
return &GeminiChannel{BaseURL: baseURL}, nil
if config.BaseURL == "" {
return nil, fmt.Errorf("base_url is required for gemini channel")
}
baseURL, err := url.Parse(config.BaseURL)
if err != nil {
return nil, fmt.Errorf("failed to parse base_url: %w", err)
}
return &GeminiChannel{
BaseChannel: BaseChannel{
Name: "gemini",
BaseURL: baseURL,
HTTPClient: &http.Client{},
},
}, nil
}
func (ch *GeminiChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) {
proxy := httputil.NewSingleHostReverseProxy(ch.BaseURL)
proxy.Director = func(req *http.Request) {
// Gemini API key is passed as a query parameter
originalPath := c.Param("path")
newPath := fmt.Sprintf("%s?key=%s", originalPath, apiKey.KeyValue)
req.URL.Scheme = ch.BaseURL.Scheme
req.URL.Host = ch.BaseURL.Host
req.URL.Path = newPath
req.Host = ch.BaseURL.Host
// Remove the Authorization header if it was passed by the client
req.Header.Del("Authorization")
func (ch *GeminiChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error {
modifier := func(req *http.Request, key *models.APIKey) {
q := req.URL.Query()
q.Set("key", key.KeyValue)
req.URL.RawQuery = q.Encode()
}
proxy.ModifyResponse = func(resp *http.Response) error {
// Log the response, etc.
return nil
}
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
logrus.Errorf("Proxy error to Gemini: %v", err)
// Handle error, maybe update key status
}
proxy.ServeHTTP(c.Writer, c.Request)
return ch.ProcessRequest(c, apiKey, modifier)
}

View File

@@ -2,17 +2,15 @@ package channel
import (
"encoding/json"
"fmt"
"gpt-load/internal/models"
"net/http"
"net/http/httputil"
"net/url"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
type OpenAIChannel struct {
BaseURL *url.URL
BaseChannel
}
type OpenAIChannelConfig struct {
@@ -22,34 +20,29 @@ type OpenAIChannelConfig struct {
func NewOpenAIChannel(group *models.Group) (*OpenAIChannel, error) {
var config OpenAIChannelConfig
if err := json.Unmarshal([]byte(group.Config), &config); err != nil {
return nil, err
return nil, fmt.Errorf("failed to unmarshal channel config: %w", err)
}
if config.BaseURL == "" {
return nil, fmt.Errorf("base_url is required for openai channel")
}
baseURL, err := url.Parse(config.BaseURL)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to parse base_url: %w", err)
}
return &OpenAIChannel{BaseURL: baseURL}, nil
return &OpenAIChannel{
BaseChannel: BaseChannel{
Name: "openai",
BaseURL: baseURL,
HTTPClient: &http.Client{},
},
}, nil
}
func (ch *OpenAIChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) {
proxy := httputil.NewSingleHostReverseProxy(ch.BaseURL)
proxy.Director = func(req *http.Request) {
req.URL.Scheme = ch.BaseURL.Scheme
req.URL.Host = ch.BaseURL.Host
req.URL.Path = c.Param("path")
req.Host = ch.BaseURL.Host
req.Header.Set("Authorization", "Bearer "+apiKey.KeyValue)
func (ch *OpenAIChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error {
modifier := func(req *http.Request, key *models.APIKey) {
req.Header.Set("Authorization", "Bearer "+key.KeyValue)
}
proxy.ModifyResponse = func(resp *http.Response) error {
// Log the response, etc.
return nil
}
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
logrus.Errorf("Proxy error: %v", err)
// Handle error, maybe update key status
}
proxy.ServeHTTP(c.Writer, c.Request)
return ch.ProcessRequest(c, apiKey, modifier)
}

View File

@@ -59,10 +59,18 @@ func (ps *ProxyServer) HandleProxy(c *gin.Context) {
}
// 4. Forward the request using the channel handler
channelHandler.Handle(c, apiKey, &group)
err = channelHandler.Handle(c, apiKey, &group)
// 5. Log the request asynchronously
go ps.logRequest(c, &group, apiKey, startTime)
isSuccess := err == nil
if !isSuccess {
logrus.WithFields(logrus.Fields{
"group": group.Name,
"key_id": apiKey.ID,
"error": err.Error(),
}).Error("Channel handler failed")
}
go ps.logRequest(c, &group, apiKey, startTime, isSuccess)
}
// selectAPIKey selects an API key from a group using round-robin
@@ -89,9 +97,8 @@ func (ps *ProxyServer) selectAPIKey(group *models.Group) (*models.APIKey, error)
return &selectedKey, nil
}
func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *models.APIKey, startTime time.Time) {
func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *models.APIKey, startTime time.Time, isSuccess bool) {
// Update key stats based on request success
isSuccess := c.Writer.Status() < 400
go ps.updateKeyStats(key.ID, isSuccess)
logEntry := models.RequestLog{