diff --git a/server/backend/backend.go b/server/backend/backend.go index 586d95132504a852dc8fa339de464d3360aca59d..ba9ad61d3713dd18f72a8a9e91d989ddcc370fe1 100644 --- a/server/backend/backend.go +++ b/server/backend/backend.go @@ -7,17 +7,19 @@ import ( "github.com/charmbracelet/soft-serve/server/config" "github.com/charmbracelet/soft-serve/server/db" "github.com/charmbracelet/soft-serve/server/store" + "github.com/charmbracelet/soft-serve/server/task" ) // Backend is the Soft Serve backend that handles users, repositories, and // server settings management and operations. type Backend struct { - ctx context.Context - cfg *config.Config - db *db.DB - store store.Store - logger *log.Logger - cache *cache + ctx context.Context + cfg *config.Config + db *db.DB + store store.Store + logger *log.Logger + cache *cache + manager *task.Manager } // New returns a new Soft Serve backend. @@ -25,11 +27,12 @@ func New(ctx context.Context, cfg *config.Config, db *db.DB) *Backend { dbstore := store.FromContext(ctx) logger := log.FromContext(ctx).WithPrefix("backend") b := &Backend{ - ctx: ctx, - cfg: cfg, - db: db, - store: dbstore, - logger: logger, + ctx: ctx, + cfg: cfg, + db: db, + store: dbstore, + logger: logger, + manager: task.NewManager(ctx), } // TODO: implement a proper caching interface diff --git a/server/backend/repo.go b/server/backend/repo.go index c9c6a7257e2674909cf7d8e44811a956d3562129..2d8633cc3961c6883cea0942eef2aea67cadaa78 100644 --- a/server/backend/repo.go +++ b/server/backend/repo.go @@ -19,6 +19,7 @@ import ( "github.com/charmbracelet/soft-serve/server/lfs" "github.com/charmbracelet/soft-serve/server/proto" "github.com/charmbracelet/soft-serve/server/storage" + "github.com/charmbracelet/soft-serve/server/task" "github.com/charmbracelet/soft-serve/server/utils" ) @@ -91,7 +92,8 @@ func (d *Backend) CreateRepository(ctx context.Context, name string, user proto. } // ImportRepository imports a repository from remote. -func (d *Backend) ImportRepository(ctx context.Context, name string, user proto.User, remote string, opts proto.RepositoryOptions) (proto.Repository, error) { +// XXX: This a expensive operation and should be run in a goroutine. +func (d *Backend) ImportRepository(_ context.Context, name string, user proto.User, remote string, opts proto.RepositoryOptions) (proto.Repository, error) { name = utils.SanitizeRepo(name) if err := utils.ValidateRepo(name); err != nil { return nil, err @@ -100,91 +102,110 @@ func (d *Backend) ImportRepository(ctx context.Context, name string, user proto. repo := name + ".git" rp := filepath.Join(d.reposPath(), repo) + tid := "import:" + name + if d.manager.Exists(tid) { + return nil, task.ErrAlreadyStarted + } + if _, err := os.Stat(rp); err == nil || os.IsExist(err) { return nil, proto.ErrRepoExist } - copts := git.CloneOptions{ - Bare: true, - Mirror: opts.Mirror, - Quiet: true, - CommandOptions: git.CommandOptions{ - Timeout: -1, - Context: ctx, - Envs: []string{ - fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`, - filepath.Join(d.cfg.DataPath, "ssh", "known_hosts"), - d.cfg.SSH.ClientKeyPath, - ), + done := make(chan error, 1) + repoc := make(chan proto.Repository, 1) + d.logger.Info("importing repository", "name", name, "remote", remote, "path", rp) + d.manager.Add(tid, func(ctx context.Context) (err error) { + copts := git.CloneOptions{ + Bare: true, + Mirror: opts.Mirror, + Quiet: true, + CommandOptions: git.CommandOptions{ + Timeout: -1, + Context: ctx, + Envs: []string{ + fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`, + filepath.Join(d.cfg.DataPath, "ssh", "known_hosts"), + d.cfg.SSH.ClientKeyPath, + ), + }, }, - }, - } - - if err := git.Clone(remote, rp, copts); err != nil { - d.logger.Error("failed to clone repository", "err", err, "mirror", opts.Mirror, "remote", remote, "path", rp) - // Cleanup the mess! - if rerr := os.RemoveAll(rp); rerr != nil { - err = errors.Join(err, rerr) } - return nil, err - } + if err := git.Clone(remote, rp, copts); err != nil { + d.logger.Error("failed to clone repository", "err", err, "mirror", opts.Mirror, "remote", remote, "path", rp) + // Cleanup the mess! + if rerr := os.RemoveAll(rp); rerr != nil { + err = errors.Join(err, rerr) + } - r, err := d.CreateRepository(ctx, name, user, opts) - if err != nil { - d.logger.Error("failed to create repository", "err", err, "name", name) - return nil, err - } + return err + } - defer func() { + r, err := d.CreateRepository(ctx, name, user, opts) if err != nil { - if rerr := d.DeleteRepository(ctx, name); rerr != nil { - d.logger.Error("failed to delete repository", "err", rerr, "name", name) + d.logger.Error("failed to create repository", "err", err, "name", name) + return err + } + + defer func() { + if err != nil { + if rerr := d.DeleteRepository(ctx, name); rerr != nil { + d.logger.Error("failed to delete repository", "err", rerr, "name", name) + } } + }() + + rr, err := r.Open() + if err != nil { + d.logger.Error("failed to open repository", "err", err, "path", rp) + return err } - }() - rr, err := r.Open() - if err != nil { - d.logger.Error("failed to open repository", "err", err, "path", rp) - return nil, err - } + repoc <- r - rcfg, err := rr.Config() - if err != nil { - d.logger.Error("failed to get repository config", "err", err, "path", rp) - return nil, err - } + rcfg, err := rr.Config() + if err != nil { + d.logger.Error("failed to get repository config", "err", err, "path", rp) + return err + } - endpoint := remote - if opts.LFSEndpoint != "" { - endpoint = opts.LFSEndpoint - } + endpoint := remote + if opts.LFSEndpoint != "" { + endpoint = opts.LFSEndpoint + } - rcfg.Section("lfs").SetOption("url", endpoint) + rcfg.Section("lfs").SetOption("url", endpoint) - if err := rr.SetConfig(rcfg); err != nil { - d.logger.Error("failed to set repository config", "err", err, "path", rp) - return nil, err - } + if err := rr.SetConfig(rcfg); err != nil { + d.logger.Error("failed to set repository config", "err", err, "path", rp) + return err + } - ep, err := lfs.NewEndpoint(endpoint) - if err != nil { - d.logger.Error("failed to create lfs endpoint", "err", err, "path", rp) - return nil, err - } + ep, err := lfs.NewEndpoint(endpoint) + if err != nil { + d.logger.Error("failed to create lfs endpoint", "err", err, "path", rp) + return err + } - client := lfs.NewClient(ep) - if client == nil { - return nil, fmt.Errorf("failed to create lfs client: unsupported endpoint %s", endpoint) - } + client := lfs.NewClient(ep) + if client == nil { + return fmt.Errorf("failed to create lfs client: unsupported endpoint %s", endpoint) + } - if err := StoreRepoMissingLFSObjects(ctx, r, d.db, d.store, client); err != nil { - d.logger.Error("failed to store missing lfs objects", "err", err, "path", rp) - return nil, err - } + if err := StoreRepoMissingLFSObjects(ctx, r, d.db, d.store, client); err != nil { + d.logger.Error("failed to store missing lfs objects", "err", err, "path", rp) + return err + } - return r, nil + return nil + }) + + go func() { + d.logger.Info("running import", "name", name) + d.manager.Run(tid, done) + }() + + return <-repoc, <-done } // DeleteRepository deletes a repository. diff --git a/server/ssh/cmd/import.go b/server/ssh/cmd/import.go index b34b46f613389ed41a779060549e2ac760d8c4b9..85cb2fb8f3e09a464a8dc9efccaa5d3defe4ebfb 100644 --- a/server/ssh/cmd/import.go +++ b/server/ssh/cmd/import.go @@ -1,8 +1,11 @@ package cmd import ( + "errors" + "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/proto" + "github.com/charmbracelet/soft-serve/server/task" "github.com/spf13/cobra" ) @@ -36,8 +39,13 @@ func importCommand() *cobra.Command { LFS: lfs, LFSEndpoint: lfsEndpoint, }); err != nil { + if errors.Is(err, task.ErrAlreadyStarted) { + return errors.New("import already in progress") + } + return err } + return nil }, } diff --git a/server/task/manager.go b/server/task/manager.go new file mode 100644 index 0000000000000000000000000000000000000000..4f8763711d9478820774c04a04a44c30e3c5d7b2 --- /dev/null +++ b/server/task/manager.go @@ -0,0 +1,116 @@ +package task + +import ( + "context" + "errors" + "sync" + "sync/atomic" +) + +var ( + // ErrNotFound is returned when a process is not found. + ErrNotFound = errors.New("task not found") + + // ErrAlreadyStarted is returned when a process is already started. + ErrAlreadyStarted = errors.New("task already started") +) + +// Task is a task that can be started and stopped. +type Task struct { + id string + fn func(context.Context) error + started atomic.Bool + ctx context.Context + cancel context.CancelFunc + err error +} + +// Manager manages tasks. +type Manager struct { + m sync.Map + ctx context.Context +} + +// NewManager returns a new task manager. +func NewManager(ctx context.Context) *Manager { + return &Manager{ + m: sync.Map{}, + ctx: ctx, + } +} + +// Add adds a task to the manager. +// If the process already exists, it is a no-op. +func (m *Manager) Add(id string, fn func(context.Context) error) { + if m.Exists(id) { + return + } + + ctx, cancel := context.WithCancel(m.ctx) + m.m.Store(id, &Task{ + id: id, + fn: fn, + ctx: ctx, + cancel: cancel, + }) +} + +// Stop stops the task and removes it from the manager. +func (m *Manager) Stop(id string) error { + v, ok := m.m.Load(id) + if !ok { + return ErrNotFound + } + + p := v.(*Task) + p.cancel() + + m.m.Delete(id) + return nil +} + +// Exists checks if a task exists. +func (m *Manager) Exists(id string) bool { + _, ok := m.m.Load(id) + return ok +} + +// Run starts the task if it exists. +// Otherwise, it waits for the process to finish. +func (m *Manager) Run(id string, done chan<- error) { + v, ok := m.m.Load(id) + if !ok { + done <- ErrNotFound + return + } + + p := v.(*Task) + if p.started.Load() { + <-p.ctx.Done() + if p.err != nil { + done <- p.err + return + } + + done <- p.ctx.Err() + } + + p.started.Store(true) + m.m.Store(id, p) + defer p.cancel() + defer m.m.Delete(id) + + errc := make(chan error, 1) + go func(ctx context.Context) { + errc <- p.fn(ctx) + }(p.ctx) + + select { + case <-m.ctx.Done(): + done <- m.ctx.Err() + case err := <-errc: + p.err = err + m.m.Store(id, p) + done <- err + } +}