From c4dde1c3f98a5f6d8a182e1ad2b460b24aeb71e7 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Wed, 2 Aug 2023 17:44:49 -0400 Subject: [PATCH] feat,fix: add task manager Implement a task manager that can run different tasks given a unique ID. This is needed to accommodate expensive tasks like importing a large repository. The current behavior uses the connection's context (the SSH connection) to import the repository. However, if the server has defined an SSH `idle_timeout`, `max_timeout`, and/or the connection drops, Soft Serve cancels the git clone process and aborts importing the repository. Instead, we add the import task to the "task manager" and wait on the connection context. If a task already exists for the same repository, return `Error: import already in progress`. Fixes: https://github.com/charmbracelet/soft-serve/issues/348 --- server/backend/backend.go | 25 ++++--- server/backend/repo.go | 151 ++++++++++++++++++++++---------------- server/ssh/cmd/import.go | 8 ++ server/task/manager.go | 116 +++++++++++++++++++++++++++++ 4 files changed, 224 insertions(+), 76 deletions(-) create mode 100644 server/task/manager.go diff --git a/server/backend/backend.go b/server/backend/backend.go index 586d95132504a852dc8fa339de464d3360aca59d..ba9ad61d3713dd18f72a8a9e91d989ddcc370fe1 100644 --- a/server/backend/backend.go +++ b/server/backend/backend.go @@ -7,17 +7,19 @@ import ( "github.com/charmbracelet/soft-serve/server/config" "github.com/charmbracelet/soft-serve/server/db" "github.com/charmbracelet/soft-serve/server/store" + "github.com/charmbracelet/soft-serve/server/task" ) // Backend is the Soft Serve backend that handles users, repositories, and // server settings management and operations. 
type Backend struct { - ctx context.Context - cfg *config.Config - db *db.DB - store store.Store - logger *log.Logger - cache *cache + ctx context.Context + cfg *config.Config + db *db.DB + store store.Store + logger *log.Logger + cache *cache + manager *task.Manager } // New returns a new Soft Serve backend. @@ -25,11 +27,12 @@ func New(ctx context.Context, cfg *config.Config, db *db.DB) *Backend { dbstore := store.FromContext(ctx) logger := log.FromContext(ctx).WithPrefix("backend") b := &Backend{ - ctx: ctx, - cfg: cfg, - db: db, - store: dbstore, - logger: logger, + ctx: ctx, + cfg: cfg, + db: db, + store: dbstore, + logger: logger, + manager: task.NewManager(ctx), } // TODO: implement a proper caching interface diff --git a/server/backend/repo.go b/server/backend/repo.go index c9c6a7257e2674909cf7d8e44811a956d3562129..2d8633cc3961c6883cea0942eef2aea67cadaa78 100644 --- a/server/backend/repo.go +++ b/server/backend/repo.go @@ -19,6 +19,7 @@ import ( "github.com/charmbracelet/soft-serve/server/lfs" "github.com/charmbracelet/soft-serve/server/proto" "github.com/charmbracelet/soft-serve/server/storage" + "github.com/charmbracelet/soft-serve/server/task" "github.com/charmbracelet/soft-serve/server/utils" ) @@ -91,7 +92,8 @@ func (d *Backend) CreateRepository(ctx context.Context, name string, user proto. } // ImportRepository imports a repository from remote. -func (d *Backend) ImportRepository(ctx context.Context, name string, user proto.User, remote string, opts proto.RepositoryOptions) (proto.Repository, error) { +// XXX: This is an expensive operation and should be run in a goroutine. +func (d *Backend) ImportRepository(_ context.Context, name string, user proto.User, remote string, opts proto.RepositoryOptions) (proto.Repository, error) { name = utils.SanitizeRepo(name) if err := utils.ValidateRepo(name); err != nil { return nil, err @@ -100,91 +102,110 @@ func (d *Backend) ImportRepository(ctx context.Context, name string, user proto. 
repo := name + ".git" rp := filepath.Join(d.reposPath(), repo) + tid := "import:" + name + if d.manager.Exists(tid) { + return nil, task.ErrAlreadyStarted + } + if _, err := os.Stat(rp); err == nil || os.IsExist(err) { return nil, proto.ErrRepoExist } - copts := git.CloneOptions{ - Bare: true, - Mirror: opts.Mirror, - Quiet: true, - CommandOptions: git.CommandOptions{ - Timeout: -1, - Context: ctx, - Envs: []string{ - fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`, - filepath.Join(d.cfg.DataPath, "ssh", "known_hosts"), - d.cfg.SSH.ClientKeyPath, - ), + done := make(chan error, 1) + repoc := make(chan proto.Repository, 1) + d.logger.Info("importing repository", "name", name, "remote", remote, "path", rp) + d.manager.Add(tid, func(ctx context.Context) (err error) { + copts := git.CloneOptions{ + Bare: true, + Mirror: opts.Mirror, + Quiet: true, + CommandOptions: git.CommandOptions{ + Timeout: -1, + Context: ctx, + Envs: []string{ + fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`, + filepath.Join(d.cfg.DataPath, "ssh", "known_hosts"), + d.cfg.SSH.ClientKeyPath, + ), + }, }, - }, - } - - if err := git.Clone(remote, rp, copts); err != nil { - d.logger.Error("failed to clone repository", "err", err, "mirror", opts.Mirror, "remote", remote, "path", rp) - // Cleanup the mess! - if rerr := os.RemoveAll(rp); rerr != nil { - err = errors.Join(err, rerr) } - return nil, err - } + if err := git.Clone(remote, rp, copts); err != nil { + d.logger.Error("failed to clone repository", "err", err, "mirror", opts.Mirror, "remote", remote, "path", rp) + // Cleanup the mess! 
+ if rerr := os.RemoveAll(rp); rerr != nil { + err = errors.Join(err, rerr) + } - r, err := d.CreateRepository(ctx, name, user, opts) - if err != nil { - d.logger.Error("failed to create repository", "err", err, "name", name) - return nil, err - } + return err + } - defer func() { + r, err := d.CreateRepository(ctx, name, user, opts) if err != nil { - if rerr := d.DeleteRepository(ctx, name); rerr != nil { - d.logger.Error("failed to delete repository", "err", rerr, "name", name) + d.logger.Error("failed to create repository", "err", err, "name", name) + return err + } + + defer func() { + if err != nil { + if rerr := d.DeleteRepository(ctx, name); rerr != nil { + d.logger.Error("failed to delete repository", "err", rerr, "name", name) + } } + }() + + rr, err := r.Open() + if err != nil { + d.logger.Error("failed to open repository", "err", err, "path", rp) + return err } - }() - rr, err := r.Open() - if err != nil { - d.logger.Error("failed to open repository", "err", err, "path", rp) - return nil, err - } + repoc <- r - rcfg, err := rr.Config() - if err != nil { - d.logger.Error("failed to get repository config", "err", err, "path", rp) - return nil, err - } + rcfg, err := rr.Config() + if err != nil { + d.logger.Error("failed to get repository config", "err", err, "path", rp) + return err + } - endpoint := remote - if opts.LFSEndpoint != "" { - endpoint = opts.LFSEndpoint - } + endpoint := remote + if opts.LFSEndpoint != "" { + endpoint = opts.LFSEndpoint + } - rcfg.Section("lfs").SetOption("url", endpoint) + rcfg.Section("lfs").SetOption("url", endpoint) - if err := rr.SetConfig(rcfg); err != nil { - d.logger.Error("failed to set repository config", "err", err, "path", rp) - return nil, err - } + if err := rr.SetConfig(rcfg); err != nil { + d.logger.Error("failed to set repository config", "err", err, "path", rp) + return err + } - ep, err := lfs.NewEndpoint(endpoint) - if err != nil { - d.logger.Error("failed to create lfs endpoint", "err", err, "path", rp) - 
return nil, err - } + ep, err := lfs.NewEndpoint(endpoint) + if err != nil { + d.logger.Error("failed to create lfs endpoint", "err", err, "path", rp) + return err + } - client := lfs.NewClient(ep) - if client == nil { - return nil, fmt.Errorf("failed to create lfs client: unsupported endpoint %s", endpoint) - } + client := lfs.NewClient(ep) + if client == nil { + return fmt.Errorf("failed to create lfs client: unsupported endpoint %s", endpoint) + } - if err := StoreRepoMissingLFSObjects(ctx, r, d.db, d.store, client); err != nil { - d.logger.Error("failed to store missing lfs objects", "err", err, "path", rp) - return nil, err - } + if err := StoreRepoMissingLFSObjects(ctx, r, d.db, d.store, client); err != nil { + d.logger.Error("failed to store missing lfs objects", "err", err, "path", rp) + return err + } - return r, nil + return nil + }) + + go func() { + d.logger.Info("running import", "name", name) + d.manager.Run(tid, done) + }() + + return <-repoc, <-done } // DeleteRepository deletes a repository. 
diff --git a/server/ssh/cmd/import.go b/server/ssh/cmd/import.go index b34b46f613389ed41a779060549e2ac760d8c4b9..85cb2fb8f3e09a464a8dc9efccaa5d3defe4ebfb 100644 --- a/server/ssh/cmd/import.go +++ b/server/ssh/cmd/import.go @@ -1,8 +1,11 @@ package cmd import ( + "errors" + "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/proto" + "github.com/charmbracelet/soft-serve/server/task" "github.com/spf13/cobra" ) @@ -36,8 +39,13 @@ func importCommand() *cobra.Command { LFS: lfs, LFSEndpoint: lfsEndpoint, }); err != nil { + if errors.Is(err, task.ErrAlreadyStarted) { + return errors.New("import already in progress") + } + return err } + return nil }, } diff --git a/server/task/manager.go b/server/task/manager.go new file mode 100644 index 0000000000000000000000000000000000000000..4f8763711d9478820774c04a04a44c30e3c5d7b2 --- /dev/null +++ b/server/task/manager.go @@ -0,0 +1,116 @@ +package task + +import ( + "context" + "errors" + "sync" + "sync/atomic" +) + +var ( + // ErrNotFound is returned when a task is not found. + ErrNotFound = errors.New("task not found") + + // ErrAlreadyStarted is returned when a task is already started. + ErrAlreadyStarted = errors.New("task already started") +) + +// Task is a task that can be started and stopped. +type Task struct { + id string + fn func(context.Context) error + started atomic.Bool + ctx context.Context + cancel context.CancelFunc + err error +} + +// Manager manages tasks. +type Manager struct { + m sync.Map + ctx context.Context +} + +// NewManager returns a new task manager. +func NewManager(ctx context.Context) *Manager { + return &Manager{ + m: sync.Map{}, + ctx: ctx, + } +} + +// Add adds a task to the manager. +// If a task with the same ID already exists, it is a no-op. 
+func (m *Manager) Add(id string, fn func(context.Context) error) {
+	if m.Exists(id) {
+		return
+	}
+
+	ctx, cancel := context.WithCancel(m.ctx)
+	t := &Task{id: id, fn: fn, ctx: ctx, cancel: cancel}
+	// LoadOrStore makes check-and-insert atomic, so a concurrent Add of the
+	// same id cannot replace a task that has already been registered.
+	if _, loaded := m.m.LoadOrStore(id, t); loaded {
+		cancel() // lost the race; release the unused context
+	}
+}
+
+// Stop cancels the task's context and removes the task from the manager.
+// It returns ErrNotFound if no task is registered under id.
+func (m *Manager) Stop(id string) error {
+	// LoadAndDelete removes the entry atomically, so only one caller stops it.
+	v, ok := m.m.LoadAndDelete(id)
+	if !ok {
+		return ErrNotFound
+	}
+
+	v.(*Task).cancel()
+	return nil
+}
+
+// Exists reports whether a task is registered under id.
+func (m *Manager) Exists(id string) bool {
+	_, ok := m.m.Load(id)
+	return ok
+}
+
+// Run runs the task registered under id and sends exactly one result on done.
+// It sends ErrNotFound if there is no such task. If the task was already
+// started by another caller, Run waits for the task's context to end and
+// sends the recorded task error or, when none was recorded, the context error.
+// Otherwise it runs the task, sends the task's error (or the manager's context
+// error if that ends first), then removes the task and cancels its context.
+func (m *Manager) Run(id string, done chan<- error) {
+	v, ok := m.m.Load(id)
+	if !ok {
+		done <- ErrNotFound
+		return
+	}
+
+	p := v.(*Task)
+	// Exactly one caller may claim the task; later callers wait for it and must
+	// return after reporting, otherwise the task function would run twice.
+	if !p.started.CompareAndSwap(false, true) {
+		<-p.ctx.Done()
+		err := p.err
+		if err == nil {
+			err = p.ctx.Err()
+		}
+		done <- err
+		return
+	}
+
+	defer p.cancel()
+	defer m.m.Delete(id)
+
+	errc := make(chan error, 1)
+	go func() { errc <- p.fn(p.ctx) }()
+
+	select {
+	case <-m.ctx.Done():
+		done <- m.ctx.Err()
+	case err := <-errc:
+		p.err = err // recorded before the deferred cancel wakes any waiters
+		done <- err
+	}
+}