1package shell
2
3import (
4 "bytes"
5 "context"
6 "errors"
7 "fmt"
8 "os"
9 "strings"
10 "sync"
11
12 "github.com/charmbracelet/crush/internal/logging"
13 "mvdan.cc/sh/v3/expand"
14 "mvdan.cc/sh/v3/interp"
15 "mvdan.cc/sh/v3/syntax"
16)
17
// PersistentShell runs shell commands through the mvdan.cc/sh interpreter
// while carrying the working directory and variables from one Exec call over
// to the next.
type PersistentShell struct {
	// env holds "NAME=value" entries used as the starting environment of the
	// next Exec call; Exec rebuilds it from the interpreter after each run.
	env []string
	// cwd is the directory the next Exec call starts in; Exec updates it from
	// the interpreter after each run.
	cwd string
	// mu is held by Exec for its whole duration, so env and cwd are read and
	// written by one command at a time.
	mu sync.Mutex
}
23
24var (
25 once sync.Once
26 shellInstance *PersistentShell
27)
28
29func GetPersistentShell(cwd string) *PersistentShell {
30 once.Do(func() {
31 shellInstance = newPersistentShell(cwd)
32 })
33 return shellInstance
34}
35
36func newPersistentShell(cwd string) *PersistentShell {
37 return &PersistentShell{
38 cwd: cwd,
39 env: os.Environ(),
40 }
41}
42
43func (s *PersistentShell) Exec(ctx context.Context, command string) (string, string, error) {
44 s.mu.Lock()
45 defer s.mu.Unlock()
46
47 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
48 if err != nil {
49 return "", "", fmt.Errorf("could not parse command: %w", err)
50 }
51
52 var stdout, stderr bytes.Buffer
53 runner, err := interp.New(
54 interp.StdIO(nil, &stdout, &stderr),
55 interp.Interactive(false),
56 interp.Env(expand.ListEnviron(s.env...)),
57 interp.Dir(s.cwd),
58 )
59 if err != nil {
60 return "", "", fmt.Errorf("could not run command: %w", err)
61 }
62
63 err = runner.Run(ctx, line)
64 s.cwd = runner.Dir
65 s.env = []string{}
66 for name, vr := range runner.Vars {
67 s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
68 }
69 logging.InfoPersist("Command finished", "command", command, "err", err)
70 return stdout.String(), stderr.String(), err
71}
72
// IsInterrupt reports whether err (or any error it wraps) is context.Canceled
// or context.DeadlineExceeded, i.e. whether a run was cut short by its
// context rather than failing on its own. A nil err reports false.
func IsInterrupt(err error) bool {
	for _, target := range []error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
77
78func ExitCode(err error) int {
79 if err == nil {
80 return 0
81 }
82 status, ok := interp.IsExitStatus(err)
83 if ok {
84 return int(status)
85 }
86 return 1
87}