shell.go

 1package shell
 2
 3import (
 4	"bytes"
 5	"context"
 6	"errors"
 7	"fmt"
 8	"os"
 9	"strings"
10	"sync"
11
12	"github.com/charmbracelet/crush/internal/logging"
13	"mvdan.cc/sh/v3/expand"
14	"mvdan.cc/sh/v3/interp"
15	"mvdan.cc/sh/v3/syntax"
16)
17
// PersistentShell is an in-process shell whose working directory and
// variables carry over from one Exec call to the next. Obtain it through
// GetPersistentShell; the zero value has no environment and no directory.
type PersistentShell struct {
	env []string   // NAME=value entries seeded into the next command
	cwd string     // directory the next command starts in
	mu  sync.Mutex // held for the duration of Exec; guards env and cwd
}
23
24var (
25	once          sync.Once
26	shellInstance *PersistentShell
27)
28
29func GetPersistentShell(cwd string) *PersistentShell {
30	once.Do(func() {
31		shellInstance = newPersistentShell(cwd)
32	})
33	return shellInstance
34}
35
36func newPersistentShell(cwd string) *PersistentShell {
37	return &PersistentShell{
38		cwd: cwd,
39		env: os.Environ(),
40	}
41}
42
43func (s *PersistentShell) Exec(ctx context.Context, command string) (string, string, error) {
44	s.mu.Lock()
45	defer s.mu.Unlock()
46
47	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
48	if err != nil {
49		return "", "", fmt.Errorf("could not parse command: %w", err)
50	}
51
52	var stdout, stderr bytes.Buffer
53	runner, err := interp.New(
54		interp.StdIO(nil, &stdout, &stderr),
55		interp.Interactive(false),
56		interp.Env(expand.ListEnviron(s.env...)),
57		interp.Dir(s.cwd),
58	)
59	if err != nil {
60		return "", "", fmt.Errorf("could not run command: %w", err)
61	}
62
63	err = runner.Run(ctx, line)
64	s.cwd = runner.Dir
65	s.env = []string{}
66	for name, vr := range runner.Vars {
67		s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
68	}
69	logging.InfoPersist("Command finished", "command", command, "err", err)
70	return stdout.String(), stderr.String(), err
71}
72
// IsInterrupt reports whether err is, or wraps, context.Canceled or
// context.DeadlineExceeded. A nil error is never an interrupt.
func IsInterrupt(err error) bool {
	for _, target := range []error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
77
78func ExitCode(err error) int {
79	if err == nil {
80		return 0
81	}
82	status, ok := interp.IsExitStatus(err)
83	if ok {
84		return int(status)
85	}
86	return 1
87}