feat,fix: add task manager

Ayman Bagabas created

Implement a task manager that can run different tasks given a unique ID.

This is needed to accommodate expensive tasks like importing a large
repository. The current behavior uses the connection's context (the SSH
connection) to import the repository. However, if the server has defined
an SSH `idle_timeout`, `max_timeout`, and/or the connection drops,
Soft Serve cancels the git clone process and aborts importing the
repository.

Instead, we add the import task to the "task manager" and wait on the
connection context. If a task already exists for the same repository,
return `Error: import already in progress`.

Fixes: https://github.com/charmbracelet/soft-serve/issues/348

Change summary

server/backend/backend.go |  25 +++--
server/backend/repo.go    | 151 +++++++++++++++++++++++-----------------
server/ssh/cmd/import.go  |   8 ++
server/task/manager.go    | 116 +++++++++++++++++++++++++++++++
4 files changed, 224 insertions(+), 76 deletions(-)

Detailed changes

server/backend/backend.go 🔗

@@ -7,17 +7,19 @@ import (
 	"github.com/charmbracelet/soft-serve/server/config"
 	"github.com/charmbracelet/soft-serve/server/db"
 	"github.com/charmbracelet/soft-serve/server/store"
+	"github.com/charmbracelet/soft-serve/server/task"
 )
 
 // Backend is the Soft Serve backend that handles users, repositories, and
 // server settings management and operations.
 type Backend struct {
-	ctx    context.Context
-	cfg    *config.Config
-	db     *db.DB
-	store  store.Store
-	logger *log.Logger
-	cache  *cache
+	ctx     context.Context
+	cfg     *config.Config
+	db      *db.DB
+	store   store.Store
+	logger  *log.Logger
+	cache   *cache
+	manager *task.Manager
 }
 
 // New returns a new Soft Serve backend.
@@ -25,11 +27,12 @@ func New(ctx context.Context, cfg *config.Config, db *db.DB) *Backend {
 	dbstore := store.FromContext(ctx)
 	logger := log.FromContext(ctx).WithPrefix("backend")
 	b := &Backend{
-		ctx:    ctx,
-		cfg:    cfg,
-		db:     db,
-		store:  dbstore,
-		logger: logger,
+		ctx:     ctx,
+		cfg:     cfg,
+		db:      db,
+		store:   dbstore,
+		logger:  logger,
+		manager: task.NewManager(ctx),
 	}
 
 	// TODO: implement a proper caching interface

server/backend/repo.go 🔗

@@ -19,6 +19,7 @@ import (
 	"github.com/charmbracelet/soft-serve/server/lfs"
 	"github.com/charmbracelet/soft-serve/server/proto"
 	"github.com/charmbracelet/soft-serve/server/storage"
+	"github.com/charmbracelet/soft-serve/server/task"
 	"github.com/charmbracelet/soft-serve/server/utils"
 )
 
@@ -91,7 +92,8 @@ func (d *Backend) CreateRepository(ctx context.Context, name string, user proto.
 }
 
 // ImportRepository imports a repository from remote.
-func (d *Backend) ImportRepository(ctx context.Context, name string, user proto.User, remote string, opts proto.RepositoryOptions) (proto.Repository, error) {
+// XXX: This a expensive operation and should be run in a goroutine.
+func (d *Backend) ImportRepository(_ context.Context, name string, user proto.User, remote string, opts proto.RepositoryOptions) (proto.Repository, error) {
 	name = utils.SanitizeRepo(name)
 	if err := utils.ValidateRepo(name); err != nil {
 		return nil, err
@@ -100,91 +102,110 @@ func (d *Backend) ImportRepository(ctx context.Context, name string, user proto.
 	repo := name + ".git"
 	rp := filepath.Join(d.reposPath(), repo)
 
+	tid := "import:" + name
+	if d.manager.Exists(tid) {
+		return nil, task.ErrAlreadyStarted
+	}
+
 	if _, err := os.Stat(rp); err == nil || os.IsExist(err) {
 		return nil, proto.ErrRepoExist
 	}
 
-	copts := git.CloneOptions{
-		Bare:   true,
-		Mirror: opts.Mirror,
-		Quiet:  true,
-		CommandOptions: git.CommandOptions{
-			Timeout: -1,
-			Context: ctx,
-			Envs: []string{
-				fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`,
-					filepath.Join(d.cfg.DataPath, "ssh", "known_hosts"),
-					d.cfg.SSH.ClientKeyPath,
-				),
+	done := make(chan error, 1)
+	repoc := make(chan proto.Repository, 1)
+	d.logger.Info("importing repository", "name", name, "remote", remote, "path", rp)
+	d.manager.Add(tid, func(ctx context.Context) (err error) {
+		copts := git.CloneOptions{
+			Bare:   true,
+			Mirror: opts.Mirror,
+			Quiet:  true,
+			CommandOptions: git.CommandOptions{
+				Timeout: -1,
+				Context: ctx,
+				Envs: []string{
+					fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`,
+						filepath.Join(d.cfg.DataPath, "ssh", "known_hosts"),
+						d.cfg.SSH.ClientKeyPath,
+					),
+				},
 			},
-		},
-	}
-
-	if err := git.Clone(remote, rp, copts); err != nil {
-		d.logger.Error("failed to clone repository", "err", err, "mirror", opts.Mirror, "remote", remote, "path", rp)
-		// Cleanup the mess!
-		if rerr := os.RemoveAll(rp); rerr != nil {
-			err = errors.Join(err, rerr)
 		}
 
-		return nil, err
-	}
+		if err := git.Clone(remote, rp, copts); err != nil {
+			d.logger.Error("failed to clone repository", "err", err, "mirror", opts.Mirror, "remote", remote, "path", rp)
+			// Cleanup the mess!
+			if rerr := os.RemoveAll(rp); rerr != nil {
+				err = errors.Join(err, rerr)
+			}
 
-	r, err := d.CreateRepository(ctx, name, user, opts)
-	if err != nil {
-		d.logger.Error("failed to create repository", "err", err, "name", name)
-		return nil, err
-	}
+			return err
+		}
 
-	defer func() {
+		r, err := d.CreateRepository(ctx, name, user, opts)
 		if err != nil {
-			if rerr := d.DeleteRepository(ctx, name); rerr != nil {
-				d.logger.Error("failed to delete repository", "err", rerr, "name", name)
+			d.logger.Error("failed to create repository", "err", err, "name", name)
+			return err
+		}
+
+		defer func() {
+			if err != nil {
+				if rerr := d.DeleteRepository(ctx, name); rerr != nil {
+					d.logger.Error("failed to delete repository", "err", rerr, "name", name)
+				}
 			}
+		}()
+
+		rr, err := r.Open()
+		if err != nil {
+			d.logger.Error("failed to open repository", "err", err, "path", rp)
+			return err
 		}
-	}()
 
-	rr, err := r.Open()
-	if err != nil {
-		d.logger.Error("failed to open repository", "err", err, "path", rp)
-		return nil, err
-	}
+		repoc <- r
 
-	rcfg, err := rr.Config()
-	if err != nil {
-		d.logger.Error("failed to get repository config", "err", err, "path", rp)
-		return nil, err
-	}
+		rcfg, err := rr.Config()
+		if err != nil {
+			d.logger.Error("failed to get repository config", "err", err, "path", rp)
+			return err
+		}
 
-	endpoint := remote
-	if opts.LFSEndpoint != "" {
-		endpoint = opts.LFSEndpoint
-	}
+		endpoint := remote
+		if opts.LFSEndpoint != "" {
+			endpoint = opts.LFSEndpoint
+		}
 
-	rcfg.Section("lfs").SetOption("url", endpoint)
+		rcfg.Section("lfs").SetOption("url", endpoint)
 
-	if err := rr.SetConfig(rcfg); err != nil {
-		d.logger.Error("failed to set repository config", "err", err, "path", rp)
-		return nil, err
-	}
+		if err := rr.SetConfig(rcfg); err != nil {
+			d.logger.Error("failed to set repository config", "err", err, "path", rp)
+			return err
+		}
 
-	ep, err := lfs.NewEndpoint(endpoint)
-	if err != nil {
-		d.logger.Error("failed to create lfs endpoint", "err", err, "path", rp)
-		return nil, err
-	}
+		ep, err := lfs.NewEndpoint(endpoint)
+		if err != nil {
+			d.logger.Error("failed to create lfs endpoint", "err", err, "path", rp)
+			return err
+		}
 
-	client := lfs.NewClient(ep)
-	if client == nil {
-		return nil, fmt.Errorf("failed to create lfs client: unsupported endpoint %s", endpoint)
-	}
+		client := lfs.NewClient(ep)
+		if client == nil {
+			return fmt.Errorf("failed to create lfs client: unsupported endpoint %s", endpoint)
+		}
 
-	if err := StoreRepoMissingLFSObjects(ctx, r, d.db, d.store, client); err != nil {
-		d.logger.Error("failed to store missing lfs objects", "err", err, "path", rp)
-		return nil, err
-	}
+		if err := StoreRepoMissingLFSObjects(ctx, r, d.db, d.store, client); err != nil {
+			d.logger.Error("failed to store missing lfs objects", "err", err, "path", rp)
+			return err
+		}
 
-	return r, nil
+		return nil
+	})
+
+	go func() {
+		d.logger.Info("running import", "name", name)
+		d.manager.Run(tid, done)
+	}()
+
+	return <-repoc, <-done
 }
 
 // DeleteRepository deletes a repository.

server/ssh/cmd/import.go 🔗

@@ -1,8 +1,11 @@
 package cmd
 
 import (
+	"errors"
+
 	"github.com/charmbracelet/soft-serve/server/backend"
 	"github.com/charmbracelet/soft-serve/server/proto"
+	"github.com/charmbracelet/soft-serve/server/task"
 	"github.com/spf13/cobra"
 )
 
@@ -36,8 +39,13 @@ func importCommand() *cobra.Command {
 				LFS:         lfs,
 				LFSEndpoint: lfsEndpoint,
 			}); err != nil {
+				if errors.Is(err, task.ErrAlreadyStarted) {
+					return errors.New("import already in progress")
+				}
+
 				return err
 			}
+
 			return nil
 		},
 	}

server/task/manager.go 🔗

@@ -0,0 +1,116 @@
+package task
+
+import (
+	"context"
+	"errors"
+	"sync"
+	"sync/atomic"
+)
+
+var (
+	// ErrNotFound is returned when a process is not found.
+	ErrNotFound = errors.New("task not found")
+
+	// ErrAlreadyStarted is returned when a process is already started.
+	ErrAlreadyStarted = errors.New("task already started")
+)
+
+// Task is a task that can be started and stopped.
+type Task struct {
+	id      string
+	fn      func(context.Context) error
+	started atomic.Bool
+	ctx     context.Context
+	cancel  context.CancelFunc
+	err     error
+}
+
+// Manager manages tasks.
+type Manager struct {
+	m   sync.Map
+	ctx context.Context
+}
+
+// NewManager returns a new task manager.
+func NewManager(ctx context.Context) *Manager {
+	return &Manager{
+		m:   sync.Map{},
+		ctx: ctx,
+	}
+}
+
+// Add adds a task to the manager.
+// If the process already exists, it is a no-op.
+func (m *Manager) Add(id string, fn func(context.Context) error) {
+	if m.Exists(id) {
+		return
+	}
+
+	ctx, cancel := context.WithCancel(m.ctx)
+	m.m.Store(id, &Task{
+		id:     id,
+		fn:     fn,
+		ctx:    ctx,
+		cancel: cancel,
+	})
+}
+
+// Stop stops the task and removes it from the manager.
+func (m *Manager) Stop(id string) error {
+	v, ok := m.m.Load(id)
+	if !ok {
+		return ErrNotFound
+	}
+
+	p := v.(*Task)
+	p.cancel()
+
+	m.m.Delete(id)
+	return nil
+}
+
+// Exists checks if a task exists.
+func (m *Manager) Exists(id string) bool {
+	_, ok := m.m.Load(id)
+	return ok
+}
+
+// Run starts the task if it exists.
+// Otherwise, it waits for the process to finish.
+func (m *Manager) Run(id string, done chan<- error) {
+	v, ok := m.m.Load(id)
+	if !ok {
+		done <- ErrNotFound
+		return
+	}
+
+	p := v.(*Task)
+	if p.started.Load() {
+		<-p.ctx.Done()
+		if p.err != nil {
+			done <- p.err
+			return
+		}
+
+		done <- p.ctx.Err()
+	}
+
+	p.started.Store(true)
+	m.m.Store(id, p)
+	defer p.cancel()
+	defer m.m.Delete(id)
+
+	errc := make(chan error, 1)
+	go func(ctx context.Context) {
+		errc <- p.fn(ctx)
+	}(p.ctx)
+
+	select {
+	case <-m.ctx.Done():
+		done <- m.ctx.Err()
+	case err := <-errc:
+		p.err = err
+		m.m.Store(id, p)
+		done <- err
+	}
+}