fix: 优化代码

This commit is contained in:
tbphp
2025-06-09 22:03:15 +08:00
parent 0c5cf4266d
commit 9562f0f6dd
5 changed files with 48 additions and 46 deletions

View File

@@ -125,11 +125,10 @@ func setupRoutes(handlers *handler.Handler, proxyServer *proxy.ProxyServer, conf
router.GET("/reset-keys", handlers.ResetKeys) router.GET("/reset-keys", handlers.ResetKeys)
router.GET("/config", handlers.GetConfig) // Debug endpoint router.GET("/config", handlers.GetConfig) // Debug endpoint
// Handle 404 and 405 // Handle 405 Method Not Allowed
router.NoRoute(handlers.NotFound)
router.NoMethod(handlers.MethodNotAllowed) router.NoMethod(handlers.MethodNotAllowed)
// Proxy all other requests // Proxy all other requests (this handles 404 as well)
router.NoRoute(proxyServer.HandleProxy) router.NoRoute(proxyServer.HandleProxy)
return router return router

View File

@@ -29,10 +29,10 @@ func NewHandler(keyManager types.KeyManager, config types.ConfigManager) *Handle
// Health handles health check requests // Health handles health check requests
func (h *Handler) Health(c *gin.Context) { func (h *Handler) Health(c *gin.Context) {
stats := h.keyManager.GetStats() stats := h.keyManager.GetStats()
status := "healthy" status := "healthy"
httpStatus := http.StatusOK httpStatus := http.StatusOK
// Check if there are any healthy keys // Check if there are any healthy keys
if stats.HealthyKeys == 0 { if stats.HealthyKeys == 0 {
status = "unhealthy" status = "unhealthy"
@@ -51,16 +51,16 @@ func (h *Handler) Health(c *gin.Context) {
// Stats handles statistics requests // Stats handles statistics requests
func (h *Handler) Stats(c *gin.Context) { func (h *Handler) Stats(c *gin.Context) {
stats := h.keyManager.GetStats() stats := h.keyManager.GetStats()
// Add additional system information // Add additional system information
var m runtime.MemStats var m runtime.MemStats
runtime.ReadMemStats(&m) runtime.ReadMemStats(&m)
response := gin.H{ response := gin.H{
"keys": gin.H{ "keys": gin.H{
"total": stats.TotalKeys, "total": stats.TotalKeys,
"healthy": stats.HealthyKeys, "healthy": stats.HealthyKeys,
"blacklisted": stats.BlacklistedKeys, "blacklisted": stats.BlacklistedKeys,
"current_index": stats.CurrentIndex, "current_index": stats.CurrentIndex,
}, },
"requests": gin.H{ "requests": gin.H{
@@ -77,9 +77,9 @@ func (h *Handler) Stats(c *gin.Context) {
"next_gc_mb": bToMb(m.NextGC), "next_gc_mb": bToMb(m.NextGC),
}, },
"system": gin.H{ "system": gin.H{
"goroutines": runtime.NumGoroutine(), "goroutines": runtime.NumGoroutine(),
"cpu_count": runtime.NumCPU(), "cpu_count": runtime.NumCPU(),
"go_version": runtime.Version(), "go_version": runtime.Version(),
}, },
"timestamp": time.Now().UTC().Format(time.RFC3339), "timestamp": time.Now().UTC().Format(time.RFC3339),
} }
@@ -90,11 +90,11 @@ func (h *Handler) Stats(c *gin.Context) {
// Blacklist handles blacklist requests // Blacklist handles blacklist requests
func (h *Handler) Blacklist(c *gin.Context) { func (h *Handler) Blacklist(c *gin.Context) {
blacklist := h.keyManager.GetBlacklist() blacklist := h.keyManager.GetBlacklist()
response := gin.H{ response := gin.H{
"blacklisted_keys": blacklist, "blacklisted_keys": blacklist,
"count": len(blacklist), "count": len(blacklist),
"timestamp": time.Now().UTC().Format(time.RFC3339), "timestamp": time.Now().UTC().Format(time.RFC3339),
} }
c.JSON(http.StatusOK, response) c.JSON(http.StatusOK, response)
@@ -104,7 +104,7 @@ func (h *Handler) Blacklist(c *gin.Context) {
func (h *Handler) ResetKeys(c *gin.Context) { func (h *Handler) ResetKeys(c *gin.Context) {
// Reset blacklist // Reset blacklist
h.keyManager.ResetBlacklist() h.keyManager.ResetBlacklist()
// Reload keys from file // Reload keys from file
if err := h.keyManager.LoadKeys(); err != nil { if err := h.keyManager.LoadKeys(); err != nil {
logrus.Errorf("Failed to reload keys: %v", err) logrus.Errorf("Failed to reload keys: %v", err)
@@ -116,25 +116,15 @@ func (h *Handler) ResetKeys(c *gin.Context) {
} }
stats := h.keyManager.GetStats() stats := h.keyManager.GetStats()
c.JSON(http.StatusOK, gin.H{
"message": "Keys reset and reloaded successfully",
"total_keys": stats.TotalKeys,
"healthy_keys": stats.HealthyKeys,
"timestamp": time.Now().UTC().Format(time.RFC3339),
})
logrus.Info("Keys reset and reloaded successfully")
}
// NotFound handles 404 requests c.JSON(http.StatusOK, gin.H{
func (h *Handler) NotFound(c *gin.Context) { "message": "Keys reset and reloaded successfully",
c.JSON(http.StatusNotFound, gin.H{ "total_keys": stats.TotalKeys,
"error": "Endpoint not found", "healthy_keys": stats.HealthyKeys,
"path": c.Request.URL.Path, "timestamp": time.Now().UTC().Format(time.RFC3339),
"method": c.Request.Method,
"timestamp": time.Now().UTC().Format(time.RFC3339),
}) })
logrus.Info("Keys reset and reloaded successfully")
} }
// MethodNotAllowed handles 405 requests // MethodNotAllowed handles 405 requests

View File

@@ -204,13 +204,15 @@ func (km *Manager) RecordFailure(key string, err error) {
} }
// Increment failure count // Increment failure count
failCount, _ := km.keyFailureCounts.LoadOrStore(key, int64(0)) failCount, _ := km.keyFailureCounts.LoadOrStore(key, new(int64))
newFailCount := atomic.AddInt64(failCount.(*int64), 1) if counter, ok := failCount.(*int64); ok {
newFailCount := atomic.AddInt64(counter, 1)
// Blacklist if threshold exceeded // Blacklist if threshold exceeded
if int(newFailCount) >= km.config.BlacklistThreshold { if int(newFailCount) >= km.config.BlacklistThreshold {
km.blacklistedKeys.Store(key, time.Now()) km.blacklistedKeys.Store(key, time.Now())
logrus.Debugf("Key blacklisted after %d failures", newFailCount) logrus.Debugf("Key blacklisted after %d failures", newFailCount)
}
} }
} }
@@ -236,7 +238,7 @@ func (km *Manager) GetStats() types.Stats {
km.keysMutex.RUnlock() km.keysMutex.RUnlock()
blacklistedCount := 0 blacklistedCount := 0
km.blacklistedKeys.Range(func(key, value interface{}) bool { km.blacklistedKeys.Range(func(key, value any) bool {
blacklistedCount++ blacklistedCount++
return true return true
}) })
@@ -273,7 +275,7 @@ func (km *Manager) ResetBlacklist() {
func (km *Manager) GetBlacklist() []types.BlacklistEntry { func (km *Manager) GetBlacklist() []types.BlacklistEntry {
var blacklist []types.BlacklistEntry var blacklist []types.BlacklistEntry
km.blacklistedKeys.Range(func(key, value interface{}) bool { km.blacklistedKeys.Range(func(key, value any) bool {
keyStr := key.(string) keyStr := key.(string)
blacklistTime := value.(time.Time) blacklistTime := value.(time.Time)

View File

@@ -173,7 +173,7 @@ func Auth(config types.AuthConfig) gin.HandlerFunc {
// Recovery creates a recovery middleware with custom error handling // Recovery creates a recovery middleware with custom error handling
func Recovery() gin.HandlerFunc { func Recovery() gin.HandlerFunc {
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) { return gin.CustomRecovery(func(c *gin.Context, recovered any) {
if err, ok := recovered.(string); ok { if err, ok := recovered.(string); ok {
logrus.Errorf("Panic recovered: %s", err) logrus.Errorf("Panic recovered: %s", err)
c.JSON(500, gin.H{ c.JSON(500, gin.H{

View File

@@ -1,7 +1,11 @@
// Package types defines common interfaces and types used across the application // Package types defines common interfaces and types used across the application
package types package types
import "time" import (
"time"
"github.com/gin-gonic/gin"
)
// ConfigManager defines the interface for configuration management // ConfigManager defines the interface for configuration management
type ConfigManager interface { type ConfigManager interface {
@@ -13,6 +17,7 @@ type ConfigManager interface {
GetPerformanceConfig() PerformanceConfig GetPerformanceConfig() PerformanceConfig
GetLogConfig() LogConfig GetLogConfig() LogConfig
Validate() error Validate() error
DisplayConfig()
} }
// KeyManager defines the interface for API key management // KeyManager defines the interface for API key management
@@ -27,6 +32,12 @@ type KeyManager interface {
Close() Close()
} }
// ProxyServer defines the interface for proxy server
type ProxyServer interface {
HandleProxy(c *gin.Context)
Close()
}
// ServerConfig represents server configuration // ServerConfig represents server configuration
type ServerConfig struct { type ServerConfig struct {
Port int `json:"port"` Port int `json:"port"`
@@ -64,8 +75,8 @@ type CORSConfig struct {
// PerformanceConfig represents performance configuration // PerformanceConfig represents performance configuration
type PerformanceConfig struct { type PerformanceConfig struct {
MaxConcurrentRequests int `json:"maxConcurrentRequests"` MaxConcurrentRequests int `json:"maxConcurrentRequests"`
RequestTimeout int `json:"requestTimeout"` RequestTimeout int `json:"requestTimeout"`
EnableGzip bool `json:"enableGzip"` EnableGzip bool `json:"enableGzip"`
} }