1// Package shell provides cross-platform shell execution capabilities.
2//
3// This package offers two main types:
4// - Shell: A general-purpose shell executor for one-off or managed commands
5// - PersistentShell: A singleton shell that maintains state across the application
6//
7// WINDOWS COMPATIBILITY:
// This implementation uses POSIX shell emulation (mvdan.cc/sh/v3) on every
// platform, including Windows. Some caution is required: commands should use
// forward slashes (/) as path separators, even on Windows.
11package shell
12
13import (
14 "bytes"
15 "context"
16 "errors"
17 "fmt"
18 "os"
19 "strings"
20 "sync"
21
22 "mvdan.cc/sh/v3/expand"
23 "mvdan.cc/sh/v3/interp"
24 "mvdan.cc/sh/v3/syntax"
25)
26
// ShellType identifies a shell dialect.
//
// NOTE(review): Exec as shown always runs the POSIX emulator; no visible code
// in this section consults the Cmd or PowerShell values.
type ShellType int

// Known shell types. The numeric values are spelled out (POSIX=0, Cmd=1,
// PowerShell=2) so they stay stable if constants are later added or reordered.
const (
	// ShellTypePOSIX is POSIX sh, emulated by mvdan.cc/sh/v3.
	ShellTypePOSIX ShellType = 0
	// ShellTypeCmd presumably denotes Windows cmd.exe — TODO confirm.
	ShellTypeCmd ShellType = 1
	// ShellTypePowerShell presumably denotes PowerShell — TODO confirm.
	ShellTypePowerShell ShellType = 2
)
35
// Logger is the optional sink a Shell uses to record finished commands.
// InfoPersist receives a message followed by alternating key/value pairs.
type Logger interface {
	InfoPersist(message string, fields ...interface{})
}
40
// noopLogger satisfies the InfoPersist contract by discarding every message.
// NewShell installs it when the caller supplies no Logger.
type noopLogger struct{}

// InfoPersist ignores all of its arguments.
func (noopLogger) InfoPersist(string, ...interface{}) {}
45
// BlockFunc decides whether a command must be refused. It receives the
// command's arguments (command name first) and returns true to block it.
type BlockFunc func(commandArgs []string) bool
48
// Shell provides cross-platform shell execution with optional state persistence.
//
// A Shell carries its working directory and environment from one Exec call to
// the next: after each command, execPOSIX writes the interpreter's final
// directory and variables back into cwd and env. The exported methods shown
// here take mu before reading or writing these fields.
type Shell struct {
	env        []string    // NAME=value entries handed to the interpreter; rebuilt after every Exec
	cwd        string      // directory the next command starts in
	mu         sync.Mutex  // guards env, cwd and blockFuncs; held for the whole of Exec
	logger     Logger      // records each finished command; never nil when built by NewShell
	blockFuncs []BlockFunc // checked by blockHandler; any true result refuses the command
}
57
// Options configures a Shell created by NewShell. Zero-valued fields fall back
// to the defaults chosen by NewShell.
type Options struct {
	WorkingDir string      // starting directory; "" means the process working directory
	Env        []string    // initial NAME=value environment; nil means os.Environ()
	Logger     Logger      // optional; nil means messages are discarded
	BlockFuncs []BlockFunc // predicates that can refuse a command before it runs
}
65
66// NewShell creates a new shell instance with the given options
67func NewShell(opts *Options) *Shell {
68 if opts == nil {
69 opts = &Options{}
70 }
71
72 cwd := opts.WorkingDir
73 if cwd == "" {
74 cwd, _ = os.Getwd()
75 }
76
77 env := opts.Env
78 if env == nil {
79 env = os.Environ()
80 }
81
82 logger := opts.Logger
83 if logger == nil {
84 logger = noopLogger{}
85 }
86
87 return &Shell{
88 cwd: cwd,
89 env: env,
90 logger: logger,
91 blockFuncs: opts.BlockFuncs,
92 }
93}
94
95// Exec executes a command in the shell
96func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
97 s.mu.Lock()
98 defer s.mu.Unlock()
99
100 return s.execPOSIX(ctx, command)
101}
102
103// GetWorkingDir returns the current working directory
104func (s *Shell) GetWorkingDir() string {
105 s.mu.Lock()
106 defer s.mu.Unlock()
107 return s.cwd
108}
109
110// SetWorkingDir sets the working directory
111func (s *Shell) SetWorkingDir(dir string) error {
112 s.mu.Lock()
113 defer s.mu.Unlock()
114
115 // Verify the directory exists
116 if _, err := os.Stat(dir); err != nil {
117 return fmt.Errorf("directory does not exist: %w", err)
118 }
119
120 s.cwd = dir
121 return nil
122}
123
124// GetEnv returns a copy of the environment variables
125func (s *Shell) GetEnv() []string {
126 s.mu.Lock()
127 defer s.mu.Unlock()
128
129 env := make([]string, len(s.env))
130 copy(env, s.env)
131 return env
132}
133
134// SetEnv sets an environment variable
135func (s *Shell) SetEnv(key, value string) {
136 s.mu.Lock()
137 defer s.mu.Unlock()
138
139 // Update or add the environment variable
140 keyPrefix := key + "="
141 for i, env := range s.env {
142 if strings.HasPrefix(env, keyPrefix) {
143 s.env[i] = keyPrefix + value
144 return
145 }
146 }
147 s.env = append(s.env, keyPrefix+value)
148}
149
150// SetBlockFuncs sets the command block functions for the shell
151func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
152 s.mu.Lock()
153 defer s.mu.Unlock()
154 s.blockFuncs = blockFuncs
155}
156
157// CommandsBlocker creates a BlockFunc that blocks exact command matches
158func CommandsBlocker(bannedCommands []string) BlockFunc {
159 bannedSet := make(map[string]bool)
160 for _, cmd := range bannedCommands {
161 bannedSet[cmd] = true
162 }
163
164 return func(args []string) bool {
165 if len(args) == 0 {
166 return false
167 }
168 return bannedSet[args[0]]
169 }
170}
171
172// ArgumentsBlocker creates a BlockFunc that blocks specific subcommands
173func ArgumentsBlocker(blockedSubCommands [][]string) BlockFunc {
174 return func(args []string) bool {
175 for _, blocked := range blockedSubCommands {
176 if len(args) >= len(blocked) {
177 match := true
178 for i, part := range blocked {
179 if args[i] != part {
180 match = false
181 break
182 }
183 }
184 if match {
185 return true
186 }
187 }
188 }
189 return false
190 }
191}
192
193func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
194 return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
195 return func(ctx context.Context, args []string) error {
196 if len(args) == 0 {
197 return next(ctx, args)
198 }
199
200 for _, blockFunc := range s.blockFuncs {
201 if blockFunc(args) {
202 return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
203 }
204 }
205
206 return next(ctx, args)
207 }
208 }
209}
210
211// execPOSIX executes commands using POSIX shell emulation (cross-platform)
212func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
213 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
214 if err != nil {
215 return "", "", fmt.Errorf("could not parse command: %w", err)
216 }
217
218 var stdout, stderr bytes.Buffer
219 runner, err := interp.New(
220 interp.StdIO(nil, &stdout, &stderr),
221 interp.Interactive(false),
222 interp.Env(expand.ListEnviron(s.env...)),
223 interp.Dir(s.cwd),
224 interp.ExecHandlers(s.blockHandler()),
225 )
226 if err != nil {
227 return "", "", fmt.Errorf("could not run command: %w", err)
228 }
229
230 err = runner.Run(ctx, line)
231 s.cwd = runner.Dir
232 s.env = []string{}
233 for name, vr := range runner.Vars {
234 s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
235 }
236 s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
237 return stdout.String(), stderr.String(), err
238}
239
// IsInterrupt reports whether err is, or wraps, context cancellation or a
// context deadline expiry. A nil error is not an interrupt.
func IsInterrupt(err error) bool {
	switch {
	case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
		return true
	default:
		return false
	}
}
245
246// ExitCode extracts the exit code from an error
247func ExitCode(err error) int {
248 if err == nil {
249 return 0
250 }
251 status, ok := interp.IsExitStatus(err)
252 if ok {
253 return int(status)
254 }
255 return 1
256}