1// Package shell provides cross-platform shell execution capabilities.
2//
3// This package offers two main types:
4// - Shell: A general-purpose shell executor for one-off or managed commands
5// - PersistentShell: A singleton shell that maintains state across the application
6//
// WINDOWS COMPATIBILITY:
// Commands normally run through an embedded POSIX shell interpreter
// (mvdan.cc/sh/v3) on every platform. On Windows, commands recognized as
// native cmd.exe commands or as PowerShell syntax are instead executed by
// cmd.exe or PowerShell respectively.
10package shell
11
12import (
13 "bytes"
14 "context"
15 "errors"
16 "fmt"
17 "os"
18 "os/exec"
19 "runtime"
20 "strings"
21 "sync"
22
23 "mvdan.cc/sh/v3/expand"
24 "mvdan.cc/sh/v3/interp"
25 "mvdan.cc/sh/v3/syntax"
26)
27
// ShellType identifies which shell implementation executes a command.
type ShellType int

// Supported shell kinds. The numeric values are fixed: 0, 1 and 2.
const (
	// ShellTypePOSIX runs the command in the embedded POSIX interpreter.
	ShellTypePOSIX ShellType = 0
	// ShellTypeCmd runs the command through cmd.exe.
	ShellTypeCmd ShellType = 1
	// ShellTypePowerShell runs the command through PowerShell.
	ShellTypePowerShell ShellType = 2
)
36
// Logger is the optional logging hook used by Shell. After a command runs,
// InfoPersist receives a message followed by alternating key/value pairs.
type Logger interface {
	InfoPersist(msg string, keysAndValues ...interface{})
}

// noopLogger discards every record; it stands in when no Logger is supplied.
type noopLogger struct{}

// Compile-time proof that noopLogger satisfies Logger.
var _ Logger = noopLogger{}

// InfoPersist implements Logger by ignoring its arguments.
func (noopLogger) InfoPersist(string, ...interface{}) {}
46
// CommandBlockFunc decides whether a command must be refused. args holds the
// command name followed by its arguments; returning true blocks the command.
type CommandBlockFunc func(args []string) bool
49
// Shell provides cross-platform shell execution with optional state persistence.
//
// The working directory and environment carry over from one Exec call to the
// next. Every field is guarded by mu, and Exec holds mu for the entire command,
// so commands on a single Shell run one at a time.
type Shell struct {
	// env is the environment handed to commands, as "KEY=value" entries.
	env []string

	// cwd is the working directory the next command starts in.
	cwd string

	// mu guards all other fields; Exec holds it for the duration of a command.
	mu sync.Mutex

	// logger receives an InfoPersist record when a command finishes.
	logger Logger

	// blockFuncs are checked by createCommandBlockHandler for commands run
	// through the POSIX interpreter; any true result refuses the command.
	blockFuncs []CommandBlockFunc
}
58
// Options configures NewShell. Every field is optional.
type Options struct {
	// WorkingDir is the initial working directory; empty means the process's
	// current directory as reported by os.Getwd.
	WorkingDir string

	// Env is the initial environment as "KEY=value" entries; nil means
	// os.Environ().
	Env []string

	// Logger receives command-completion records; nil disables logging.
	Logger Logger

	// BlockFuncs is the initial list of command block functions.
	BlockFuncs []CommandBlockFunc
}
66
67// NewShell creates a new shell instance with the given options
68func NewShell(opts *Options) *Shell {
69 if opts == nil {
70 opts = &Options{}
71 }
72
73 cwd := opts.WorkingDir
74 if cwd == "" {
75 cwd, _ = os.Getwd()
76 }
77
78 env := opts.Env
79 if env == nil {
80 env = os.Environ()
81 }
82
83 logger := opts.Logger
84 if logger == nil {
85 logger = noopLogger{}
86 }
87
88 return &Shell{
89 cwd: cwd,
90 env: env,
91 logger: logger,
92 blockFuncs: opts.BlockFuncs,
93 }
94}
95
96// Exec executes a command in the shell
97func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
98 s.mu.Lock()
99 defer s.mu.Unlock()
100
101 // Determine which shell to use based on platform and command
102 shellType := s.determineShellType(command)
103
104 switch shellType {
105 case ShellTypeCmd:
106 return s.execWindows(ctx, command, "cmd")
107 case ShellTypePowerShell:
108 return s.execWindows(ctx, command, "powershell")
109 default:
110 return s.execPOSIX(ctx, command)
111 }
112}
113
114// GetWorkingDir returns the current working directory
115func (s *Shell) GetWorkingDir() string {
116 s.mu.Lock()
117 defer s.mu.Unlock()
118 return s.cwd
119}
120
121// SetWorkingDir sets the working directory
122func (s *Shell) SetWorkingDir(dir string) error {
123 s.mu.Lock()
124 defer s.mu.Unlock()
125
126 // Verify the directory exists
127 if _, err := os.Stat(dir); err != nil {
128 return fmt.Errorf("directory does not exist: %w", err)
129 }
130
131 s.cwd = dir
132 return nil
133}
134
135// GetEnv returns a copy of the environment variables
136func (s *Shell) GetEnv() []string {
137 s.mu.Lock()
138 defer s.mu.Unlock()
139
140 env := make([]string, len(s.env))
141 copy(env, s.env)
142 return env
143}
144
145// SetEnv sets an environment variable
146func (s *Shell) SetEnv(key, value string) {
147 s.mu.Lock()
148 defer s.mu.Unlock()
149
150 // Update or add the environment variable
151 keyPrefix := key + "="
152 for i, env := range s.env {
153 if strings.HasPrefix(env, keyPrefix) {
154 s.env[i] = keyPrefix + value
155 return
156 }
157 }
158 s.env = append(s.env, keyPrefix+value)
159}
160
161// SetBlockFuncs sets the command block functions for the shell
162func (s *Shell) SetBlockFuncs(blockFuncs []CommandBlockFunc) {
163 s.mu.Lock()
164 defer s.mu.Unlock()
165 s.blockFuncs = blockFuncs
166}
167
// windowsNativeCommands is the set of lower-cased command names that
// determineShellType sends to cmd.exe on Windows.
//
// NOTE(review): several names ("mkdir", "rmdir", "type") also exist in POSIX
// shells, so e.g. `mkdir -p a/b` is handed to cmd.exe, which has no -p flag.
var windowsNativeCommands = func() map[string]bool {
	names := []string{
		"dir", "type", "copy", "move", "del",
		"md", "mkdir", "rd", "rmdir", "cls",
		"where", "tasklist", "taskkill",
		"net", "sc", "reg", "wmic",
	}
	set := make(map[string]bool, len(names))
	for _, name := range names {
		set[name] = true
	}
	return set
}()
188
189// determineShellType decides which shell to use based on platform and command
190func (s *Shell) determineShellType(command string) ShellType {
191 if runtime.GOOS != "windows" {
192 return ShellTypePOSIX
193 }
194
195 // Extract the first command from the command line
196 parts := strings.Fields(command)
197 if len(parts) == 0 {
198 return ShellTypePOSIX
199 }
200
201 firstCmd := strings.ToLower(parts[0])
202
203 // Check if it's a Windows-specific command
204 if windowsNativeCommands[firstCmd] {
205 return ShellTypeCmd
206 }
207
208 // Check for PowerShell-specific syntax
209 if strings.Contains(command, "Get-") || strings.Contains(command, "Set-") ||
210 strings.Contains(command, "New-") || strings.Contains(command, "$_") ||
211 strings.Contains(command, "| Where-Object") || strings.Contains(command, "| ForEach-Object") {
212 return ShellTypePowerShell
213 }
214
215 // Default to POSIX emulation for cross-platform compatibility
216 return ShellTypePOSIX
217}
218
219// CreateSimpleCommandBlocker creates a CommandBlockFunc that blocks exact command matches
220func CreateSimpleCommandBlocker(bannedCommands []string) CommandBlockFunc {
221 bannedSet := make(map[string]bool)
222 for _, cmd := range bannedCommands {
223 bannedSet[cmd] = true
224 }
225
226 return func(args []string) bool {
227 if len(args) == 0 {
228 return false
229 }
230 return bannedSet[args[0]]
231 }
232}
233
234// CreateSubCommandBlocker creates a CommandBlockFunc that blocks specific subcommands
235func CreateSubCommandBlocker(blockedSubCommands [][]string) CommandBlockFunc {
236 return func(args []string) bool {
237 for _, blocked := range blockedSubCommands {
238 if len(args) >= len(blocked) {
239 match := true
240 for i, part := range blocked {
241 if args[i] != part {
242 match = false
243 break
244 }
245 }
246 if match {
247 return true
248 }
249 }
250 }
251 return false
252 }
253}
254func (s *Shell) createCommandBlockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
255 return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
256 return func(ctx context.Context, args []string) error {
257 if len(args) == 0 {
258 return next(ctx, args)
259 }
260
261 for _, blockFunc := range s.blockFuncs {
262 if blockFunc(args) {
263 return fmt.Errorf("command '%s' is not allowed for security reasons", strings.Join(args, " "))
264 }
265 }
266
267 return next(ctx, args)
268 }
269 }
270}
271
272// execWindows executes commands using native Windows shells (cmd.exe or PowerShell)
273func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
274 var cmd *exec.Cmd
275
276 // Handle directory changes specially to maintain persistent shell behavior
277 if strings.HasPrefix(strings.TrimSpace(command), "cd ") {
278 return s.handleWindowsCD(command)
279 }
280
281 switch shell {
282 case "cmd":
283 // Use cmd.exe for Windows commands
284 // Add current directory context to maintain state
285 fullCommand := fmt.Sprintf("cd /d \"%s\" && %s", s.cwd, command)
286 cmd = exec.CommandContext(ctx, "cmd", "/C", fullCommand)
287 case "powershell":
288 // Use PowerShell for PowerShell commands
289 // Add current directory context to maintain state
290 fullCommand := fmt.Sprintf("Set-Location '%s'; %s", s.cwd, command)
291 cmd = exec.CommandContext(ctx, "powershell", "-Command", fullCommand)
292 default:
293 return "", "", fmt.Errorf("unsupported Windows shell: %s", shell)
294 }
295
296 // Set environment variables
297 cmd.Env = s.env
298
299 var stdout, stderr bytes.Buffer
300 cmd.Stdout = &stdout
301 cmd.Stderr = &stderr
302
303 err := cmd.Run()
304
305 s.logger.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err)
306 return stdout.String(), stderr.String(), err
307}
308
309// handleWindowsCD handles directory changes for Windows shells
310func (s *Shell) handleWindowsCD(command string) (string, string, error) {
311 // Extract the target directory from the cd command
312 parts := strings.Fields(command)
313 if len(parts) < 2 {
314 return "", "cd: missing directory argument", fmt.Errorf("missing directory argument")
315 }
316
317 targetDir := parts[1]
318
319 // Handle relative paths
320 if !strings.Contains(targetDir, ":") && !strings.HasPrefix(targetDir, "\\") {
321 // Relative path - resolve against current directory
322 if targetDir == ".." {
323 // Go up one directory
324 if len(s.cwd) > 3 { // Don't go above drive root (C:\)
325 lastSlash := strings.LastIndex(s.cwd, "\\")
326 if lastSlash > 2 { // Keep drive letter
327 s.cwd = s.cwd[:lastSlash]
328 }
329 }
330 } else if targetDir != "." {
331 // Go to subdirectory
332 s.cwd = s.cwd + "\\" + targetDir
333 }
334 } else {
335 // Absolute path
336 s.cwd = targetDir
337 }
338
339 // Verify the directory exists
340 if _, err := os.Stat(s.cwd); err != nil {
341 return "", fmt.Sprintf("cd: %s: No such file or directory", targetDir), err
342 }
343
344 return "", "", nil
345}
346
347// execPOSIX executes commands using POSIX shell emulation (cross-platform)
348func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
349 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
350 if err != nil {
351 return "", "", fmt.Errorf("could not parse command: %w", err)
352 }
353
354 var stdout, stderr bytes.Buffer
355 runner, err := interp.New(
356 interp.StdIO(nil, &stdout, &stderr),
357 interp.Interactive(false),
358 interp.Env(expand.ListEnviron(s.env...)),
359 interp.Dir(s.cwd),
360 interp.ExecHandlers(s.createCommandBlockHandler()),
361 )
362 if err != nil {
363 return "", "", fmt.Errorf("could not run command: %w", err)
364 }
365
366 err = runner.Run(ctx, line)
367 s.cwd = runner.Dir
368 s.env = []string{}
369 for name, vr := range runner.Vars {
370 s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
371 }
372 s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
373 return stdout.String(), stderr.String(), err
374}
375
// IsInterrupt reports whether err is, or wraps, context.Canceled or
// context.DeadlineExceeded. A nil error is never an interrupt.
func IsInterrupt(err error) bool {
	for _, cause := range []error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, cause) {
			return true
		}
	}
	return false
}
381
382// ExitCode extracts the exit code from an error
383func ExitCode(err error) int {
384 if err == nil {
385 return 0
386 }
387 status, ok := interp.IsExitStatus(err)
388 if ok {
389 return int(status)
390 }
391 return 1
392}