1// Package shell provides cross-platform shell execution capabilities.
2//
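// This package provides Shell instances for executing commands with their own
// working directory and environment. Shells are independent of one another,
// but a single Shell carries state between calls: the working directory and
// environment left behind by one command are the starting point for the next.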
5//
6// WINDOWS COMPATIBILITY:
7// This implementation provides POSIX shell emulation (mvdan.cc/sh/v3) even on
8// Windows. Commands should use forward slashes (/) as path separators to work
9// correctly on all platforms.
10package shell
11
12import (
13 "bytes"
14 "context"
15 "errors"
16 "fmt"
17 "io"
18 "os"
19 "slices"
20 "strings"
21 "sync"
22
23 "github.com/charmbracelet/x/exp/slice"
24 "mvdan.cc/sh/moreinterp/coreutils"
25 "mvdan.cc/sh/v3/expand"
26 "mvdan.cc/sh/v3/interp"
27 "mvdan.cc/sh/v3/syntax"
28)
29
// ShellType identifies a flavor of shell syntax.
type ShellType int

// Supported ShellType values, numbered from zero in declaration order.
const (
	// ShellTypePOSIX is POSIX sh syntax.
	ShellTypePOSIX ShellType = iota
	// ShellTypeCmd is Windows cmd.exe syntax.
	ShellTypeCmd
	// ShellTypePowerShell is PowerShell syntax.
	ShellTypePowerShell
)
38
// Logger is the optional sink for the shell's diagnostic records. Callers
// pass a message followed by alternating key/value pairs.
type Logger interface {
	InfoPersist(msg string, keysAndValues ...any)
}

// noopLogger discards every record; NewShell falls back to it when no
// Logger is configured.
type noopLogger struct{}

// Compile-time check that noopLogger satisfies Logger.
var _ Logger = noopLogger{}

// InfoPersist implements Logger by ignoring all of its arguments.
func (noopLogger) InfoPersist(string, ...any) {}

// BlockFunc inspects a command's argument vector (args[0] is the command
// name) and reports whether the shell must refuse to run it.
type BlockFunc func(args []string) bool
51
52// Shell provides cross-platform shell execution with optional state persistence
53type Shell struct {
54 env []string
55 cwd string
56 mu sync.Mutex
57 logger Logger
58 blockFuncs []BlockFunc
59}
60
// Options configures a Shell created by NewShell. Every field is optional;
// NewShell substitutes a default for each zero value as described below.
type Options struct {
	// WorkingDir is the initial working directory. If empty, NewShell uses
	// the process's current directory (os.Getwd).
	WorkingDir string
	// Env is the initial environment as KEY=VALUE strings. If nil, NewShell
	// uses os.Environ(); a non-nil empty slice yields an empty environment.
	Env []string
	// Logger receives a record after each command finishes. If nil, records
	// are discarded.
	Logger Logger
	// BlockFuncs are consulted before each command runs; if any of them
	// returns true the command is refused.
	BlockFuncs []BlockFunc
}
68
69// NewShell creates a new shell instance with the given options
70func NewShell(opts *Options) *Shell {
71 if opts == nil {
72 opts = &Options{}
73 }
74
75 cwd := opts.WorkingDir
76 if cwd == "" {
77 cwd, _ = os.Getwd()
78 }
79
80 env := opts.Env
81 if env == nil {
82 env = os.Environ()
83 }
84
85 logger := opts.Logger
86 if logger == nil {
87 logger = noopLogger{}
88 }
89
90 return &Shell{
91 cwd: cwd,
92 env: env,
93 logger: logger,
94 blockFuncs: opts.BlockFuncs,
95 }
96}
97
98// Exec executes a command in the shell
99func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
100 s.mu.Lock()
101 defer s.mu.Unlock()
102
103 return s.exec(ctx, command)
104}
105
106// ExecStream executes a command in the shell with streaming output to provided writers
107func (s *Shell) ExecStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
108 s.mu.Lock()
109 defer s.mu.Unlock()
110
111 return s.execPOSIXStream(ctx, command, stdout, stderr)
112}
113
114// GetWorkingDir returns the current working directory
115func (s *Shell) GetWorkingDir() string {
116 s.mu.Lock()
117 defer s.mu.Unlock()
118 return s.cwd
119}
120
121// SetWorkingDir sets the working directory
122func (s *Shell) SetWorkingDir(dir string) error {
123 s.mu.Lock()
124 defer s.mu.Unlock()
125
126 // Verify the directory exists
127 if _, err := os.Stat(dir); err != nil {
128 return fmt.Errorf("directory does not exist: %w", err)
129 }
130
131 s.cwd = dir
132 return nil
133}
134
135// GetEnv returns a copy of the environment variables
136func (s *Shell) GetEnv() []string {
137 s.mu.Lock()
138 defer s.mu.Unlock()
139
140 env := make([]string, len(s.env))
141 copy(env, s.env)
142 return env
143}
144
145// SetEnv sets an environment variable
146func (s *Shell) SetEnv(key, value string) {
147 s.mu.Lock()
148 defer s.mu.Unlock()
149
150 // Update or add the environment variable
151 keyPrefix := key + "="
152 for i, env := range s.env {
153 if strings.HasPrefix(env, keyPrefix) {
154 s.env[i] = keyPrefix + value
155 return
156 }
157 }
158 s.env = append(s.env, keyPrefix+value)
159}
160
161// SetBlockFuncs sets the command block functions for the shell
162func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
163 s.mu.Lock()
164 defer s.mu.Unlock()
165 s.blockFuncs = blockFuncs
166}
167
168// CommandsBlocker creates a BlockFunc that blocks exact command matches
169func CommandsBlocker(cmds []string) BlockFunc {
170 bannedSet := make(map[string]struct{})
171 for _, cmd := range cmds {
172 bannedSet[cmd] = struct{}{}
173 }
174
175 return func(args []string) bool {
176 if len(args) == 0 {
177 return false
178 }
179 _, ok := bannedSet[args[0]]
180 return ok
181 }
182}
183
184// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
185func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
186 return func(parts []string) bool {
187 if len(parts) == 0 || parts[0] != cmd {
188 return false
189 }
190
191 argParts, flagParts := splitArgsFlags(parts[1:])
192 if len(argParts) < len(args) || len(flagParts) < len(flags) {
193 return false
194 }
195
196 argsMatch := slices.Equal(argParts[:len(args)], args)
197 flagsMatch := slice.IsSubset(flags, flagParts)
198
199 return argsMatch && flagsMatch
200 }
201}
202
// splitArgsFlags partitions parts into positional arguments and flags,
// preserving the relative order within each group. Any element starting
// with "-" counts as a flag (including "-" and "--"). For "--name=value",
// only "--name" is kept. Both results are non-nil, even for empty input.
func splitArgsFlags(parts []string) (args []string, flags []string) {
	args = make([]string, 0, len(parts))
	flags = make([]string, 0, len(parts))
	for _, p := range parts {
		if !strings.HasPrefix(p, "-") {
			args = append(args, p)
			continue
		}
		name, _, _ := strings.Cut(p, "=")
		flags = append(flags, name)
	}
	return args, flags
}
220
221func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
222 return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
223 return func(ctx context.Context, args []string) error {
224 if len(args) == 0 {
225 return next(ctx, args)
226 }
227
228 for _, blockFunc := range s.blockFuncs {
229 if blockFunc(args) {
230 return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
231 }
232 }
233
234 return next(ctx, args)
235 }
236 }
237}
238
239// exec executes commands using a cross-platform shell interpreter.
240func (s *Shell) exec(ctx context.Context, command string) (string, string, error) {
241 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
242 if err != nil {
243 return "", "", fmt.Errorf("could not parse command: %w", err)
244 }
245
246 var stdout, stderr bytes.Buffer
247 runner, err := interp.New(
248 interp.StdIO(nil, &stdout, &stderr),
249 interp.Interactive(false),
250 interp.Env(expand.ListEnviron(s.env...)),
251 interp.Dir(s.cwd),
252 interp.ExecHandlers(s.execHandlers()...),
253 )
254 if err != nil {
255 return "", "", fmt.Errorf("could not run command: %w", err)
256 }
257
258 err = runner.Run(ctx, line)
259 s.cwd = runner.Dir
260 for name, vr := range runner.Vars {
261 s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
262 }
263 s.logger.InfoPersist("command finished", "command", command, "err", err)
264 return stdout.String(), stderr.String(), err
265}
266
267// execPOSIXStream executes commands using POSIX shell emulation with streaming output
268func (s *Shell) execPOSIXStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
269 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
270 if err != nil {
271 return fmt.Errorf("could not parse command: %w", err)
272 }
273
274 runner, err := interp.New(
275 interp.StdIO(nil, stdout, stderr),
276 interp.Interactive(false),
277 interp.Env(expand.ListEnviron(s.env...)),
278 interp.Dir(s.cwd),
279 interp.ExecHandlers(s.execHandlers()...),
280 )
281 if err != nil {
282 return fmt.Errorf("could not run command: %w", err)
283 }
284
285 err = runner.Run(ctx, line)
286 s.cwd = runner.Dir
287 s.env = []string{}
288 for name, vr := range runner.Vars {
289 s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
290 }
291 s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
292 return err
293}
294
295func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
296 handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
297 s.blockHandler(),
298 }
299 if useGoCoreUtils {
300 handlers = append(handlers, coreutils.ExecHandler)
301 }
302 return handlers
303}
304
// IsInterrupt reports whether err is, or wraps, context.Canceled or
// context.DeadlineExceeded. In other words, it reports whether the command
// was stopped because its context ended.
func IsInterrupt(err error) bool {
	for _, target := range [...]error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
310
311// ExitCode extracts the exit code from an error
312func ExitCode(err error) int {
313 if err == nil {
314 return 0
315 }
316 var exitErr interp.ExitStatus
317 if errors.As(err, &exitErr) {
318 return int(exitErr)
319 }
320 return 1
321}