feat: 添加基础通道实现和请求处理逻辑
This commit is contained in:
122
internal/channel/base_channel.go
Normal file
122
internal/channel/base_channel.go
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
package channel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"fmt"
|
||||||
|
"gpt-load/internal/models"
|
||||||
|
"gpt-load/internal/response"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestModifier defines a function that can modify the upstream request,
|
||||||
|
// for example, by adding authentication headers.
|
||||||
|
type RequestModifier func(req *http.Request, apiKey *models.APIKey)
|
||||||
|
|
||||||
|
// BaseChannel provides a foundation for specific channel implementations.
|
||||||
|
type BaseChannel struct {
|
||||||
|
Name string
|
||||||
|
BaseURL *url.URL
|
||||||
|
HTTPClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessRequest handles the generic logic of creating, sending, and handling an upstream request.
|
||||||
|
func (ch *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier) error {
|
||||||
|
// 1. Create the upstream request
|
||||||
|
req, err := ch.createUpstreamRequest(c, apiKey, modifier)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusInternalServerError, "Failed to create upstream request")
|
||||||
|
return fmt.Errorf("create upstream request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Send the request
|
||||||
|
resp, err := ch.HTTPClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Upstream service unavailable")
|
||||||
|
return fmt.Errorf("upstream request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// 3. Handle non-200 status codes
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
errorMsg := ch.getErrorMessage(resp)
|
||||||
|
response.Error(c, resp.StatusCode, errorMsg)
|
||||||
|
return fmt.Errorf("upstream returned status %d: %s", resp.StatusCode, errorMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Stream the successful response back to the client
|
||||||
|
for key, values := range resp.Header {
|
||||||
|
for _, value := range values {
|
||||||
|
c.Header(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
_, err = io.Copy(c.Writer, resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Errorf("Failed to copy response body to client: %v", err)
|
||||||
|
return fmt.Errorf("copy response body failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ch *BaseChannel) createUpstreamRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier) (*http.Request, error) {
|
||||||
|
targetURL := *ch.BaseURL
|
||||||
|
targetURL.Path = c.Param("path")
|
||||||
|
|
||||||
|
body, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||||
|
}
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewBuffer(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create new request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header = c.Request.Header.Clone()
|
||||||
|
req.Host = ch.BaseURL.Host
|
||||||
|
|
||||||
|
// Apply the channel-specific modifications
|
||||||
|
if modifier != nil {
|
||||||
|
modifier(req, apiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ch *BaseChannel) getErrorMessage(resp *http.Response) string {
|
||||||
|
bodyBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("HTTP %d (failed to read error body: %v)", resp.StatusCode, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var errorMessage string
|
||||||
|
if resp.Header.Get("Content-Encoding") == "gzip" {
|
||||||
|
reader, gErr := gzip.NewReader(bytes.NewReader(bodyBytes))
|
||||||
|
if gErr != nil {
|
||||||
|
return string(bodyBytes)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
uncompressedBytes, rErr := io.ReadAll(reader)
|
||||||
|
if rErr != nil {
|
||||||
|
return fmt.Sprintf("gzip read error: %v", rErr)
|
||||||
|
}
|
||||||
|
errorMessage = string(uncompressedBytes)
|
||||||
|
} else {
|
||||||
|
errorMessage = string(bodyBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(errorMessage) == "" {
|
||||||
|
return fmt.Sprintf("HTTP %d: %s", resp.StatusCode, http.StatusText(resp.StatusCode))
|
||||||
|
}
|
||||||
|
|
||||||
|
return errorMessage
|
||||||
|
}
|
@@ -10,5 +10,5 @@ import (
|
|||||||
type ChannelProxy interface {
|
type ChannelProxy interface {
|
||||||
// Handle takes a context, an API key, and the original request,
|
// Handle takes a context, an API key, and the original request,
|
||||||
// then forwards the request to the upstream service.
|
// then forwards the request to the upstream service.
|
||||||
Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group)
|
Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error
|
||||||
}
|
}
|
@@ -1,55 +1,50 @@
|
|||||||
package channel
|
package channel
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"gpt-load/internal/models"
|
"gpt-load/internal/models"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const GeminiBaseURL = "https://generativelanguage.googleapis.com"
|
|
||||||
|
|
||||||
type GeminiChannel struct {
|
type GeminiChannel struct {
|
||||||
BaseURL *url.URL
|
BaseChannel
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiChannelConfig struct {
|
||||||
|
BaseURL string `json:"base_url"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGeminiChannel(group *models.Group) (*GeminiChannel, error) {
|
func NewGeminiChannel(group *models.Group) (*GeminiChannel, error) {
|
||||||
baseURL, err := url.Parse(GeminiBaseURL)
|
var config GeminiChannelConfig
|
||||||
|
if err := json.Unmarshal([]byte(group.Config), &config); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal channel config: %w", err)
|
||||||
|
}
|
||||||
|
if config.BaseURL == "" {
|
||||||
|
return nil, fmt.Errorf("base_url is required for gemini channel")
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL, err := url.Parse(config.BaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err // Should not happen with a constant
|
return nil, fmt.Errorf("failed to parse base_url: %w", err)
|
||||||
}
|
|
||||||
return &GeminiChannel{BaseURL: baseURL}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ch *GeminiChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) {
|
return &GeminiChannel{
|
||||||
proxy := httputil.NewSingleHostReverseProxy(ch.BaseURL)
|
BaseChannel: BaseChannel{
|
||||||
|
Name: "gemini",
|
||||||
proxy.Director = func(req *http.Request) {
|
BaseURL: baseURL,
|
||||||
// Gemini API key is passed as a query parameter
|
HTTPClient: &http.Client{},
|
||||||
originalPath := c.Param("path")
|
},
|
||||||
newPath := fmt.Sprintf("%s?key=%s", originalPath, apiKey.KeyValue)
|
}, nil
|
||||||
|
|
||||||
req.URL.Scheme = ch.BaseURL.Scheme
|
|
||||||
req.URL.Host = ch.BaseURL.Host
|
|
||||||
req.URL.Path = newPath
|
|
||||||
req.Host = ch.BaseURL.Host
|
|
||||||
// Remove the Authorization header if it was passed by the client
|
|
||||||
req.Header.Del("Authorization")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
func (ch *GeminiChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error {
|
||||||
// Log the response, etc.
|
modifier := func(req *http.Request, key *models.APIKey) {
|
||||||
return nil
|
q := req.URL.Query()
|
||||||
|
q.Set("key", key.KeyValue)
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
}
|
}
|
||||||
|
return ch.ProcessRequest(c, apiKey, modifier)
|
||||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
|
||||||
logrus.Errorf("Proxy error to Gemini: %v", err)
|
|
||||||
// Handle error, maybe update key status
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy.ServeHTTP(c.Writer, c.Request)
|
|
||||||
}
|
}
|
@@ -2,17 +2,15 @@ package channel
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"gpt-load/internal/models"
|
"gpt-load/internal/models"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type OpenAIChannel struct {
|
type OpenAIChannel struct {
|
||||||
BaseURL *url.URL
|
BaseChannel
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIChannelConfig struct {
|
type OpenAIChannelConfig struct {
|
||||||
@@ -22,34 +20,29 @@ type OpenAIChannelConfig struct {
|
|||||||
func NewOpenAIChannel(group *models.Group) (*OpenAIChannel, error) {
|
func NewOpenAIChannel(group *models.Group) (*OpenAIChannel, error) {
|
||||||
var config OpenAIChannelConfig
|
var config OpenAIChannelConfig
|
||||||
if err := json.Unmarshal([]byte(group.Config), &config); err != nil {
|
if err := json.Unmarshal([]byte(group.Config), &config); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to unmarshal channel config: %w", err)
|
||||||
}
|
}
|
||||||
|
if config.BaseURL == "" {
|
||||||
|
return nil, fmt.Errorf("base_url is required for openai channel")
|
||||||
|
}
|
||||||
|
|
||||||
baseURL, err := url.Parse(config.BaseURL)
|
baseURL, err := url.Parse(config.BaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to parse base_url: %w", err)
|
||||||
}
|
|
||||||
return &OpenAIChannel{BaseURL: baseURL}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ch *OpenAIChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) {
|
return &OpenAIChannel{
|
||||||
proxy := httputil.NewSingleHostReverseProxy(ch.BaseURL)
|
BaseChannel: BaseChannel{
|
||||||
proxy.Director = func(req *http.Request) {
|
Name: "openai",
|
||||||
req.URL.Scheme = ch.BaseURL.Scheme
|
BaseURL: baseURL,
|
||||||
req.URL.Host = ch.BaseURL.Host
|
HTTPClient: &http.Client{},
|
||||||
req.URL.Path = c.Param("path")
|
},
|
||||||
req.Host = ch.BaseURL.Host
|
}, nil
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey.KeyValue)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
func (ch *OpenAIChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error {
|
||||||
// Log the response, etc.
|
modifier := func(req *http.Request, key *models.APIKey) {
|
||||||
return nil
|
req.Header.Set("Authorization", "Bearer "+key.KeyValue)
|
||||||
}
|
}
|
||||||
|
return ch.ProcessRequest(c, apiKey, modifier)
|
||||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
|
||||||
logrus.Errorf("Proxy error: %v", err)
|
|
||||||
// Handle error, maybe update key status
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy.ServeHTTP(c.Writer, c.Request)
|
|
||||||
}
|
}
|
@@ -59,10 +59,18 @@ func (ps *ProxyServer) HandleProxy(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 4. Forward the request using the channel handler
|
// 4. Forward the request using the channel handler
|
||||||
channelHandler.Handle(c, apiKey, &group)
|
err = channelHandler.Handle(c, apiKey, &group)
|
||||||
|
|
||||||
// 5. Log the request asynchronously
|
// 5. Log the request asynchronously
|
||||||
go ps.logRequest(c, &group, apiKey, startTime)
|
isSuccess := err == nil
|
||||||
|
if !isSuccess {
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"group": group.Name,
|
||||||
|
"key_id": apiKey.ID,
|
||||||
|
"error": err.Error(),
|
||||||
|
}).Error("Channel handler failed")
|
||||||
|
}
|
||||||
|
go ps.logRequest(c, &group, apiKey, startTime, isSuccess)
|
||||||
}
|
}
|
||||||
|
|
||||||
// selectAPIKey selects an API key from a group using round-robin
|
// selectAPIKey selects an API key from a group using round-robin
|
||||||
@@ -89,9 +97,8 @@ func (ps *ProxyServer) selectAPIKey(group *models.Group) (*models.APIKey, error)
|
|||||||
return &selectedKey, nil
|
return &selectedKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *models.APIKey, startTime time.Time) {
|
func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *models.APIKey, startTime time.Time, isSuccess bool) {
|
||||||
// Update key stats based on request success
|
// Update key stats based on request success
|
||||||
isSuccess := c.Writer.Status() < 400
|
|
||||||
go ps.updateKeyStats(key.ID, isSuccess)
|
go ps.updateKeyStats(key.ID, isSuccess)
|
||||||
|
|
||||||
logEntry := models.RequestLog{
|
logEntry := models.RequestLog{
|
||||||
|
Reference in New Issue
Block a user