memoization.go

// Package memoization implements a simple memoization cache. It's designed to
// improve performance in textarea.
  3package memoization
  4
  5import (
  6	"container/list"
  7	"crypto/sha256"
  8	"fmt"
  9	"sync"
 10)
 11
 12// Hasher is an interface that requires a Hash method. The Hash method is
 13// expected to return a string representation of the hash of the object.
 14type Hasher interface {
 15	Hash() string
 16}
 17
 18// entry is a struct that holds a key-value pair. It is used as an element
 19// in the evictionList of the MemoCache.
 20type entry[T any] struct {
 21	key   string
 22	value T
 23}
 24
 25// MemoCache is a struct that represents a cache with a set capacity. It
 26// uses an LRU (Least Recently Used) eviction policy. It is safe for
 27// concurrent use.
 28type MemoCache[H Hasher, T any] struct {
 29	capacity      int
 30	mutex         sync.Mutex
 31	cache         map[string]*list.Element // The cache holding the results
 32	evictionList  *list.List               // A list to keep track of the order for LRU
 33	hashableItems map[string]T             // This map keeps track of the original hashable items (optional)
 34}
 35
 36// NewMemoCache is a function that creates a new MemoCache with a given
 37// capacity. It returns a pointer to the created MemoCache.
 38func NewMemoCache[H Hasher, T any](capacity int) *MemoCache[H, T] {
 39	return &MemoCache[H, T]{
 40		capacity:      capacity,
 41		cache:         make(map[string]*list.Element),
 42		evictionList:  list.New(),
 43		hashableItems: make(map[string]T),
 44	}
 45}
 46
 47// Capacity is a method that returns the capacity of the MemoCache.
 48func (m *MemoCache[H, T]) Capacity() int {
 49	return m.capacity
 50}
 51
 52// Size is a method that returns the current size of the MemoCache. It is
 53// the number of items currently stored in the cache.
 54func (m *MemoCache[H, T]) Size() int {
 55	m.mutex.Lock()
 56	defer m.mutex.Unlock()
 57	return m.evictionList.Len()
 58}
 59
 60// Get is a method that returns the value associated with the given
 61// hashable item in the MemoCache. If there is no corresponding value, the
 62// method returns nil.
 63func (m *MemoCache[H, T]) Get(h H) (T, bool) {
 64	m.mutex.Lock()
 65	defer m.mutex.Unlock()
 66
 67	hashedKey := h.Hash()
 68	if element, found := m.cache[hashedKey]; found {
 69		m.evictionList.MoveToFront(element)
 70		return element.Value.(*entry[T]).value, true
 71	}
 72	var result T
 73	return result, false
 74}
 75
 76// Set is a method that sets the value for the given hashable item in the
 77// MemoCache. If the cache is at capacity, it evicts the least recently
 78// used item before adding the new item.
 79func (m *MemoCache[H, T]) Set(h H, value T) {
 80	m.mutex.Lock()
 81	defer m.mutex.Unlock()
 82
 83	hashedKey := h.Hash()
 84	if element, found := m.cache[hashedKey]; found {
 85		m.evictionList.MoveToFront(element)
 86		element.Value.(*entry[T]).value = value
 87		return
 88	}
 89
 90	// Check if the cache is at capacity
 91	if m.evictionList.Len() >= m.capacity {
 92		// Evict the least recently used item from the cache
 93		toEvict := m.evictionList.Back()
 94		if toEvict != nil {
 95			evictedEntry := m.evictionList.Remove(toEvict).(*entry[T])
 96			delete(m.cache, evictedEntry.key)
 97			delete(m.hashableItems, evictedEntry.key) // if you're keeping track of original items
 98		}
 99	}
100
101	// Add the value to the cache and the evictionList
102	newEntry := &entry[T]{
103		key:   hashedKey,
104		value: value,
105	}
106	element := m.evictionList.PushFront(newEntry)
107	m.cache[hashedKey] = element
108	m.hashableItems[hashedKey] = value // if you're keeping track of original items
109}
110
// HString is a string key type for MemoCache; it satisfies Hasher.
type HString string

// Hash returns the SHA-256 digest of the string's bytes, rendered as 64
// lowercase hexadecimal characters.
func (h HString) Hash() string {
	digest := sha256.Sum256([]byte(h))
	return fmt.Sprintf("%x", digest)
}
118
// HInt is an int key type for MemoCache; it satisfies Hasher.
type HInt int

// Hash returns the SHA-256 digest of the integer's base-10 text (e.g. "-42"),
// rendered as 64 lowercase hexadecimal characters.
func (h HInt) Hash() string {
	decimal := fmt.Appendf(nil, "%d", int(h))
	digest := sha256.Sum256(decimal)
	return fmt.Sprintf("%x", digest)
}
124	return fmt.Sprintf("%x", sha256.Sum256(fmt.Appendf(nil, "%d", h)))
125}