1// Package shell provides cross-platform shell execution capabilities.
2//
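// This package provides Shell instances for executing commands with their own
// working directory and environment. Each Shell instance is independent of the
// others; within one instance, the working directory and variables left behind
// by a command carry over to the next command.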
5//
6// WINDOWS COMPATIBILITY:
7// This implementation provides POSIX shell emulation (mvdan.cc/sh/v3) even on
8// Windows. Commands should use forward slashes (/) as path separators to work
9// correctly on all platforms.
10package shell
11
12import (
13 "bytes"
14 "context"
15 "errors"
16 "fmt"
17 "io"
18 "os"
19 "slices"
20 "strings"
21 "sync"
22
23 "github.com/charmbracelet/x/exp/slice"
24 "mvdan.cc/sh/moreinterp/coreutils"
25 "mvdan.cc/sh/v3/expand"
26 "mvdan.cc/sh/v3/interp"
27 "mvdan.cc/sh/v3/syntax"
28)
29
// ShellType represents the type of shell to use.
type ShellType int

// Known ShellType values. They are numbered by iota in declaration order, so
// ShellTypePOSIX is the zero value of ShellType.
//
// NOTE(review): nothing in the visible part of this file reads these values;
// the meaning of each constant is inferred from its name only.
const (
	ShellTypePOSIX ShellType = iota
	ShellTypeCmd
	ShellTypePowerShell
)
38
// Logger is the optional logging hook used by Shell.
//
// InfoPersist receives a message followed by a variadic list of values; the
// shell calls it with alternating key/value pairs (see execCommon).
type Logger interface {
	InfoPersist(msg string, keysAndValues ...any)
}
43
// noopLogger is a Logger that does nothing. NewShell installs it when no
// Logger is supplied, so the shell can always call its logger without a nil
// check.
type noopLogger struct{}

// InfoPersist implements Logger and discards its arguments.
func (noopLogger) InfoPersist(msg string, keysAndValues ...any) {}
48
// BlockFunc is a function that determines if a command should be blocked.
//
// args is the command line handed to the shell's exec handler: args[0] is the
// command name and the remaining elements are its arguments. Returning true
// blocks the command.
type BlockFunc func(args []string) bool
51
// Shell provides cross-platform shell execution with optional state persistence.
//
// The working directory and environment are passed to the interpreter for
// every command and are overwritten from the interpreter once the command has
// run, so changes made by one command are seen by the next one.
type Shell struct {
	// env holds the environment as "KEY=value" entries.
	env []string
	// cwd is the directory the next command starts in.
	cwd string
	// mu is held by the exported methods while they read or modify the shell.
	mu sync.Mutex
	// logger is never nil after NewShell; it falls back to noopLogger.
	logger Logger
	// blockFuncs are consulted before each command is executed.
	blockFuncs []BlockFunc
}
60
// Options for creating a new shell. Every field is optional; NewShell fills in
// a default for each zero value.
type Options struct {
	// WorkingDir is the initial working directory. Empty means the current
	// directory of the process.
	WorkingDir string
	// Env is the initial environment as "KEY=value" entries. Nil means the
	// environment of the process.
	Env []string
	// Logger receives log output. Nil disables logging.
	Logger Logger
	// BlockFuncs are the command filters installed on the shell.
	BlockFuncs []BlockFunc
}
68
69// NewShell creates a new shell instance with the given options
70func NewShell(opts *Options) *Shell {
71 if opts == nil {
72 opts = &Options{}
73 }
74
75 cwd := opts.WorkingDir
76 if cwd == "" {
77 cwd, _ = os.Getwd()
78 }
79
80 env := opts.Env
81 if env == nil {
82 env = os.Environ()
83 }
84
85 logger := opts.Logger
86 if logger == nil {
87 logger = noopLogger{}
88 }
89
90 return &Shell{
91 cwd: cwd,
92 env: env,
93 logger: logger,
94 blockFuncs: opts.BlockFuncs,
95 }
96}
97
// Exec executes a command in the shell and returns what the command wrote to
// stdout and stderr, in that order, together with the error from the run.
//
// The shell's mutex is held for the whole call, so concurrent calls on the
// same Shell run one after another.
func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	return s.exec(ctx, command)
}
105
// ExecStream executes a command in the shell with streaming output to provided writers.
//
// stdout and stderr are handed to the interpreter as the command's output
// streams. The shell's mutex is held for the whole call, so concurrent calls
// on the same Shell run one after another.
func (s *Shell) ExecStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	return s.execStream(ctx, command, stdout, stderr)
}
113
// GetWorkingDir returns the current working directory, i.e. the directory the
// next command will start in. The value is read under the shell's mutex.
func (s *Shell) GetWorkingDir() string {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.cwd
}
120
121// SetWorkingDir sets the working directory
122func (s *Shell) SetWorkingDir(dir string) error {
123 s.mu.Lock()
124 defer s.mu.Unlock()
125
126 // Verify the directory exists
127 if _, err := os.Stat(dir); err != nil {
128 return fmt.Errorf("directory does not exist: %w", err)
129 }
130
131 s.cwd = dir
132 return nil
133}
134
135// GetEnv returns a copy of the environment variables
136func (s *Shell) GetEnv() []string {
137 s.mu.Lock()
138 defer s.mu.Unlock()
139
140 env := make([]string, len(s.env))
141 copy(env, s.env)
142 return env
143}
144
145// SetEnv sets an environment variable
146func (s *Shell) SetEnv(key, value string) {
147 s.mu.Lock()
148 defer s.mu.Unlock()
149
150 // Update or add the environment variable
151 keyPrefix := key + "="
152 for i, env := range s.env {
153 if strings.HasPrefix(env, keyPrefix) {
154 s.env[i] = keyPrefix + value
155 return
156 }
157 }
158 s.env = append(s.env, keyPrefix+value)
159}
160
// SetBlockFuncs sets the command block functions for the shell, replacing any
// that were installed before. The slice is stored as given, not copied.
func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.blockFuncs = blockFuncs
}
167
168// CommandsBlocker creates a BlockFunc that blocks exact command matches
169func CommandsBlocker(cmds []string) BlockFunc {
170 bannedSet := make(map[string]struct{})
171 for _, cmd := range cmds {
172 bannedSet[cmd] = struct{}{}
173 }
174
175 return func(args []string) bool {
176 if len(args) == 0 {
177 return false
178 }
179 _, ok := bannedSet[args[0]]
180 return ok
181 }
182}
183
184// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
185func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
186 return func(parts []string) bool {
187 if len(parts) == 0 || parts[0] != cmd {
188 return false
189 }
190
191 argParts, flagParts := splitArgsFlags(parts[1:])
192 if len(argParts) < len(args) || len(flagParts) < len(flags) {
193 return false
194 }
195
196 argsMatch := slices.Equal(argParts[:len(args)], args)
197 flagsMatch := slice.IsSubset(flags, flagParts)
198
199 return argsMatch && flagsMatch
200 }
201}
202
// splitArgsFlags partitions parts into positional arguments and flags,
// preserving the relative order within each group.
//
// Anything that starts with "-" counts as a flag; for a flag of the form
// "name=value" only the part before the first '=' is kept. Everything else is
// returned unchanged as an argument. Both results are non-nil.
func splitArgsFlags(parts []string) (args []string, flags []string) {
	args = make([]string, 0, len(parts))
	flags = make([]string, 0, len(parts))

	for _, token := range parts {
		if !strings.HasPrefix(token, "-") {
			args = append(args, token)
			continue
		}
		// Cut returns the whole token when there is no '='.
		name, _, _ := strings.Cut(token, "=")
		flags = append(flags, name)
	}
	return args, flags
}
220
221func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
222 return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
223 return func(ctx context.Context, args []string) error {
224 if len(args) == 0 {
225 return next(ctx, args)
226 }
227
228 for _, blockFunc := range s.blockFuncs {
229 if blockFunc(args) {
230 return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
231 }
232 }
233
234 return next(ctx, args)
235 }
236 }
237}
238
239// newInterp creates a new interpreter with the current shell state
240func (s *Shell) newInterp(stdout, stderr io.Writer) (*interp.Runner, error) {
241 return interp.New(
242 interp.StdIO(nil, stdout, stderr),
243 interp.Interactive(false),
244 interp.Env(expand.ListEnviron(s.env...)),
245 interp.Dir(s.cwd),
246 interp.ExecHandlers(s.execHandlers()...),
247 )
248}
249
250// updateShellFromRunner updates the shell from the interpreter after execution
251func (s *Shell) updateShellFromRunner(runner *interp.Runner) {
252 s.cwd = runner.Dir
253 s.env = nil
254 for name, vr := range runner.Vars {
255 s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
256 }
257}
258
259// execCommon is the shared implementation for executing commands
260func (s *Shell) execCommon(ctx context.Context, command string, stdout, stderr io.Writer) error {
261 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
262 if err != nil {
263 return fmt.Errorf("could not parse command: %w", err)
264 }
265
266 runner, err := s.newInterp(stdout, stderr)
267 if err != nil {
268 return fmt.Errorf("could not run command: %w", err)
269 }
270
271 err = runner.Run(ctx, line)
272 s.updateShellFromRunner(runner)
273 s.logger.InfoPersist("command finished", "command", command, "err", err)
274 return err
275}
276
// exec executes commands using a cross-platform shell interpreter.
//
// Output is collected in memory: the returned strings are everything the
// command wrote to stdout and stderr, in that order. They are returned even
// when err is non-nil.
func (s *Shell) exec(ctx context.Context, command string) (string, string, error) {
	var stdout, stderr bytes.Buffer
	err := s.execCommon(ctx, command, &stdout, &stderr)
	return stdout.String(), stderr.String(), err
}
283
// execStream executes commands using POSIX shell emulation with streaming output.
// It forwards directly to execCommon, passing the caller's writers through
// unchanged.
func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
	return s.execCommon(ctx, command, stdout, stderr)
}
288
289func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
290 handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
291 s.blockHandler(),
292 }
293 if useGoCoreUtils {
294 handlers = append(handlers, coreutils.ExecHandler)
295 }
296 return handlers
297}
298
// IsInterrupt reports whether err is, or wraps, context.Canceled or
// context.DeadlineExceeded. A nil error is not an interrupt.
func IsInterrupt(err error) bool {
	for _, cause := range []error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, cause) {
			return true
		}
	}
	return false
}
304
305// ExitCode extracts the exit code from an error
306func ExitCode(err error) int {
307 if err == nil {
308 return 0
309 }
310 var exitErr interp.ExitStatus
311 if errors.As(err, &exitErr) {
312 return int(exitErr)
313 }
314 return 1
315}