From 762dfe48e816f90cf73fbfad938397fea5885ce1 Mon Sep 17 00:00:00 2001 From: tbphp Date: Wed, 2 Jul 2025 09:46:08 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=9F=BA=E7=A1=80?= =?UTF-8?q?=E9=80=9A=E9=81=93=E5=AE=9E=E7=8E=B0=E5=92=8C=E8=AF=B7=E6=B1=82?= =?UTF-8?q?=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/channel/base_channel.go | 122 +++++++++++++++++++++++++++++ internal/channel/channel.go | 2 +- internal/channel/gemini_channel.go | 67 ++++++++-------- internal/channel/openai_channel.go | 47 +++++------ internal/proxy/server.go | 15 +++- 5 files changed, 185 insertions(+), 68 deletions(-) create mode 100644 internal/channel/base_channel.go diff --git a/internal/channel/base_channel.go b/internal/channel/base_channel.go new file mode 100644 index 0000000..8e0d7ee --- /dev/null +++ b/internal/channel/base_channel.go @@ -0,0 +1,122 @@ +package channel + +import ( + "bytes" + "compress/gzip" + "fmt" + "gpt-load/internal/models" + "gpt-load/internal/response" + "io" + "net/http" + "net/url" + "strings" + + "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" +) + +// RequestModifier defines a function that can modify the upstream request, +// for example, by adding authentication headers. +type RequestModifier func(req *http.Request, apiKey *models.APIKey) + +// BaseChannel provides a foundation for specific channel implementations. +type BaseChannel struct { + Name string + BaseURL *url.URL + HTTPClient *http.Client +} + +// ProcessRequest handles the generic logic of creating, sending, and handling an upstream request. +func (ch *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier) error { + // 1. Create the upstream request + req, err := ch.createUpstreamRequest(c, apiKey, modifier) + if err != nil { + response.Error(c, http.StatusInternalServerError, "Failed to create upstream request") + return fmt.Errorf("create upstream request failed: %w", err) + } + + // 2. Send the request + resp, err := ch.HTTPClient.Do(req) + if err != nil { + response.Error(c, http.StatusServiceUnavailable, "Upstream service unavailable") + return fmt.Errorf("upstream request failed: %w", err) + } + defer resp.Body.Close() + + // 3. Handle non-200 status codes + if resp.StatusCode != http.StatusOK { + errorMsg := ch.getErrorMessage(resp) + response.Error(c, resp.StatusCode, errorMsg) + return fmt.Errorf("upstream returned status %d: %s", resp.StatusCode, errorMsg) + } + + // 4. Stream the successful response back to the client + for key, values := range resp.Header { + for _, value := range values { + c.Header(key, value) + } + } + c.Status(http.StatusOK) + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + logrus.Errorf("Failed to copy response body to client: %v", err) + return fmt.Errorf("copy response body failed: %w", err) + } + + return nil +} + +func (ch *BaseChannel) createUpstreamRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier) (*http.Request, error) { + targetURL := *ch.BaseURL + targetURL.Path = c.Param("path") + + body, err := io.ReadAll(c.Request.Body) + if err != nil { + return nil, fmt.Errorf("failed to read request body: %w", err) + } + c.Request.Body = io.NopCloser(bytes.NewBuffer(body)) + + req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewBuffer(body)) + if err != nil { + return nil, fmt.Errorf("failed to create new request: %w", err) + } + + req.Header = c.Request.Header.Clone() + req.Host = ch.BaseURL.Host + + // Apply the channel-specific modifications + if modifier != nil { + modifier(req, apiKey) + } + + return req, nil +} + +func (ch *BaseChannel) getErrorMessage(resp *http.Response) string { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Sprintf("HTTP %d (failed to read error body: %v)", resp.StatusCode, err) + } + + var errorMessage string + if resp.Header.Get("Content-Encoding") == "gzip" { + reader, gErr := gzip.NewReader(bytes.NewReader(bodyBytes)) + if gErr != nil { + return string(bodyBytes) + } + defer reader.Close() + uncompressedBytes, rErr := io.ReadAll(reader) + if rErr != nil { + return fmt.Sprintf("gzip read error: %v", rErr) + } + errorMessage = string(uncompressedBytes) + } else { + errorMessage = string(bodyBytes) + } + + if strings.TrimSpace(errorMessage) == "" { + return fmt.Sprintf("HTTP %d: %s", resp.StatusCode, http.StatusText(resp.StatusCode)) + } + + return errorMessage +} \ No newline at end of file diff --git a/internal/channel/channel.go b/internal/channel/channel.go index 71fc2e6..9db511a 100644 --- a/internal/channel/channel.go +++ b/internal/channel/channel.go @@ -10,5 +10,5 @@ import ( type ChannelProxy interface { // Handle takes a context, an API key, and the original request, // then forwards the request to the upstream service. - Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) + Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error } \ No newline at end of file diff --git a/internal/channel/gemini_channel.go b/internal/channel/gemini_channel.go index d46b3a7..a90bda1 100644 --- a/internal/channel/gemini_channel.go +++ b/internal/channel/gemini_channel.go @@ -1,55 +1,50 @@ package channel import ( + "encoding/json" "fmt" "gpt-load/internal/models" "net/http" - "net/http/httputil" "net/url" "github.com/gin-gonic/gin" - "github.com/sirupsen/logrus" ) - -const GeminiBaseURL = "https://generativelanguage.googleapis.com" - type GeminiChannel struct { - BaseURL *url.URL + BaseChannel +} + +type GeminiChannelConfig struct { + BaseURL string `json:"base_url"` } func NewGeminiChannel(group *models.Group) (*GeminiChannel, error) { - baseURL, err := url.Parse(GeminiBaseURL) - if err != nil { - return nil, err // Should not happen with a constant + var config GeminiChannelConfig + if err := json.Unmarshal([]byte(group.Config), &config); err != nil { + return nil, fmt.Errorf("failed to unmarshal channel config: %w", err) } - return &GeminiChannel{BaseURL: baseURL}, nil + if config.BaseURL == "" { + return nil, fmt.Errorf("base_url is required for gemini channel") + } + + baseURL, err := url.Parse(config.BaseURL) + if err != nil { + return nil, fmt.Errorf("failed to parse base_url: %w", err) + } + + return &GeminiChannel{ + BaseChannel: BaseChannel{ + Name: "gemini", + BaseURL: baseURL, + HTTPClient: &http.Client{}, + }, + }, nil } -func (ch *GeminiChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) { - proxy := httputil.NewSingleHostReverseProxy(ch.BaseURL) - - proxy.Director = func(req *http.Request) { - // Gemini API key is passed as a query parameter - originalPath := c.Param("path") - newPath := fmt.Sprintf("%s?key=%s", originalPath, apiKey.KeyValue) - - req.URL.Scheme = ch.BaseURL.Scheme - req.URL.Host = ch.BaseURL.Host - req.URL.Path = newPath - req.Host = ch.BaseURL.Host - // Remove the Authorization header if it was passed by the client - req.Header.Del("Authorization") +func (ch *GeminiChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error { + modifier := func(req *http.Request, key *models.APIKey) { + q := req.URL.Query() + q.Set("key", key.KeyValue) + req.URL.RawQuery = q.Encode() } - - proxy.ModifyResponse = func(resp *http.Response) error { - // Log the response, etc. - return nil - } - - proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { - logrus.Errorf("Proxy error to Gemini: %v", err) - // Handle error, maybe update key status - } - - proxy.ServeHTTP(c.Writer, c.Request) + return ch.ProcessRequest(c, apiKey, modifier) } \ No newline at end of file diff --git a/internal/channel/openai_channel.go b/internal/channel/openai_channel.go index 6f566f9..ec9559a 100644 --- a/internal/channel/openai_channel.go +++ b/internal/channel/openai_channel.go @@ -2,17 +2,15 @@ package channel import ( "encoding/json" + "fmt" "gpt-load/internal/models" "net/http" - "net/http/httputil" "net/url" "github.com/gin-gonic/gin" - "github.com/sirupsen/logrus" ) - type OpenAIChannel struct { - BaseURL *url.URL + BaseChannel } type OpenAIChannelConfig struct { @@ -22,34 +20,29 @@ type OpenAIChannelConfig struct { func NewOpenAIChannel(group *models.Group) (*OpenAIChannel, error) { var config OpenAIChannelConfig if err := json.Unmarshal([]byte(group.Config), &config); err != nil { - return nil, err + return nil, fmt.Errorf("failed to unmarshal channel config: %w", err) } + if config.BaseURL == "" { + return nil, fmt.Errorf("base_url is required for openai channel") + } + baseURL, err := url.Parse(config.BaseURL) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to parse base_url: %w", err) } - return &OpenAIChannel{BaseURL: baseURL}, nil + + return &OpenAIChannel{ + BaseChannel: BaseChannel{ + Name: "openai", + BaseURL: baseURL, + HTTPClient: &http.Client{}, + }, + }, nil } -func (ch *OpenAIChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) { - proxy := httputil.NewSingleHostReverseProxy(ch.BaseURL) - proxy.Director = func(req *http.Request) { - req.URL.Scheme = ch.BaseURL.Scheme - req.URL.Host = ch.BaseURL.Host - req.URL.Path = c.Param("path") - req.Host = ch.BaseURL.Host - req.Header.Set("Authorization", "Bearer "+apiKey.KeyValue) +func (ch *OpenAIChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *models.Group) error { + modifier := func(req *http.Request, key *models.APIKey) { + req.Header.Set("Authorization", "Bearer "+key.KeyValue) } - - proxy.ModifyResponse = func(resp *http.Response) error { - // Log the response, etc. - return nil - } - - proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { - logrus.Errorf("Proxy error: %v", err) - // Handle error, maybe update key status - } - - proxy.ServeHTTP(c.Writer, c.Request) + return ch.ProcessRequest(c, apiKey, modifier) } \ No newline at end of file diff --git a/internal/proxy/server.go b/internal/proxy/server.go index c474501..9f07f9d 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -59,10 +59,18 @@ func (ps *ProxyServer) HandleProxy(c *gin.Context) { } // 4. Forward the request using the channel handler - channelHandler.Handle(c, apiKey, &group) + err = channelHandler.Handle(c, apiKey, &group) // 5. Log the request asynchronously - go ps.logRequest(c, &group, apiKey, startTime) + isSuccess := err == nil + if !isSuccess { + logrus.WithFields(logrus.Fields{ + "group": group.Name, + "key_id": apiKey.ID, + "error": err.Error(), + }).Error("Channel handler failed") + } + go ps.logRequest(c, &group, apiKey, startTime, isSuccess) } // selectAPIKey selects an API key from a group using round-robin @@ -89,9 +97,8 @@ func (ps *ProxyServer) selectAPIKey(group *models.Group) (*models.APIKey, error) return &selectedKey, nil } -func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *models.APIKey, startTime time.Time) { +func (ps *ProxyServer) logRequest(c *gin.Context, group *models.Group, key *models.APIKey, startTime time.Time, isSuccess bool) { // Update key stats based on request success - isSuccess := c.Writer.Status() < 400 go ps.updateKeyStats(key.ID, isSuccess) logEntry := models.RequestLog{