1// Package shell provides cross-platform shell execution capabilities.
2//
3// This package offers two main types:
4// - Shell: A general-purpose shell executor for one-off or managed commands
5// - PersistentShell: A singleton shell that maintains state across the application
6//
7// WINDOWS COMPATIBILITY:
8// This implementation provides both POSIX shell emulation (mvdan.cc/sh/v3) and
9// native Windows shell support (cmd.exe/PowerShell) for optimal compatibility.
10package shell
11
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strings"
	"sync"

	"mvdan.cc/sh/v3/expand"
	"mvdan.cc/sh/v3/interp"
	"mvdan.cc/sh/v3/syntax"
)
27
// ShellType identifies which shell implementation Exec uses for a command.
type ShellType int

// Shell kinds. The zero value, ShellTypePOSIX, is the default.
const (
	// ShellTypePOSIX runs the command through the built-in POSIX interpreter.
	ShellTypePOSIX ShellType = 0
	// ShellTypeCmd runs the command through Windows cmd.exe.
	ShellTypeCmd ShellType = 1
	// ShellTypePowerShell runs the command through Windows PowerShell.
	ShellTypePowerShell ShellType = 2
)
36
// Logger is the optional logging sink a Shell reports to. The shell calls
// InfoPersist once per finished command with alternating key/value pairs.
type Logger interface {
	InfoPersist(msg string, keysAndValues ...interface{})
}

// noopLogger discards every message. NewShell installs it when the caller
// supplies no Logger, so the shell never has to nil-check its logger.
type noopLogger struct{}

var _ Logger = noopLogger{}

// InfoPersist ignores its arguments.
func (noopLogger) InfoPersist(string, ...interface{}) {}
46
// BlockFunc inspects a command's argument vector (the command name first,
// then its arguments) and returns true if the command must not be run.
type BlockFunc func(cmdArgs []string) bool
49
50// Shell provides cross-platform shell execution with optional state persistence
51type Shell struct {
52 env []string
53 cwd string
54 mu sync.Mutex
55 logger Logger
56 blockFuncs []BlockFunc
57}
58
59// Options for creating a new shell
60type Options struct {
61 WorkingDir string
62 Env []string
63 Logger Logger
64 BlockFuncs []BlockFunc
65}
66
67// NewShell creates a new shell instance with the given options
68func NewShell(opts *Options) *Shell {
69 if opts == nil {
70 opts = &Options{}
71 }
72
73 cwd := opts.WorkingDir
74 if cwd == "" {
75 cwd, _ = os.Getwd()
76 }
77
78 env := opts.Env
79 if env == nil {
80 env = os.Environ()
81 }
82
83 logger := opts.Logger
84 if logger == nil {
85 logger = noopLogger{}
86 }
87
88 return &Shell{
89 cwd: cwd,
90 env: env,
91 logger: logger,
92 blockFuncs: opts.BlockFuncs,
93 }
94}
95
96// Exec executes a command in the shell
97func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
98 s.mu.Lock()
99 defer s.mu.Unlock()
100
101 // Determine which shell to use based on platform and command
102 shellType := s.determineShellType(command)
103
104 switch shellType {
105 case ShellTypeCmd:
106 return s.execWindows(ctx, command, "cmd")
107 case ShellTypePowerShell:
108 return s.execWindows(ctx, command, "powershell")
109 default:
110 return s.execPOSIX(ctx, command)
111 }
112}
113
114// GetWorkingDir returns the current working directory
115func (s *Shell) GetWorkingDir() string {
116 s.mu.Lock()
117 defer s.mu.Unlock()
118 return s.cwd
119}
120
121// SetWorkingDir sets the working directory
122func (s *Shell) SetWorkingDir(dir string) error {
123 s.mu.Lock()
124 defer s.mu.Unlock()
125
126 // Verify the directory exists
127 if _, err := os.Stat(dir); err != nil {
128 return fmt.Errorf("directory does not exist: %w", err)
129 }
130
131 s.cwd = dir
132 return nil
133}
134
135// GetEnv returns a copy of the environment variables
136func (s *Shell) GetEnv() []string {
137 s.mu.Lock()
138 defer s.mu.Unlock()
139
140 env := make([]string, len(s.env))
141 copy(env, s.env)
142 return env
143}
144
145// SetEnv sets an environment variable
146func (s *Shell) SetEnv(key, value string) {
147 s.mu.Lock()
148 defer s.mu.Unlock()
149
150 // Update or add the environment variable
151 keyPrefix := key + "="
152 for i, env := range s.env {
153 if strings.HasPrefix(env, keyPrefix) {
154 s.env[i] = keyPrefix + value
155 return
156 }
157 }
158 s.env = append(s.env, keyPrefix+value)
159}
160
161// SetBlockFuncs sets the command block functions for the shell
162func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
163 s.mu.Lock()
164 defer s.mu.Unlock()
165 s.blockFuncs = blockFuncs
166}
167
// windowsNativeCommands is the set of command names (lower case) that, on
// Windows, are sent to cmd.exe instead of the POSIX interpreter when they are
// the first word of a command line. Look up the lower-cased word.
var windowsNativeCommands = func() map[string]bool {
	names := []string{
		// file and directory commands
		"dir", "type", "copy", "move", "del",
		"md", "mkdir", "rd", "rmdir", "cls",
		// system utilities
		"where", "tasklist", "taskkill", "net", "sc", "reg", "wmic",
	}
	set := make(map[string]bool, len(names))
	for _, name := range names {
		set[name] = true
	}
	return set
}()
188
189// determineShellType decides which shell to use based on platform and command
190func (s *Shell) determineShellType(command string) ShellType {
191 if runtime.GOOS != "windows" {
192 return ShellTypePOSIX
193 }
194
195 // Extract the first command from the command line
196 parts := strings.Fields(command)
197 if len(parts) == 0 {
198 return ShellTypePOSIX
199 }
200
201 firstCmd := strings.ToLower(parts[0])
202
203 // Check if it's a Windows-specific command
204 if windowsNativeCommands[firstCmd] {
205 return ShellTypeCmd
206 }
207
208 // Check for PowerShell-specific syntax
209 if strings.Contains(command, "Get-") || strings.Contains(command, "Set-") ||
210 strings.Contains(command, "New-") || strings.Contains(command, "$_") ||
211 strings.Contains(command, "| Where-Object") || strings.Contains(command, "| ForEach-Object") {
212 return ShellTypePowerShell
213 }
214
215 // Default to POSIX emulation for cross-platform compatibility
216 return ShellTypePOSIX
217}
218
219// CommandsBlocker creates a BlockFunc that blocks exact command matches
220func CommandsBlocker(bannedCommands []string) BlockFunc {
221 bannedSet := make(map[string]bool)
222 for _, cmd := range bannedCommands {
223 bannedSet[cmd] = true
224 }
225
226 return func(args []string) bool {
227 if len(args) == 0 {
228 return false
229 }
230 return bannedSet[args[0]]
231 }
232}
233
234// ArgumentsBlocker creates a BlockFunc that blocks specific subcommands
235func ArgumentsBlocker(blockedSubCommands [][]string) BlockFunc {
236 return func(args []string) bool {
237 for _, blocked := range blockedSubCommands {
238 if len(args) >= len(blocked) {
239 match := true
240 for i, part := range blocked {
241 if args[i] != part {
242 match = false
243 break
244 }
245 }
246 if match {
247 return true
248 }
249 }
250 }
251 return false
252 }
253}
254
255func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
256 return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
257 return func(ctx context.Context, args []string) error {
258 if len(args) == 0 {
259 return next(ctx, args)
260 }
261
262 for _, blockFunc := range s.blockFuncs {
263 if blockFunc(args) {
264 return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
265 }
266 }
267
268 return next(ctx, args)
269 }
270 }
271}
272
273// execWindows executes commands using native Windows shells (cmd.exe or PowerShell)
274func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
275 var cmd *exec.Cmd
276
277 // Handle directory changes specially to maintain persistent shell behavior
278 if strings.HasPrefix(strings.TrimSpace(command), "cd ") {
279 return s.handleWindowsCD(command)
280 }
281
282 switch shell {
283 case "cmd":
284 // Use cmd.exe for Windows commands
285 // Add current directory context to maintain state
286 fullCommand := fmt.Sprintf("cd /d \"%s\" && %s", s.cwd, command)
287 cmd = exec.CommandContext(ctx, "cmd", "/C", fullCommand)
288 case "powershell":
289 // Use PowerShell for PowerShell commands
290 // Add current directory context to maintain state
291 fullCommand := fmt.Sprintf("Set-Location '%s'; %s", s.cwd, command)
292 cmd = exec.CommandContext(ctx, "powershell", "-Command", fullCommand)
293 default:
294 return "", "", fmt.Errorf("unsupported Windows shell: %s", shell)
295 }
296
297 // Set environment variables
298 cmd.Env = s.env
299
300 var stdout, stderr bytes.Buffer
301 cmd.Stdout = &stdout
302 cmd.Stderr = &stderr
303
304 err := cmd.Run()
305
306 s.logger.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err)
307 return stdout.String(), stderr.String(), err
308}
309
310// handleWindowsCD handles directory changes for Windows shells
311func (s *Shell) handleWindowsCD(command string) (string, string, error) {
312 // Extract the target directory from the cd command
313 parts := strings.Fields(command)
314 if len(parts) < 2 {
315 return "", "cd: missing directory argument", fmt.Errorf("missing directory argument")
316 }
317
318 targetDir := parts[1]
319
320 // Handle relative paths
321 if !strings.Contains(targetDir, ":") && !strings.HasPrefix(targetDir, "\\") {
322 // Relative path - resolve against current directory
323 if targetDir == ".." {
324 // Go up one directory
325 if len(s.cwd) > 3 { // Don't go above drive root (C:\)
326 lastSlash := strings.LastIndex(s.cwd, "\\")
327 if lastSlash > 2 { // Keep drive letter
328 s.cwd = s.cwd[:lastSlash]
329 }
330 }
331 } else if targetDir != "." {
332 // Go to subdirectory
333 s.cwd = s.cwd + "\\" + targetDir
334 }
335 } else {
336 // Absolute path
337 s.cwd = targetDir
338 }
339
340 // Verify the directory exists
341 if _, err := os.Stat(s.cwd); err != nil {
342 return "", fmt.Sprintf("cd: %s: No such file or directory", targetDir), err
343 }
344
345 return "", "", nil
346}
347
348// execPOSIX executes commands using POSIX shell emulation (cross-platform)
349func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
350 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
351 if err != nil {
352 return "", "", fmt.Errorf("could not parse command: %w", err)
353 }
354
355 var stdout, stderr bytes.Buffer
356 runner, err := interp.New(
357 interp.StdIO(nil, &stdout, &stderr),
358 interp.Interactive(false),
359 interp.Env(expand.ListEnviron(s.env...)),
360 interp.Dir(s.cwd),
361 interp.ExecHandlers(s.blockHandler()),
362 )
363 if err != nil {
364 return "", "", fmt.Errorf("could not run command: %w", err)
365 }
366
367 err = runner.Run(ctx, line)
368 s.cwd = runner.Dir
369 s.env = []string{}
370 for name, vr := range runner.Vars {
371 s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
372 }
373 s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
374 return stdout.String(), stderr.String(), err
375}
376
// IsInterrupt reports whether err, or any error it wraps, is a context
// cancellation or an expired deadline. This indicates the command was stopped
// rather than failing on its own. A nil err is not an interrupt.
func IsInterrupt(err error) bool {
	for _, cause := range []error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, cause) {
			return true
		}
	}
	return false
}
382
383// ExitCode extracts the exit code from an error
384func ExitCode(err error) int {
385 if err == nil {
386 return 0
387 }
388 status, ok := interp.IsExitStatus(err)
389 if ok {
390 return int(status)
391 }
392 return 1
393}