1// Package shell provides cross-platform shell execution capabilities.
2//
3// This package offers two main types:
4// - Shell: A general-purpose shell executor for one-off or managed commands
5// - PersistentShell: A singleton shell that maintains state across the application
6//
7// WINDOWS COMPATIBILITY:
// This implementation uses POSIX shell emulation (mvdan.cc/sh/v3) on every
// platform, including Windows. One caveat applies: commands must use forward
// slashes (/) as path separators, even on Windows.
11package shell
12
13import (
14 "bytes"
15 "context"
16 "errors"
17 "fmt"
18 "os"
19 "slices"
20 "strings"
21 "sync"
22
23 "github.com/charmbracelet/crush/internal/slicesext"
24 "mvdan.cc/sh/moreinterp/coreutils"
25 "mvdan.cc/sh/v3/expand"
26 "mvdan.cc/sh/v3/interp"
27 "mvdan.cc/sh/v3/syntax"
28)
29
// ShellType identifies which kind of shell a command is meant for.
type ShellType int

// Known shell types. The values are spelled out explicitly; they are the same
// 0, 1, 2 sequence an iota block would produce.
// NOTE(review): Shell.Exec in this file always runs the POSIX emulator;
// confirm where ShellTypeCmd and ShellTypePowerShell are consumed.
const (
	ShellTypePOSIX      ShellType = 0
	ShellTypeCmd        ShellType = 1
	ShellTypePowerShell ShellType = 2
)
38
// Logger is the optional logging sink used by Shell. InfoPersist receives a
// message followed by alternating key/value pairs.
// NOTE(review): what "Persist" implies (e.g. durable storage) is decided by
// the implementation; confirm with the concrete logger.
type Logger interface {
	InfoPersist(msg string, keysAndValues ...any)
}

// noopLogger satisfies Logger and discards every call.
type noopLogger struct{}

// Compile-time proof that noopLogger implements Logger.
var _ Logger = noopLogger{}

func (noopLogger) InfoPersist(string, ...any) {}
48
// BlockFunc inspects a command's argument vector (argv[0] is the command
// name) and reports true when the command must be rejected.
type BlockFunc func(argv []string) (blocked bool)
51
// Shell provides cross-platform shell execution with optional state persistence.
//
// After each command, the working directory and variables left by the
// interpreter are copied back into cwd and env and are used by the next
// command. The methods in this file lock mu before reading or writing env,
// cwd, or blockFuncs.
type Shell struct {
	env        []string // "KEY=value" entries handed to the interpreter
	cwd        string   // directory the next command runs in
	mu         sync.Mutex
	logger     Logger      // set to a no-op logger by NewShell when none is given
	blockFuncs []BlockFunc // consulted before each command; any true rejects it
}
60
// Options configures NewShell. Every field is optional; zero values select
// the defaults described below.
type Options struct {
	// WorkingDir is the initial directory. Empty means the process's current
	// working directory.
	WorkingDir string
	// Env holds "KEY=value" entries. Nil means os.Environ(); a non-nil empty
	// slice yields an empty environment.
	Env []string
	// Logger receives a log entry after each command. Nil disables logging.
	Logger Logger
	// BlockFuncs are consulted before each command runs; if any returns true
	// the command is rejected with an error.
	BlockFuncs []BlockFunc
}
68
69// NewShell creates a new shell instance with the given options
70func NewShell(opts *Options) *Shell {
71 if opts == nil {
72 opts = &Options{}
73 }
74
75 cwd := opts.WorkingDir
76 if cwd == "" {
77 cwd, _ = os.Getwd()
78 }
79
80 env := opts.Env
81 if env == nil {
82 env = os.Environ()
83 }
84
85 logger := opts.Logger
86 if logger == nil {
87 logger = noopLogger{}
88 }
89
90 return &Shell{
91 cwd: cwd,
92 env: env,
93 logger: logger,
94 blockFuncs: opts.BlockFuncs,
95 }
96}
97
98// Exec executes a command in the shell
99func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
100 s.mu.Lock()
101 defer s.mu.Unlock()
102
103 return s.execPOSIX(ctx, command)
104}
105
106// GetWorkingDir returns the current working directory
107func (s *Shell) GetWorkingDir() string {
108 s.mu.Lock()
109 defer s.mu.Unlock()
110 return s.cwd
111}
112
113// SetWorkingDir sets the working directory
114func (s *Shell) SetWorkingDir(dir string) error {
115 s.mu.Lock()
116 defer s.mu.Unlock()
117
118 // Verify the directory exists
119 if _, err := os.Stat(dir); err != nil {
120 return fmt.Errorf("directory does not exist: %w", err)
121 }
122
123 s.cwd = dir
124 return nil
125}
126
127// GetEnv returns a copy of the environment variables
128func (s *Shell) GetEnv() []string {
129 s.mu.Lock()
130 defer s.mu.Unlock()
131
132 env := make([]string, len(s.env))
133 copy(env, s.env)
134 return env
135}
136
137// SetEnv sets an environment variable
138func (s *Shell) SetEnv(key, value string) {
139 s.mu.Lock()
140 defer s.mu.Unlock()
141
142 // Update or add the environment variable
143 keyPrefix := key + "="
144 for i, env := range s.env {
145 if strings.HasPrefix(env, keyPrefix) {
146 s.env[i] = keyPrefix + value
147 return
148 }
149 }
150 s.env = append(s.env, keyPrefix+value)
151}
152
153// SetBlockFuncs sets the command block functions for the shell
154func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
155 s.mu.Lock()
156 defer s.mu.Unlock()
157 s.blockFuncs = blockFuncs
158}
159
160// CommandsBlocker creates a BlockFunc that blocks exact command matches
161func CommandsBlocker(cmds []string) BlockFunc {
162 bannedSet := make(map[string]struct{})
163 for _, cmd := range cmds {
164 bannedSet[cmd] = struct{}{}
165 }
166
167 return func(args []string) bool {
168 if len(args) == 0 {
169 return false
170 }
171 _, ok := bannedSet[args[0]]
172 return ok
173 }
174}
175
176// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
177func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
178 return func(parts []string) bool {
179 if len(parts) == 0 || parts[0] != cmd {
180 return false
181 }
182
183 argParts, flagParts := splitArgsFlags(parts[1:])
184 if len(argParts) < len(args) || len(flagParts) < len(flags) {
185 return false
186 }
187
188 argsMatch := slices.Equal(argParts[:len(args)], args)
189 flagsMatch := slicesext.IsSubset(flags, flagParts)
190
191 return argsMatch && flagsMatch
192 }
193}
194
// splitArgsFlags partitions parts, preserving order, into positional tokens
// and flag tokens. Any token starting with "-" counts as a flag, including a
// lone "-". Both returned slices are non-nil, even for empty input.
func splitArgsFlags(parts []string) ([]string, []string) {
	positional := make([]string, 0, len(parts))
	flagged := make([]string, 0, len(parts))
	for i := range parts {
		switch tok := parts[i]; {
		case strings.HasPrefix(tok, "-"):
			flagged = append(flagged, tok)
		default:
			positional = append(positional, tok)
		}
	}
	return positional, flagged
}
207
208func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
209 return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
210 return func(ctx context.Context, args []string) error {
211 if len(args) == 0 {
212 return next(ctx, args)
213 }
214
215 for _, blockFunc := range s.blockFuncs {
216 if blockFunc(args) {
217 return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
218 }
219 }
220
221 return next(ctx, args)
222 }
223 }
224}
225
226// execPOSIX executes commands using POSIX shell emulation (cross-platform)
227func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
228 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
229 if err != nil {
230 return "", "", fmt.Errorf("could not parse command: %w", err)
231 }
232
233 var stdout, stderr bytes.Buffer
234 runner, err := interp.New(
235 interp.StdIO(nil, &stdout, &stderr),
236 interp.Interactive(false),
237 interp.Env(expand.ListEnviron(s.env...)),
238 interp.Dir(s.cwd),
239 interp.ExecHandlers(s.blockHandler(), coreutils.ExecHandler),
240 )
241 if err != nil {
242 return "", "", fmt.Errorf("could not run command: %w", err)
243 }
244
245 err = runner.Run(ctx, line)
246 s.cwd = runner.Dir
247 s.env = []string{}
248 for name, vr := range runner.Vars {
249 s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
250 }
251 s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
252 return stdout.String(), stderr.String(), err
253}
254
// IsInterrupt reports whether err is, or wraps, a context cancellation or a
// context deadline expiry. A nil error is not an interrupt.
func IsInterrupt(err error) bool {
	switch {
	case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
		return true
	default:
		return false
	}
}
260
261// ExitCode extracts the exit code from an error
262func ExitCode(err error) int {
263 if err == nil {
264 return 0
265 }
266 var exitErr interp.ExitStatus
267 if errors.As(err, &exitErr) {
268 return int(exitErr)
269 }
270 return 1
271}