1package shell
2
3import (
4 "context"
5 "strings"
6 "testing"
7
8 "github.com/stretchr/testify/require"
9)
10
11func TestCommandBlocking(t *testing.T) {
12 tests := []struct {
13 name string
14 blockFuncs []BlockFunc
15 command string
16 shouldBlock bool
17 }{
18 {
19 name: "block simple command",
20 blockFuncs: []BlockFunc{
21 func(args []string) bool {
22 return len(args) > 0 && args[0] == "curl"
23 },
24 },
25 command: "curl https://example.com",
26 shouldBlock: true,
27 },
28 {
29 name: "allow non-blocked command",
30 blockFuncs: []BlockFunc{
31 func(args []string) bool {
32 return len(args) > 0 && args[0] == "curl"
33 },
34 },
35 command: "echo hello",
36 shouldBlock: false,
37 },
38 {
39 name: "block subcommand",
40 blockFuncs: []BlockFunc{
41 func(args []string) bool {
42 return len(args) >= 2 && args[0] == "brew" && args[1] == "install"
43 },
44 },
45 command: "brew install wget",
46 shouldBlock: true,
47 },
48 {
49 name: "allow different subcommand",
50 blockFuncs: []BlockFunc{
51 func(args []string) bool {
52 return len(args) >= 2 && args[0] == "brew" && args[1] == "install"
53 },
54 },
55 command: "brew list",
56 shouldBlock: false,
57 },
58 {
59 name: "block npm global install with -g",
60 blockFuncs: []BlockFunc{
61 ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
62 },
63 command: "npm install -g typescript",
64 shouldBlock: true,
65 },
66 {
67 name: "block npm global install with --global",
68 blockFuncs: []BlockFunc{
69 ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
70 },
71 command: "npm install --global typescript",
72 shouldBlock: true,
73 },
74 {
75 name: "allow npm local install",
76 blockFuncs: []BlockFunc{
77 ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
78 ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
79 },
80 command: "npm install typescript",
81 shouldBlock: false,
82 },
83 }
84
85 for _, tt := range tests {
86 t.Run(tt.name, func(t *testing.T) {
87 // Create a temporary directory for each test
88 tmpDir := t.TempDir()
89
90 shell := NewShell(&Options{
91 WorkingDir: tmpDir,
92 BlockFuncs: tt.blockFuncs,
93 })
94
95 _, _, err := shell.Exec(context.Background(), tt.command)
96
97 if tt.shouldBlock {
98 if err == nil {
99 t.Errorf("Expected command to be blocked, but it was allowed")
100 } else if !strings.Contains(err.Error(), "not allowed for security reasons") {
101 t.Errorf("Expected security error, got: %v", err)
102 }
103 } else {
104 // For non-blocked commands, we might get other errors (like command not found)
105 // but we shouldn't get the security error
106 if err != nil && strings.Contains(err.Error(), "not allowed for security reasons") {
107 t.Errorf("Command was unexpectedly blocked: %v", err)
108 }
109 }
110 })
111 }
112}
113
114func TestArgumentsBlocker(t *testing.T) {
115 tests := []struct {
116 name string
117 cmd string
118 args []string
119 flags []string
120 input []string
121 shouldBlock bool
122 }{
123 // Basic command blocking
124 {
125 name: "block exact command match",
126 cmd: "npm",
127 args: []string{"install"},
128 flags: nil,
129 input: []string{"npm", "install", "package"},
130 shouldBlock: true,
131 },
132 {
133 name: "allow different command",
134 cmd: "npm",
135 args: []string{"install"},
136 flags: nil,
137 input: []string{"yarn", "install", "package"},
138 shouldBlock: false,
139 },
140 {
141 name: "allow different subcommand",
142 cmd: "npm",
143 args: []string{"install"},
144 flags: nil,
145 input: []string{"npm", "list"},
146 shouldBlock: false,
147 },
148
149 // Flag-based blocking
150 {
151 name: "block with single flag",
152 cmd: "npm",
153 args: []string{"install"},
154 flags: []string{"-g"},
155 input: []string{"npm", "install", "-g", "typescript"},
156 shouldBlock: true,
157 },
158 {
159 name: "block with flag in different position",
160 cmd: "npm",
161 args: []string{"install"},
162 flags: []string{"-g"},
163 input: []string{"npm", "install", "typescript", "-g"},
164 shouldBlock: true,
165 },
166 {
167 name: "allow without required flag",
168 cmd: "npm",
169 args: []string{"install"},
170 flags: []string{"-g"},
171 input: []string{"npm", "install", "typescript"},
172 shouldBlock: false,
173 },
174 {
175 name: "block with multiple flags",
176 cmd: "pip",
177 args: []string{"install"},
178 flags: []string{"--user"},
179 input: []string{"pip", "install", "--user", "--upgrade", "package"},
180 shouldBlock: true,
181 },
182
183 // Complex argument patterns
184 {
185 name: "block multi-arg subcommand",
186 cmd: "yarn",
187 args: []string{"global", "add"},
188 flags: nil,
189 input: []string{"yarn", "global", "add", "typescript"},
190 shouldBlock: true,
191 },
192 {
193 name: "allow partial multi-arg match",
194 cmd: "yarn",
195 args: []string{"global", "add"},
196 flags: nil,
197 input: []string{"yarn", "global", "list"},
198 shouldBlock: false,
199 },
200
201 // Edge cases
202 {
203 name: "handle empty input",
204 cmd: "npm",
205 args: []string{"install"},
206 flags: nil,
207 input: []string{},
208 shouldBlock: false,
209 },
210 {
211 name: "handle command only",
212 cmd: "npm",
213 args: []string{"install"},
214 flags: nil,
215 input: []string{"npm"},
216 shouldBlock: false,
217 },
218 {
219 name: "block pacman with -S flag",
220 cmd: "pacman",
221 args: nil,
222 flags: []string{"-S"},
223 input: []string{"pacman", "-S", "package"},
224 shouldBlock: true,
225 },
226 {
227 name: "allow pacman without -S flag",
228 cmd: "pacman",
229 args: nil,
230 flags: []string{"-S"},
231 input: []string{"pacman", "-Q", "package"},
232 shouldBlock: false,
233 },
234 }
235
236 for _, tt := range tests {
237 t.Run(tt.name, func(t *testing.T) {
238 blocker := ArgumentsBlocker(tt.cmd, tt.args, tt.flags)
239 result := blocker(tt.input)
240 require.Equal(t, tt.shouldBlock, result,
241 "Expected block=%v for input %v", tt.shouldBlock, tt.input)
242 })
243 }
244}
245
246func TestCommandsBlocker(t *testing.T) {
247 tests := []struct {
248 name string
249 banned []string
250 input []string
251 shouldBlock bool
252 }{
253 {
254 name: "block single banned command",
255 banned: []string{"curl"},
256 input: []string{"curl", "https://example.com"},
257 shouldBlock: true,
258 },
259 {
260 name: "allow non-banned command",
261 banned: []string{"curl", "wget"},
262 input: []string{"echo", "hello"},
263 shouldBlock: false,
264 },
265 {
266 name: "block from multiple banned",
267 banned: []string{"curl", "wget", "nc"},
268 input: []string{"wget", "https://example.com"},
269 shouldBlock: true,
270 },
271 {
272 name: "handle empty input",
273 banned: []string{"curl"},
274 input: []string{},
275 shouldBlock: false,
276 },
277 {
278 name: "case sensitive matching",
279 banned: []string{"curl"},
280 input: []string{"CURL", "https://example.com"},
281 shouldBlock: false,
282 },
283 }
284
285 for _, tt := range tests {
286 t.Run(tt.name, func(t *testing.T) {
287 blocker := CommandsBlocker(tt.banned)
288 result := blocker(tt.input)
289 require.Equal(t, tt.shouldBlock, result,
290 "Expected block=%v for input %v", tt.shouldBlock, tt.input)
291 })
292 }
293}
294
295func TestSplitArgsFlags(t *testing.T) {
296 tests := []struct {
297 name string
298 input []string
299 wantArgs []string
300 wantFlags []string
301 }{
302 {
303 name: "only args",
304 input: []string{"install", "package", "another"},
305 wantArgs: []string{"install", "package", "another"},
306 wantFlags: []string{},
307 },
308 {
309 name: "only flags",
310 input: []string{"-g", "--verbose", "-f"},
311 wantArgs: []string{},
312 wantFlags: []string{"-g", "--verbose", "-f"},
313 },
314 {
315 name: "mixed args and flags",
316 input: []string{"install", "-g", "package", "--verbose"},
317 wantArgs: []string{"install", "package"},
318 wantFlags: []string{"-g", "--verbose"},
319 },
320 {
321 name: "empty input",
322 input: []string{},
323 wantArgs: []string{},
324 wantFlags: []string{},
325 },
326 {
327 name: "single dash flag",
328 input: []string{"-S", "package"},
329 wantArgs: []string{"package"},
330 wantFlags: []string{"-S"},
331 },
332 }
333
334 for _, tt := range tests {
335 t.Run(tt.name, func(t *testing.T) {
336 args, flags := splitArgsFlags(tt.input)
337 require.Equal(t, tt.wantArgs, args, "args mismatch")
338 require.Equal(t, tt.wantFlags, flags, "flags mismatch")
339 })
340 }
341}