1// Package shell provides cross-platform shell execution capabilities.
2//
3// This package offers two main types:
4// - Shell: A general-purpose shell executor for one-off or managed commands
5// - PersistentShell: A singleton shell that maintains state across the application
6//
// WINDOWS COMPATIBILITY:
// This implementation always uses POSIX shell emulation (mvdan.cc/sh/v3),
// even on Windows. Some caution is required: commands should use forward
// slashes (/) as path separators, even on Windows.
11package shell
12
13import (
14 "bytes"
15 "context"
16 "errors"
17 "fmt"
18 "os"
19 "strings"
20 "sync"
21
22 "mvdan.cc/sh/moreinterp/coreutils"
23 "mvdan.cc/sh/v3/expand"
24 "mvdan.cc/sh/v3/interp"
25 "mvdan.cc/sh/v3/syntax"
26)
27
// ShellType identifies a shell dialect. Shell.Exec currently always runs
// commands through POSIX emulation, whatever ShellType a caller has in mind.
type ShellType int

// ShellType values. The zero value is ShellTypePOSIX.
const (
	ShellTypePOSIX      ShellType = 0
	ShellTypeCmd        ShellType = 1
	ShellTypePowerShell ShellType = 2
)
36
// Logger is the optional sink for diagnostic messages produced by a Shell.
// keysAndValues carries alternating key/value pairs, e.g.
// InfoPersist("POSIX command finished", "command", cmd, "err", err).
type Logger interface {
	InfoPersist(message string, keysAndValues ...any)
}
41
// noopLogger satisfies Logger by discarding everything it is given.
// NewShell installs it when Options.Logger is nil.
type noopLogger struct{}

// InfoPersist ignores its arguments.
func (noopLogger) InfoPersist(string, ...any) {}
46
// BlockFunc decides whether a single command invocation must be refused.
// args is the command's argument list, with args[0] being the command name.
// Returning true makes the shell reject the command instead of running it.
type BlockFunc func(args []string) bool
49
50// Shell provides cross-platform shell execution with optional state persistence
51type Shell struct {
52 env []string
53 cwd string
54 mu sync.Mutex
55 logger Logger
56 blockFuncs []BlockFunc
57}
58
59// Options for creating a new shell
60type Options struct {
61 WorkingDir string
62 Env []string
63 Logger Logger
64 BlockFuncs []BlockFunc
65}
66
67// NewShell creates a new shell instance with the given options
68func NewShell(opts *Options) *Shell {
69 if opts == nil {
70 opts = &Options{}
71 }
72
73 cwd := opts.WorkingDir
74 if cwd == "" {
75 cwd, _ = os.Getwd()
76 }
77
78 env := opts.Env
79 if env == nil {
80 env = os.Environ()
81 }
82
83 logger := opts.Logger
84 if logger == nil {
85 logger = noopLogger{}
86 }
87
88 return &Shell{
89 cwd: cwd,
90 env: env,
91 logger: logger,
92 blockFuncs: opts.BlockFuncs,
93 }
94}
95
96// Exec executes a command in the shell
97func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
98 s.mu.Lock()
99 defer s.mu.Unlock()
100
101 return s.execPOSIX(ctx, command)
102}
103
104// GetWorkingDir returns the current working directory
105func (s *Shell) GetWorkingDir() string {
106 s.mu.Lock()
107 defer s.mu.Unlock()
108 return s.cwd
109}
110
111// SetWorkingDir sets the working directory
112func (s *Shell) SetWorkingDir(dir string) error {
113 s.mu.Lock()
114 defer s.mu.Unlock()
115
116 // Verify the directory exists
117 if _, err := os.Stat(dir); err != nil {
118 return fmt.Errorf("directory does not exist: %w", err)
119 }
120
121 s.cwd = dir
122 return nil
123}
124
125// GetEnv returns a copy of the environment variables
126func (s *Shell) GetEnv() []string {
127 s.mu.Lock()
128 defer s.mu.Unlock()
129
130 env := make([]string, len(s.env))
131 copy(env, s.env)
132 return env
133}
134
135// SetEnv sets an environment variable
136func (s *Shell) SetEnv(key, value string) {
137 s.mu.Lock()
138 defer s.mu.Unlock()
139
140 // Update or add the environment variable
141 keyPrefix := key + "="
142 for i, env := range s.env {
143 if strings.HasPrefix(env, keyPrefix) {
144 s.env[i] = keyPrefix + value
145 return
146 }
147 }
148 s.env = append(s.env, keyPrefix+value)
149}
150
151// SetBlockFuncs sets the command block functions for the shell
152func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
153 s.mu.Lock()
154 defer s.mu.Unlock()
155 s.blockFuncs = blockFuncs
156}
157
158// CommandsBlocker creates a BlockFunc that blocks exact command matches
159func CommandsBlocker(bannedCommands []string) BlockFunc {
160 bannedSet := make(map[string]bool)
161 for _, cmd := range bannedCommands {
162 bannedSet[cmd] = true
163 }
164
165 return func(args []string) bool {
166 if len(args) == 0 {
167 return false
168 }
169 return bannedSet[args[0]]
170 }
171}
172
173// ArgumentsBlocker creates a BlockFunc that blocks specific subcommands
174func ArgumentsBlocker(blockedSubCommands [][]string) BlockFunc {
175 return func(args []string) bool {
176 for _, blocked := range blockedSubCommands {
177 if len(args) >= len(blocked) {
178 match := true
179 for i, part := range blocked {
180 if args[i] != part {
181 match = false
182 break
183 }
184 }
185 if match {
186 return true
187 }
188 }
189 }
190 return false
191 }
192}
193
194func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
195 return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
196 return func(ctx context.Context, args []string) error {
197 if len(args) == 0 {
198 return next(ctx, args)
199 }
200
201 for _, blockFunc := range s.blockFuncs {
202 if blockFunc(args) {
203 return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
204 }
205 }
206
207 return next(ctx, args)
208 }
209 }
210}
211
212// execPOSIX executes commands using POSIX shell emulation (cross-platform)
213func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
214 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
215 if err != nil {
216 return "", "", fmt.Errorf("could not parse command: %w", err)
217 }
218
219 var stdout, stderr bytes.Buffer
220 runner, err := interp.New(
221 interp.StdIO(nil, &stdout, &stderr),
222 interp.Interactive(false),
223 interp.Env(expand.ListEnviron(s.env...)),
224 interp.Dir(s.cwd),
225 interp.ExecHandlers(s.blockHandler(), coreutils.ExecHandler),
226 )
227 if err != nil {
228 return "", "", fmt.Errorf("could not run command: %w", err)
229 }
230
231 err = runner.Run(ctx, line)
232 s.cwd = runner.Dir
233 s.env = []string{}
234 for name, vr := range runner.Vars {
235 s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
236 }
237 s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
238 return stdout.String(), stderr.String(), err
239}
240
// IsInterrupt reports whether err, or any error it wraps, is
// context.Canceled or context.DeadlineExceeded.
func IsInterrupt(err error) bool {
	for _, target := range []error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
246
247// ExitCode extracts the exit code from an error
248func ExitCode(err error) int {
249 if err == nil {
250 return 0
251 }
252 var exitErr interp.ExitStatus
253 if errors.As(err, &exitErr) {
254 return int(exitErr)
255 }
256 return 1
257}