bashkit_test.go

  1package bashkit
  2
  3import (
  4	"strings"
  5	"testing"
  6
  7	"mvdan.cc/sh/v3/syntax"
  8)
  9
 10func TestCheck(t *testing.T) {
 11	tests := []struct {
 12		name     string
 13		script   string
 14		wantErr  bool
 15		errMatch string // string to match in error message, if wantErr is true
 16	}{
 17		{
 18			name:     "valid script",
 19			script:   "echo hello world",
 20			wantErr:  false,
 21			errMatch: "",
 22		},
 23		{
 24			name:     "invalid syntax",
 25			script:   "echo 'unterminated string",
 26			wantErr:  false, // As per implementation, syntax errors are not flagged
 27			errMatch: "",
 28		},
 29		// Git add validation tests
 30		{
 31			name:     "git add with -A flag",
 32			script:   "git add -A",
 33			wantErr:  true,
 34			errMatch: "blind git add commands",
 35		},
 36		{
 37			name:     "git add with --all flag",
 38			script:   "git add --all",
 39			wantErr:  true,
 40			errMatch: "blind git add commands",
 41		},
 42		{
 43			name:     "git add with dot",
 44			script:   "git add .",
 45			wantErr:  true,
 46			errMatch: "blind git add commands",
 47		},
 48		{
 49			name:     "git add with asterisk",
 50			script:   "git add *",
 51			wantErr:  true,
 52			errMatch: "blind git add commands",
 53		},
 54		{
 55			name:     "git add with multiple flags including -A",
 56			script:   "git add -v -A",
 57			wantErr:  true,
 58			errMatch: "blind git add commands",
 59		},
 60		{
 61			name:     "git add with specific file",
 62			script:   "git add main.go",
 63			wantErr:  false,
 64			errMatch: "",
 65		},
 66		{
 67			name:     "git add with multiple specific files",
 68			script:   "git add main.go utils.go",
 69			wantErr:  false,
 70			errMatch: "",
 71		},
 72		{
 73			name:     "git add with directory path",
 74			script:   "git add src/main.go",
 75			wantErr:  false,
 76			errMatch: "",
 77		},
 78		{
 79			name:     "git add with git flags before add",
 80			script:   "git -C /path/to/repo add -A",
 81			wantErr:  true,
 82			errMatch: "blind git add commands",
 83		},
 84		{
 85			name:     "git add with valid flags",
 86			script:   "git add -v main.go",
 87			wantErr:  false,
 88			errMatch: "",
 89		},
 90		{
 91			name:     "git command without add",
 92			script:   "git status",
 93			wantErr:  false,
 94			errMatch: "",
 95		},
 96		{
 97			name:     "multiline script with blind git add",
 98			script:   "echo 'Adding files' && git add -A && git commit -m 'Update'",
 99			wantErr:  true,
100			errMatch: "blind git add commands",
101		},
102		{
103			name:     "git add with pattern that looks like blind but is specific",
104			script:   "git add file.A",
105			wantErr:  false,
106			errMatch: "",
107		},
108		{
109			name:     "commented blind git add",
110			script:   "# git add -A",
111			wantErr:  false,
112			errMatch: "",
113		},
114	}
115
116	for _, tc := range tests {
117		t.Run(tc.name, func(t *testing.T) {
118			err := Check(tc.script)
119			if (err != nil) != tc.wantErr {
120				t.Errorf("Check() error = %v, wantErr %v", err, tc.wantErr)
121				return
122			}
123			if tc.wantErr && err != nil && !strings.Contains(err.Error(), tc.errMatch) {
124				t.Errorf("Check() error message = %v, want containing %v", err, tc.errMatch)
125			}
126		})
127	}
128}
129
130func TestWillRunGitCommit(t *testing.T) {
131	tests := []struct {
132		name       string
133		script     string
134		wantCommit bool
135	}{
136		{
137			name:       "simple git commit",
138			script:     "git commit -m 'Add feature'",
139			wantCommit: true,
140		},
141		{
142			name:       "git command without commit",
143			script:     "git status",
144			wantCommit: false,
145		},
146		{
147			name:       "multiline script with git commit",
148			script:     "echo 'Making changes' && git add . && git commit -m 'Update files'",
149			wantCommit: true,
150		},
151		{
152			name:       "multiline script without git commit",
153			script:     "echo 'Checking status' && git status",
154			wantCommit: false,
155		},
156		{
157			name:       "script with commented git commit",
158			script:     "# git commit -m 'This is commented out'",
159			wantCommit: false,
160		},
161		{
162			name:       "git commit with variables",
163			script:     "MSG='Fix bug' && git commit -m 'Using variable'",
164			wantCommit: true,
165		},
166		{
167			name:       "only git command",
168			script:     "git",
169			wantCommit: false,
170		},
171		{
172			name:       "script with invalid syntax",
173			script:     "git commit -m 'unterminated string",
174			wantCommit: false,
175		},
176		{
177			name:       "commit used in different context",
178			script:     "echo 'commit message'",
179			wantCommit: false,
180		},
181		{
182			name:       "git with flags before commit",
183			script:     "git -C /path/to/repo commit -m 'Update'",
184			wantCommit: true,
185		},
186		{
187			name:       "git with multiple flags",
188			script:     "git --git-dir=.git -C repo commit -a -m 'Update'",
189			wantCommit: true,
190		},
191		{
192			name:       "git with env vars",
193			script:     "GIT_AUTHOR_NAME=\"Josh Bleecher Snyder\" GIT_AUTHOR_EMAIL=\"josharian@gmail.com\" git commit -am \"Updated code\"",
194			wantCommit: true,
195		},
196		{
197			name:       "git with redirections",
198			script:     "git commit -m 'Fix issue' > output.log 2>&1",
199			wantCommit: true,
200		},
201		{
202			name:       "git with piped commands",
203			script:     "echo 'Committing' | git commit -F -",
204			wantCommit: true,
205		},
206	}
207
208	for _, tc := range tests {
209		t.Run(tc.name, func(t *testing.T) {
210			gotCommit, err := WillRunGitCommit(tc.script)
211			if err != nil {
212				t.Errorf("WillRunGitCommit() error = %v", err)
213				return
214			}
215			if gotCommit != tc.wantCommit {
216				t.Errorf("WillRunGitCommit() = %v, want %v", gotCommit, tc.wantCommit)
217			}
218		})
219	}
220}
221
222func TestSketchWipBranchProtection(t *testing.T) {
223	tests := []struct {
224		name        string
225		script      string
226		wantErr     bool
227		errMatch    string
228		resetBefore bool // if true, reset warning state before test
229	}{
230		{
231			name:        "git branch rename sketch-wip",
232			script:      "git branch -m sketch-wip new-branch",
233			wantErr:     true,
234			errMatch:    "cannot leave 'sketch-wip' branch",
235			resetBefore: true,
236		},
237		{
238			name:        "git branch force rename sketch-wip",
239			script:      "git branch -M sketch-wip new-branch",
240			wantErr:     false, // second call should not error (already warned)
241			errMatch:    "",
242			resetBefore: false,
243		},
244		{
245			name:        "git checkout to other branch",
246			script:      "git checkout main",
247			wantErr:     false, // third call should not error (already warned)
248			errMatch:    "",
249			resetBefore: false,
250		},
251		{
252			name:        "git switch to other branch",
253			script:      "git switch main",
254			wantErr:     false, // fourth call should not error (already warned)
255			errMatch:    "",
256			resetBefore: false,
257		},
258		{
259			name:        "git checkout file (should be allowed)",
260			script:      "git checkout -- file.txt",
261			wantErr:     false,
262			errMatch:    "",
263			resetBefore: false,
264		},
265		{
266			name:        "git checkout path (should be allowed)",
267			script:      "git checkout -- src/main.go",
268			wantErr:     false,
269			errMatch:    "",
270			resetBefore: false,
271		},
272		{
273			name:        "git commit (should be allowed)",
274			script:      "git commit -m 'test'",
275			wantErr:     false,
276			errMatch:    "",
277			resetBefore: false,
278		},
279		{
280			name:        "git status (should be allowed)",
281			script:      "git status",
282			wantErr:     false,
283			errMatch:    "",
284			resetBefore: false,
285		},
286		{
287			name:        "git branch rename other branch (should be allowed)",
288			script:      "git branch -m old-branch new-branch",
289			wantErr:     false,
290			errMatch:    "",
291			resetBefore: false,
292		},
293	}
294
295	for _, tc := range tests {
296		t.Run(tc.name, func(t *testing.T) {
297			if tc.resetBefore {
298				ResetSketchWipWarning()
299			}
300			err := Check(tc.script)
301			if (err != nil) != tc.wantErr {
302				t.Errorf("Check() error = %v, wantErr %v", err, tc.wantErr)
303				return
304			}
305			if tc.wantErr && err != nil && !strings.Contains(err.Error(), tc.errMatch) {
306				t.Errorf("Check() error message = %v, want containing %v", err, tc.errMatch)
307			}
308		})
309	}
310}
311
312func TestHasSketchWipBranchChanges(t *testing.T) {
313	tests := []struct {
314		name    string
315		script  string
316		wantHas bool
317	}{
318		{
319			name:    "git branch rename sketch-wip",
320			script:  "git branch -m sketch-wip new-branch",
321			wantHas: true,
322		},
323		{
324			name:    "git branch force rename sketch-wip",
325			script:  "git branch -M sketch-wip new-branch",
326			wantHas: true,
327		},
328		{
329			name:    "git checkout to branch",
330			script:  "git checkout main",
331			wantHas: true,
332		},
333		{
334			name:    "git switch to branch",
335			script:  "git switch main",
336			wantHas: true,
337		},
338		{
339			name:    "git checkout file",
340			script:  "git checkout -- file.txt",
341			wantHas: false,
342		},
343		{
344			name:    "git checkout path",
345			script:  "git checkout src/main.go",
346			wantHas: false,
347		},
348		{
349			name:    "git checkout with .extension",
350			script:  "git checkout file.go",
351			wantHas: false,
352		},
353		{
354			name:    "git status",
355			script:  "git status",
356			wantHas: false,
357		},
358		{
359			name:    "git commit",
360			script:  "git commit -m 'test'",
361			wantHas: false,
362		},
363		{
364			name:    "git branch rename other",
365			script:  "git branch -m old-branch new-branch",
366			wantHas: false,
367		},
368		{
369			name:    "git switch with flag",
370			script:  "git switch -c new-branch",
371			wantHas: false,
372		},
373		{
374			name:    "git checkout with flag",
375			script:  "git checkout -b new-branch",
376			wantHas: false,
377		},
378		{
379			name:    "not a git command",
380			script:  "echo hello",
381			wantHas: false,
382		},
383		{
384			name:    "empty command",
385			script:  "",
386			wantHas: false,
387		},
388	}
389
390	for _, tc := range tests {
391		t.Run(tc.name, func(t *testing.T) {
392			r := strings.NewReader(tc.script)
393			parser := syntax.NewParser()
394			file, err := parser.Parse(r, "")
395			if err != nil {
396				if tc.wantHas {
397					t.Errorf("Parse error: %v", err)
398				}
399				return
400			}
401
402			found := false
403			syntax.Walk(file, func(node syntax.Node) bool {
404				callExpr, ok := node.(*syntax.CallExpr)
405				if !ok {
406					return true
407				}
408				if hasSketchWipBranchChanges(callExpr) {
409					found = true
410					return false
411				}
412				return true
413			})
414
415			if found != tc.wantHas {
416				t.Errorf("hasSketchWipBranchChanges() = %v, want %v", found, tc.wantHas)
417			}
418		})
419	}
420}
421
422func TestDangerousRmRf(t *testing.T) {
423	tests := []struct {
424		name     string
425		script   string
426		wantErr  bool
427		errMatch string
428	}{
429		// Dangerous rm -rf commands that should be blocked
430		{
431			name:     "rm -rf .git",
432			script:   "rm -rf .git",
433			wantErr:  true,
434			errMatch: "could delete critical data",
435		},
436		{
437			name:     "rm -rf with path ending in .git",
438			script:   "rm -rf /path/to/.git",
439			wantErr:  true,
440			errMatch: "could delete critical data",
441		},
442		{
443			name:     "rm -rf ~ (home directory)",
444			script:   "rm -rf ~",
445			wantErr:  true,
446			errMatch: "could delete critical data",
447		},
448		{
449			name:     "rm -rf ~/ (home directory with slash)",
450			script:   "rm -rf ~/",
451			wantErr:  true,
452			errMatch: "could delete critical data",
453		},
454		{
455			name:     "rm -rf ~/path",
456			script:   "rm -rf ~/Documents",
457			wantErr:  true,
458			errMatch: "could delete critical data",
459		},
460		{
461			name:     "rm -rf $HOME",
462			script:   "rm -rf $HOME",
463			wantErr:  true,
464			errMatch: "could delete critical data",
465		},
466		{
467			name:     "rm -rf ${HOME}",
468			script:   "rm -rf ${HOME}",
469			wantErr:  true,
470			errMatch: "could delete critical data",
471		},
472		{
473			name:     "rm -rf / (root)",
474			script:   "rm -rf /",
475			wantErr:  true,
476			errMatch: "could delete critical data",
477		},
478		{
479			name:     "rm -rf .* (hidden files wildcard)",
480			script:   "rm -rf .*",
481			wantErr:  true,
482			errMatch: "could delete critical data",
483		},
484		{
485			name:     "rm -rf * (all files wildcard)",
486			script:   "rm -rf *",
487			wantErr:  true,
488			errMatch: "could delete critical data",
489		},
490		{
491			name:     "rm -rf /* (root wildcard)",
492			script:   "rm -rf /*",
493			wantErr:  true,
494			errMatch: "could delete critical data",
495		},
496		{
497			name:     "rm -rf with separate flags",
498			script:   "rm -r -f .git",
499			wantErr:  true,
500			errMatch: "could delete critical data",
501		},
502		{
503			name:     "rm -Rf .git (capital R)",
504			script:   "rm -Rf .git",
505			wantErr:  true,
506			errMatch: "could delete critical data",
507		},
508		{
509			name:     "rm --recursive --force .git",
510			script:   "rm --recursive --force .git",
511			wantErr:  true,
512			errMatch: "could delete critical data",
513		},
514		{
515			name:     "rm -rf path/.*/",
516			script:   "rm -rf path/.*",
517			wantErr:  true,
518			errMatch: "could delete critical data",
519		},
520		// Safe rm commands that should be allowed
521		{
522			name:    "rm -rf specific directory",
523			script:  "rm -rf /tmp/build",
524			wantErr: false,
525		},
526		{
527			name:    "rm -rf node_modules",
528			script:  "rm -rf node_modules",
529			wantErr: false,
530		},
531		{
532			name:    "rm -rf specific file",
533			script:  "rm -rf /tmp/file.txt",
534			wantErr: false,
535		},
536		{
537			name:    "rm without recursive (safe)",
538			script:  "rm -f .git",
539			wantErr: false,
540		},
541		{
542			name:    "rm without force (safe)",
543			script:  "rm -r .git",
544			wantErr: false,
545		},
546		{
547			name:    "rm single file",
548			script:  "rm file.txt",
549			wantErr: false,
550		},
551		{
552			name:    "rm -rf with quoted $HOME (literal string)",
553			script:  "rm -rf '$HOME'",
554			wantErr: false, // single quotes make it literal
555		},
556		// Complex commands
557		{
558			name:     "multiline with dangerous rm",
559			script:   "echo cleaning && rm -rf .git && echo done",
560			wantErr:  true,
561			errMatch: "could delete critical data",
562		},
563		{
564			name:    "multiline with safe rm",
565			script:  "echo cleaning && rm -rf /tmp/build && echo done",
566			wantErr: false,
567		},
568	}
569
570	for _, tc := range tests {
571		t.Run(tc.name, func(t *testing.T) {
572			err := Check(tc.script)
573			if (err != nil) != tc.wantErr {
574				t.Errorf("Check() error = %v, wantErr %v", err, tc.wantErr)
575				return
576			}
577			if tc.wantErr && err != nil && !strings.Contains(err.Error(), tc.errMatch) {
578				t.Errorf("Check() error message = %v, want containing %v", err, tc.errMatch)
579			}
580		})
581	}
582}
583
584func TestEdgeCases(t *testing.T) {
585	tests := []struct {
586		name        string
587		script      string
588		wantErr     bool
589		resetBefore bool // if true, reset warning state before test
590	}{
591		{
592			name:        "git branch -m with current branch to sketch-wip (should be allowed)",
593			script:      "git branch -m current-branch sketch-wip",
594			wantErr:     false,
595			resetBefore: true,
596		},
597		{
598			name:        "git branch -m sketch-wip with no destination (should be blocked)",
599			script:      "git branch -m sketch-wip",
600			wantErr:     true,
601			resetBefore: true,
602		},
603		{
604			name:        "git branch -M with current branch to sketch-wip (should be allowed)",
605			script:      "git branch -M current-branch sketch-wip",
606			wantErr:     false,
607			resetBefore: true,
608		},
609		{
610			name:        "git checkout with -- flags (should be allowed)",
611			script:      "git checkout -- --weird-filename",
612			wantErr:     false,
613			resetBefore: true,
614		},
615		{
616			name:        "git switch with create flag (should be allowed)",
617			script:      "git switch --create new-branch",
618			wantErr:     false,
619			resetBefore: true,
620		},
621		{
622			name:        "complex git command with sketch-wip rename",
623			script:      "git add . && git commit -m \"test\" && git branch -m sketch-wip production",
624			wantErr:     true,
625			resetBefore: true,
626		},
627		{
628			name:        "git switch with -c short form (should be allowed)",
629			script:      "git switch -c feature-branch",
630			wantErr:     false,
631			resetBefore: true,
632		},
633	}
634
635	for _, tc := range tests {
636		t.Run(tc.name, func(t *testing.T) {
637			if tc.resetBefore {
638				ResetSketchWipWarning()
639			}
640			err := Check(tc.script)
641			if (err != nil) != tc.wantErr {
642				t.Errorf("Check() error = %v, wantErr %v", err, tc.wantErr)
643			}
644		})
645	}
646}
647
648func TestHasBlindGitAddEdgeCases(t *testing.T) {
649	tests := []struct {
650		name    string
651		script  string
652		wantHas bool
653	}{
654		{
655			name:    "command with less than 2 args",
656			script:  "git",
657			wantHas: false,
658		},
659		{
660			name:    "non-git command",
661			script:  "ls -A",
662			wantHas: false,
663		},
664		{
665			name:    "git command without add subcommand",
666			script:  "git status",
667			wantHas: false,
668		},
669		{
670			name:    "git add with no arguments after add",
671			script:  "git add",
672			wantHas: false,
673		},
674		{
675			name:    "git add with valid file after flags",
676			script:  "git add -v file.txt",
677			wantHas: false,
678		},
679	}
680
681	for _, tc := range tests {
682		t.Run(tc.name, func(t *testing.T) {
683			r := strings.NewReader(tc.script)
684			parser := syntax.NewParser()
685			file, err := parser.Parse(r, "")
686			if err != nil {
687				if tc.wantHas {
688					t.Errorf("Parse error: %v", err)
689				}
690				return
691			}
692
693			found := false
694			syntax.Walk(file, func(node syntax.Node) bool {
695				callExpr, ok := node.(*syntax.CallExpr)
696				if !ok {
697					return true
698				}
699				if hasBlindGitAdd(callExpr) {
700					found = true
701					return false
702				}
703				return true
704			})
705
706			if found != tc.wantHas {
707				t.Errorf("hasBlindGitAdd() = %v, want %v", found, tc.wantHas)
708			}
709		})
710	}
711}
712
713func TestHasSketchWipBranchChangesEdgeCases(t *testing.T) {
714	tests := []struct {
715		name    string
716		script  string
717		wantHas bool
718	}{
719		{
720			name:    "git command with less than 2 args",
721			script:  "git",
722			wantHas: false,
723		},
724		{
725			name:    "non-git command",
726			script:  "ls main",
727			wantHas: false,
728		},
729		{
730			name:    "git branch -m with sketch-wip not as source",
731			script:  "git branch -m other-branch sketch-wip",
732			wantHas: false,
733		},
734		{
735			name:    "git checkout with complex path",
736			script:  "git checkout src/components/file.go",
737			wantHas: false,
738		},
739		{
740			name:    "git switch with complex flag",
741			script:  "git switch --detach HEAD~1",
742			wantHas: false,
743		},
744		{
745			name:    "git checkout with multiple flags",
746			script:  "git checkout --ours --theirs file.txt",
747			wantHas: false,
748		},
749	}
750
751	for _, tc := range tests {
752		t.Run(tc.name, func(t *testing.T) {
753			r := strings.NewReader(tc.script)
754			parser := syntax.NewParser()
755			file, err := parser.Parse(r, "")
756			if err != nil {
757				if tc.wantHas {
758					t.Errorf("Parse error: %v", err)
759				}
760				return
761			}
762
763			found := false
764			syntax.Walk(file, func(node syntax.Node) bool {
765				callExpr, ok := node.(*syntax.CallExpr)
766				if !ok {
767					return true
768				}
769				if hasSketchWipBranchChanges(callExpr) {
770					found = true
771					return false
772				}
773				return true
774			})
775
776			if found != tc.wantHas {
777				t.Errorf("hasSketchWipBranchChanges() = %v, want %v", found, tc.wantHas)
778			}
779		})
780	}
781}