1// Package shell provides cross-platform shell execution capabilities.
2//
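// This package provides Shell instances that execute commands with their own
// working directory and environment. Each call runs in a fresh interpreter, but
// the working directory and variables it leaves behind carry over to the next
// call on the same Shell.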
5//
6// WINDOWS COMPATIBILITY:
7// This implementation provides POSIX shell emulation (mvdan.cc/sh/v3) even on
8// Windows. Commands should use forward slashes (/) as path separators to work
9// correctly on all platforms.
10package shell
11
12import (
13 "bytes"
14 "context"
15 "errors"
16 "fmt"
17 "io"
18 "os"
19 "slices"
20 "strings"
21 "sync"
22
23 "github.com/charmbracelet/x/exp/slice"
24 "mvdan.cc/sh/moreinterp/coreutils"
25 "mvdan.cc/sh/v3/expand"
26 "mvdan.cc/sh/v3/interp"
27 "mvdan.cc/sh/v3/syntax"
28)
29
// ShellType identifies a shell dialect. The numeric values are spelled out
// explicitly; the zero value is ShellTypePOSIX.
type ShellType int

const (
	// ShellTypePOSIX is the POSIX shell dialect (value 0).
	ShellTypePOSIX ShellType = 0
	// ShellTypeCmd has value 1; presumably Windows cmd — not referenced in this file.
	ShellTypeCmd ShellType = 1
	// ShellTypePowerShell has value 2; presumably PowerShell — not referenced in this file.
	ShellTypePowerShell ShellType = 2
)
38
// Logger is an optional sink for shell diagnostics. InfoPersist receives a
// message followed by alternating key/value pairs.
type Logger interface {
	InfoPersist(msg string, keysAndValues ...any)
}

// noopLogger discards every log call; it is the fallback when no Logger is
// configured.
type noopLogger struct{}

// Compile-time proof that noopLogger satisfies Logger.
var _ Logger = noopLogger{}

// InfoPersist ignores its arguments.
func (noopLogger) InfoPersist(string, ...any) {}
48
// BlockFunc is a function that determines if a command should be blocked.
// args is the argument vector handed to the exec middleware (args[0] is the
// command name); returning true makes the shell refuse the command with an
// error instead of executing it.
type BlockFunc func(args []string) bool
51
// Shell provides cross-platform shell execution with optional state persistence.
//
// Every Exec* call builds a fresh interpreter seeded from cwd and env, and
// afterwards writes the interpreter's final directory and variables back into
// these fields, so state carries over between calls. The exported methods
// serialize access to the fields through mu.
type Shell struct {
	env                []string // "NAME=value" pairs seeded into each interpreter
	cwd                string   // directory the next command starts in
	mu                 sync.Mutex
	logger             Logger
	blockFuncs         []BlockFunc // checked by blockHandler before a command is executed
	customExecHandlers []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc
}
61
// Options for creating a new shell. Every field is optional.
type Options struct {
	// WorkingDir is the initial working directory; empty means the result of
	// os.Getwd at construction time.
	WorkingDir string
	// Env is the initial environment as "NAME=value" pairs; nil means
	// os.Environ().
	Env []string
	// Logger receives a log entry when a command finishes; nil means no logging.
	Logger Logger
	// BlockFuncs decide which commands are refused before execution.
	BlockFuncs []BlockFunc
	// ExecHandlers are extra exec middlewares, installed after the block
	// handler and before the optional coreutils handler.
	ExecHandlers []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc
}
70
71// NewShell creates a new shell instance with the given options
72func NewShell(opts *Options) *Shell {
73 if opts == nil {
74 opts = &Options{}
75 }
76
77 cwd := opts.WorkingDir
78 if cwd == "" {
79 cwd, _ = os.Getwd()
80 }
81
82 env := opts.Env
83 if env == nil {
84 env = os.Environ()
85 }
86
87 logger := opts.Logger
88 if logger == nil {
89 logger = noopLogger{}
90 }
91
92 return &Shell{
93 cwd: cwd,
94 env: env,
95 logger: logger,
96 blockFuncs: opts.BlockFuncs,
97 customExecHandlers: opts.ExecHandlers,
98 }
99}
100
101// Exec executes a command in the shell
102func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
103 s.mu.Lock()
104 defer s.mu.Unlock()
105
106 return s.exec(ctx, command, nil)
107}
108
109// ExecWithStdin executes a command in the shell with provided stdin
110func (s *Shell) ExecWithStdin(ctx context.Context, command string, stdin io.Reader) (string, string, error) {
111 s.mu.Lock()
112 defer s.mu.Unlock()
113
114 return s.exec(ctx, command, stdin)
115}
116
117// ExecStream executes a command in the shell with streaming output to provided writers
118func (s *Shell) ExecStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
119 s.mu.Lock()
120 defer s.mu.Unlock()
121
122 return s.execStream(ctx, command, stdout, stderr)
123}
124
125// GetWorkingDir returns the current working directory
126func (s *Shell) GetWorkingDir() string {
127 s.mu.Lock()
128 defer s.mu.Unlock()
129 return s.cwd
130}
131
132// SetWorkingDir sets the working directory
133func (s *Shell) SetWorkingDir(dir string) error {
134 s.mu.Lock()
135 defer s.mu.Unlock()
136
137 // Verify the directory exists
138 if _, err := os.Stat(dir); err != nil {
139 return fmt.Errorf("directory does not exist: %w", err)
140 }
141
142 s.cwd = dir
143 return nil
144}
145
146// GetEnv returns a copy of the environment variables
147func (s *Shell) GetEnv() []string {
148 s.mu.Lock()
149 defer s.mu.Unlock()
150
151 env := make([]string, len(s.env))
152 copy(env, s.env)
153 return env
154}
155
156// SetEnv sets an environment variable
157func (s *Shell) SetEnv(key, value string) {
158 s.mu.Lock()
159 defer s.mu.Unlock()
160
161 // Update or add the environment variable
162 keyPrefix := key + "="
163 for i, env := range s.env {
164 if strings.HasPrefix(env, keyPrefix) {
165 s.env[i] = keyPrefix + value
166 return
167 }
168 }
169 s.env = append(s.env, keyPrefix+value)
170}
171
172// SetBlockFuncs sets the command block functions for the shell
173func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
174 s.mu.Lock()
175 defer s.mu.Unlock()
176 s.blockFuncs = blockFuncs
177}
178
179// CommandsBlocker creates a BlockFunc that blocks exact command matches
180func CommandsBlocker(cmds []string) BlockFunc {
181 bannedSet := make(map[string]struct{})
182 for _, cmd := range cmds {
183 bannedSet[cmd] = struct{}{}
184 }
185
186 return func(args []string) bool {
187 if len(args) == 0 {
188 return false
189 }
190 _, ok := bannedSet[args[0]]
191 return ok
192 }
193}
194
195// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
196func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
197 return func(parts []string) bool {
198 if len(parts) == 0 || parts[0] != cmd {
199 return false
200 }
201
202 argParts, flagParts := splitArgsFlags(parts[1:])
203 if len(argParts) < len(args) || len(flagParts) < len(flags) {
204 return false
205 }
206
207 argsMatch := slices.Equal(argParts[:len(args)], args)
208 flagsMatch := slice.IsSubset(flags, flagParts)
209
210 return argsMatch && flagsMatch
211 }
212}
213
// splitArgsFlags partitions parts into positional arguments and flags, keeping
// the relative order within each group. Any word beginning with "-" counts as
// a flag (including a lone "-" or "--"); for "-name=value" only "-name" is
// kept. Both results are always non-nil.
func splitArgsFlags(parts []string) (args []string, flags []string) {
	args = make([]string, 0, len(parts))
	flags = make([]string, 0, len(parts))
	for _, word := range parts {
		if !strings.HasPrefix(word, "-") {
			args = append(args, word)
			continue
		}
		name, _, _ := strings.Cut(word, "=")
		flags = append(flags, name)
	}
	return args, flags
}
231
232func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
233 return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
234 return func(ctx context.Context, args []string) error {
235 if len(args) == 0 {
236 return next(ctx, args)
237 }
238
239 for _, blockFunc := range s.blockFuncs {
240 if blockFunc(args) {
241 return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
242 }
243 }
244
245 return next(ctx, args)
246 }
247 }
248}
249
250// newInterp creates a new interpreter with the current shell state
251func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) {
252 opts := []interp.RunnerOption{
253 interp.StdIO(stdin, stdout, stderr),
254 interp.Interactive(false),
255 interp.Env(expand.ListEnviron(s.env...)),
256 interp.Dir(s.cwd),
257 interp.ExecHandlers(s.execHandlers()...),
258 }
259
260 return interp.New(opts...)
261}
262
263// updateShellFromRunner updates the shell from the interpreter after execution
264func (s *Shell) updateShellFromRunner(runner *interp.Runner) {
265 s.cwd = runner.Dir
266 s.env = nil
267 for name, vr := range runner.Vars {
268 s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
269 }
270}
271
272// execCommon is the shared implementation for executing commands
273func (s *Shell) execCommon(ctx context.Context, command string, stdin io.Reader, stdout, stderr io.Writer) error {
274 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
275 if err != nil {
276 return fmt.Errorf("could not parse command: %w", err)
277 }
278
279 runner, err := s.newInterp(stdin, stdout, stderr)
280 if err != nil {
281 return fmt.Errorf("could not run command: %w", err)
282 }
283
284 err = runner.Run(ctx, line)
285 s.updateShellFromRunner(runner)
286 s.logger.InfoPersist("command finished", "command", command, "err", err)
287 return err
288}
289
290// exec executes commands using a cross-platform shell interpreter.
291func (s *Shell) exec(ctx context.Context, command string, stdin io.Reader) (string, string, error) {
292 var stdout, stderr bytes.Buffer
293 err := s.execCommon(ctx, command, stdin, &stdout, &stderr)
294 return stdout.String(), stderr.String(), err
295}
296
297// execStream executes commands using POSIX shell emulation with streaming output
298func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
299 return s.execCommon(ctx, command, nil, stdout, stderr)
300}
301
302func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
303 handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
304 s.blockHandler(),
305 }
306 // Add custom exec handlers first (they get priority)
307 handlers = append(handlers, s.customExecHandlers...)
308 if useGoCoreUtils {
309 handlers = append(handlers, coreutils.ExecHandler)
310 }
311 return handlers
312}
313
// IsInterrupt reports whether err, or anything it wraps, is context.Canceled
// or context.DeadlineExceeded.
func IsInterrupt(err error) bool {
	for _, target := range [...]error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
319
320// ExitCode extracts the exit code from an error
321func ExitCode(err error) int {
322 if err == nil {
323 return 0
324 }
325 var exitErr interp.ExitStatus
326 if errors.As(err, &exitErr) {
327 return int(exitErr)
328 }
329 return 1
330}