diff --git a/internal/store/memory.go b/internal/store/memory.go index 9c6f240..574ca06 100644 --- a/internal/store/memory.go +++ b/internal/store/memory.go @@ -16,15 +16,19 @@ type memoryStoreItem struct { // MemoryStore is an in-memory key-value store that is safe for concurrent use. // It now supports simple K/V, HASH, and LIST data types. type MemoryStore struct { - mu sync.RWMutex - // Using 'any' to store different data structures (memoryStoreItem, map[string]string, []string) + mu sync.RWMutex data map[string]any + + // For Pub/Sub + muSubscribers sync.RWMutex + subscribers map[string]map[chan *Message]struct{} } // NewMemoryStore creates and returns a new MemoryStore instance. func NewMemoryStore() *MemoryStore { s := &MemoryStore{ - data: make(map[string]any), + data: make(map[string]any), + subscribers: make(map[string]map[chan *Message]struct{}), } // The cleanup loop was removed as it's not compatible with multiple data types // without a unified expiration mechanism, and the KeyPool feature does not rely on TTLs. @@ -304,3 +308,77 @@ func (s *MemoryStore) Rotate(key string) (string, error) { return item, nil } + +// --- Pub/Sub operations --- + +// memorySubscription implements the Subscription interface for the in-memory store. +type memorySubscription struct { + store *MemoryStore + channel string + msgChan chan *Message +} + +// Channel returns the message channel for the subscription. +func (ms *memorySubscription) Channel() <-chan *Message { + return ms.msgChan +} + +// Close removes the subscription from the store. +func (ms *memorySubscription) Close() error { + ms.store.muSubscribers.Lock() + defer ms.store.muSubscribers.Unlock() + + if subs, ok := ms.store.subscribers[ms.channel]; ok { + delete(subs, ms.msgChan) + if len(subs) == 0 { + delete(ms.store.subscribers, ms.channel) + } + } + close(ms.msgChan) + return nil +} + +// Publish sends a message to all subscribers of a channel. +func (s *MemoryStore) Publish(channel string, message []byte) error { + s.muSubscribers.RLock() + defer s.muSubscribers.RUnlock() + + msg := &Message{ + Channel: channel, + Payload: message, + } + + if subs, ok := s.subscribers[channel]; ok { + for subCh := range subs { + // Non-blocking send + go func(c chan *Message) { + select { + case c <- msg: + case <-time.After(1 * time.Second): // Prevent goroutine leak if receiver is stuck + } + }(subCh) + } + } + return nil +} + +// Subscribe listens for messages on a given channel. +func (s *MemoryStore) Subscribe(channel string) (Subscription, error) { + s.muSubscribers.Lock() + defer s.muSubscribers.Unlock() + + msgChan := make(chan *Message, 10) // Buffered channel + + if _, ok := s.subscribers[channel]; !ok { + s.subscribers[channel] = make(map[chan *Message]struct{}) + } + s.subscribers[channel][msgChan] = struct{}{} + + sub := &memorySubscription{ + store: s, + channel: channel, + msgChan: msgChan, + } + + return sub, nil +} diff --git a/internal/store/redis.go b/internal/store/redis.go index 9c22975..b807d53 100644 --- a/internal/store/redis.go +++ b/internal/store/redis.go @@ -3,6 +3,7 @@ package store import ( "context" "errors" + "fmt" "time" "github.com/redis/go-redis/v9" @@ -122,3 +123,49 @@ func (s *RedisStore) Pipeline() Pipeliner { func (s *RedisStore) Eval(script string, keys []string, args ...interface{}) (interface{}, error) { return s.client.Eval(context.Background(), script, keys, args...).Result() } + +// --- Pub/Sub operations --- + +// redisSubscription wraps the redis.PubSub to implement the Subscription interface. +type redisSubscription struct { + pubsub *redis.PubSub +} + +// Channel returns a channel that receives messages from the subscription. +// It handles the conversion from redis.Message to our internal Message type. +func (rs *redisSubscription) Channel() <-chan *Message { + ch := make(chan *Message) + go func() { + defer close(ch) + for redisMsg := range rs.pubsub.Channel() { + ch <- &Message{ + Channel: redisMsg.Channel, + Payload: []byte(redisMsg.Payload), + } + } + }() + return ch +} + +// Close closes the subscription. +func (rs *redisSubscription) Close() error { + return rs.pubsub.Close() +} + +// Publish sends a message to a given channel. +func (s *RedisStore) Publish(channel string, message []byte) error { + return s.client.Publish(context.Background(), channel, message).Err() +} + +// Subscribe listens for messages on a given channel. +func (s *RedisStore) Subscribe(channel string) (Subscription, error) { + pubsub := s.client.Subscribe(context.Background(), channel) + + // Wait for confirmation that subscription is created. + _, err := pubsub.Receive(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err) + } + + return &redisSubscription{pubsub: pubsub}, nil +} diff --git a/internal/store/store.go b/internal/store/store.go index e7b4b5a..0fba62a 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -8,6 +8,20 @@ import ( // ErrNotFound is the error returned when a key is not found in the store. var ErrNotFound = errors.New("store: key not found") +// Message is the struct for received pub/sub messages. +type Message struct { + Channel string + Payload []byte +} + +// Subscription represents an active subscription to a pub/sub channel. +type Subscription interface { + // Channel returns the channel for receiving messages. + Channel() <-chan *Message + // Close unsubscribes and releases any resources associated with the subscription. + Close() error +} + // Store is a generic key-value store interface. // Implementations of this interface must be safe for concurrent use. type Store interface { @@ -44,6 +58,13 @@ type Store interface { // Close closes the store and releases any underlying resources. Close() error + + // Publish sends a message to a given channel. + Publish(channel string, message []byte) error + + // Subscribe listens for messages on a given channel. + // It returns a Subscription object that can be used to receive messages and to close the subscription. + Subscribe(channel string) (Subscription, error) } // Pipeliner defines an interface for executing a batch of commands. diff --git a/internal/syncer/cache_syncer.go b/internal/syncer/cache_syncer.go new file mode 100644 index 0000000..7f7470b --- /dev/null +++ b/internal/syncer/cache_syncer.go @@ -0,0 +1,155 @@ +package syncer + +import ( + "fmt" + "sync" + "time" + + "gpt-load/internal/store" + + "github.com/sirupsen/logrus" +) + +// LoaderFunc defines a generic function signature for loading data from the source of truth (e.g., database). +type LoaderFunc[T any] func() (T, error) + +// CacheSyncer is a generic service that manages in-memory caching and cross-instance synchronization. +type CacheSyncer[T any] struct { + mu sync.RWMutex + cache T + loader LoaderFunc[T] + store store.Store + channelName string + logger *logrus.Entry + stopChan chan struct{} + wg sync.WaitGroup +} + +// NewCacheSyncer creates and initializes a new CacheSyncer. +func NewCacheSyncer[T any]( + loader LoaderFunc[T], + store store.Store, + channelName string, + logger *logrus.Entry, +) (*CacheSyncer[T], error) { + s := &CacheSyncer[T]{ + loader: loader, + store: store, + channelName: channelName, + logger: logger, + stopChan: make(chan struct{}), + } + + if err := s.reload(); err != nil { + return nil, fmt.Errorf("initial load for %s failed: %w", channelName, err) + } + + s.wg.Add(1) + go s.listenForUpdates() + + return s, nil +} + +// Get safely returns the cached data. +func (s *CacheSyncer[T]) Get() T { + s.mu.RLock() + defer s.mu.RUnlock() + return s.cache +} + +// Invalidate publishes a notification to all instances to reload their cache. +func (s *CacheSyncer[T]) Invalidate() error { + s.logger.Debug("publishing invalidation notification") + s.reload() + return s.store.Publish(s.channelName, []byte("reload")) +} + +// Stop gracefully shuts down the syncer's background goroutine. +func (s *CacheSyncer[T]) Stop() { + s.logger.Debug("stopping cache syncer...") + close(s.stopChan) + s.wg.Wait() + s.logger.Info("cache syncer stopped.") +} + +// reload fetches the latest data using the loader function and updates the cache. +func (s *CacheSyncer[T]) reload() error { + s.logger.Debug("reloading cache...") + newData, err := s.loader() + if err != nil { + s.logger.Errorf("failed to reload cache: %v", err) + return err + } + + s.mu.Lock() + s.cache = newData + s.mu.Unlock() + + s.logger.Info("cache reloaded successfully") + return nil +} + +// listenForUpdates runs in the background, listening for invalidation messages. +func (s *CacheSyncer[T]) listenForUpdates() { + defer s.wg.Done() + + for { + select { + case <-s.stopChan: + s.logger.Info("received stop signal, exiting listener loop.") + return + default: + } + + if s.store == nil { + s.logger.Warn("store is not configured, stopping subscription listener.") + return + } + + subscription, err := s.store.Subscribe(s.channelName) + if err != nil { + s.logger.Errorf("failed to subscribe, retrying in 5s: %v", err) + select { + case <-time.After(5 * time.Second): + continue + case <-s.stopChan: + return + } + } + + s.logger.Debugf("subscribed to channel: %s", s.channelName) + + subscriberLoop: + for { + select { + case msg, ok := <-subscription.Channel(): + if !ok { + s.logger.Warn("subscription channel closed, attempting to re-subscribe...") + break subscriberLoop // This will lead to closing the current subscription and retrying. + } + s.logger.Debugf("received invalidation notification, payload: %s", string(msg.Payload)) + if err := s.reload(); err != nil { + s.logger.Errorf("failed to reload cache after notification: %v", err) + } + case <-s.stopChan: + s.logger.Info("received stop signal, exiting subscriber loop.") + if err := subscription.Close(); err != nil { + s.logger.Errorf("failed to close subscription: %v", err) + } + return + } + } + + // Before retrying, ensure the old subscription is closed. + if err := subscription.Close(); err != nil { + s.logger.Errorf("failed to close subscription before retrying: %v", err) + } + + // Wait a moment before retrying to avoid tight loops on persistent errors. + select { + case <-time.After(2 * time.Second): + case <-s.stopChan: + return + } + } +}