1// Package shell provides cross-platform shell execution capabilities.
2//
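// This package provides Shell instances for executing commands with their own
// working directory and environment. State carries over between commands run
// on the same Shell: the working directory and variables left behind by one
// command become the starting point for the next.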
5//
6// WINDOWS COMPATIBILITY:
7// This implementation provides POSIX shell emulation (mvdan.cc/sh/v3) even on
8// Windows. Commands should use forward slashes (/) as path separators to work
9// correctly on all platforms.
10package shell
11
12import (
13 "bytes"
14 "context"
15 "errors"
16 "fmt"
17 "io"
18 "os"
19 "slices"
20 "strings"
21 "sync"
22
23 "github.com/charmbracelet/x/exp/slice"
24 "mvdan.cc/sh/moreinterp/coreutils"
25 "mvdan.cc/sh/v3/expand"
26 "mvdan.cc/sh/v3/interp"
27 "mvdan.cc/sh/v3/syntax"
28)
29
// ShellType identifies which shell dialect a command is written for.
//
// NOTE(review): nothing in this file switches on ShellType; the meaning of
// each value is established by its users elsewhere.
type ShellType int

// Known shell dialects. Values are assigned in declaration order starting at
// zero, so new dialects must be appended to keep existing values stable.
const (
	// ShellTypePOSIX is a POSIX-style shell, the dialect this package interprets.
	ShellTypePOSIX ShellType = iota

	// ShellTypeCmd is, judging by its name, the Windows cmd.exe dialect.
	ShellTypeCmd

	// ShellTypePowerShell is, judging by its name, PowerShell.
	ShellTypePowerShell
)
38
// Logger is the minimal logging hook a Shell reports through. Callers that do
// not need logs can leave Options.Logger nil.
type Logger interface {
	// InfoPersist records an informational message followed by alternating
	// key/value pairs that describe it.
	InfoPersist(msg string, keysAndValues ...any)
}
43
// noopLogger satisfies Logger by discarding every message. NewShell installs
// it when no Logger is supplied.
type noopLogger struct{}

// InfoPersist ignores all of its arguments.
func (noopLogger) InfoPersist(string, ...any) {}
48
// BlockFunc inspects the words of a command that is about to run (args[0] is
// the command name) and reports true when the command must be refused.
type BlockFunc func(args []string) bool
51
52// Shell provides cross-platform shell execution with optional state persistence
53type Shell struct {
54 env []string
55 cwd string
56 mu sync.Mutex
57 logger Logger
58 blockFuncs []BlockFunc
59}
60
61// Options for creating a new shell
62type Options struct {
63 WorkingDir string
64 Env []string
65 Logger Logger
66 BlockFuncs []BlockFunc
67}
68
69// NewShell creates a new shell instance with the given options
70func NewShell(opts *Options) *Shell {
71 if opts == nil {
72 opts = &Options{}
73 }
74
75 cwd := opts.WorkingDir
76 if cwd == "" {
77 cwd, _ = os.Getwd()
78 }
79
80 env := opts.Env
81 if env == nil {
82 env = os.Environ()
83 }
84
85 logger := opts.Logger
86 if logger == nil {
87 logger = noopLogger{}
88 }
89
90 return &Shell{
91 cwd: cwd,
92 env: env,
93 logger: logger,
94 blockFuncs: opts.BlockFuncs,
95 }
96}
97
98// Exec executes a command in the shell
99func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
100 s.mu.Lock()
101 defer s.mu.Unlock()
102
103 return s.exec(ctx, command)
104}
105
106// ExecStream executes a command in the shell with streaming output to provided writers
107func (s *Shell) ExecStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
108 s.mu.Lock()
109 defer s.mu.Unlock()
110
111 return s.execStream(ctx, command, stdout, stderr)
112}
113
114func (s *Shell) ExecWithStdin(ctx context.Context, command string, stdin string) (string, string, error) {
115 s.mu.Lock()
116 defer s.mu.Unlock()
117
118 var stdout, stderr bytes.Buffer
119 err := s.execCommonWithStdin(ctx, command, strings.NewReader(stdin), &stdout, &stderr)
120 return stdout.String(), stderr.String(), err
121}
122
123// GetWorkingDir returns the current working directory
124func (s *Shell) GetWorkingDir() string {
125 s.mu.Lock()
126 defer s.mu.Unlock()
127 return s.cwd
128}
129
130// SetWorkingDir sets the working directory
131func (s *Shell) SetWorkingDir(dir string) error {
132 s.mu.Lock()
133 defer s.mu.Unlock()
134
135 // Verify the directory exists
136 if _, err := os.Stat(dir); err != nil {
137 return fmt.Errorf("directory does not exist: %w", err)
138 }
139
140 s.cwd = dir
141 return nil
142}
143
144// GetEnv returns a copy of the environment variables
145func (s *Shell) GetEnv() []string {
146 s.mu.Lock()
147 defer s.mu.Unlock()
148
149 env := make([]string, len(s.env))
150 copy(env, s.env)
151 return env
152}
153
154// SetEnv sets an environment variable
155func (s *Shell) SetEnv(key, value string) {
156 s.mu.Lock()
157 defer s.mu.Unlock()
158
159 // Update or add the environment variable
160 keyPrefix := key + "="
161 for i, env := range s.env {
162 if strings.HasPrefix(env, keyPrefix) {
163 s.env[i] = keyPrefix + value
164 return
165 }
166 }
167 s.env = append(s.env, keyPrefix+value)
168}
169
170// SetBlockFuncs sets the command block functions for the shell
171func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
172 s.mu.Lock()
173 defer s.mu.Unlock()
174 s.blockFuncs = blockFuncs
175}
176
177// CommandsBlocker creates a BlockFunc that blocks exact command matches
178func CommandsBlocker(cmds []string) BlockFunc {
179 bannedSet := make(map[string]struct{})
180 for _, cmd := range cmds {
181 bannedSet[cmd] = struct{}{}
182 }
183
184 return func(args []string) bool {
185 if len(args) == 0 {
186 return false
187 }
188 _, ok := bannedSet[args[0]]
189 return ok
190 }
191}
192
193// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
194func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
195 return func(parts []string) bool {
196 if len(parts) == 0 || parts[0] != cmd {
197 return false
198 }
199
200 argParts, flagParts := splitArgsFlags(parts[1:])
201 if len(argParts) < len(args) || len(flagParts) < len(flags) {
202 return false
203 }
204
205 argsMatch := slices.Equal(argParts[:len(args)], args)
206 flagsMatch := slice.IsSubset(flags, flagParts)
207
208 return argsMatch && flagsMatch
209 }
210}
211
// splitArgsFlags partitions command words into positional arguments and flag
// names. Any word starting with "-" counts as a flag. For words of the form
// -name=value or --name=value, only the text before the first "=" is kept.
// Relative order is preserved in both results, and both are non-nil even when
// parts is empty.
func splitArgsFlags(parts []string) (args []string, flags []string) {
	args = make([]string, 0, len(parts))
	flags = make([]string, 0, len(parts))
	for _, word := range parts {
		if !strings.HasPrefix(word, "-") {
			args = append(args, word)
			continue
		}
		name, _, _ := strings.Cut(word, "=")
		flags = append(flags, name)
	}
	return args, flags
}
229
230func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
231 return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
232 return func(ctx context.Context, args []string) error {
233 if len(args) == 0 {
234 return next(ctx, args)
235 }
236
237 for _, blockFunc := range s.blockFuncs {
238 if blockFunc(args) {
239 return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
240 }
241 }
242
243 return next(ctx, args)
244 }
245 }
246}
247
248// newInterp creates a new interpreter with the current shell state
249func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) {
250 return interp.New(
251 interp.StdIO(stdin, stdout, stderr),
252 interp.Interactive(false),
253 interp.Env(expand.ListEnviron(s.env...)),
254 interp.Dir(s.cwd),
255 interp.ExecHandlers(s.execHandlers()...),
256 )
257}
258
259// updateShellFromRunner updates the shell from the interpreter after execution
260func (s *Shell) updateShellFromRunner(runner *interp.Runner) {
261 s.cwd = runner.Dir
262 s.env = nil
263 for name, vr := range runner.Vars {
264 s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
265 }
266}
267
268// execCommon is the shared implementation for executing commands
269func (s *Shell) execCommon(ctx context.Context, command string, stdout, stderr io.Writer) error {
270 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
271 if err != nil {
272 return fmt.Errorf("could not parse command: %w", err)
273 }
274
275 runner, err := s.newInterp(nil, stdout, stderr)
276 if err != nil {
277 return fmt.Errorf("could not run command: %w", err)
278 }
279
280 err = runner.Run(ctx, line)
281 s.updateShellFromRunner(runner)
282 s.logger.InfoPersist("command finished", "command", command, "err", err)
283 return err
284}
285
286// exec executes commands using a cross-platform shell interpreter.
287func (s *Shell) exec(ctx context.Context, command string) (string, string, error) {
288 var stdout, stderr bytes.Buffer
289 err := s.execCommon(ctx, command, &stdout, &stderr)
290 return stdout.String(), stderr.String(), err
291}
292
293// execStream executes commands using POSIX shell emulation with streaming output
294func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
295 return s.execCommon(ctx, command, stdout, stderr)
296}
297
298// execCommonWithStdin is like execCommon but with stdin support
299func (s *Shell) execCommonWithStdin(ctx context.Context, command string, stdin io.Reader, stdout, stderr io.Writer) error {
300 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
301 if err != nil {
302 return fmt.Errorf("could not parse command: %w", err)
303 }
304
305 runner, err := s.newInterp(stdin, stdout, stderr)
306 if err != nil {
307 return fmt.Errorf("could not run command: %w", err)
308 }
309
310 err = runner.Run(ctx, line)
311 s.updateShellFromRunner(runner)
312 s.logger.InfoPersist("command finished", "command", command, "err", err)
313 return err
314}
315
316func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
317 handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
318 s.blockHandler(),
319 }
320 if useGoCoreUtils {
321 handlers = append(handlers, coreutils.ExecHandler)
322 }
323 return handlers
324}
325
// IsInterrupt reports whether err, or any error it wraps, is
// context.Canceled or context.DeadlineExceeded. A nil err is not an
// interrupt.
func IsInterrupt(err error) bool {
	for _, target := range []error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
331
332// ExitCode extracts the exit code from an error
333func ExitCode(err error) int {
334 if err == nil {
335 return 0
336 }
337 var exitErr interp.ExitStatus
338 if errors.As(err, &exitErr) {
339 return int(exitErr)
340 }
341 return 1
342}