1package bashkit
2
3import (
4 "strings"
5 "testing"
6
7 "mvdan.cc/sh/v3/syntax"
8)
9
10func TestCheck(t *testing.T) {
11 tests := []struct {
12 name string
13 script string
14 wantErr bool
15 errMatch string // string to match in error message, if wantErr is true
16 }{
17 {
18 name: "valid script",
19 script: "echo hello world",
20 wantErr: false,
21 errMatch: "",
22 },
23 {
24 name: "invalid syntax",
25 script: "echo 'unterminated string",
26 wantErr: false, // As per implementation, syntax errors are not flagged
27 errMatch: "",
28 },
29 // Git add validation tests
30 {
31 name: "git add with -A flag",
32 script: "git add -A",
33 wantErr: true,
34 errMatch: "blind git add commands",
35 },
36 {
37 name: "git add with --all flag",
38 script: "git add --all",
39 wantErr: true,
40 errMatch: "blind git add commands",
41 },
42 {
43 name: "git add with dot",
44 script: "git add .",
45 wantErr: true,
46 errMatch: "blind git add commands",
47 },
48 {
49 name: "git add with asterisk",
50 script: "git add *",
51 wantErr: true,
52 errMatch: "blind git add commands",
53 },
54 {
55 name: "git add with multiple flags including -A",
56 script: "git add -v -A",
57 wantErr: true,
58 errMatch: "blind git add commands",
59 },
60 {
61 name: "git add with specific file",
62 script: "git add main.go",
63 wantErr: false,
64 errMatch: "",
65 },
66 {
67 name: "git add with multiple specific files",
68 script: "git add main.go utils.go",
69 wantErr: false,
70 errMatch: "",
71 },
72 {
73 name: "git add with directory path",
74 script: "git add src/main.go",
75 wantErr: false,
76 errMatch: "",
77 },
78 {
79 name: "git add with git flags before add",
80 script: "git -C /path/to/repo add -A",
81 wantErr: true,
82 errMatch: "blind git add commands",
83 },
84 {
85 name: "git add with valid flags",
86 script: "git add -v main.go",
87 wantErr: false,
88 errMatch: "",
89 },
90 {
91 name: "git command without add",
92 script: "git status",
93 wantErr: false,
94 errMatch: "",
95 },
96 {
97 name: "multiline script with blind git add",
98 script: "echo 'Adding files' && git add -A && git commit -m 'Update'",
99 wantErr: true,
100 errMatch: "blind git add commands",
101 },
102 {
103 name: "git add with pattern that looks like blind but is specific",
104 script: "git add file.A",
105 wantErr: false,
106 errMatch: "",
107 },
108 {
109 name: "commented blind git add",
110 script: "# git add -A",
111 wantErr: false,
112 errMatch: "",
113 },
114 }
115
116 for _, tc := range tests {
117 t.Run(tc.name, func(t *testing.T) {
118 err := Check(tc.script)
119 if (err != nil) != tc.wantErr {
120 t.Errorf("Check() error = %v, wantErr %v", err, tc.wantErr)
121 return
122 }
123 if tc.wantErr && err != nil && !strings.Contains(err.Error(), tc.errMatch) {
124 t.Errorf("Check() error message = %v, want containing %v", err, tc.errMatch)
125 }
126 })
127 }
128}
129
130func TestWillRunGitCommit(t *testing.T) {
131 tests := []struct {
132 name string
133 script string
134 wantCommit bool
135 }{
136 {
137 name: "simple git commit",
138 script: "git commit -m 'Add feature'",
139 wantCommit: true,
140 },
141 {
142 name: "git command without commit",
143 script: "git status",
144 wantCommit: false,
145 },
146 {
147 name: "multiline script with git commit",
148 script: "echo 'Making changes' && git add . && git commit -m 'Update files'",
149 wantCommit: true,
150 },
151 {
152 name: "multiline script without git commit",
153 script: "echo 'Checking status' && git status",
154 wantCommit: false,
155 },
156 {
157 name: "script with commented git commit",
158 script: "# git commit -m 'This is commented out'",
159 wantCommit: false,
160 },
161 {
162 name: "git commit with variables",
163 script: "MSG='Fix bug' && git commit -m 'Using variable'",
164 wantCommit: true,
165 },
166 {
167 name: "only git command",
168 script: "git",
169 wantCommit: false,
170 },
171 {
172 name: "script with invalid syntax",
173 script: "git commit -m 'unterminated string",
174 wantCommit: false,
175 },
176 {
177 name: "commit used in different context",
178 script: "echo 'commit message'",
179 wantCommit: false,
180 },
181 {
182 name: "git with flags before commit",
183 script: "git -C /path/to/repo commit -m 'Update'",
184 wantCommit: true,
185 },
186 {
187 name: "git with multiple flags",
188 script: "git --git-dir=.git -C repo commit -a -m 'Update'",
189 wantCommit: true,
190 },
191 {
192 name: "git with env vars",
193 script: "GIT_AUTHOR_NAME=\"Josh Bleecher Snyder\" GIT_AUTHOR_EMAIL=\"josharian@gmail.com\" git commit -am \"Updated code\"",
194 wantCommit: true,
195 },
196 {
197 name: "git with redirections",
198 script: "git commit -m 'Fix issue' > output.log 2>&1",
199 wantCommit: true,
200 },
201 {
202 name: "git with piped commands",
203 script: "echo 'Committing' | git commit -F -",
204 wantCommit: true,
205 },
206 }
207
208 for _, tc := range tests {
209 t.Run(tc.name, func(t *testing.T) {
210 gotCommit, err := WillRunGitCommit(tc.script)
211 if err != nil {
212 t.Errorf("WillRunGitCommit() error = %v", err)
213 return
214 }
215 if gotCommit != tc.wantCommit {
216 t.Errorf("WillRunGitCommit() = %v, want %v", gotCommit, tc.wantCommit)
217 }
218 })
219 }
220}
221
222func TestSketchWipBranchProtection(t *testing.T) {
223 tests := []struct {
224 name string
225 script string
226 wantErr bool
227 errMatch string
228 resetBefore bool // if true, reset warning state before test
229 }{
230 {
231 name: "git branch rename sketch-wip",
232 script: "git branch -m sketch-wip new-branch",
233 wantErr: true,
234 errMatch: "cannot leave 'sketch-wip' branch",
235 resetBefore: true,
236 },
237 {
238 name: "git branch force rename sketch-wip",
239 script: "git branch -M sketch-wip new-branch",
240 wantErr: false, // second call should not error (already warned)
241 errMatch: "",
242 resetBefore: false,
243 },
244 {
245 name: "git checkout to other branch",
246 script: "git checkout main",
247 wantErr: false, // third call should not error (already warned)
248 errMatch: "",
249 resetBefore: false,
250 },
251 {
252 name: "git switch to other branch",
253 script: "git switch main",
254 wantErr: false, // fourth call should not error (already warned)
255 errMatch: "",
256 resetBefore: false,
257 },
258 {
259 name: "git checkout file (should be allowed)",
260 script: "git checkout -- file.txt",
261 wantErr: false,
262 errMatch: "",
263 resetBefore: false,
264 },
265 {
266 name: "git checkout path (should be allowed)",
267 script: "git checkout -- src/main.go",
268 wantErr: false,
269 errMatch: "",
270 resetBefore: false,
271 },
272 {
273 name: "git commit (should be allowed)",
274 script: "git commit -m 'test'",
275 wantErr: false,
276 errMatch: "",
277 resetBefore: false,
278 },
279 {
280 name: "git status (should be allowed)",
281 script: "git status",
282 wantErr: false,
283 errMatch: "",
284 resetBefore: false,
285 },
286 {
287 name: "git branch rename other branch (should be allowed)",
288 script: "git branch -m old-branch new-branch",
289 wantErr: false,
290 errMatch: "",
291 resetBefore: false,
292 },
293 }
294
295 for _, tc := range tests {
296 t.Run(tc.name, func(t *testing.T) {
297 if tc.resetBefore {
298 ResetSketchWipWarning()
299 }
300 err := Check(tc.script)
301 if (err != nil) != tc.wantErr {
302 t.Errorf("Check() error = %v, wantErr %v", err, tc.wantErr)
303 return
304 }
305 if tc.wantErr && err != nil && !strings.Contains(err.Error(), tc.errMatch) {
306 t.Errorf("Check() error message = %v, want containing %v", err, tc.errMatch)
307 }
308 })
309 }
310}
311
312func TestHasSketchWipBranchChanges(t *testing.T) {
313 tests := []struct {
314 name string
315 script string
316 wantHas bool
317 }{
318 {
319 name: "git branch rename sketch-wip",
320 script: "git branch -m sketch-wip new-branch",
321 wantHas: true,
322 },
323 {
324 name: "git branch force rename sketch-wip",
325 script: "git branch -M sketch-wip new-branch",
326 wantHas: true,
327 },
328 {
329 name: "git checkout to branch",
330 script: "git checkout main",
331 wantHas: true,
332 },
333 {
334 name: "git switch to branch",
335 script: "git switch main",
336 wantHas: true,
337 },
338 {
339 name: "git checkout file",
340 script: "git checkout -- file.txt",
341 wantHas: false,
342 },
343 {
344 name: "git checkout path",
345 script: "git checkout src/main.go",
346 wantHas: false,
347 },
348 {
349 name: "git checkout with .extension",
350 script: "git checkout file.go",
351 wantHas: false,
352 },
353 {
354 name: "git status",
355 script: "git status",
356 wantHas: false,
357 },
358 {
359 name: "git commit",
360 script: "git commit -m 'test'",
361 wantHas: false,
362 },
363 {
364 name: "git branch rename other",
365 script: "git branch -m old-branch new-branch",
366 wantHas: false,
367 },
368 {
369 name: "git switch with flag",
370 script: "git switch -c new-branch",
371 wantHas: false,
372 },
373 {
374 name: "git checkout with flag",
375 script: "git checkout -b new-branch",
376 wantHas: false,
377 },
378 {
379 name: "not a git command",
380 script: "echo hello",
381 wantHas: false,
382 },
383 {
384 name: "empty command",
385 script: "",
386 wantHas: false,
387 },
388 }
389
390 for _, tc := range tests {
391 t.Run(tc.name, func(t *testing.T) {
392 r := strings.NewReader(tc.script)
393 parser := syntax.NewParser()
394 file, err := parser.Parse(r, "")
395 if err != nil {
396 if tc.wantHas {
397 t.Errorf("Parse error: %v", err)
398 }
399 return
400 }
401
402 found := false
403 syntax.Walk(file, func(node syntax.Node) bool {
404 callExpr, ok := node.(*syntax.CallExpr)
405 if !ok {
406 return true
407 }
408 if hasSketchWipBranchChanges(callExpr) {
409 found = true
410 return false
411 }
412 return true
413 })
414
415 if found != tc.wantHas {
416 t.Errorf("hasSketchWipBranchChanges() = %v, want %v", found, tc.wantHas)
417 }
418 })
419 }
420}
421
422func TestDangerousRmRf(t *testing.T) {
423 tests := []struct {
424 name string
425 script string
426 wantErr bool
427 errMatch string
428 }{
429 // Dangerous rm -rf commands that should be blocked
430 {
431 name: "rm -rf .git",
432 script: "rm -rf .git",
433 wantErr: true,
434 errMatch: "could delete critical data",
435 },
436 {
437 name: "rm -rf with path ending in .git",
438 script: "rm -rf /path/to/.git",
439 wantErr: true,
440 errMatch: "could delete critical data",
441 },
442 {
443 name: "rm -rf ~ (home directory)",
444 script: "rm -rf ~",
445 wantErr: true,
446 errMatch: "could delete critical data",
447 },
448 {
449 name: "rm -rf ~/ (home directory with slash)",
450 script: "rm -rf ~/",
451 wantErr: true,
452 errMatch: "could delete critical data",
453 },
454 {
455 name: "rm -rf ~/path",
456 script: "rm -rf ~/Documents",
457 wantErr: true,
458 errMatch: "could delete critical data",
459 },
460 {
461 name: "rm -rf $HOME",
462 script: "rm -rf $HOME",
463 wantErr: true,
464 errMatch: "could delete critical data",
465 },
466 {
467 name: "rm -rf ${HOME}",
468 script: "rm -rf ${HOME}",
469 wantErr: true,
470 errMatch: "could delete critical data",
471 },
472 {
473 name: "rm -rf / (root)",
474 script: "rm -rf /",
475 wantErr: true,
476 errMatch: "could delete critical data",
477 },
478 {
479 name: "rm -rf .* (hidden files wildcard)",
480 script: "rm -rf .*",
481 wantErr: true,
482 errMatch: "could delete critical data",
483 },
484 {
485 name: "rm -rf * (all files wildcard)",
486 script: "rm -rf *",
487 wantErr: true,
488 errMatch: "could delete critical data",
489 },
490 {
491 name: "rm -rf /* (root wildcard)",
492 script: "rm -rf /*",
493 wantErr: true,
494 errMatch: "could delete critical data",
495 },
496 {
497 name: "rm -rf with separate flags",
498 script: "rm -r -f .git",
499 wantErr: true,
500 errMatch: "could delete critical data",
501 },
502 {
503 name: "rm -Rf .git (capital R)",
504 script: "rm -Rf .git",
505 wantErr: true,
506 errMatch: "could delete critical data",
507 },
508 {
509 name: "rm --recursive --force .git",
510 script: "rm --recursive --force .git",
511 wantErr: true,
512 errMatch: "could delete critical data",
513 },
514 {
515 name: "rm -rf path/.*/",
516 script: "rm -rf path/.*",
517 wantErr: true,
518 errMatch: "could delete critical data",
519 },
520 // Safe rm commands that should be allowed
521 {
522 name: "rm -rf specific directory",
523 script: "rm -rf /tmp/build",
524 wantErr: false,
525 },
526 {
527 name: "rm -rf node_modules",
528 script: "rm -rf node_modules",
529 wantErr: false,
530 },
531 {
532 name: "rm -rf specific file",
533 script: "rm -rf /tmp/file.txt",
534 wantErr: false,
535 },
536 {
537 name: "rm without recursive (safe)",
538 script: "rm -f .git",
539 wantErr: false,
540 },
541 {
542 name: "rm without force (safe)",
543 script: "rm -r .git",
544 wantErr: false,
545 },
546 {
547 name: "rm single file",
548 script: "rm file.txt",
549 wantErr: false,
550 },
551 {
552 name: "rm -rf with quoted $HOME (literal string)",
553 script: "rm -rf '$HOME'",
554 wantErr: false, // single quotes make it literal
555 },
556 // Complex commands
557 {
558 name: "multiline with dangerous rm",
559 script: "echo cleaning && rm -rf .git && echo done",
560 wantErr: true,
561 errMatch: "could delete critical data",
562 },
563 {
564 name: "multiline with safe rm",
565 script: "echo cleaning && rm -rf /tmp/build && echo done",
566 wantErr: false,
567 },
568 }
569
570 for _, tc := range tests {
571 t.Run(tc.name, func(t *testing.T) {
572 err := Check(tc.script)
573 if (err != nil) != tc.wantErr {
574 t.Errorf("Check() error = %v, wantErr %v", err, tc.wantErr)
575 return
576 }
577 if tc.wantErr && err != nil && !strings.Contains(err.Error(), tc.errMatch) {
578 t.Errorf("Check() error message = %v, want containing %v", err, tc.errMatch)
579 }
580 })
581 }
582}
583
584func TestEdgeCases(t *testing.T) {
585 tests := []struct {
586 name string
587 script string
588 wantErr bool
589 resetBefore bool // if true, reset warning state before test
590 }{
591 {
592 name: "git branch -m with current branch to sketch-wip (should be allowed)",
593 script: "git branch -m current-branch sketch-wip",
594 wantErr: false,
595 resetBefore: true,
596 },
597 {
598 name: "git branch -m sketch-wip with no destination (should be blocked)",
599 script: "git branch -m sketch-wip",
600 wantErr: true,
601 resetBefore: true,
602 },
603 {
604 name: "git branch -M with current branch to sketch-wip (should be allowed)",
605 script: "git branch -M current-branch sketch-wip",
606 wantErr: false,
607 resetBefore: true,
608 },
609 {
610 name: "git checkout with -- flags (should be allowed)",
611 script: "git checkout -- --weird-filename",
612 wantErr: false,
613 resetBefore: true,
614 },
615 {
616 name: "git switch with create flag (should be allowed)",
617 script: "git switch --create new-branch",
618 wantErr: false,
619 resetBefore: true,
620 },
621 {
622 name: "complex git command with sketch-wip rename",
623 script: "git add . && git commit -m \"test\" && git branch -m sketch-wip production",
624 wantErr: true,
625 resetBefore: true,
626 },
627 {
628 name: "git switch with -c short form (should be allowed)",
629 script: "git switch -c feature-branch",
630 wantErr: false,
631 resetBefore: true,
632 },
633 }
634
635 for _, tc := range tests {
636 t.Run(tc.name, func(t *testing.T) {
637 if tc.resetBefore {
638 ResetSketchWipWarning()
639 }
640 err := Check(tc.script)
641 if (err != nil) != tc.wantErr {
642 t.Errorf("Check() error = %v, wantErr %v", err, tc.wantErr)
643 }
644 })
645 }
646}
647
648func TestHasBlindGitAddEdgeCases(t *testing.T) {
649 tests := []struct {
650 name string
651 script string
652 wantHas bool
653 }{
654 {
655 name: "command with less than 2 args",
656 script: "git",
657 wantHas: false,
658 },
659 {
660 name: "non-git command",
661 script: "ls -A",
662 wantHas: false,
663 },
664 {
665 name: "git command without add subcommand",
666 script: "git status",
667 wantHas: false,
668 },
669 {
670 name: "git add with no arguments after add",
671 script: "git add",
672 wantHas: false,
673 },
674 {
675 name: "git add with valid file after flags",
676 script: "git add -v file.txt",
677 wantHas: false,
678 },
679 }
680
681 for _, tc := range tests {
682 t.Run(tc.name, func(t *testing.T) {
683 r := strings.NewReader(tc.script)
684 parser := syntax.NewParser()
685 file, err := parser.Parse(r, "")
686 if err != nil {
687 if tc.wantHas {
688 t.Errorf("Parse error: %v", err)
689 }
690 return
691 }
692
693 found := false
694 syntax.Walk(file, func(node syntax.Node) bool {
695 callExpr, ok := node.(*syntax.CallExpr)
696 if !ok {
697 return true
698 }
699 if hasBlindGitAdd(callExpr) {
700 found = true
701 return false
702 }
703 return true
704 })
705
706 if found != tc.wantHas {
707 t.Errorf("hasBlindGitAdd() = %v, want %v", found, tc.wantHas)
708 }
709 })
710 }
711}
712
713func TestHasSketchWipBranchChangesEdgeCases(t *testing.T) {
714 tests := []struct {
715 name string
716 script string
717 wantHas bool
718 }{
719 {
720 name: "git command with less than 2 args",
721 script: "git",
722 wantHas: false,
723 },
724 {
725 name: "non-git command",
726 script: "ls main",
727 wantHas: false,
728 },
729 {
730 name: "git branch -m with sketch-wip not as source",
731 script: "git branch -m other-branch sketch-wip",
732 wantHas: false,
733 },
734 {
735 name: "git checkout with complex path",
736 script: "git checkout src/components/file.go",
737 wantHas: false,
738 },
739 {
740 name: "git switch with complex flag",
741 script: "git switch --detach HEAD~1",
742 wantHas: false,
743 },
744 {
745 name: "git checkout with multiple flags",
746 script: "git checkout --ours --theirs file.txt",
747 wantHas: false,
748 },
749 }
750
751 for _, tc := range tests {
752 t.Run(tc.name, func(t *testing.T) {
753 r := strings.NewReader(tc.script)
754 parser := syntax.NewParser()
755 file, err := parser.Parse(r, "")
756 if err != nil {
757 if tc.wantHas {
758 t.Errorf("Parse error: %v", err)
759 }
760 return
761 }
762
763 found := false
764 syntax.Walk(file, func(node syntax.Node) bool {
765 callExpr, ok := node.(*syntax.CallExpr)
766 if !ok {
767 return true
768 }
769 if hasSketchWipBranchChanges(callExpr) {
770 found = true
771 return false
772 }
773 return true
774 })
775
776 if found != tc.wantHas {
777 t.Errorf("hasSketchWipBranchChanges() = %v, want %v", found, tc.wantHas)
778 }
779 })
780 }
781}