Files
gpt-load/internal/store/memory.go
2025-07-09 13:28:11 +08:00

385 lines
8.5 KiB
Go

package store
import (
"fmt"
"strconv"
"sync"
"time"
)
// memoryStoreItem is the record kept for a plain key/value entry: the payload
// bytes together with the instant at which the entry stops being valid.
type memoryStoreItem struct {
	// value is the stored payload.
	value []byte
	// expiresAt is the expiry instant as a Unix timestamp in nanoseconds.
	// Zero means the entry never expires.
	expiresAt int64
}
// MemoryStore is an in-memory key-value store that is safe for concurrent use.
// Besides simple K/V entries it also holds HASH and LIST values, and it keeps
// a registry of Pub/Sub subscribers.
type MemoryStore struct {
	// mu guards data.
	mu sync.RWMutex
	// data maps every key to its value. The dynamic type of the value tells
	// the data types apart: memoryStoreItem for K/V, map[string]string for
	// HASH and []string for LIST.
	data map[string]any

	// muSubscribers guards subscribers.
	muSubscribers sync.RWMutex
	// subscribers maps a Pub/Sub channel name to the set of delivery channels
	// registered for it.
	subscribers map[string]map[chan *Message]struct{}
}
// NewMemoryStore returns a new, empty MemoryStore with its key space and
// subscriber registry initialised.
//
// No background cleanup loop is started: such a loop is not compatible with
// multiple data types without a unified expiration mechanism, and the KeyPool
// feature does not rely on TTLs.
func NewMemoryStore() *MemoryStore {
	return &MemoryStore{
		data:        make(map[string]any),
		subscribers: make(map[string]map[chan *Message]struct{}),
	}
}
// Close releases the resources held by the store. The in-memory store
// currently holds nothing that needs releasing, so Close does nothing and
// always returns nil.
func (s *MemoryStore) Close() error { return nil }
// Set stores value under key as a simple K/V entry, replacing whatever the
// key held before (a HASH or LIST value included).
//
// A positive ttl makes the entry expire that long after the call; a zero or
// negative ttl stores it without an expiry. The slice is stored as given, not
// copied. Set always returns nil.
func (s *MemoryStore) Set(key string, value []byte, ttl time.Duration) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	entry := memoryStoreItem{value: value}
	if ttl > 0 {
		entry.expiresAt = time.Now().UnixNano() + ttl.Nanoseconds()
	}
	s.data[key] = entry
	return nil
}
// Get retrieves the value stored under key.
//
// It returns ErrNotFound when the key is absent or its entry has expired, and
// a type-mismatch error when the key holds a HASH or LIST instead of a simple
// K/V entry. The returned slice is the stored slice itself, not a copy.
//
// An expired entry is removed lazily. The lookup runs under the read lock, so
// the removal has to re-acquire the write lock; by then another goroutine may
// already have stored a fresh value under the same key. The entry is therefore
// re-checked under the write lock and deleted only if it is still an expired
// K/V entry, so a concurrent write is never discarded.
func (s *MemoryStore) Get(key string) ([]byte, error) {
	s.mu.RLock()
	rawItem, exists := s.data[key]
	s.mu.RUnlock()
	if !exists {
		return nil, ErrNotFound
	}

	item, ok := rawItem.(memoryStoreItem)
	if !ok {
		return nil, fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
	}

	if expired := item.expiresAt > 0 && time.Now().UnixNano() > item.expiresAt; !expired {
		return item.value, nil
	}

	// Lazy deletion, guarded against a write that slipped in between the two
	// lock acquisitions.
	s.mu.Lock()
	if current, isKV := s.data[key].(memoryStoreItem); isKV &&
		current.expiresAt > 0 && time.Now().UnixNano() > current.expiresAt {
		delete(s.data, key)
	}
	s.mu.Unlock()
	return nil, ErrNotFound
}
// Delete removes key from the store, whatever type of value it holds.
// Deleting a key that does not exist is a no-op. Delete always returns nil.
func (s *MemoryStore) Delete(key string) error {
	s.mu.Lock()
	delete(s.data, key)
	s.mu.Unlock()
	return nil
}
// Exists reports whether key is present in the store.
//
// HASH and LIST keys count as existing whenever they are present. A simple
// K/V entry counts only while it has not expired; an expired entry is removed
// lazily and reported as absent. The returned error is always nil.
//
// The lookup runs under the read lock, so the lazy removal has to re-acquire
// the write lock; by then another goroutine may already have stored a fresh
// value under the same key. The entry is therefore re-checked under the write
// lock and deleted only if it is still an expired K/V entry, so a concurrent
// write is never discarded.
func (s *MemoryStore) Exists(key string) (bool, error) {
	s.mu.RLock()
	rawItem, exists := s.data[key]
	s.mu.RUnlock()
	if !exists {
		return false, nil
	}

	// Expiration only applies to simple K/V items.
	item, isKV := rawItem.(memoryStoreItem)
	if !isKV {
		return true, nil
	}
	if expired := item.expiresAt > 0 && time.Now().UnixNano() > item.expiresAt; !expired {
		return true, nil
	}

	// Lazy deletion, guarded against a write that slipped in between the two
	// lock acquisitions.
	s.mu.Lock()
	if current, ok := s.data[key].(memoryStoreItem); ok &&
		current.expiresAt > 0 && time.Now().UnixNano() > current.expiresAt {
		delete(s.data, key)
	}
	s.mu.Unlock()
	return false, nil
}
// SetNX stores value under key only if the key is not currently in use, and
// reports whether the value was stored.
//
// A key counts as in use when it holds a HASH or LIST, or a simple K/V entry
// that has no expiry or has not yet reached it. An expired K/V entry is
// overwritten. ttl is interpreted as in Set: positive values set an expiry,
// anything else means no expiry. The returned error is always nil.
func (s *MemoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// There is no background expiry, so liveness is decided right here.
	if stored, found := s.data[key]; found {
		kv, isKV := stored.(memoryStoreItem)
		switch {
		case !isKV:
			// Another data type lives under this key; treat it as existing.
			return false, nil
		case kv.expiresAt == 0, time.Now().UnixNano() < kv.expiresAt:
			// Still-valid K/V entry.
			return false, nil
		}
	}

	entry := memoryStoreItem{value: value}
	if ttl > 0 {
		entry.expiresAt = time.Now().UnixNano() + ttl.Nanoseconds()
	}
	s.data[key] = entry
	return true, nil
}
// --- HASH operations ---

// HSet writes the given field/value pairs into the hash stored at key,
// creating the hash when the key does not exist yet. Existing fields are
// overwritten and other fields are left alone. Every value is stored as its
// fmt.Sprint representation.
//
// It returns a type-mismatch error when key holds a value that is not a hash.
func (s *MemoryStore) HSet(key string, values map[string]any) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	stored, found := s.data[key]
	hash, isHash := stored.(map[string]string)
	switch {
	case !found:
		hash = make(map[string]string, len(values))
		s.data[key] = hash
	case !isHash:
		return fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
	}

	for field, value := range values {
		hash[field] = fmt.Sprint(value)
	}
	return nil
}
// HGetAll returns every field/value pair of the hash stored at key.
//
// Following the Redis convention, a key that does not exist yields an empty
// map rather than an error. A key holding a non-hash value yields a
// type-mismatch error. The result is a copy, so the caller may modify it
// without affecting the store.
func (s *MemoryStore) HGetAll(key string) (map[string]string, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	stored, found := s.data[key]
	if !found {
		return map[string]string{}, nil
	}

	hash, isHash := stored.(map[string]string)
	if !isHash {
		return nil, fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
	}

	// Hand out a snapshot so the caller never shares the live map.
	snapshot := make(map[string]string, len(hash))
	for field, val := range hash {
		snapshot[field] = val
	}
	return snapshot, nil
}
// HIncrBy adds incr to the integer stored in field of the hash at key and
// returns the new value. A missing key is created as an empty hash and a
// missing field counts as 0, so the first increment stores incr itself.
//
// It returns an error, leaving the stored value untouched, when
//   - key holds a value that is not a hash,
//   - field holds a value that is not a base-10 64-bit integer, or
//   - the addition would overflow int64.
//
// Previously a non-integer field was silently treated as 0 and overwritten,
// and an overflowing addition silently wrapped around.
func (s *MemoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	rawHash, exists := s.data[key]
	hash, ok := rawHash.(map[string]string)
	if exists && !ok {
		return 0, fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
	}

	// Reading from a nil map is fine, so this also covers the missing-key case.
	var currentVal int64
	if raw, present := hash[field]; present {
		parsed, err := strconv.ParseInt(raw, 10, 64)
		if err != nil {
			return 0, fmt.Errorf("hash value is not an integer: key '%s' field '%s': %w", key, field, err)
		}
		currentVal = parsed
	}

	// Signed addition wraps in Go; a result that moved in the wrong direction
	// means the true sum does not fit in int64.
	newVal := currentVal + incr
	if (incr > 0 && newVal < currentVal) || (incr < 0 && newVal > currentVal) {
		return 0, fmt.Errorf("increment would overflow: key '%s' field '%s'", key, field)
	}

	// Create the hash only once the write is known to succeed.
	if hash == nil {
		hash = make(map[string]string)
		s.data[key] = hash
	}
	hash[field] = strconv.FormatInt(newVal, 10)
	return newVal, nil
}
// --- LIST operations ---

// LPush inserts values at the head of the list stored at key, creating the
// list when the key does not exist yet. Every value is stored as its
// fmt.Sprint representation. The values keep their argument order, so after
// LPush(key, "a", "b") the list starts with "a", "b".
//
// NOTE(review): Redis LPUSH pushes one value at a time and would leave the
// same call starting with "b", "a" — confirm whether parity with a Redis
// backend matters to callers.
//
// It returns a type-mismatch error when key holds a value that is not a list.
func (s *MemoryStore) LPush(key string, values ...any) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	var tail []string
	if stored, found := s.data[key]; found {
		existing, isList := stored.([]string)
		if !isList {
			return fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
		}
		tail = existing
	}

	// Build a fresh slice: new values first, previous contents after them.
	merged := make([]string, 0, len(values)+len(tail))
	for _, v := range values {
		merged = append(merged, fmt.Sprint(v))
	}
	s.data[key] = append(merged, tail...)
	return nil
}
// LRem removes every element equal to value (compared by its fmt.Sprint
// representation) from the list stored at key.
//
// Only count == 0 ("remove all occurrences") is supported, as that is all the
// store's users need for now; any other count yields an error and leaves the
// list untouched. A missing key is a no-op that returns nil regardless of
// count. A key holding a non-list value yields a type-mismatch error.
func (s *MemoryStore) LRem(key string, count int64, value any) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	stored, found := s.data[key]
	if !found {
		return nil
	}
	list, isList := stored.([]string)
	if !isList {
		return fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
	}

	target := fmt.Sprint(value)
	if count != 0 {
		return fmt.Errorf("LRem with non-zero count is not implemented in MemoryStore")
	}

	// Copy the survivors into a new slice rather than filtering in place.
	kept := make([]string, 0, len(list))
	for i := range list {
		if list[i] == target {
			continue
		}
		kept = append(kept, list[i])
	}
	s.data[key] = kept
	return nil
}
// Rotate moves the last element of the list stored at key to the front of
// the list and returns that element — an "RPOP" followed by an "LPUSH" done
// under a single lock.
//
// It returns ErrNotFound when the key does not exist or the list is empty,
// and a type-mismatch error when the key holds a value that is not a list.
func (s *MemoryStore) Rotate(key string) (string, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	stored, found := s.data[key]
	if !found {
		return "", ErrNotFound
	}
	list, isList := stored.([]string)
	switch {
	case !isList:
		return "", fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
	case len(list) == 0:
		return "", ErrNotFound
	}

	last := len(list) - 1
	tail := list[last]

	// Build the rotated list in a fresh slice: tail first, the rest after it.
	rotated := make([]string, len(list))
	rotated[0] = tail
	copy(rotated[1:], list[:last])
	s.data[key] = rotated
	return tail, nil
}
// --- Pub/Sub operations ---

// memorySubscription is the in-memory store's implementation of the
// Subscription interface. It ties one delivery channel to the store and the
// Pub/Sub channel name it was registered under.
type memorySubscription struct {
	// store is the MemoryStore whose subscriber registry holds msgChan.
	store *MemoryStore
	// channel is the Pub/Sub channel name this subscription listens on.
	channel string
	// msgChan is the Go channel on which published messages are delivered.
	msgChan chan *Message
}
// Channel returns the receive-only view of the channel on which messages
// for this subscription are delivered.
func (ms *memorySubscription) Channel() <-chan *Message {
	var recv <-chan *Message = ms.msgChan
	return recv
}
// Close unregisters the subscription from the store and closes its message
// channel, which lets a receiver ranging over Channel() terminate. When the
// last subscriber of a Pub/Sub channel goes away, the channel's registry
// entry is dropped as well.
//
// Close is idempotent: the message channel is closed only while it is still
// registered, so a second call is a no-op instead of panicking with "close
// of closed channel". Close always returns nil.
func (ms *memorySubscription) Close() error {
	ms.store.muSubscribers.Lock()
	defer ms.store.muSubscribers.Unlock()

	subs, ok := ms.store.subscribers[ms.channel]
	if !ok {
		return nil
	}
	// Registration doubles as the "still open" marker: each subscription owns
	// a distinct channel that is removed from the registry exactly once.
	if _, registered := subs[ms.msgChan]; !registered {
		return nil
	}

	delete(subs, ms.msgChan)
	if len(subs) == 0 {
		delete(ms.store.subscribers, ms.channel)
	}
	close(ms.msgChan)
	return nil
}
// Publish delivers message to every subscriber currently registered on
// channel. Publishing to a channel without subscribers is a no-op. Publish
// never blocks on a slow subscriber and always returns nil.
//
// Delivery happens in two stages:
//
//  1. A non-blocking send is attempted while the subscriber registry is
//     read-locked. Closing a subscription requires the write lock, so this
//     send can never hit a closed channel, and messages published one after
//     another reach a subscriber with free buffer space in order.
//  2. If the subscriber's buffer is full, delivery falls back to a detached
//     goroutine that keeps trying for up to one second and then drops the
//     message. The subscription may be closed during that wait; the resulting
//     "send on closed channel" panic is recovered so that a late delivery to
//     a departed subscriber cannot crash the process.
func (s *MemoryStore) Publish(channel string, message []byte) error {
	s.muSubscribers.RLock()
	defer s.muSubscribers.RUnlock()

	msg := &Message{
		Channel: channel,
		Payload: message,
	}

	// Ranging over the nil map of an unknown channel simply does nothing.
	for subCh := range s.subscribers[channel] {
		select {
		case subCh <- msg:
			continue
		default:
			// Buffer full; hand over to the timed fallback below.
		}

		go func(c chan *Message) {
			// The only operation here that can panic is the send on a channel
			// closed by a concurrent memorySubscription.Close.
			defer func() { _ = recover() }()

			timer := time.NewTimer(1 * time.Second)
			defer timer.Stop()
			select {
			case c <- msg:
			case <-timer.C: // Give up so a stuck receiver cannot leak goroutines.
			}
		}(subCh)
	}
	return nil
}
// Subscribe registers a new subscriber on channel and returns the
// Subscription through which its messages are received. Each call creates an
// independent subscriber with its own delivery channel, buffered for 10
// messages. The returned error is always nil.
func (s *MemoryStore) Subscribe(channel string) (Subscription, error) {
	s.muSubscribers.Lock()
	defer s.muSubscribers.Unlock()

	// Create the subscriber set for this channel on first use.
	subs, known := s.subscribers[channel]
	if !known {
		subs = make(map[chan *Message]struct{})
		s.subscribers[channel] = subs
	}

	delivery := make(chan *Message, 10)
	subs[delivery] = struct{}{}

	return &memorySubscription{
		store:   s,
		channel: channel,
		msgChan: delivery,
	}, nil
}