bashkit.go

  1package bashkit
  2
  3import (
  4	"bytes"
  5	"fmt"
  6	"strings"
  7	"sync"
  8
  9	"mvdan.cc/sh/v3/syntax"
 10)
 11
// checks lists the stateless validators that Check runs against every
// command call expression found in a script.
var checks = []func(*syntax.CallExpr) error{
	noBlindGitAdd,
	noDangerousRmRf,
}
 16
// processAwareChecks lists validators that keep process-level state across
// calls to Check. Check runs them after the regular checks for each command.
var processAwareChecks = []func(*syntax.CallExpr) error{
	noSketchWipBranchChangesOnce,
}
 21
// State for the sketch-wip branch warning, which is returned at most once
// per process (until ResetSketchWipWarning is called).
var (
	sketchWipWarningMu    sync.Mutex // guards sketchWipWarningShown
	sketchWipWarningShown bool       // true once the warning has been returned
)
 27
 28// ResetSketchWipWarning resets the warning state for testing purposes
 29func ResetSketchWipWarning() {
 30	sketchWipWarningMu.Lock()
 31	sketchWipWarningShown = false
 32	sketchWipWarningMu.Unlock()
 33}
 34
 35// Check inspects bashScript and returns an error if it ought not be executed.
 36// Check DOES NOT PROVIDE SECURITY against malicious actors.
 37// It is intended to catch straightforward mistakes in which a model
 38// does things despite having been instructed not to do them.
 39func Check(bashScript string) error {
 40	r := strings.NewReader(bashScript)
 41	parser := syntax.NewParser()
 42	file, err := parser.Parse(r, "")
 43	if err != nil {
 44		// Execution will fail, but we'll get a better error message from bash.
 45		// Note that if this were security load bearing, this would be a terrible idea:
 46		// You could smuggle stuff past Check by exploiting differences in what is considered syntactically valid.
 47		// But it is not.
 48		return nil
 49	}
 50
 51	syntax.Walk(file, func(node syntax.Node) bool {
 52		if err != nil {
 53			return false
 54		}
 55		callExpr, ok := node.(*syntax.CallExpr)
 56		if !ok {
 57			return true
 58		}
 59		// Run regular checks
 60		for _, check := range checks {
 61			err = check(callExpr)
 62			if err != nil {
 63				return false
 64			}
 65		}
 66		// Run process-aware checks
 67		for _, check := range processAwareChecks {
 68			err = check(callExpr)
 69			if err != nil {
 70				return false
 71			}
 72		}
 73		return true
 74	})
 75
 76	return err
 77}
 78
 79// WillRunGitCommit reports whether bashScript contains a git commit command.
 80func WillRunGitCommit(bashScript string) (bool, error) {
 81	r := strings.NewReader(bashScript)
 82	parser := syntax.NewParser()
 83	file, err := parser.Parse(r, "")
 84	if err != nil {
 85		// Parsing failed, but let's not consider this an error for the same reasons as in Check
 86		return false, nil
 87	}
 88
 89	willCommit := false
 90
 91	syntax.Walk(file, func(node syntax.Node) bool {
 92		callExpr, ok := node.(*syntax.CallExpr)
 93		if !ok {
 94			return true
 95		}
 96		if isGitCommitCommand(callExpr) {
 97			willCommit = true
 98			return false
 99		}
100		return true
101	})
102
103	return willCommit, nil
104}
105
106// noDangerousRmRf checks for rm -rf commands that could delete critical directories.
107// It rejects patterns that could delete .git directories, home directories (~, $HOME),
108// or root directories.
109func noDangerousRmRf(cmd *syntax.CallExpr) error {
110	if hasDangerousRmRf(cmd) {
111		return fmt.Errorf("permission denied: this rm command could delete critical data (.git, home directory, or root). If you really need to run this command, spell out the full path explicitly (no wildcards, ~, or $HOME). Consider confirming with the user before running destructive cleanup commands")
112	}
113	return nil
114}
115
116// hasDangerousRmRf checks if an rm command could delete critical directories.
117func hasDangerousRmRf(cmd *syntax.CallExpr) bool {
118	if len(cmd.Args) < 1 {
119		return false
120	}
121
122	// Check if the command is rm
123	firstArg := cmd.Args[0].Lit()
124	if firstArg != "rm" {
125		return false
126	}
127
128	// Check if -r or -R is present (recursive)
129	hasRecursive := false
130	hasForce := false
131	for _, arg := range cmd.Args[1:] {
132		lit := arg.Lit()
133		// Handle combined flags like -rf, -fr, -Rf, etc.
134		if strings.HasPrefix(lit, "-") && !strings.HasPrefix(lit, "--") {
135			if strings.ContainsAny(lit, "rR") {
136				hasRecursive = true
137			}
138			if strings.Contains(lit, "f") {
139				hasForce = true
140			}
141		}
142		if lit == "--recursive" {
143			hasRecursive = true
144		}
145		if lit == "--force" {
146			hasForce = true
147		}
148	}
149
150	// Only check for dangerous paths if it's a recursive and forced rm
151	if !hasRecursive || !hasForce {
152		return false
153	}
154
155	// Check arguments for dangerous patterns
156	for _, arg := range cmd.Args[1:] {
157		lit := arg.Lit()
158		// Skip flags
159		if strings.HasPrefix(lit, "-") {
160			continue
161		}
162
163		// Check for .git directory patterns
164		if lit == ".git" || strings.HasSuffix(lit, "/.git") ||
165			strings.Contains(lit, ".git/") || strings.Contains(lit, ".git ") {
166			return true
167		}
168
169		// Check for home directory patterns
170		if lit == "~" || lit == "~/" || strings.HasPrefix(lit, "~/") {
171			return true
172		}
173
174		// Check for root directory
175		if lit == "/" {
176			return true
177		}
178
179		// Check for wildcards that could match .git
180		if lit == ".*" || strings.HasSuffix(lit, "/.*") {
181			return true
182		}
183
184		// Check for broad wildcards at dangerous locations
185		if lit == "*" || lit == "/*" {
186			return true
187		}
188	}
189
190	// Also check if the argument uses variable expansion (like $HOME)
191	// We need to walk the AST more carefully for this
192	for _, arg := range cmd.Args[1:] {
193		if containsHomeVariable(arg) {
194			return true
195		}
196	}
197
198	return false
199}
200
201// containsHomeVariable checks if a word contains $HOME or ${HOME} expansion
202func containsHomeVariable(word *syntax.Word) bool {
203	for _, part := range word.Parts {
204		switch p := part.(type) {
205		case *syntax.ParamExp:
206			if p.Param != nil && p.Param.Value == "HOME" {
207				return true
208			}
209		case *syntax.DblQuoted:
210			for _, inner := range p.Parts {
211				if pe, ok := inner.(*syntax.ParamExp); ok {
212					if pe.Param != nil && pe.Param.Value == "HOME" {
213						return true
214					}
215				}
216			}
217		}
218	}
219	return false
220}
221
222// noBlindGitAdd checks for git add commands that blindly add all files.
223// It rejects patterns like 'git add -A', 'git add .', 'git add --all', 'git add *'.
224func noBlindGitAdd(cmd *syntax.CallExpr) error {
225	if hasBlindGitAdd(cmd) {
226		return fmt.Errorf("permission denied: blind git add commands (git add -A, git add ., git add --all, git add *) are not allowed, specify files explicitly")
227	}
228	return nil
229}
230
231func hasBlindGitAdd(cmd *syntax.CallExpr) bool {
232	if len(cmd.Args) < 2 {
233		return false
234	}
235	if cmd.Args[0].Lit() != "git" {
236		return false
237	}
238
239	// Find the 'add' subcommand
240	addIndex := -1
241	for i, arg := range cmd.Args {
242		if arg.Lit() == "add" {
243			addIndex = i
244			break
245		}
246	}
247
248	if addIndex < 0 {
249		return false
250	}
251
252	// Check arguments after 'add' for blind patterns
253	for i := addIndex + 1; i < len(cmd.Args); i++ {
254		arg := cmd.Args[i].Lit()
255		// Check for blind add patterns
256		if arg == "-A" || arg == "--all" || arg == "." || arg == "*" {
257			return true
258		}
259	}
260
261	return false
262}
263
264// AddCoauthorTrailer modifies a bash script to add a Co-authored-by trailer
265// to any git commit commands. Returns the modified script.
266func AddCoauthorTrailer(bashScript, trailer string) string {
267	r := strings.NewReader(bashScript)
268	parser := syntax.NewParser(syntax.KeepComments(true))
269	file, err := parser.Parse(r, "")
270	if err != nil {
271		// Can't parse, return original
272		return bashScript
273	}
274
275	modified := false
276	syntax.Walk(file, func(node syntax.Node) bool {
277		callExpr, ok := node.(*syntax.CallExpr)
278		if !ok {
279			return true
280		}
281		if addTrailerToGitCommit(callExpr, trailer) {
282			modified = true
283		}
284		return true
285	})
286
287	if !modified {
288		return bashScript
289	}
290
291	var buf bytes.Buffer
292	printer := syntax.NewPrinter()
293	if err := printer.Print(&buf, file); err != nil {
294		return bashScript
295	}
296	return buf.String()
297}
298
299// addTrailerToGitCommit adds --trailer to a git commit command.
300// Returns true if the command was modified.
301func addTrailerToGitCommit(cmd *syntax.CallExpr, trailer string) bool {
302	if !isGitCommitCommand(cmd) {
303		return false
304	}
305
306	// Find where to insert --trailer (right after "commit")
307	insertIdx := -1
308	for i := 1; i < len(cmd.Args); i++ {
309		if cmd.Args[i].Lit() == "commit" {
310			insertIdx = i + 1
311			break
312		}
313	}
314	if insertIdx < 0 {
315		return false
316	}
317
318	// Create --trailer argument
319	trailerArg := &syntax.Word{
320		Parts: []syntax.WordPart{
321			&syntax.Lit{Value: "--trailer"},
322		},
323	}
324	// Create the trailer value argument
325	trailerVal := &syntax.Word{
326		Parts: []syntax.WordPart{
327			&syntax.DblQuoted{
328				Parts: []syntax.WordPart{
329					&syntax.Lit{Value: trailer},
330				},
331			},
332		},
333	}
334
335	// Insert the two new arguments
336	newArgs := make([]*syntax.Word, 0, len(cmd.Args)+2)
337	newArgs = append(newArgs, cmd.Args[:insertIdx]...)
338	newArgs = append(newArgs, trailerArg, trailerVal)
339	newArgs = append(newArgs, cmd.Args[insertIdx:]...)
340	cmd.Args = newArgs
341
342	return true
343}
344
345// isGitCommitCommand checks if a command is 'git commit'.
346func isGitCommitCommand(cmd *syntax.CallExpr) bool {
347	if len(cmd.Args) < 2 {
348		return false
349	}
350
351	// First argument must be 'git'
352	if cmd.Args[0].Lit() != "git" {
353		return false
354	}
355
356	// Look for 'commit' in any position after 'git'
357	for i := 1; i < len(cmd.Args); i++ {
358		if cmd.Args[i].Lit() == "commit" {
359			return true
360		}
361	}
362
363	return false
364}
365
366// noSketchWipBranchChangesOnce checks for git commands that would change the sketch-wip branch.
367// It rejects commands that would rename the sketch-wip branch or switch away from it.
368// This check only shows the warning once per process.
369func noSketchWipBranchChangesOnce(cmd *syntax.CallExpr) error {
370	if hasSketchWipBranchChanges(cmd) {
371		// Check if we've already warned in this process
372		sketchWipWarningMu.Lock()
373		alreadyWarned := sketchWipWarningShown
374		if !alreadyWarned {
375			sketchWipWarningShown = true
376		}
377		sketchWipWarningMu.Unlock()
378
379		if !alreadyWarned {
380			return fmt.Errorf("permission denied: cannot leave 'sketch-wip' branch. This branch is designated for change detection and auto-push; work on other branches may be lost. Warning shown once per session. Repeat command if needed for temporary operations (rebase, bisect, etc.) but return to sketch-wip afterward. Note: users can push to any branch via the Push button in the UI")
381		}
382	}
383	return nil
384}
385
386// hasSketchWipBranchChanges checks if a git command would change the sketch-wip branch.
387func hasSketchWipBranchChanges(cmd *syntax.CallExpr) bool {
388	if len(cmd.Args) < 2 {
389		return false
390	}
391	if cmd.Args[0].Lit() != "git" {
392		return false
393	}
394
395	// Look for subcommands that could change the sketch-wip branch
396	for i := 1; i < len(cmd.Args); i++ {
397		arg := cmd.Args[i].Lit()
398		switch arg {
399		case "branch":
400			// Check for branch rename: git branch -m sketch-wip newname or git branch -M sketch-wip newname
401			if i+2 < len(cmd.Args) {
402				// Look for -m or -M flag
403				for j := i + 1; j < len(cmd.Args)-1; j++ {
404					flag := cmd.Args[j].Lit()
405					if flag == "-m" || flag == "-M" {
406						// Check if sketch-wip is the source branch
407						if cmd.Args[j+1].Lit() == "sketch-wip" {
408							return true
409						}
410					}
411				}
412			}
413		case "checkout":
414			// Check for branch switching: git checkout otherbranch
415			// But allow git checkout files/paths
416			if i+1 < len(cmd.Args) {
417				nextArg := cmd.Args[i+1].Lit()
418				// Skip if it's a flag
419				if !strings.HasPrefix(nextArg, "-") {
420					// This might be a branch checkout - we'll be conservative and warn
421					// unless it looks like a file path
422					if !strings.Contains(nextArg, "/") && !strings.Contains(nextArg, ".") {
423						return true
424					}
425				}
426			}
427		case "switch":
428			// Check for branch switching: git switch otherbranch
429			if i+1 < len(cmd.Args) {
430				nextArg := cmd.Args[i+1].Lit()
431				// Skip if it's a flag
432				if !strings.HasPrefix(nextArg, "-") {
433					return true
434				}
435			}
436		}
437	}
438
439	return false
440}