From 745c646530c35abd8c4923ac4d4d80d9361aa03c Mon Sep 17 00:00:00 2001 From: tbphp Date: Thu, 3 Jul 2025 21:18:43 +0800 Subject: [PATCH] fix: group api --- internal/handler/dashboard_handler.go | 6 +-- internal/handler/group_handler.go | 70 +++++++++++++++++++-------- internal/models/types.go | 3 +- internal/router/router.go | 1 - web/src/types/models.ts | 3 +- 5 files changed, 57 insertions(+), 26 deletions(-) diff --git a/internal/handler/dashboard_handler.go b/internal/handler/dashboard_handler.go index 82e28c6..bab3cc0 100644 --- a/internal/handler/dashboard_handler.go +++ b/internal/handler/dashboard_handler.go @@ -24,9 +24,9 @@ func (s *Server) Stats(c *gin.Context) { // 2. Get request counts per group s.DB.Table("api_keys"). - Select("groups.name as group_name, SUM(api_keys.request_count) as request_count"). + Select("groups.nickname as group_nickname, SUM(api_keys.request_count) as request_count"). Joins("join groups on groups.id = api_keys.group_id"). - Group("groups.name"). + Group("groups.id, groups.nickname"). Order("request_count DESC"). Scan(&groupStats) @@ -44,4 +44,4 @@ func (s *Server) Stats(c *gin.Context) { } response.Success(c, stats) -} \ No newline at end of file +} diff --git a/internal/handler/group_handler.go b/internal/handler/group_handler.go index 5075278..e9c7692 100644 --- a/internal/handler/group_handler.go +++ b/internal/handler/group_handler.go @@ -6,11 +6,22 @@ import ( "gpt-load/internal/models" "gpt-load/internal/response" "net/http" + "regexp" "strconv" "github.com/gin-gonic/gin" ) +// isValidGroupName checks if the group name is valid. +func isValidGroupName(name string) bool { + if name == "" { + return false + } + // 允许使用小写字母、数字和下划线,长度在 3 到 30 个字符之间 + match, _ := regexp.MatchString("^[a-z0-9_]{3,30}$", name) + return match +} + // CreateGroup handles the creation of a new group. func (s *Server) CreateGroup(c *gin.Context) { var group models.Group @@ -20,8 +31,8 @@ func (s *Server) CreateGroup(c *gin.Context) { } // Validation - if group.Name == "" { - response.Error(c, http.StatusBadRequest, "Group name is required") + if !isValidGroupName(group.Name) { + response.Error(c, http.StatusBadRequest, "Invalid group name format. Use lowercase letters and underscores, and do not start with an underscore.") return } if len(group.Upstreams) == 0 { @@ -51,23 +62,6 @@ func (s *Server) ListGroups(c *gin.Context) { response.Success(c, groups) } -// GetGroup handles getting a single group by its ID. -func (s *Server) GetGroup(c *gin.Context) { - id, err := strconv.Atoi(c.Param("id")) - if err != nil { - response.Error(c, http.StatusBadRequest, "Invalid group ID") - return - } - - var group models.Group - if err := s.DB.Preload("APIKeys").First(&group, id).Error; err != nil { - response.Error(c, http.StatusNotFound, "Group not found") - return - } - - response.Success(c, group) -} - // UpdateGroup handles updating an existing group. func (s *Server) UpdateGroup(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) @@ -88,6 +82,12 @@ func (s *Server) UpdateGroup(c *gin.Context) { return } + // Validate group name if it's being updated + if updateData.Name != "" && !isValidGroupName(updateData.Name) { + response.Error(c, http.StatusBadRequest, "Invalid group name format. Use lowercase letters and underscores, and do not start with an underscore.") + return + } + // Use a transaction to ensure atomicity tx := s.DB.Begin() if tx.Error != nil { @@ -103,6 +103,36 @@ func (s *Server) UpdateGroup(c *gin.Context) { return } + // If config is being updated, it needs to be marshalled to JSON string for GORM + if config, ok := updateMap["config"]; ok { + if configMap, isMap := config.(map[string]interface{}); isMap { + configJSON, err := json.Marshal(configMap) + if err != nil { + response.Error(c, http.StatusBadRequest, "Failed to process config data") + return + } + updateMap["config"] = string(configJSON) + } + } + + // Handle upstreams field specifically + if upstreams, ok := updateMap["upstreams"]; ok { + if upstreamsSlice, isSlice := upstreams.([]interface{}); isSlice { + upstreamsJSON, err := json.Marshal(upstreamsSlice) + if err != nil { + response.Error(c, http.StatusBadRequest, "Failed to process upstreams data") + return + } + updateMap["upstreams"] = string(upstreamsJSON) + } + } + + // Remove fields that are not actual columns or should not be updated from the map + delete(updateMap, "id") + delete(updateMap, "api_keys") + delete(updateMap, "created_at") + delete(updateMap, "updated_at") + // Use Updates with a map to only update provided fields, including zero values if err := tx.Model(&group).Updates(updateMap).Error; err != nil { tx.Rollback() @@ -118,7 +148,7 @@ func (s *Server) UpdateGroup(c *gin.Context) { // Re-fetch the group to return the updated data var updatedGroup models.Group - if err := s.DB.Preload("APIKeys").First(&updatedGroup, id).Error; err != nil { + if err := s.DB.First(&updatedGroup, id).Error; err != nil { response.Error(c, http.StatusNotFound, "Failed to fetch updated group data") return } diff --git a/internal/models/types.go b/internal/models/types.go index 12bf19f..e047f0a 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -56,6 +56,7 @@ type GroupConfig struct { type Group struct { ID uint `gorm:"primaryKey;autoIncrement" json:"id"` Name string `gorm:"type:varchar(255);not null;unique" json:"name"` + Nickname string `gorm:"type:varchar(255)" json:"nickname"` Description string `gorm:"type:varchar(512)" json:"description"` Upstreams Upstreams `gorm:"type:json;not null" json:"upstreams"` ChannelType string `gorm:"type:varchar(50);not null" json:"channel_type"` @@ -93,7 +94,7 @@ type RequestLog struct { // GroupRequestStat 用于表示每个分组的请求统计 type GroupRequestStat struct { - GroupName string `json:"group_name"` + GroupNickname string `json:"group_nickname"` RequestCount int64 `json:"request_count"` } diff --git a/internal/router/router.go b/internal/router/router.go index ac99f96..b107633 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -103,7 +103,6 @@ func registerProtectedAPIRoutes(api *gin.RouterGroup, serverHandler *handler.Ser { groups.POST("", serverHandler.CreateGroup) groups.GET("", serverHandler.ListGroups) - groups.GET("/:id", serverHandler.GetGroup) groups.PUT("/:id", serverHandler.UpdateGroup) groups.DELETE("/:id", serverHandler.DeleteGroup) diff --git a/web/src/types/models.ts b/web/src/types/models.ts index 80af94a..376ca76 100644 --- a/web/src/types/models.ts +++ b/web/src/types/models.ts @@ -14,6 +14,7 @@ export interface APIKey { export interface Group { id: number; name: string; + nickname: string; description: string; channel_type: "openai" | "gemini"; config: string; @@ -58,6 +59,6 @@ export interface DashboardStats { } export interface GroupRequestStat { - group_name: string; + group_nickname: string; request_count: number; }