1// Package shell provides cross-platform shell execution capabilities.
2//
3// This package offers two main types:
4// - Shell: A general-purpose shell executor for one-off or managed commands
5// - PersistentShell: A singleton shell that maintains state across the application
6//
// WINDOWS COMPATIBILITY:
// This implementation provides POSIX shell emulation (mvdan.cc/sh/v3) on every
// platform, including Windows. Some caution is required: commands must use
// forward slashes (/) as path separators to work, even on Windows.
11package shell
12
13import (
14 "bytes"
15 "context"
16 "errors"
17 "fmt"
18 "os"
19 "strings"
20 "sync"
21
22 "mvdan.cc/sh/moreinterp/coreutils"
23 "mvdan.cc/sh/v3/expand"
24 "mvdan.cc/sh/v3/interp"
25 "mvdan.cc/sh/v3/syntax"
26)
27
// ShellType enumerates the kinds of shell a caller may ask for.
type ShellType int

// Known shell types. The values are assigned sequentially starting at zero.
const (
	// ShellTypePOSIX is the zero value of ShellType.
	ShellTypePOSIX ShellType = iota
	// ShellTypeCmd is not referenced by any code visible in this file.
	ShellTypeCmd
	// ShellTypePowerShell is not referenced by any code visible in this file.
	ShellTypePowerShell
)
36
// Logger is the optional logging sink used by the shell.
//
// Implementations receive a message plus a flat list of alternating keys and
// values, as passed by the call sites in this file.
type Logger interface {
	// InfoPersist records msg together with the given key/value pairs.
	InfoPersist(msg string, keysAndValues ...any)
}
41
// noopLogger is the fallback logger: it accepts every call and discards it.
type noopLogger struct{}

// InfoPersist intentionally does nothing.
func (noopLogger) InfoPersist(string, ...any) {
}
46
// BlockFunc decides whether a command may run. It receives the command's
// argument vector and reports true when the command must be blocked.
type BlockFunc func(args []string) bool
49
50// Shell provides cross-platform shell execution with optional state persistence
51type Shell struct {
52 env []string
53 cwd string
54 mu sync.Mutex
55 logger Logger
56 blockFuncs []BlockFunc
57}
58
59// Options for creating a new shell
60type Options struct {
61 WorkingDir string
62 Env []string
63 Logger Logger
64 BlockFuncs []BlockFunc
65}
66
67// NewShell creates a new shell instance with the given options
68func NewShell(opts *Options) *Shell {
69 if opts == nil {
70 opts = &Options{}
71 }
72
73 cwd := opts.WorkingDir
74 if cwd == "" {
75 cwd, _ = os.Getwd()
76 }
77
78 env := opts.Env
79 if env == nil {
80 env = os.Environ()
81 }
82
83 logger := opts.Logger
84 if logger == nil {
85 logger = noopLogger{}
86 }
87
88 return &Shell{
89 cwd: cwd,
90 env: env,
91 logger: logger,
92 blockFuncs: opts.BlockFuncs,
93 }
94}
95
96// Exec executes a command in the shell
97func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
98 s.mu.Lock()
99 defer s.mu.Unlock()
100
101 return s.execPOSIX(ctx, command)
102}
103
104// GetWorkingDir returns the current working directory
105func (s *Shell) GetWorkingDir() string {
106 s.mu.Lock()
107 defer s.mu.Unlock()
108 return s.cwd
109}
110
111// SetWorkingDir sets the working directory
112func (s *Shell) SetWorkingDir(dir string) error {
113 s.mu.Lock()
114 defer s.mu.Unlock()
115
116 // Verify the directory exists
117 if _, err := os.Stat(dir); err != nil {
118 return fmt.Errorf("directory does not exist: %w", err)
119 }
120
121 s.cwd = dir
122 return nil
123}
124
125// GetEnv returns a copy of the environment variables
126func (s *Shell) GetEnv() []string {
127 s.mu.Lock()
128 defer s.mu.Unlock()
129
130 env := make([]string, len(s.env))
131 copy(env, s.env)
132 return env
133}
134
135// SetEnv sets an environment variable
136func (s *Shell) SetEnv(key, value string) {
137 s.mu.Lock()
138 defer s.mu.Unlock()
139
140 // Update or add the environment variable
141 keyPrefix := key + "="
142 for i, env := range s.env {
143 if strings.HasPrefix(env, keyPrefix) {
144 s.env[i] = keyPrefix + value
145 return
146 }
147 }
148 s.env = append(s.env, keyPrefix+value)
149}
150
151// SetBlockFuncs sets the command block functions for the shell
152func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
153 s.mu.Lock()
154 defer s.mu.Unlock()
155 s.blockFuncs = blockFuncs
156}
157
158// CommandsBlocker creates a BlockFunc that blocks exact command matches
159func CommandsBlocker(bannedCommands []string) BlockFunc {
160 bannedSet := make(map[string]struct{})
161 for _, cmd := range bannedCommands {
162 bannedSet[cmd] = struct{}{}
163 }
164
165 return func(args []string) bool {
166 if len(args) == 0 {
167 return false
168 }
169 _, ok := bannedSet[args[0]]
170 return ok
171 }
172}
173
174// ArgumentsBlocker creates a BlockFunc that blocks specific subcommands
175func ArgumentsBlocker(blockedSubCommands [][]string) BlockFunc {
176 return func(args []string) bool {
177 for _, blocked := range blockedSubCommands {
178 if len(args) >= len(blocked) {
179 match := true
180 for i, part := range blocked {
181 if args[i] != part {
182 match = false
183 break
184 }
185 }
186 if match {
187 return true
188 }
189 }
190 }
191 return false
192 }
193}
194
195func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
196 return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
197 return func(ctx context.Context, args []string) error {
198 if len(args) == 0 {
199 return next(ctx, args)
200 }
201
202 for _, blockFunc := range s.blockFuncs {
203 if blockFunc(args) {
204 return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
205 }
206 }
207
208 return next(ctx, args)
209 }
210 }
211}
212
213// execPOSIX executes commands using POSIX shell emulation (cross-platform)
214func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
215 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
216 if err != nil {
217 return "", "", fmt.Errorf("could not parse command: %w", err)
218 }
219
220 var stdout, stderr bytes.Buffer
221 runner, err := interp.New(
222 interp.StdIO(nil, &stdout, &stderr),
223 interp.Interactive(false),
224 interp.Env(expand.ListEnviron(s.env...)),
225 interp.Dir(s.cwd),
226 interp.ExecHandlers(s.blockHandler(), coreutils.ExecHandler),
227 )
228 if err != nil {
229 return "", "", fmt.Errorf("could not run command: %w", err)
230 }
231
232 err = runner.Run(ctx, line)
233 s.cwd = runner.Dir
234 s.env = []string{}
235 for name, vr := range runner.Vars {
236 s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
237 }
238 s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
239 return stdout.String(), stderr.String(), err
240}
241
// IsInterrupt reports whether err is, or wraps, context.Canceled or
// context.DeadlineExceeded. A nil error is not an interrupt.
func IsInterrupt(err error) bool {
	for _, target := range []error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
247
248// ExitCode extracts the exit code from an error
249func ExitCode(err error) int {
250 if err == nil {
251 return 0
252 }
253 var exitErr interp.ExitStatus
254 if errors.As(err, &exitErr) {
255 return int(exitErr)
256 }
257 return 1
258}