feat: add a cache and implement a default lru policy

Ayman Bagabas created

Change summary

cmd/soft/root.go                |  1 
internal/init/cache.go          | 12 +++++
server/backend/sqlite/sqlite.go | 73 ++++++++++--------------------
server/cache/cache.go           | 26 ++++++++++
server/cache/context.go         | 24 ++++++++++
server/cache/lru/lru.go         | 84 +++++++++++++++++++++++++++++++++++
server/cache/noop/noop.go       | 42 +++++++++++++++++
server/cache/registry.go        | 39 ++++++++++++++++
server/daemon/daemon_test.go    |  4 +
server/server.go                | 12 +++++
server/ssh/session_test.go      |  4 +
11 files changed, 272 insertions(+), 49 deletions(-)

Detailed changes

cmd/soft/root.go 🔗

@@ -6,6 +6,7 @@ import (
 	"runtime/debug"
 
 	"github.com/charmbracelet/log"
+	_ "github.com/charmbracelet/soft-serve/internal/init" // initialize registry
 	. "github.com/charmbracelet/soft-serve/internal/log"
 	"github.com/spf13/cobra"
 	"go.uber.org/automaxprocs/maxprocs"

internal/init/cache.go 🔗

@@ -0,0 +1,12 @@
+package init
+
+import (
+	"github.com/charmbracelet/soft-serve/server/cache"
+	"github.com/charmbracelet/soft-serve/server/cache/lru"
+	"github.com/charmbracelet/soft-serve/server/cache/noop"
+)
+
+func init() {
+	cache.Register("lru", lru.NewCache)
+	cache.Register("noop", noop.NewCache)
+}

server/backend/sqlite/sqlite.go 🔗

@@ -11,10 +11,10 @@ import (
 	"github.com/charmbracelet/log"
 	"github.com/charmbracelet/soft-serve/git"
 	"github.com/charmbracelet/soft-serve/server/backend"
+	"github.com/charmbracelet/soft-serve/server/cache"
 	"github.com/charmbracelet/soft-serve/server/config"
 	"github.com/charmbracelet/soft-serve/server/hooks"
 	"github.com/charmbracelet/soft-serve/server/utils"
-	lru "github.com/hashicorp/golang-lru/v2"
 	"github.com/jmoiron/sqlx"
 	_ "modernc.org/sqlite" // sqlite driver
 )
@@ -29,7 +29,7 @@ type SqliteBackend struct { //nolint: revive
 	logger *log.Logger
 
 	// Repositories cache
-	cache *cache
+	cache cache.Cache
 }
 
 var _ backend.Backend = (*SqliteBackend)(nil)
@@ -57,12 +57,10 @@ func NewSqliteBackend(ctx context.Context) (*SqliteBackend, error) {
 		ctx:    ctx,
 		dp:     dataPath,
 		db:     db,
+		cache:  cache.FromContext(ctx),
 		logger: log.FromContext(ctx).WithPrefix("sqlite"),
 	}
 
-	// Set up LRU cache with size 1000
-	d.cache = newCache(d, 1000)
-
 	if err := d.init(); err != nil {
 		return nil, err
 	}
@@ -74,7 +72,7 @@ func NewSqliteBackend(ctx context.Context) (*SqliteBackend, error) {
 	return d, d.initRepos()
 }
 
-// WithContext returns a copy of SqliteBackend with the given context.
+// WithContext returns a shallow copy of SqliteBackend with the given context.
 func (d SqliteBackend) WithContext(ctx context.Context) backend.Backend {
 	d.ctx = ctx
 	return &d
@@ -170,7 +168,7 @@ func (d *SqliteBackend) CreateRepository(name string, opts backend.RepositoryOpt
 	}
 
 	// Set cache
-	d.cache.Set(name, r)
+	d.cache.Set(d.ctx, cacheKey(name), r)
 
 	return r, d.initRepo(name)
 }
@@ -228,7 +226,7 @@ func (d *SqliteBackend) DeleteRepository(name string) error {
 
 	return wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		// Delete repo from cache
-		defer d.cache.Delete(name)
+		defer d.cache.Delete(d.ctx, cacheKey(name))
 
 		if _, err := tx.Exec("DELETE FROM repo WHERE name = ?;", name); err != nil {
 			return err
@@ -265,7 +263,7 @@ func (d *SqliteBackend) RenameRepository(oldName string, newName string) error {
 
 	if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		// Delete cache
-		defer d.cache.Delete(oldName)
+		defer d.cache.Delete(d.ctx, cacheKey(oldName))
 
 		_, err := tx.Exec("UPDATE repo SET name = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?;", newName, oldName)
 		if err != nil {
@@ -308,8 +306,10 @@ func (d *SqliteBackend) Repositories() ([]backend.Repository, error) {
 				return err
 			}
 
-			if r, ok := d.cache.Get(name); ok && r != nil {
-				repos = append(repos, r)
+			if r, ok := d.cache.Get(d.ctx, cacheKey(name)); ok && r != nil {
+				if r, ok := r.(*Repo); ok {
+					repos = append(repos, r)
+				}
 				continue
 			}
 
@@ -320,7 +320,7 @@ func (d *SqliteBackend) Repositories() ([]backend.Repository, error) {
 			}
 
 			// Cache repositories
-			d.cache.Set(name, r)
+			d.cache.Set(d.ctx, cacheKey(name), r)
 
 			repos = append(repos, r)
 		}
@@ -339,8 +339,10 @@ func (d *SqliteBackend) Repositories() ([]backend.Repository, error) {
 func (d *SqliteBackend) Repository(repo string) (backend.Repository, error) {
 	repo = utils.SanitizeRepo(repo)
 
-	if r, ok := d.cache.Get(repo); ok && r != nil {
-		return r, nil
+	if r, ok := d.cache.Get(d.ctx, cacheKey(repo)); ok && r != nil {
+		if r, ok := r.(*Repo); ok {
+			return r, nil
+		}
 	}
 
 	rp := filepath.Join(d.reposPath(), repo+".git")
@@ -367,7 +369,7 @@ func (d *SqliteBackend) Repository(repo string) (backend.Repository, error) {
 	}
 
 	// Add to cache
-	d.cache.Set(repo, r)
+	d.cache.Set(d.ctx, cacheKey(repo), r)
 
 	return r, nil
 }
@@ -427,7 +429,7 @@ func (d *SqliteBackend) SetHidden(repo string, hidden bool) error {
 	repo = utils.SanitizeRepo(repo)
 
 	// Delete cache
-	d.cache.Delete(repo)
+	d.cache.Delete(d.ctx, cacheKey(repo))
 
 	return wrapDbErr(wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		var count int
@@ -461,7 +463,7 @@ func (d *SqliteBackend) SetDescription(repo string, desc string) error {
 	repo = utils.SanitizeRepo(repo)
 
 	// Delete cache
-	d.cache.Delete(repo)
+	d.cache.Delete(d.ctx, cacheKey(repo))
 
 	return wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
 		var count int
@@ -483,7 +485,7 @@ func (d *SqliteBackend) SetPrivate(repo string, private bool) error {
 	repo = utils.SanitizeRepo(repo)
 
 	// Delete cache
-	d.cache.Delete(repo)
+	d.cache.Delete(d.ctx, cacheKey(repo))
 
 	return wrapDbErr(
 		wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
@@ -507,7 +509,7 @@ func (d *SqliteBackend) SetProjectName(repo string, name string) error {
 	repo = utils.SanitizeRepo(repo)
 
 	// Delete cache
-	d.cache.Delete(repo)
+	d.cache.Delete(d.ctx, cacheKey(repo))
 
 	return wrapDbErr(
 		wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
@@ -616,34 +618,7 @@ func (d *SqliteBackend) initRepos() error {
 	return nil
 }
 
-// TODO: implement a caching interface.
-type cache struct {
-	b     *SqliteBackend
-	repos *lru.Cache[string, *Repo]
-}
-
-func newCache(b *SqliteBackend, size int) *cache {
-	if size <= 0 {
-		size = 1
-	}
-	c := &cache{b: b}
-	cache, _ := lru.New[string, *Repo](size)
-	c.repos = cache
-	return c
-}
-
-func (c *cache) Get(repo string) (*Repo, bool) {
-	return c.repos.Get(repo)
-}
-
-func (c *cache) Set(repo string, r *Repo) {
-	c.repos.Add(repo, r)
-}
-
-func (c *cache) Delete(repo string) {
-	c.repos.Remove(repo)
-}
-
-func (c *cache) Len() int {
-	return c.repos.Len()
+// cacheKey returns the cache key for a repository.
+func cacheKey(name string) string {
+	return fmt.Sprintf("repo:%s", name)
 }

server/cache/cache.go 🔗

@@ -0,0 +1,26 @@
+package cache
+
+import (
+	"context"
+)
+
+// ItemOption is an option for setting cache items.
+type ItemOption func(Item)
+
+// Item is an interface that represents a cache item.
+type Item interface {
+	item()
+}
+
+// Option is an option for creating new cache.
+type Option func(Cache)
+
+// Cache is a caching interface.
+type Cache interface {
+	Get(ctx context.Context, key string) (value any, ok bool)
+	Set(ctx context.Context, key string, val any, opts ...ItemOption)
+	Keys(ctx context.Context) []string
+	Len(ctx context.Context) int64
+	Contains(ctx context.Context, key string) bool
+	Delete(ctx context.Context, key string)
+}

server/cache/context.go 🔗

@@ -0,0 +1,24 @@
+package cache
+
+import "context"
+
+var contextKey = &struct{ string }{"cache"}
+
+// WithContext returns a new context with the cache.
+func WithContext(ctx context.Context, c Cache) context.Context {
+	if c == nil {
+		return ctx
+	}
+	return context.WithValue(ctx, contextKey, c)
+}
+
+// FromContext returns the cache from the context.
+// If no cache is found, nil is returned.
+func FromContext(ctx context.Context) Cache {
+	c, ok := ctx.Value(contextKey).(Cache)
+	if !ok {
+		return nil
+	}
+
+	return c
+}

server/cache/lru/lru.go 🔗

@@ -0,0 +1,84 @@
+package lru
+
+import (
+	"context"
+
+	"github.com/charmbracelet/soft-serve/server/cache"
+	lru "github.com/hashicorp/golang-lru/v2"
+)
+
+// Cache is a memory cache that uses a LRU cache policy.
+type Cache struct {
+	cache   *lru.Cache[string, any]
+	onEvict func(key string, value any)
+	size    int
+}
+
+var _ cache.Cache = (*Cache)(nil)
+
+// WithSize sets the cache size.
+func WithSize(s int) cache.Option {
+	return func(c cache.Cache) {
+		ca := c.(*Cache)
+		ca.size = s
+	}
+}
+
+// WithEvictCallback sets the eviction callback.
+func WithEvictCallback(cb func(key string, value any)) cache.Option {
+	return func(c cache.Cache) {
+		ca := c.(*Cache)
+		ca.onEvict = cb
+	}
+}
+
+// NewCache returns a new Cache.
+func NewCache(_ context.Context, opts ...cache.Option) (cache.Cache, error) {
+	c := &Cache{}
+	for _, opt := range opts {
+		opt(c)
+	}
+
+	if c.size <= 0 {
+		c.size = 1
+	}
+
+	var err error
+	c.cache, err = lru.NewWithEvict(c.size, c.onEvict)
+	if err != nil {
+		return nil, err
+	}
+
+	return c, nil
+}
+
+// Delete implements cache.Cache.
+func (c *Cache) Delete(_ context.Context, key string) {
+	c.cache.Remove(key)
+}
+
+// Get implements cache.Cache.
+func (c *Cache) Get(_ context.Context, key string) (value any, ok bool) {
+	value, ok = c.cache.Get(key)
+	return
+}
+
+// Keys implements cache.Cache.
+func (c *Cache) Keys(_ context.Context) []string {
+	return c.cache.Keys()
+}
+
+// Set implements cache.Cache.
+func (c *Cache) Set(_ context.Context, key string, val any, _ ...cache.ItemOption) {
+	c.cache.Add(key, val)
+}
+
+// Len implements cache.Cache.
+func (c *Cache) Len(_ context.Context) int64 {
+	return int64(c.cache.Len())
+}
+
+// Contains implements cache.Cache.
+func (c *Cache) Contains(_ context.Context, key string) bool {
+	return c.cache.Contains(key)
+}

server/cache/noop/noop.go 🔗

@@ -0,0 +1,42 @@
+package noop
+
+import (
+	"context"
+
+	"github.com/charmbracelet/soft-serve/server/cache"
+)
+
+type noopCache struct{}
+
+// NewCache returns a new Cache.
+func NewCache(_ context.Context, _ ...cache.Option) (cache.Cache, error) {
+	return &noopCache{}, nil
+}
+
+// Contains implements Cache.
+func (*noopCache) Contains(_ context.Context, _ string) bool {
+	return false
+}
+
+// Delete implements Cache.
+func (*noopCache) Delete(_ context.Context, _ string) {}
+
+// Get implements Cache.
+func (*noopCache) Get(_ context.Context, _ string) (any, bool) {
+	return nil, false
+}
+
+// Keys implements Cache.
+func (*noopCache) Keys(_ context.Context) []string {
+	return []string{}
+}
+
+// Len implements Cache.
+func (*noopCache) Len(_ context.Context) int64 {
+	return -1
+}
+
+// Set implements Cache.
+func (*noopCache) Set(_ context.Context, _ string, _ any, _ ...cache.ItemOption) {}
+
+var _ cache.Cache = &noopCache{}

server/cache/registry.go 🔗

@@ -0,0 +1,39 @@
+package cache
+
+import (
+	"context"
+	"fmt"
+	"sync"
+)
+
+// Constructor is a function that returns a new cache.
+type Constructor func(context.Context, ...Option) (Cache, error)
+
+var (
+	registry = map[string]Constructor{}
+	mtx      sync.RWMutex
+
+	// ErrCacheNotFound is returned when a cache is not found.
+	ErrCacheNotFound = fmt.Errorf("cache not found")
+)
+
+// Register registers a cache.
+func Register(name string, fn Constructor) {
+	mtx.Lock()
+	defer mtx.Unlock()
+
+	registry[name] = fn
+}
+
+// New returns a new cache.
+func New(name string, ctx context.Context, opts ...Option) (Cache, error) {
+	mtx.RLock()
+	fn, ok := registry[name]
+	mtx.RUnlock()
+
+	if !ok {
+		return nil, ErrCacheNotFound
+	}
+
+	return fn(ctx, opts...)
+}

server/daemon/daemon_test.go 🔗

@@ -14,6 +14,8 @@ import (
 
 	"github.com/charmbracelet/soft-serve/server/backend"
 	"github.com/charmbracelet/soft-serve/server/backend/sqlite"
+	"github.com/charmbracelet/soft-serve/server/cache"
+	"github.com/charmbracelet/soft-serve/server/cache/noop"
 	"github.com/charmbracelet/soft-serve/server/config"
 	"github.com/charmbracelet/soft-serve/server/git"
 	"github.com/charmbracelet/soft-serve/server/test"
@@ -34,6 +36,8 @@ func TestMain(m *testing.M) {
 	os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "1")
 	os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", test.RandomPort()))
 	ctx := context.TODO()
+	ca, _ := noop.NewCache()
+	ctx = cache.WithContext(ctx, ca)
 	cfg := config.DefaultConfig()
 	if err := cfg.WriteConfig(); err != nil {
 		log.Fatal("failed to write default config: %w", err)

server/server.go 🔗

@@ -10,6 +10,8 @@ import (
 
 	"github.com/charmbracelet/soft-serve/server/backend"
 	"github.com/charmbracelet/soft-serve/server/backend/sqlite"
+	"github.com/charmbracelet/soft-serve/server/cache"
+	"github.com/charmbracelet/soft-serve/server/cache/lru"
 	"github.com/charmbracelet/soft-serve/server/config"
 	"github.com/charmbracelet/soft-serve/server/cron"
 	"github.com/charmbracelet/soft-serve/server/daemon"
@@ -43,6 +45,16 @@ func NewServer(ctx context.Context) (*Server, error) {
 	cfg := config.FromContext(ctx)
 
 	var err error
+
+	if c := cache.FromContext(ctx); c == nil {
+		lruCache, err := lru.NewCache(lru.WithSize(1000))
+		if err != nil {
+			return nil, fmt.Errorf("create default cache: %w", err)
+		}
+
+		ctx = cache.WithContext(ctx, lruCache)
+	}
+
 	if cfg.Backend == nil {
 		sb, err := sqlite.NewSqliteBackend(ctx)
 		if err != nil {

server/ssh/session_test.go 🔗

@@ -10,6 +10,8 @@ import (
 	"time"
 
 	"github.com/charmbracelet/soft-serve/server/backend/sqlite"
+	"github.com/charmbracelet/soft-serve/server/cache"
+	"github.com/charmbracelet/soft-serve/server/cache/noop"
 	"github.com/charmbracelet/soft-serve/server/config"
 	"github.com/charmbracelet/soft-serve/server/test"
 	"github.com/charmbracelet/ssh"
@@ -58,6 +60,8 @@ func setup(tb testing.TB) (*gossh.Session, func() error) {
 		is.NoErr(os.RemoveAll(dp))
 	})
 	ctx := context.TODO()
+	ca, _ := noop.NewCache()
+	ctx = cache.WithContext(ctx, ca)
 	cfg := config.DefaultConfig()
 	ctx = config.WithContext(ctx, cfg)
 	fb, err := sqlite.NewSqliteBackend(ctx)