434 lines
9.0 KiB
Go
434 lines
9.0 KiB
Go
package store
|
|
|
|
import (
|
|
"fmt"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// memoryStoreItem holds the value and expiration timestamp for a key.
|
|
type memoryStoreItem struct {
|
|
value []byte
|
|
expiresAt int64 // Unix-nano timestamp. 0 for no expiry.
|
|
}
|
|
|
|
// MemoryStore is an in-memory key-value store that is safe for concurrent use.
|
|
type MemoryStore struct {
|
|
mu sync.RWMutex
|
|
data map[string]any
|
|
muSubscribers sync.RWMutex
|
|
subscribers map[string]map[chan *Message]struct{}
|
|
}
|
|
|
|
// NewMemoryStore creates and returns a new MemoryStore instance.
|
|
func NewMemoryStore() *MemoryStore {
|
|
s := &MemoryStore{
|
|
data: make(map[string]any),
|
|
subscribers: make(map[string]map[chan *Message]struct{}),
|
|
}
|
|
return s
|
|
}
|
|
|
|
// Close cleans up resources.
|
|
func (s *MemoryStore) Close() error {
|
|
return nil
|
|
}
|
|
|
|
// Set stores a key-value pair.
|
|
func (s *MemoryStore) Set(key string, value []byte, ttl time.Duration) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
var expiresAt int64
|
|
if ttl > 0 {
|
|
expiresAt = time.Now().UnixNano() + ttl.Nanoseconds()
|
|
}
|
|
|
|
s.data[key] = memoryStoreItem{
|
|
value: value,
|
|
expiresAt: expiresAt,
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Get retrieves a value by its key.
|
|
func (s *MemoryStore) Get(key string) ([]byte, error) {
|
|
s.mu.RLock()
|
|
rawItem, exists := s.data[key]
|
|
s.mu.RUnlock()
|
|
|
|
if !exists {
|
|
return nil, ErrNotFound
|
|
}
|
|
|
|
item, ok := rawItem.(memoryStoreItem)
|
|
if !ok {
|
|
return nil, fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
|
|
}
|
|
|
|
if item.expiresAt > 0 && time.Now().UnixNano() > item.expiresAt {
|
|
s.mu.Lock()
|
|
delete(s.data, key)
|
|
s.mu.Unlock()
|
|
return nil, ErrNotFound
|
|
}
|
|
|
|
return item.value, nil
|
|
}
|
|
|
|
// Delete removes a value by its key.
|
|
func (s *MemoryStore) Delete(key string) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
delete(s.data, key)
|
|
return nil
|
|
}
|
|
|
|
// Del removes multiple values by their keys.
|
|
func (s *MemoryStore) Del(keys ...string) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
for _, key := range keys {
|
|
delete(s.data, key)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Exists checks if a key exists.
|
|
func (s *MemoryStore) Exists(key string) (bool, error) {
|
|
s.mu.RLock()
|
|
rawItem, exists := s.data[key]
|
|
s.mu.RUnlock()
|
|
|
|
if !exists {
|
|
return false, nil
|
|
}
|
|
|
|
if item, ok := rawItem.(memoryStoreItem); ok {
|
|
if item.expiresAt > 0 && time.Now().UnixNano() > item.expiresAt {
|
|
s.mu.Lock()
|
|
delete(s.data, key)
|
|
s.mu.Unlock()
|
|
return false, nil
|
|
}
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
// SetNX sets a key-value pair if the key does not already exist.
|
|
func (s *MemoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
rawItem, exists := s.data[key]
|
|
if exists {
|
|
if item, ok := rawItem.(memoryStoreItem); ok {
|
|
if item.expiresAt == 0 || time.Now().UnixNano() < item.expiresAt {
|
|
return false, nil
|
|
}
|
|
} else {
|
|
// Key exists but is not a simple K/V item, treat as existing
|
|
return false, nil
|
|
}
|
|
}
|
|
|
|
// Key does not exist or is expired, so we can set it.
|
|
var expiresAt int64
|
|
if ttl > 0 {
|
|
expiresAt = time.Now().UnixNano() + ttl.Nanoseconds()
|
|
}
|
|
s.data[key] = memoryStoreItem{
|
|
value: value,
|
|
expiresAt: expiresAt,
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
// --- HASH operations ---
|
|
|
|
func (s *MemoryStore) HSet(key string, values map[string]any) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
var hash map[string]string
|
|
rawHash, exists := s.data[key]
|
|
if !exists {
|
|
hash = make(map[string]string)
|
|
s.data[key] = hash
|
|
} else {
|
|
var ok bool
|
|
hash, ok = rawHash.(map[string]string)
|
|
if !ok {
|
|
return fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
|
|
}
|
|
}
|
|
|
|
for field, value := range values {
|
|
hash[field] = fmt.Sprint(value)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *MemoryStore) HGetAll(key string) (map[string]string, error) {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
rawHash, exists := s.data[key]
|
|
if !exists {
|
|
return make(map[string]string), nil
|
|
}
|
|
|
|
hash, ok := rawHash.(map[string]string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
|
|
}
|
|
|
|
result := make(map[string]string, len(hash))
|
|
for k, v := range hash {
|
|
result[k] = v
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (s *MemoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
var hash map[string]string
|
|
rawHash, exists := s.data[key]
|
|
if !exists {
|
|
hash = make(map[string]string)
|
|
s.data[key] = hash
|
|
} else {
|
|
var ok bool
|
|
hash, ok = rawHash.(map[string]string)
|
|
if !ok {
|
|
return 0, fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
|
|
}
|
|
}
|
|
|
|
currentVal, _ := strconv.ParseInt(hash[field], 10, 64)
|
|
newVal := currentVal + incr
|
|
hash[field] = strconv.FormatInt(newVal, 10)
|
|
|
|
return newVal, nil
|
|
}
|
|
|
|
// --- LIST operations ---
|
|
|
|
func (s *MemoryStore) LPush(key string, values ...any) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
var list []string
|
|
rawList, exists := s.data[key]
|
|
if !exists {
|
|
list = make([]string, 0)
|
|
} else {
|
|
var ok bool
|
|
list, ok = rawList.([]string)
|
|
if !ok {
|
|
return fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
|
|
}
|
|
}
|
|
|
|
strValues := make([]string, len(values))
|
|
for i, v := range values {
|
|
strValues[i] = fmt.Sprint(v)
|
|
}
|
|
|
|
s.data[key] = append(strValues, list...) // Prepend
|
|
return nil
|
|
}
|
|
|
|
func (s *MemoryStore) LRem(key string, count int64, value any) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
rawList, exists := s.data[key]
|
|
if !exists {
|
|
return nil
|
|
}
|
|
|
|
list, ok := rawList.([]string)
|
|
if !ok {
|
|
return fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
|
|
}
|
|
|
|
strValue := fmt.Sprint(value)
|
|
newList := make([]string, 0, len(list))
|
|
|
|
if count != 0 {
|
|
return fmt.Errorf("LRem with non-zero count is not implemented in MemoryStore")
|
|
}
|
|
|
|
for _, item := range list {
|
|
if item != strValue {
|
|
newList = append(newList, item)
|
|
}
|
|
}
|
|
s.data[key] = newList
|
|
return nil
|
|
}
|
|
|
|
func (s *MemoryStore) Rotate(key string) (string, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
rawList, exists := s.data[key]
|
|
if !exists {
|
|
return "", ErrNotFound
|
|
}
|
|
|
|
list, ok := rawList.([]string)
|
|
if !ok {
|
|
return "", fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
|
|
}
|
|
|
|
if len(list) == 0 {
|
|
return "", ErrNotFound
|
|
}
|
|
|
|
lastIndex := len(list) - 1
|
|
item := list[lastIndex]
|
|
|
|
// "LPUSH"
|
|
newList := append([]string{item}, list[:lastIndex]...)
|
|
s.data[key] = newList
|
|
|
|
return item, nil
|
|
}
|
|
|
|
// --- SET operations ---
|
|
|
|
// SAdd adds members to a set.
|
|
func (s *MemoryStore) SAdd(key string, members ...any) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
var set map[string]struct{}
|
|
rawSet, exists := s.data[key]
|
|
if !exists {
|
|
set = make(map[string]struct{})
|
|
s.data[key] = set
|
|
} else {
|
|
var ok bool
|
|
set, ok = rawSet.(map[string]struct{})
|
|
if !ok {
|
|
return fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
|
|
}
|
|
}
|
|
|
|
for _, member := range members {
|
|
set[fmt.Sprint(member)] = struct{}{}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SPopN randomly removes and returns the given number of members from a set.
|
|
func (s *MemoryStore) SPopN(key string, count int64) ([]string, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
rawSet, exists := s.data[key]
|
|
if !exists {
|
|
return []string{}, nil
|
|
}
|
|
|
|
set, ok := rawSet.(map[string]struct{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("type mismatch: key '%s' holds a different data type", key)
|
|
}
|
|
|
|
if count > int64(len(set)) {
|
|
count = int64(len(set))
|
|
}
|
|
|
|
popped := make([]string, 0, count)
|
|
for member := range set {
|
|
if int64(len(popped)) >= count {
|
|
break
|
|
}
|
|
popped = append(popped, member)
|
|
delete(set, member)
|
|
}
|
|
|
|
return popped, nil
|
|
}
|
|
|
|
// --- Pub/Sub operations ---
|
|
|
|
// memorySubscription implements the Subscription interface for the in-memory store.
|
|
type memorySubscription struct {
|
|
store *MemoryStore
|
|
channel string
|
|
msgChan chan *Message
|
|
}
|
|
|
|
// Channel returns the message channel for the subscription.
|
|
func (ms *memorySubscription) Channel() <-chan *Message {
|
|
return ms.msgChan
|
|
}
|
|
|
|
// Close removes the subscription from the store.
|
|
func (ms *memorySubscription) Close() error {
|
|
ms.store.muSubscribers.Lock()
|
|
defer ms.store.muSubscribers.Unlock()
|
|
|
|
if subs, ok := ms.store.subscribers[ms.channel]; ok {
|
|
delete(subs, ms.msgChan)
|
|
if len(subs) == 0 {
|
|
delete(ms.store.subscribers, ms.channel)
|
|
}
|
|
}
|
|
close(ms.msgChan)
|
|
return nil
|
|
}
|
|
|
|
// Publish sends a message to all subscribers of a channel.
|
|
func (s *MemoryStore) Publish(channel string, message []byte) error {
|
|
s.muSubscribers.RLock()
|
|
defer s.muSubscribers.RUnlock()
|
|
|
|
msg := &Message{
|
|
Channel: channel,
|
|
Payload: message,
|
|
}
|
|
|
|
if subs, ok := s.subscribers[channel]; ok {
|
|
for subCh := range subs {
|
|
go func(c chan *Message) {
|
|
select {
|
|
case c <- msg:
|
|
case <-time.After(1 * time.Second):
|
|
}
|
|
}(subCh)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Subscribe listens for messages on a given channel.
|
|
func (s *MemoryStore) Subscribe(channel string) (Subscription, error) {
|
|
s.muSubscribers.Lock()
|
|
defer s.muSubscribers.Unlock()
|
|
|
|
msgChan := make(chan *Message, 10) // Buffered channel
|
|
|
|
if _, ok := s.subscribers[channel]; !ok {
|
|
s.subscribers[channel] = make(map[chan *Message]struct{})
|
|
}
|
|
s.subscribers[channel][msgChan] = struct{}{}
|
|
|
|
sub := &memorySubscription{
|
|
store: s,
|
|
channel: channel,
|
|
msgChan: msgChan,
|
|
}
|
|
|
|
return sub, nil
|
|
}
|