1// Package shell provides cross-platform shell execution capabilities.
2//
3// This package offers two main types:
4// - Shell: A general-purpose shell executor for one-off or managed commands
5// - PersistentShell: A singleton shell that maintains state across the application
6//
// WINDOWS COMPATIBILITY:
// This implementation uses POSIX shell emulation (mvdan.cc/sh/v3) on every
// platform, including Windows. Some care is required: commands must use
// forward slashes (/) as path separators to work, even on Windows.
11package shell
12
13import (
14 "bytes"
15 "context"
16 "errors"
17 "fmt"
18 "io"
19 "os"
20 "slices"
21 "strings"
22 "sync"
23
24 "github.com/charmbracelet/x/exp/slice"
25 "mvdan.cc/sh/moreinterp/coreutils"
26 "mvdan.cc/sh/v3/expand"
27 "mvdan.cc/sh/v3/interp"
28 "mvdan.cc/sh/v3/syntax"
29)
30
// ShellType identifies the kind of shell a command is meant for.
type ShellType int

// Known shell kinds. Values are assigned sequentially by iota, so
// ShellTypePOSIX is the zero value of ShellType.
const (
	// ShellTypePOSIX selects the POSIX shell.
	ShellTypePOSIX ShellType = iota
	// ShellTypeCmd presumably refers to Windows cmd.exe.
	// NOTE(review): not referenced by the visible code; confirm against callers.
	ShellTypeCmd
	// ShellTypePowerShell presumably refers to PowerShell.
	// NOTE(review): not referenced by the visible code; confirm against callers.
	ShellTypePowerShell
)
39
// Logger is the optional logging hook accepted by Shell. When no Logger is
// supplied, NewShell substitutes one that discards everything.
type Logger interface {
	// InfoPersist records msg followed by alternating key/value pairs
	// (for example "command", cmd, "err", err).
	InfoPersist(msg string, keysAndValues ...any)
}
44
// noopLogger is the Logger used when the caller does not provide one; it
// silently drops every message.
type noopLogger struct{}

// InfoPersist satisfies Logger and intentionally ignores all of its arguments.
func (noopLogger) InfoPersist(_ string, _ ...any) {}
49
// BlockFunc decides whether a command must be blocked. It receives the
// command's argument vector (args[0] is the command name, which is what
// CommandsBlocker and ArgumentsBlocker inspect) and returns true to reject it.
type BlockFunc func(args []string) bool
52
// Shell provides cross-platform shell execution with optional state persistence.
//
// The working directory and environment carry over from one command to the
// next: after every run, execPOSIX writes the interpreter's final directory
// and variables back into the Shell. The methods in this file lock mu, so
// commands issued on the same Shell run one at a time.
type Shell struct {
	// env holds the environment as "KEY=value" entries.
	env []string
	// cwd is the directory in which the next command runs.
	cwd string
	// mu guards the other fields; Exec holds it for the whole command.
	mu sync.Mutex
	// logger receives a message after each command finishes.
	logger Logger
	// blockFuncs are consulted by blockHandler before a command is executed;
	// the first one returning true rejects the command.
	blockFuncs []BlockFunc
}
61
// Options configures a Shell created by NewShell. Every field is optional.
type Options struct {
	// WorkingDir is the initial working directory. When empty, NewShell uses
	// the process working directory (os.Getwd).
	WorkingDir string
	// Env is the initial environment as "KEY=value" entries. When nil,
	// NewShell uses os.Environ().
	Env []string
	// Logger receives log messages. When nil, messages are discarded.
	Logger Logger
	// BlockFuncs are the initial command block functions.
	BlockFuncs []BlockFunc
}
69
70// NewShell creates a new shell instance with the given options
71func NewShell(opts *Options) *Shell {
72 if opts == nil {
73 opts = &Options{}
74 }
75
76 cwd := opts.WorkingDir
77 if cwd == "" {
78 cwd, _ = os.Getwd()
79 }
80
81 env := opts.Env
82 if env == nil {
83 env = os.Environ()
84 }
85
86 logger := opts.Logger
87 if logger == nil {
88 logger = noopLogger{}
89 }
90
91 return &Shell{
92 cwd: cwd,
93 env: env,
94 logger: logger,
95 blockFuncs: opts.BlockFuncs,
96 }
97}
98
99// Exec executes a command in the shell
100func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
101 s.mu.Lock()
102 defer s.mu.Unlock()
103
104 return s.execPOSIX(ctx, command, nil)
105}
106
107// ExecWithStdin executes a command in the shell with the given stdin
108func (s *Shell) ExecWithStdin(ctx context.Context, command string, stdin string) (string, string, error) {
109 s.mu.Lock()
110 defer s.mu.Unlock()
111
112 return s.execPOSIX(ctx, command, strings.NewReader(stdin))
113}
114
115// GetWorkingDir returns the current working directory
116func (s *Shell) GetWorkingDir() string {
117 s.mu.Lock()
118 defer s.mu.Unlock()
119 return s.cwd
120}
121
122// SetWorkingDir sets the working directory
123func (s *Shell) SetWorkingDir(dir string) error {
124 s.mu.Lock()
125 defer s.mu.Unlock()
126
127 // Verify the directory exists
128 if _, err := os.Stat(dir); err != nil {
129 return fmt.Errorf("directory does not exist: %w", err)
130 }
131
132 s.cwd = dir
133 return nil
134}
135
136// GetEnv returns a copy of the environment variables
137func (s *Shell) GetEnv() []string {
138 s.mu.Lock()
139 defer s.mu.Unlock()
140
141 env := make([]string, len(s.env))
142 copy(env, s.env)
143 return env
144}
145
146// SetEnv sets an environment variable
147func (s *Shell) SetEnv(key, value string) {
148 s.mu.Lock()
149 defer s.mu.Unlock()
150
151 // Update or add the environment variable
152 keyPrefix := key + "="
153 for i, env := range s.env {
154 if strings.HasPrefix(env, keyPrefix) {
155 s.env[i] = keyPrefix + value
156 return
157 }
158 }
159 s.env = append(s.env, keyPrefix+value)
160}
161
162// SetBlockFuncs sets the command block functions for the shell
163func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
164 s.mu.Lock()
165 defer s.mu.Unlock()
166 s.blockFuncs = blockFuncs
167}
168
169// CommandsBlocker creates a BlockFunc that blocks exact command matches
170func CommandsBlocker(cmds []string) BlockFunc {
171 bannedSet := make(map[string]struct{})
172 for _, cmd := range cmds {
173 bannedSet[cmd] = struct{}{}
174 }
175
176 return func(args []string) bool {
177 if len(args) == 0 {
178 return false
179 }
180 _, ok := bannedSet[args[0]]
181 return ok
182 }
183}
184
185// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
186func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
187 return func(parts []string) bool {
188 if len(parts) == 0 || parts[0] != cmd {
189 return false
190 }
191
192 argParts, flagParts := splitArgsFlags(parts[1:])
193 if len(argParts) < len(args) || len(flagParts) < len(flags) {
194 return false
195 }
196
197 argsMatch := slices.Equal(argParts[:len(args)], args)
198 flagsMatch := slice.IsSubset(flags, flagParts)
199
200 return argsMatch && flagsMatch
201 }
202}
203
// splitArgsFlags partitions command-line tokens into positional arguments and
// flags, preserving their relative order. A token is a flag when it starts
// with "-"; for flags only the name is kept, i.e. everything before the first
// '=' ("--depth=1" becomes "--depth"). Both results are non-nil.
func splitArgsFlags(parts []string) (args []string, flags []string) {
	args = make([]string, 0, len(parts))
	flags = make([]string, 0, len(parts))

	for _, tok := range parts {
		if !strings.HasPrefix(tok, "-") {
			args = append(args, tok)
			continue
		}
		// Cut returns the whole token when it contains no '='.
		name, _, _ := strings.Cut(tok, "=")
		flags = append(flags, name)
	}
	return args, flags
}
221
222func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
223 return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
224 return func(ctx context.Context, args []string) error {
225 if len(args) == 0 {
226 return next(ctx, args)
227 }
228
229 for _, blockFunc := range s.blockFuncs {
230 if blockFunc(args) {
231 return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
232 }
233 }
234
235 return next(ctx, args)
236 }
237 }
238}
239
240// execPOSIX executes commands using POSIX shell emulation (cross-platform)
241func (s *Shell) execPOSIX(ctx context.Context, command string, stdin *strings.Reader) (string, string, error) {
242 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
243 if err != nil {
244 return "", "", fmt.Errorf("could not parse command: %w", err)
245 }
246
247 var stdout, stderr bytes.Buffer
248 var stdinReader io.Reader
249 if stdin != nil {
250 stdinReader = stdin
251 }
252 runner, err := interp.New(
253 interp.StdIO(stdinReader, &stdout, &stderr),
254 interp.Interactive(false),
255 interp.Env(expand.ListEnviron(s.env...)),
256 interp.Dir(s.cwd),
257 interp.ExecHandlers(s.blockHandler(), coreutils.ExecHandler),
258 )
259 if err != nil {
260 return "", "", fmt.Errorf("could not run command: %w", err)
261 }
262
263 err = runner.Run(ctx, line)
264 s.cwd = runner.Dir
265 s.env = []string{}
266 for name, vr := range runner.Vars {
267 s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
268 }
269 s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
270 return stdout.String(), stderr.String(), err
271}
272
// IsInterrupt reports whether err was caused by its context ending, i.e.
// whether err matches context.Canceled or context.DeadlineExceeded according
// to errors.Is. A nil err yields false.
func IsInterrupt(err error) bool {
	for _, cause := range []error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, cause) {
			return true
		}
	}
	return false
}
278
279// ExitCode extracts the exit code from an error
280func ExitCode(err error) int {
281 if err == nil {
282 return 0
283 }
284 var exitErr interp.ExitStatus
285 if errors.As(err, &exitErr) {
286 return int(exitErr)
287 }
288 return 1
289}