1// Package shell provides cross-platform shell execution capabilities.
2//
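// This package provides Shell instances that execute commands, each with its
// own working directory and environment. A Shell carries its working directory
// and environment over from one execution to the next, so a `cd` in one
// command changes where the following commands run.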
5//
6// WINDOWS COMPATIBILITY:
7// This implementation provides POSIX shell emulation (mvdan.cc/sh/v3) even on
8// Windows. Commands should use forward slashes (/) as path separators to work
9// correctly on all platforms.
10package shell
11
12import (
13 "bytes"
14 "context"
15 "errors"
16 "fmt"
17 "io"
18 "os"
19 "slices"
20 "strings"
21 "sync"
22
23 "github.com/charmbracelet/x/exp/slice"
24 "mvdan.cc/sh/moreinterp/coreutils"
25 "mvdan.cc/sh/v3/expand"
26 "mvdan.cc/sh/v3/interp"
27 "mvdan.cc/sh/v3/syntax"
28)
29
// ShellType identifies a kind of shell.
//
// NOTE(review): nothing in this file selects behavior based on ShellType;
// execution always goes through the POSIX interpreter. Confirm whether other
// files consume it.
type ShellType int

// Known shell types. The numeric values are written out explicitly so that
// reordering the declarations cannot silently renumber them.
const (
	// ShellTypePOSIX denotes a POSIX-compatible shell.
	ShellTypePOSIX ShellType = 0
	// ShellTypeCmd denotes the cmd shell (presumably Windows cmd.exe).
	ShellTypeCmd ShellType = 1
	// ShellTypePowerShell denotes PowerShell.
	ShellTypePowerShell ShellType = 2
)
38
39// Logger interface for optional logging
40type Logger interface {
41 InfoPersist(msg string, keysAndValues ...any)
42}
43
44// noopLogger is a logger that does nothing
45type noopLogger struct{}
46
47func (noopLogger) InfoPersist(msg string, keysAndValues ...any) {}
48
// BlockFunc inspects a command's argument vector (args[0] is the command
// name) and reports whether the command must be refused.
type BlockFunc func(args []string) bool
51
52// Shell provides cross-platform shell execution with optional state persistence
53type Shell struct {
54 env []string
55 cwd string
56 mu sync.Mutex
57 logger Logger
58 blockFuncs []BlockFunc
59}
60
61// Options for creating a new shell
62type Options struct {
63 WorkingDir string
64 Env []string
65 Logger Logger
66 BlockFuncs []BlockFunc
67}
68
69// NewShell creates a new shell instance with the given options
70func NewShell(opts *Options) *Shell {
71 if opts == nil {
72 opts = &Options{}
73 }
74
75 cwd := opts.WorkingDir
76 if cwd == "" {
77 cwd, _ = os.Getwd()
78 }
79
80 env := opts.Env
81 if env == nil {
82 env = os.Environ()
83 }
84
85 logger := opts.Logger
86 if logger == nil {
87 logger = noopLogger{}
88 }
89
90 return &Shell{
91 cwd: cwd,
92 env: env,
93 logger: logger,
94 blockFuncs: opts.BlockFuncs,
95 }
96}
97
98// Exec executes a command in the shell
99func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
100 s.mu.Lock()
101 defer s.mu.Unlock()
102
103 return s.exec(ctx, command, nil)
104}
105
106// ExecWithStdin executes a command in the shell with provided stdin
107func (s *Shell) ExecWithStdin(ctx context.Context, command string, stdin io.Reader) (string, string, error) {
108 s.mu.Lock()
109 defer s.mu.Unlock()
110
111 return s.exec(ctx, command, stdin)
112}
113
114// ExecStream executes a command in the shell with streaming output to provided writers
115func (s *Shell) ExecStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
116 s.mu.Lock()
117 defer s.mu.Unlock()
118
119 return s.execStream(ctx, command, stdout, stderr)
120}
121
122// GetWorkingDir returns the current working directory
123func (s *Shell) GetWorkingDir() string {
124 s.mu.Lock()
125 defer s.mu.Unlock()
126 return s.cwd
127}
128
129// SetWorkingDir sets the working directory
130func (s *Shell) SetWorkingDir(dir string) error {
131 s.mu.Lock()
132 defer s.mu.Unlock()
133
134 // Verify the directory exists
135 if _, err := os.Stat(dir); err != nil {
136 return fmt.Errorf("directory does not exist: %w", err)
137 }
138
139 s.cwd = dir
140 return nil
141}
142
143// GetEnv returns a copy of the environment variables
144func (s *Shell) GetEnv() []string {
145 s.mu.Lock()
146 defer s.mu.Unlock()
147
148 env := make([]string, len(s.env))
149 copy(env, s.env)
150 return env
151}
152
153// SetEnv sets an environment variable
154func (s *Shell) SetEnv(key, value string) {
155 s.mu.Lock()
156 defer s.mu.Unlock()
157
158 // Update or add the environment variable
159 keyPrefix := key + "="
160 for i, env := range s.env {
161 if strings.HasPrefix(env, keyPrefix) {
162 s.env[i] = keyPrefix + value
163 return
164 }
165 }
166 s.env = append(s.env, keyPrefix+value)
167}
168
169// SetBlockFuncs sets the command block functions for the shell
170func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
171 s.mu.Lock()
172 defer s.mu.Unlock()
173 s.blockFuncs = blockFuncs
174}
175
176// CommandsBlocker creates a BlockFunc that blocks exact command matches
177func CommandsBlocker(cmds []string) BlockFunc {
178 bannedSet := make(map[string]struct{})
179 for _, cmd := range cmds {
180 bannedSet[cmd] = struct{}{}
181 }
182
183 return func(args []string) bool {
184 if len(args) == 0 {
185 return false
186 }
187 _, ok := bannedSet[args[0]]
188 return ok
189 }
190}
191
192// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
193func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
194 return func(parts []string) bool {
195 if len(parts) == 0 || parts[0] != cmd {
196 return false
197 }
198
199 argParts, flagParts := splitArgsFlags(parts[1:])
200 if len(argParts) < len(args) || len(flagParts) < len(flags) {
201 return false
202 }
203
204 argsMatch := slices.Equal(argParts[:len(args)], args)
205 flagsMatch := slice.IsSubset(flags, flagParts)
206
207 return argsMatch && flagsMatch
208 }
209}
210
// splitArgsFlags partitions parts into positional arguments and flags. A part
// starting with "-" is a flag, recorded without any "=value" suffix; every
// other part, including "", is a positional argument. Relative order is kept
// within each group, and both results are non-nil.
func splitArgsFlags(parts []string) (args []string, flags []string) {
	args = make([]string, 0, len(parts))
	flags = make([]string, 0, len(parts))
	for _, part := range parts {
		if !strings.HasPrefix(part, "-") {
			args = append(args, part)
			continue
		}
		// Compare "--name=value" by its "--name" part only.
		name, _, _ := strings.Cut(part, "=")
		flags = append(flags, name)
	}
	return args, flags
}
228
229func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
230 return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
231 return func(ctx context.Context, args []string) error {
232 if len(args) == 0 {
233 return next(ctx, args)
234 }
235
236 for _, blockFunc := range s.blockFuncs {
237 if blockFunc(args) {
238 return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
239 }
240 }
241
242 return next(ctx, args)
243 }
244 }
245}
246
247// newInterp creates a new interpreter with the current shell state
248func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) {
249 return interp.New(
250 interp.StdIO(stdin, stdout, stderr),
251 interp.Interactive(false),
252 interp.Env(expand.ListEnviron(s.env...)),
253 interp.Dir(s.cwd),
254 interp.ExecHandlers(s.execHandlers()...),
255 )
256}
257
258// updateShellFromRunner updates the shell from the interpreter after execution
259func (s *Shell) updateShellFromRunner(runner *interp.Runner) {
260 s.cwd = runner.Dir
261 s.env = nil
262 for name, vr := range runner.Vars {
263 s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
264 }
265}
266
267// execCommon is the shared implementation for executing commands
268func (s *Shell) execCommon(ctx context.Context, command string, stdin io.Reader, stdout, stderr io.Writer) error {
269 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
270 if err != nil {
271 return fmt.Errorf("could not parse command: %w", err)
272 }
273
274 runner, err := s.newInterp(stdin, stdout, stderr)
275 if err != nil {
276 return fmt.Errorf("could not run command: %w", err)
277 }
278
279 err = runner.Run(ctx, line)
280 s.updateShellFromRunner(runner)
281 s.logger.InfoPersist("command finished", "command", command, "err", err)
282 return err
283}
284
285// exec executes commands using a cross-platform shell interpreter.
286func (s *Shell) exec(ctx context.Context, command string, stdin io.Reader) (string, string, error) {
287 var stdout, stderr bytes.Buffer
288 err := s.execCommon(ctx, command, stdin, &stdout, &stderr)
289 return stdout.String(), stderr.String(), err
290}
291
292// execStream executes commands using POSIX shell emulation with streaming output
293func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
294 return s.execCommon(ctx, command, nil, stdout, stderr)
295}
296
297func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
298 handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
299 s.blockHandler(),
300 }
301 if useGoCoreUtils {
302 handlers = append(handlers, coreutils.ExecHandler)
303 }
304 return handlers
305}
306
// IsInterrupt reports whether err is, or wraps, context.Canceled or
// context.DeadlineExceeded. A nil err is not an interrupt.
func IsInterrupt(err error) bool {
	for _, target := range [...]error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
312
313// ExitCode extracts the exit code from an error
314func ExitCode(err error) int {
315 if err == nil {
316 return 0
317 }
318 var exitErr interp.ExitStatus
319 if errors.As(err, &exitErr) {
320 return int(exitErr)
321 }
322 return 1
323}