1// Package shell provides cross-platform shell execution capabilities.
2//
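// This package provides Shell instances for executing commands with their own
// working directory and environment. Each Shell instance is independent of the
// others; within a single instance, the working directory and variables left
// behind by one command carry over to the next.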
5//
6// WINDOWS COMPATIBILITY:
7// This implementation provides POSIX shell emulation (mvdan.cc/sh/v3) even on
8// Windows. Commands should use forward slashes (/) as path separators to work
9// correctly on all platforms.
10package shell
11
12import (
13 "bytes"
14 "context"
15 "errors"
16 "fmt"
17 "os"
18 "slices"
19 "strings"
20 "sync"
21
22 "github.com/charmbracelet/x/exp/slice"
23 "mvdan.cc/sh/moreinterp/coreutils"
24 "mvdan.cc/sh/v3/expand"
25 "mvdan.cc/sh/v3/interp"
26 "mvdan.cc/sh/v3/syntax"
27)
28
// ShellType represents the type of shell to use.
//
// NOTE(review): none of the code visible in this file reads a ShellType; Exec
// always runs commands through the POSIX emulation path. Confirm whether other
// files in the package consume these values.
type ShellType int

const (
	// ShellTypePOSIX is the zero value of ShellType.
	ShellTypePOSIX ShellType = iota
	// ShellTypeCmd presumably identifies the Windows cmd.exe shell; confirm
	// against callers.
	ShellTypeCmd
	// ShellTypePowerShell presumably identifies PowerShell; confirm against
	// callers.
	ShellTypePowerShell
)
37
// Logger is the optional logging hook used by Shell. When Options.Logger is
// nil, NewShell installs a no-op implementation instead.
type Logger interface {
	// InfoPersist records msg followed by alternating key/value pairs.
	// execPOSIX calls it once after every command run, passing the command
	// text and the resulting error.
	InfoPersist(msg string, keysAndValues ...any)
}
42
// noopLogger is a Logger that does nothing. NewShell uses it when no Logger is
// supplied so that the rest of the code can log unconditionally.
type noopLogger struct{}

// InfoPersist discards the message and all key/value pairs.
func (noopLogger) InfoPersist(msg string, keysAndValues ...any) {}
47
// BlockFunc is a function that determines if a command should be blocked.
//
// It receives the argument vector of a command about to be executed (the
// command name first, then its arguments) and returns true to reject it.
type BlockFunc func(args []string) bool
50
// Shell provides cross-platform shell execution with optional state persistence.
//
// The exported methods lock mu before reading or writing the other fields, so
// a Shell must be used through a pointer and never copied.
type Shell struct {
	// env holds the environment as "KEY=value" entries. It is rebuilt from
	// the interpreter's variables after every command run.
	env []string
	// cwd is the directory the next command starts in. It is updated from
	// the interpreter after every command run.
	cwd string
	// mu serializes access to the fields of this struct.
	mu sync.Mutex
	// logger receives one message per executed command; it is never nil
	// when the Shell was created by NewShell.
	logger Logger
	// blockFuncs are consulted before each command is executed; a command
	// is rejected as soon as one of them returns true.
	blockFuncs []BlockFunc
}
59
// Options configures a Shell created by NewShell. Every field is optional.
type Options struct {
	// WorkingDir is the initial working directory. When empty, NewShell
	// uses the result of os.Getwd.
	WorkingDir string
	// Env is the initial environment. When nil, NewShell uses os.Environ().
	Env []string
	// Logger receives a message after each command. When nil, NewShell
	// installs a logger that discards everything.
	Logger Logger
	// BlockFuncs are the initial command filters; they can be replaced
	// later with SetBlockFuncs.
	BlockFuncs []BlockFunc
}
67
68// NewShell creates a new shell instance with the given options
69func NewShell(opts *Options) *Shell {
70 if opts == nil {
71 opts = &Options{}
72 }
73
74 cwd := opts.WorkingDir
75 if cwd == "" {
76 cwd, _ = os.Getwd()
77 }
78
79 env := opts.Env
80 if env == nil {
81 env = os.Environ()
82 }
83
84 logger := opts.Logger
85 if logger == nil {
86 logger = noopLogger{}
87 }
88
89 return &Shell{
90 cwd: cwd,
91 env: env,
92 logger: logger,
93 blockFuncs: opts.BlockFuncs,
94 }
95}
96
// Exec executes a command in the shell.
//
// The shell's mutex is held for the whole run, so concurrent calls on the same
// Shell execute one after another. The results are exactly those of
// execPOSIX: captured stdout, captured stderr, and the run error.
func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	return s.execPOSIX(ctx, command)
}
104
// GetWorkingDir returns the current working directory, i.e. the directory the
// next command will start in. It takes the shell's mutex, so it blocks while
// an Exec call is in progress.
func (s *Shell) GetWorkingDir() string {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.cwd
}
111
112// SetWorkingDir sets the working directory
113func (s *Shell) SetWorkingDir(dir string) error {
114 s.mu.Lock()
115 defer s.mu.Unlock()
116
117 // Verify the directory exists
118 if _, err := os.Stat(dir); err != nil {
119 return fmt.Errorf("directory does not exist: %w", err)
120 }
121
122 s.cwd = dir
123 return nil
124}
125
126// GetEnv returns a copy of the environment variables
127func (s *Shell) GetEnv() []string {
128 s.mu.Lock()
129 defer s.mu.Unlock()
130
131 env := make([]string, len(s.env))
132 copy(env, s.env)
133 return env
134}
135
136// SetEnv sets an environment variable
137func (s *Shell) SetEnv(key, value string) {
138 s.mu.Lock()
139 defer s.mu.Unlock()
140
141 // Update or add the environment variable
142 keyPrefix := key + "="
143 for i, env := range s.env {
144 if strings.HasPrefix(env, keyPrefix) {
145 s.env[i] = keyPrefix + value
146 return
147 }
148 }
149 s.env = append(s.env, keyPrefix+value)
150}
151
// SetBlockFuncs sets the command block functions for the shell, replacing any
// that were configured before. The slice is stored as given, not copied.
func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.blockFuncs = blockFuncs
}
158
159// CommandsBlocker creates a BlockFunc that blocks exact command matches
160func CommandsBlocker(cmds []string) BlockFunc {
161 bannedSet := make(map[string]struct{})
162 for _, cmd := range cmds {
163 bannedSet[cmd] = struct{}{}
164 }
165
166 return func(args []string) bool {
167 if len(args) == 0 {
168 return false
169 }
170 _, ok := bannedSet[args[0]]
171 return ok
172 }
173}
174
175// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
176func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
177 return func(parts []string) bool {
178 if len(parts) == 0 || parts[0] != cmd {
179 return false
180 }
181
182 argParts, flagParts := splitArgsFlags(parts[1:])
183 if len(argParts) < len(args) || len(flagParts) < len(flags) {
184 return false
185 }
186
187 argsMatch := slices.Equal(argParts[:len(args)], args)
188 flagsMatch := slice.IsSubset(flags, flagParts)
189
190 return argsMatch && flagsMatch
191 }
192}
193
// splitArgsFlags partitions command-line tokens into positional arguments and
// flags, preserving their relative order.
//
// Any token starting with "-" counts as a flag; for flags written as
// "name=value" only the part before the first '=' is kept. Both results are
// non-nil even when empty.
func splitArgsFlags(parts []string) (args []string, flags []string) {
	args = make([]string, 0, len(parts))
	flags = make([]string, 0, len(parts))
	for _, tok := range parts {
		if !strings.HasPrefix(tok, "-") {
			args = append(args, tok)
			continue
		}
		// Cut returns the whole token when it contains no '='.
		name, _, _ := strings.Cut(tok, "=")
		flags = append(flags, name)
	}
	return args, flags
}
211
212func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
213 return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
214 return func(ctx context.Context, args []string) error {
215 if len(args) == 0 {
216 return next(ctx, args)
217 }
218
219 for _, blockFunc := range s.blockFuncs {
220 if blockFunc(args) {
221 return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
222 }
223 }
224
225 return next(ctx, args)
226 }
227 }
228}
229
230// execPOSIX executes commands using POSIX shell emulation (cross-platform)
231func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
232 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
233 if err != nil {
234 return "", "", fmt.Errorf("could not parse command: %w", err)
235 }
236
237 var stdout, stderr bytes.Buffer
238 runner, err := interp.New(
239 interp.StdIO(nil, &stdout, &stderr),
240 interp.Interactive(false),
241 interp.Env(expand.ListEnviron(s.env...)),
242 interp.Dir(s.cwd),
243 interp.ExecHandlers(s.blockHandler(), coreutils.ExecHandler),
244 )
245 if err != nil {
246 return "", "", fmt.Errorf("could not run command: %w", err)
247 }
248
249 err = runner.Run(ctx, line)
250 s.cwd = runner.Dir
251 s.env = []string{}
252 for name, vr := range runner.Vars {
253 s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
254 }
255 s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
256 return stdout.String(), stderr.String(), err
257}
258
// IsInterrupt checks if an error is due to interruption, meaning that err, or
// any error it wraps, is context.Canceled or context.DeadlineExceeded.
func IsInterrupt(err error) bool {
	switch {
	case errors.Is(err, context.Canceled):
		return true
	case errors.Is(err, context.DeadlineExceeded):
		return true
	default:
		return false
	}
}
264
265// ExitCode extracts the exit code from an error
266func ExitCode(err error) int {
267 if err == nil {
268 return 0
269 }
270 var exitErr interp.ExitStatus
271 if errors.As(err, &exitErr) {
272 return int(exitErr)
273 }
274 return 1
275}