From 80662af9deeaf20a4a65b95ef36488e82ed73d18 Mon Sep 17 00:00:00 2001 From: tbphp Date: Thu, 3 Jul 2025 16:29:12 +0800 Subject: [PATCH] feat: group api --- go.mod | 1 + go.sum | 26 ++++ internal/channel/base_channel.go | 203 ++++++++++++++++------------- internal/channel/factory.go | 30 ++++- internal/channel/gemini_channel.go | 30 +---- internal/channel/openai_channel.go | 30 +---- internal/handler/group_handler.go | 68 ++++++++-- internal/models/types.go | 56 ++++++-- 8 files changed, 288 insertions(+), 156 deletions(-) diff --git a/go.mod b/go.mod index 5976584..19bc47d 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/gin-gonic/gin v1.10.1 github.com/joho/godotenv v1.5.1 github.com/sirupsen/logrus v1.9.3 + gorm.io/datatypes v1.2.1 gorm.io/driver/mysql v1.6.0 gorm.io/gorm v1.30.0 ) diff --git a/go.sum b/go.sum index 18f6de4..116e32c 100644 --- a/go.sum +++ b/go.sum @@ -34,9 +34,21 @@ github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpv github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= +github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= +github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= +github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= @@ -57,6 +69,10 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= +github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/microsoft/go-mssqldb v0.17.0 h1:Fto83dMZPnYv1Zwx5vHHxpNraeEaUlQ/hhHLgZiaenE= +github.com/microsoft/go-mssqldb v0.17.0/go.mod h1:OkoNGhGEs8EZqchVTtochlXruEhEOaO4S0d2sB5aeGQ= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -90,6 +106,8 @@ golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= +golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= +golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= @@ -104,8 +122,16 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/datatypes v1.2.1 h1:r+g0bk4LPCW2v4+Ls7aeNgGme7JYdNDQ2VtvlNUfBh0= +gorm.io/datatypes v1.2.1/go.mod h1:hYK6OTb/1x+m96PgoZZq10UXJ6RvEBb9kRDQ2yyhzGs= gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg= gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo= +gorm.io/driver/postgres v1.5.0 h1:u2FXTy14l45qc3UeCJ7QaAXZmZfDDv0YrthvmRq1l0U= +gorm.io/driver/postgres v1.5.0/go.mod h1:FUZXzO+5Uqg5zzwzv4KK49R8lvGIyscBOqYrtI1Ce9A= +gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= +gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= +gorm.io/driver/sqlserver v1.4.1 h1:t4r4r6Jam5E6ejqP7N82qAJIJAht27EGT41HyPfXRw0= +gorm.io/driver/sqlserver v1.4.1/go.mod h1:DJ4P+MeZbc5rvY58PnmN1Lnyvb5gw5NPzGshHDnJLig= gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs= gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= diff --git a/internal/channel/base_channel.go b/internal/channel/base_channel.go index 8e0d7ee..5f3bed9 100644 --- a/internal/channel/base_channel.go +++ b/internal/channel/base_channel.go @@ -1,122 +1,151 @@ package channel import ( - "bytes" - "compress/gzip" "fmt" "gpt-load/internal/models" - "gpt-load/internal/response" "io" "net/http" + "net/http/httputil" "net/url" "strings" + "sync/atomic" "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin/binding" "github.com/sirupsen/logrus" ) -// RequestModifier defines a function that can modify the upstream request, -// for example, by adding authentication headers. -type RequestModifier func(req *http.Request, apiKey *models.APIKey) - -// BaseChannel provides a foundation for specific channel implementations. +// BaseChannel provides common functionality for channel proxies. type BaseChannel struct { - Name string - BaseURL *url.URL - HTTPClient *http.Client + Name string + Upstreams []*url.URL + HTTPClient *http.Client + roundRobin uint64 } -// ProcessRequest handles the generic logic of creating, sending, and handling an upstream request. -func (ch *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier) error { - // 1. Create the upstream request - req, err := ch.createUpstreamRequest(c, apiKey, modifier) - if err != nil { - response.Error(c, http.StatusInternalServerError, "Failed to create upstream request") - return fmt.Errorf("create upstream request failed: %w", err) +// RequestModifier is a function that can modify the request before it's sent. +type RequestModifier func(req *http.Request, key *models.APIKey) + +// getUpstreamURL selects an upstream URL using round-robin. +func (b *BaseChannel) getUpstreamURL() *url.URL { + if len(b.Upstreams) == 0 { + return nil + } + if len(b.Upstreams) == 1 { + return b.Upstreams[0] + } + index := atomic.AddUint64(&b.roundRobin, 1) - 1 + return b.Upstreams[index%uint64(len(b.Upstreams))] +} + +// ProcessRequest handles the common logic of processing and forwarding a request. +func (b *BaseChannel) ProcessRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier) error { + upstreamURL := b.getUpstreamURL() + if upstreamURL == nil { + return fmt.Errorf("no upstream URL configured for channel %s", b.Name) } - // 2. Send the request - resp, err := ch.HTTPClient.Do(req) - if err != nil { - response.Error(c, http.StatusServiceUnavailable, "Upstream service unavailable") - return fmt.Errorf("upstream request failed: %w", err) - } - defer resp.Body.Close() + director := func(req *http.Request) { + req.URL.Scheme = upstreamURL.Scheme + req.URL.Host = upstreamURL.Host + req.URL.Path = singleJoiningSlash(upstreamURL.Path, req.URL.Path) + req.Host = upstreamURL.Host - // 3. Handle non-200 status codes - if resp.StatusCode != http.StatusOK { - errorMsg := ch.getErrorMessage(resp) - response.Error(c, resp.StatusCode, errorMsg) - return fmt.Errorf("upstream returned status %d: %s", resp.StatusCode, errorMsg) - } - - // 4. Stream the successful response back to the client - for key, values := range resp.Header { - for _, value := range values { - c.Header(key, value) + // Apply the channel-specific modifications + if modifier != nil { + modifier(req, apiKey) } + + // Remove headers that should not be forwarded + req.Header.Del("Cookie") + req.Header.Del("X-Real-Ip") + req.Header.Del("X-Forwarded-For") } - c.Status(http.StatusOK) - _, err = io.Copy(c.Writer, resp.Body) + + errorHandler := func(rw http.ResponseWriter, req *http.Request, err error) { + logrus.WithFields(logrus.Fields{ + "channel": b.Name, + "key_id": apiKey.ID, + "error": err, + }).Error("HTTP proxy error") + rw.WriteHeader(http.StatusBadGateway) + } + + proxy := &httputil.ReverseProxy{ + Director: director, + ErrorHandler: errorHandler, + Transport: b.HTTPClient.Transport, + } + + // Check if the client request is for a streaming endpoint + if isStreamingRequest(c) { + return b.handleStreaming(c, proxy) + } + + proxy.ServeHTTP(c.Writer, c.Request) + return nil +} + +func (b *BaseChannel) handleStreaming(c *gin.Context, proxy *httputil.ReverseProxy) error { + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + + // Use a pipe to avoid buffering the entire response + pr, pw := io.Pipe() + defer pr.Close() + + // Create a new request with the pipe reader as the body + // This is a bit of a hack to get ReverseProxy to stream + req := c.Request.Clone(c.Request.Context()) + req.Body = pr + + // Start the proxy in a goroutine + go func() { + defer pw.Close() + proxy.ServeHTTP(c.Writer, req) + }() + + // Copy the original request body to the pipe writer + _, err := io.Copy(pw, c.Request.Body) if err != nil { - logrus.Errorf("Failed to copy response body to client: %v", err) - return fmt.Errorf("copy response body failed: %w", err) + logrus.Errorf("Error copying request body to pipe: %v", err) + return err } return nil } -func (ch *BaseChannel) createUpstreamRequest(c *gin.Context, apiKey *models.APIKey, modifier RequestModifier) (*http.Request, error) { - targetURL := *ch.BaseURL - targetURL.Path = c.Param("path") - - body, err := io.ReadAll(c.Request.Body) - if err != nil { - return nil, fmt.Errorf("failed to read request body: %w", err) - } - c.Request.Body = io.NopCloser(bytes.NewBuffer(body)) - - req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewBuffer(body)) - if err != nil { - return nil, fmt.Errorf("failed to create new request: %w", err) +// isStreamingRequest checks if the request is for a streaming response. +func isStreamingRequest(c *gin.Context) bool { + // For Gemini, streaming is indicated by the path. + if strings.Contains(c.Request.URL.Path, ":streamGenerateContent") { + return true } - req.Header = c.Request.Header.Clone() - req.Host = ch.BaseURL.Host - - // Apply the channel-specific modifications - if modifier != nil { - modifier(req, apiKey) + // For OpenAI, streaming is indicated by a "stream": true field in the JSON body. + // We use ShouldBindBodyWith to check the body without consuming it, so it can be read again by the proxy. + type streamPayload struct { + Stream bool `json:"stream"` + } + var p streamPayload + if err := c.ShouldBindBodyWith(&p, binding.JSON); err == nil { + return p.Stream } - return req, nil + return false } -func (ch *BaseChannel) getErrorMessage(resp *http.Response) string { - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Sprintf("HTTP %d (failed to read error body: %v)", resp.StatusCode, err) +// singleJoiningSlash joins two URL paths with a single slash. +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b } - - var errorMessage string - if resp.Header.Get("Content-Encoding") == "gzip" { - reader, gErr := gzip.NewReader(bytes.NewReader(bodyBytes)) - if gErr != nil { - return string(bodyBytes) - } - defer reader.Close() - uncompressedBytes, rErr := io.ReadAll(reader) - if rErr != nil { - return fmt.Sprintf("gzip read error: %v", rErr) - } - errorMessage = string(uncompressedBytes) - } else { - errorMessage = string(bodyBytes) - } - - if strings.TrimSpace(errorMessage) == "" { - return fmt.Sprintf("HTTP %d: %s", resp.StatusCode, http.StatusText(resp.StatusCode)) - } - - return errorMessage -} \ No newline at end of file + return a + b +} diff --git a/internal/channel/factory.go b/internal/channel/factory.go index 3f4deb1..611149a 100644 --- a/internal/channel/factory.go +++ b/internal/channel/factory.go @@ -3,16 +3,40 @@ package channel import ( "fmt" "gpt-load/internal/models" + "net/http" + "net/url" ) // GetChannel returns a channel proxy based on the group's channel type. func GetChannel(group *models.Group) (ChannelProxy, error) { switch group.ChannelType { case "openai": - return NewOpenAIChannel(group) + return NewOpenAIChannel(group.Upstreams) case "gemini": - return NewGeminiChannel(group) + return NewGeminiChannel(group.Upstreams) default: return nil, fmt.Errorf("unsupported channel type: %s", group.ChannelType) } -} \ No newline at end of file +} + +// newBaseChannelWithUpstreams is a helper function to create and configure a BaseChannel. +func newBaseChannelWithUpstreams(name string, upstreams []string) (BaseChannel, error) { + if len(upstreams) == 0 { + return BaseChannel{}, fmt.Errorf("at least one upstream is required for %s channel", name) + } + + var upstreamURLs []*url.URL + for _, us := range upstreams { + u, err := url.Parse(us) + if err != nil { + return BaseChannel{}, fmt.Errorf("failed to parse upstream url '%s' for %s channel: %w", us, name, err) + } + upstreamURLs = append(upstreamURLs, u) + } + + return BaseChannel{ + Name: name, + Upstreams: upstreamURLs, + HTTPClient: &http.Client{}, + }, nil +} diff --git a/internal/channel/gemini_channel.go b/internal/channel/gemini_channel.go index a90bda1..37c57e4 100644 --- a/internal/channel/gemini_channel.go +++ b/internal/channel/gemini_channel.go @@ -1,42 +1,24 @@ package channel import ( - "encoding/json" - "fmt" "gpt-load/internal/models" "net/http" - "net/url" "github.com/gin-gonic/gin" ) + type GeminiChannel struct { BaseChannel } -type GeminiChannelConfig struct { - BaseURL string `json:"base_url"` -} - -func NewGeminiChannel(group *models.Group) (*GeminiChannel, error) { - var config GeminiChannelConfig - if err := json.Unmarshal([]byte(group.Config), &config); err != nil { - return nil, fmt.Errorf("failed to unmarshal channel config: %w", err) - } - if config.BaseURL == "" { - return nil, fmt.Errorf("base_url is required for gemini channel") - } - - baseURL, err := url.Parse(config.BaseURL) +func NewGeminiChannel(upstreams []string) (*GeminiChannel, error) { + base, err := newBaseChannelWithUpstreams("gemini", upstreams) if err != nil { - return nil, fmt.Errorf("failed to parse base_url: %w", err) + return nil, err } return &GeminiChannel{ - BaseChannel: BaseChannel{ - Name: "gemini", - BaseURL: baseURL, - HTTPClient: &http.Client{}, - }, + BaseChannel: base, }, nil } @@ -47,4 +29,4 @@ func (ch *GeminiChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *mo req.URL.RawQuery = q.Encode() } return ch.ProcessRequest(c, apiKey, modifier) -} \ No newline at end of file +} diff --git a/internal/channel/openai_channel.go b/internal/channel/openai_channel.go index ec9559a..da6ae5a 100644 --- a/internal/channel/openai_channel.go +++ b/internal/channel/openai_channel.go @@ -1,42 +1,24 @@ package channel import ( - "encoding/json" - "fmt" "gpt-load/internal/models" "net/http" - "net/url" "github.com/gin-gonic/gin" ) + type OpenAIChannel struct { BaseChannel } -type OpenAIChannelConfig struct { - BaseURL string `json:"base_url"` -} - -func NewOpenAIChannel(group *models.Group) (*OpenAIChannel, error) { - var config OpenAIChannelConfig - if err := json.Unmarshal([]byte(group.Config), &config); err != nil { - return nil, fmt.Errorf("failed to unmarshal channel config: %w", err) - } - if config.BaseURL == "" { - return nil, fmt.Errorf("base_url is required for openai channel") - } - - baseURL, err := url.Parse(config.BaseURL) +func NewOpenAIChannel(upstreams []string) (*OpenAIChannel, error) { + base, err := newBaseChannelWithUpstreams("openai", upstreams) if err != nil { - return nil, fmt.Errorf("failed to parse base_url: %w", err) + return nil, err } return &OpenAIChannel{ - BaseChannel: BaseChannel{ - Name: "openai", - BaseURL: baseURL, - HTTPClient: &http.Client{}, - }, + BaseChannel: base, }, nil } @@ -45,4 +27,4 @@ func (ch *OpenAIChannel) Handle(c *gin.Context, apiKey *models.APIKey, group *mo req.Header.Set("Authorization", "Bearer "+key.KeyValue) } return ch.ProcessRequest(c, apiKey, modifier) -} \ No newline at end of file +} diff --git a/internal/handler/group_handler.go b/internal/handler/group_handler.go index fd27ae9..5075278 100644 --- a/internal/handler/group_handler.go +++ b/internal/handler/group_handler.go @@ -2,6 +2,7 @@ package handler import ( + "encoding/json" "gpt-load/internal/models" "gpt-load/internal/response" "net/http" @@ -18,6 +19,20 @@ func (s *Server) CreateGroup(c *gin.Context) { return } + // Validation + if group.Name == "" { + response.Error(c, http.StatusBadRequest, "Group name is required") + return + } + if len(group.Upstreams) == 0 { + response.Error(c, http.StatusBadRequest, "At least one upstream is required") + return + } + if group.ChannelType == "" { + response.Error(c, http.StatusBadRequest, "Channel type is required") + return + } + if err := s.DB.Create(&group).Error; err != nil { response.Error(c, http.StatusInternalServerError, "Failed to create group") return @@ -29,7 +44,7 @@ func (s *Server) CreateGroup(c *gin.Context) { // ListGroups handles listing all groups. func (s *Server) ListGroups(c *gin.Context) { var groups []models.Group - if err := s.DB.Find(&groups).Error; err != nil { + if err := s.DB.Order("sort asc, id desc").Find(&groups).Error; err != nil { response.Error(c, http.StatusInternalServerError, "Failed to list groups") return } @@ -73,18 +88,42 @@ func (s *Server) UpdateGroup(c *gin.Context) { return } - // We only allow updating certain fields - group.Name = updateData.Name - group.Description = updateData.Description - group.ChannelType = updateData.ChannelType - group.Config = updateData.Config + // Use a transaction to ensure atomicity + tx := s.DB.Begin() + if tx.Error != nil { + response.Error(c, http.StatusInternalServerError, "Failed to start transaction") + return + } - if err := s.DB.Save(&group).Error; err != nil { + // Convert updateData to a map to ensure zero values (like Sort: 0) are updated + var updateMap map[string]interface{} + updateBytes, _ := json.Marshal(updateData) + if err := json.Unmarshal(updateBytes, &updateMap); err != nil { + response.Error(c, http.StatusBadRequest, "Failed to process update data") + return + } + + // Use Updates with a map to only update provided fields, including zero values + if err := tx.Model(&group).Updates(updateMap).Error; err != nil { + tx.Rollback() response.Error(c, http.StatusInternalServerError, "Failed to update group") return } - response.Success(c, group) + if err := tx.Commit().Error; err != nil { + tx.Rollback() + response.Error(c, http.StatusInternalServerError, "Failed to commit transaction") + return + } + + // Re-fetch the group to return the updated data + var updatedGroup models.Group + if err := s.DB.Preload("APIKeys").First(&updatedGroup, id).Error; err != nil { + response.Error(c, http.StatusNotFound, "Failed to fetch updated group data") + return + } + + response.Success(c, updatedGroup) } // DeleteGroup handles deleting a group. @@ -101,6 +140,11 @@ func (s *Server) DeleteGroup(c *gin.Context) { response.Error(c, http.StatusInternalServerError, "Failed to start transaction") return } + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } + }() // Also delete associated API keys if err := tx.Where("group_id = ?", id).Delete(&models.APIKey{}).Error; err != nil { @@ -109,10 +153,14 @@ func (s *Server) DeleteGroup(c *gin.Context) { return } - if err := tx.Delete(&models.Group{}, id).Error; err != nil { + if result := tx.Delete(&models.Group{}, id); result.Error != nil { tx.Rollback() response.Error(c, http.StatusInternalServerError, "Failed to delete group") return + } else if result.RowsAffected == 0 { + tx.Rollback() + response.Error(c, http.StatusNotFound, "Group not found") + return } if err := tx.Commit().Error; err != nil { @@ -122,4 +170,4 @@ func (s *Server) DeleteGroup(c *gin.Context) { } response.Success(c, gin.H{"message": "Group and associated keys deleted successfully"}) -} \ No newline at end of file +} diff --git a/internal/models/types.go b/internal/models/types.go index cd000ae..12bf19f 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -1,7 +1,12 @@ package models import ( + "database/sql/driver" + "encoding/json" + "errors" "time" + + "gorm.io/datatypes" ) // SystemSetting 对应 system_settings 表 @@ -14,16 +19,51 @@ type SystemSetting struct { UpdatedAt time.Time `json:"updated_at"` } +// Upstreams 是一个上游地址的切片,可以被 GORM 正确处理 +type Upstreams []string + +// Value 实现 driver.Valuer 接口,用于将 Upstreams 类型转换为数据库值 +func (u Upstreams) Value() (driver.Value, error) { + if len(u) == 0 { + return "[]", nil + } + return json.Marshal(u) +} + +// Scan 实现 sql.Scanner 接口,用于将数据库值扫描到 Upstreams 类型 +func (u *Upstreams) Scan(value interface{}) error { + bytes, ok := value.([]byte) + if !ok { + return errors.New("type assertion to []byte failed") + } + return json.Unmarshal(bytes, u) +} + +// GroupConfig 存储特定于分组的配置 +type GroupConfig struct { + BlacklistThreshold *int `json:"blacklist_threshold,omitempty"` + MaxRetries *int `json:"max_retries,omitempty"` + ServerReadTimeout *int `json:"server_read_timeout,omitempty"` + ServerWriteTimeout *int `json:"server_write_timeout,omitempty"` + ServerIdleTimeout *int `json:"server_idle_timeout,omitempty"` + ServerGracefulShutdownTimeout *int `json:"server_graceful_shutdown_timeout,omitempty"` + RequestTimeout *int `json:"request_timeout,omitempty"` + ResponseTimeout *int `json:"response_timeout,omitempty"` + IdleConnTimeout *int `json:"idle_conn_timeout,omitempty"` +} + // Group 对应 groups 表 type Group struct { - ID uint `gorm:"primaryKey;autoIncrement" json:"id"` - Name string `gorm:"type:varchar(255);not null;unique" json:"name"` - Description string `gorm:"type:varchar(512)" json:"description"` - ChannelType string `gorm:"type:varchar(50);not null" json:"channel_type"` - Config string `gorm:"type:json" json:"config"` - APIKeys []APIKey `gorm:"foreignKey:GroupID" json:"api_keys"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + Name string `gorm:"type:varchar(255);not null;unique" json:"name"` + Description string `gorm:"type:varchar(512)" json:"description"` + Upstreams Upstreams `gorm:"type:json;not null" json:"upstreams"` + ChannelType string `gorm:"type:varchar(50);not null" json:"channel_type"` + Sort int `gorm:"default:0" json:"sort"` + Config datatypes.JSONMap `gorm:"type:json" json:"config"` + APIKeys []APIKey `gorm:"foreignKey:GroupID" json:"api_keys"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // APIKey 对应 api_keys 表