bashkit.go

  1package bashkit
  2
  3import (
  4	"bytes"
  5	"fmt"
  6	"strings"
  7	"sync"
  8
  9	"mvdan.cc/sh/v3/syntax"
 10)
 11
 12var checks = []func(*syntax.CallExpr) error{
 13	noBlindGitAdd,
 14	noDangerousRmRf,
 15}
 16
 17// Process-level checks that track state across calls
 18var processAwareChecks = []func(*syntax.CallExpr) error{
 19	noSketchWipBranchChangesOnce,
 20}
 21
 22// Track whether sketch-wip branch warning has been shown in this process
 23var (
 24	sketchWipWarningMu    sync.Mutex
 25	sketchWipWarningShown bool
 26)
 27
 28// ResetSketchWipWarning resets the warning state for testing purposes
 29func ResetSketchWipWarning() {
 30	sketchWipWarningMu.Lock()
 31	sketchWipWarningShown = false
 32	sketchWipWarningMu.Unlock()
 33}
 34
 35// Check inspects bashScript and returns an error if it ought not be executed.
 36// Check DOES NOT PROVIDE SECURITY against malicious actors.
 37// It is intended to catch straightforward mistakes in which a model
 38// does things despite having been instructed not to do them.
 39func Check(bashScript string) error {
 40	r := strings.NewReader(bashScript)
 41	parser := syntax.NewParser()
 42	file, err := parser.Parse(r, "")
 43	if err != nil {
 44		// Execution will fail, but we'll get a better error message from bash.
 45		// Note that if this were security load bearing, this would be a terrible idea:
 46		// You could smuggle stuff past Check by exploiting differences in what is considered syntactically valid.
 47		// But it is not.
 48		return nil
 49	}
 50
 51	syntax.Walk(file, func(node syntax.Node) bool {
 52		if err != nil {
 53			return false
 54		}
 55		callExpr, ok := node.(*syntax.CallExpr)
 56		if !ok {
 57			return true
 58		}
 59		// Run regular checks
 60		for _, check := range checks {
 61			err = check(callExpr)
 62			if err != nil {
 63				return false
 64			}
 65		}
 66		// Run process-aware checks
 67		for _, check := range processAwareChecks {
 68			err = check(callExpr)
 69			if err != nil {
 70				return false
 71			}
 72		}
 73		return true
 74	})
 75
 76	return err
 77}
 78
 79// WillRunGitCommit checks if the provided bash script will run 'git commit'.
 80// It returns true if any command in the script is a git commit command.
 81func WillRunGitCommit(bashScript string) (bool, error) {
 82	r := strings.NewReader(bashScript)
 83	parser := syntax.NewParser()
 84	file, err := parser.Parse(r, "")
 85	if err != nil {
 86		// Parsing failed, but let's not consider this an error for the same reasons as in Check
 87		return false, nil
 88	}
 89
 90	willCommit := false
 91
 92	syntax.Walk(file, func(node syntax.Node) bool {
 93		callExpr, ok := node.(*syntax.CallExpr)
 94		if !ok {
 95			return true
 96		}
 97		if isGitCommitCommand(callExpr) {
 98			willCommit = true
 99			return false
100		}
101		return true
102	})
103
104	return willCommit, nil
105}
106
107// noDangerousRmRf checks for rm -rf commands that could delete critical directories.
108// It rejects patterns that could delete .git directories, home directories (~, $HOME),
109// or root directories.
110func noDangerousRmRf(cmd *syntax.CallExpr) error {
111	if hasDangerousRmRf(cmd) {
112		return fmt.Errorf("permission denied: this rm command could delete critical data (.git, home directory, or root). If you really need to run this command, spell out the full path explicitly (no wildcards, ~, or $HOME). Consider confirming with the user before running destructive cleanup commands")
113	}
114	return nil
115}
116
117// hasDangerousRmRf checks if an rm command could delete critical directories.
118func hasDangerousRmRf(cmd *syntax.CallExpr) bool {
119	if len(cmd.Args) < 1 {
120		return false
121	}
122
123	// Check if the command is rm
124	firstArg := cmd.Args[0].Lit()
125	if firstArg != "rm" {
126		return false
127	}
128
129	// Check if -r or -R is present (recursive)
130	hasRecursive := false
131	hasForce := false
132	for _, arg := range cmd.Args[1:] {
133		lit := arg.Lit()
134		// Handle combined flags like -rf, -fr, -Rf, etc.
135		if strings.HasPrefix(lit, "-") && !strings.HasPrefix(lit, "--") {
136			if strings.ContainsAny(lit, "rR") {
137				hasRecursive = true
138			}
139			if strings.Contains(lit, "f") {
140				hasForce = true
141			}
142		}
143		if lit == "--recursive" {
144			hasRecursive = true
145		}
146		if lit == "--force" {
147			hasForce = true
148		}
149	}
150
151	// Only check for dangerous paths if it's a recursive and forced rm
152	if !hasRecursive || !hasForce {
153		return false
154	}
155
156	// Check arguments for dangerous patterns
157	for _, arg := range cmd.Args[1:] {
158		lit := arg.Lit()
159		// Skip flags
160		if strings.HasPrefix(lit, "-") {
161			continue
162		}
163
164		// Check for .git directory patterns
165		if lit == ".git" || strings.HasSuffix(lit, "/.git") ||
166			strings.Contains(lit, ".git/") || strings.Contains(lit, ".git ") {
167			return true
168		}
169
170		// Check for home directory patterns
171		if lit == "~" || lit == "~/" || strings.HasPrefix(lit, "~/") {
172			return true
173		}
174
175		// Check for root directory
176		if lit == "/" {
177			return true
178		}
179
180		// Check for wildcards that could match .git
181		if lit == ".*" || strings.HasSuffix(lit, "/.*") {
182			return true
183		}
184
185		// Check for broad wildcards at dangerous locations
186		if lit == "*" || lit == "/*" {
187			return true
188		}
189	}
190
191	// Also check if the argument uses variable expansion (like $HOME)
192	// We need to walk the AST more carefully for this
193	for _, arg := range cmd.Args[1:] {
194		if containsHomeVariable(arg) {
195			return true
196		}
197	}
198
199	return false
200}
201
202// containsHomeVariable checks if a word contains $HOME or ${HOME} expansion
203func containsHomeVariable(word *syntax.Word) bool {
204	for _, part := range word.Parts {
205		switch p := part.(type) {
206		case *syntax.ParamExp:
207			if p.Param != nil && p.Param.Value == "HOME" {
208				return true
209			}
210		case *syntax.DblQuoted:
211			for _, inner := range p.Parts {
212				if pe, ok := inner.(*syntax.ParamExp); ok {
213					if pe.Param != nil && pe.Param.Value == "HOME" {
214						return true
215					}
216				}
217			}
218		}
219	}
220	return false
221}
222
223// noBlindGitAdd checks for git add commands that blindly add all files.
224// It rejects patterns like 'git add -A', 'git add .', 'git add --all', 'git add *'.
225func noBlindGitAdd(cmd *syntax.CallExpr) error {
226	if hasBlindGitAdd(cmd) {
227		return fmt.Errorf("permission denied: blind git add commands (git add -A, git add ., git add --all, git add *) are not allowed, specify files explicitly")
228	}
229	return nil
230}
231
232func hasBlindGitAdd(cmd *syntax.CallExpr) bool {
233	if len(cmd.Args) < 2 {
234		return false
235	}
236	if cmd.Args[0].Lit() != "git" {
237		return false
238	}
239
240	// Find the 'add' subcommand
241	addIndex := -1
242	for i, arg := range cmd.Args {
243		if arg.Lit() == "add" {
244			addIndex = i
245			break
246		}
247	}
248
249	if addIndex < 0 {
250		return false
251	}
252
253	// Check arguments after 'add' for blind patterns
254	for i := addIndex + 1; i < len(cmd.Args); i++ {
255		arg := cmd.Args[i].Lit()
256		// Check for blind add patterns
257		if arg == "-A" || arg == "--all" || arg == "." || arg == "*" {
258			return true
259		}
260	}
261
262	return false
263}
264
265// AddCoauthorTrailer modifies a bash script to add a Co-authored-by trailer
266// to any git commit commands. Returns the modified script.
267func AddCoauthorTrailer(bashScript, trailer string) string {
268	r := strings.NewReader(bashScript)
269	parser := syntax.NewParser(syntax.KeepComments(true))
270	file, err := parser.Parse(r, "")
271	if err != nil {
272		// Can't parse, return original
273		return bashScript
274	}
275
276	modified := false
277	syntax.Walk(file, func(node syntax.Node) bool {
278		callExpr, ok := node.(*syntax.CallExpr)
279		if !ok {
280			return true
281		}
282		if addTrailerToGitCommit(callExpr, trailer) {
283			modified = true
284		}
285		return true
286	})
287
288	if !modified {
289		return bashScript
290	}
291
292	var buf bytes.Buffer
293	printer := syntax.NewPrinter()
294	if err := printer.Print(&buf, file); err != nil {
295		return bashScript
296	}
297	return buf.String()
298}
299
300// addTrailerToGitCommit adds --trailer to a git commit command.
301// Returns true if the command was modified.
302func addTrailerToGitCommit(cmd *syntax.CallExpr, trailer string) bool {
303	if !isGitCommitCommand(cmd) {
304		return false
305	}
306
307	// Find where to insert --trailer (right after "commit")
308	insertIdx := -1
309	for i := 1; i < len(cmd.Args); i++ {
310		if cmd.Args[i].Lit() == "commit" {
311			insertIdx = i + 1
312			break
313		}
314	}
315	if insertIdx < 0 {
316		return false
317	}
318
319	// Create --trailer argument
320	trailerArg := &syntax.Word{
321		Parts: []syntax.WordPart{
322			&syntax.Lit{Value: "--trailer"},
323		},
324	}
325	// Create the trailer value argument
326	trailerVal := &syntax.Word{
327		Parts: []syntax.WordPart{
328			&syntax.DblQuoted{
329				Parts: []syntax.WordPart{
330					&syntax.Lit{Value: trailer},
331				},
332			},
333		},
334	}
335
336	// Insert the two new arguments
337	newArgs := make([]*syntax.Word, 0, len(cmd.Args)+2)
338	newArgs = append(newArgs, cmd.Args[:insertIdx]...)
339	newArgs = append(newArgs, trailerArg, trailerVal)
340	newArgs = append(newArgs, cmd.Args[insertIdx:]...)
341	cmd.Args = newArgs
342
343	return true
344}
345
346// isGitCommitCommand checks if a command is 'git commit'.
347func isGitCommitCommand(cmd *syntax.CallExpr) bool {
348	if len(cmd.Args) < 2 {
349		return false
350	}
351
352	// First argument must be 'git'
353	if cmd.Args[0].Lit() != "git" {
354		return false
355	}
356
357	// Look for 'commit' in any position after 'git'
358	for i := 1; i < len(cmd.Args); i++ {
359		if cmd.Args[i].Lit() == "commit" {
360			return true
361		}
362	}
363
364	return false
365}
366
367// noSketchWipBranchChangesOnce checks for git commands that would change the sketch-wip branch.
368// It rejects commands that would rename the sketch-wip branch or switch away from it.
369// This check only shows the warning once per process.
370func noSketchWipBranchChangesOnce(cmd *syntax.CallExpr) error {
371	if hasSketchWipBranchChanges(cmd) {
372		// Check if we've already warned in this process
373		sketchWipWarningMu.Lock()
374		alreadyWarned := sketchWipWarningShown
375		if !alreadyWarned {
376			sketchWipWarningShown = true
377		}
378		sketchWipWarningMu.Unlock()
379
380		if !alreadyWarned {
381			return fmt.Errorf("permission denied: cannot leave 'sketch-wip' branch. This branch is designated for change detection and auto-push; work on other branches may be lost. Warning shown once per session. Repeat command if needed for temporary operations (rebase, bisect, etc.) but return to sketch-wip afterward. Note: users can push to any branch via the Push button in the UI")
382		}
383	}
384	return nil
385}
386
387// hasSketchWipBranchChanges checks if a git command would change the sketch-wip branch.
388func hasSketchWipBranchChanges(cmd *syntax.CallExpr) bool {
389	if len(cmd.Args) < 2 {
390		return false
391	}
392	if cmd.Args[0].Lit() != "git" {
393		return false
394	}
395
396	// Look for subcommands that could change the sketch-wip branch
397	for i := 1; i < len(cmd.Args); i++ {
398		arg := cmd.Args[i].Lit()
399		switch arg {
400		case "branch":
401			// Check for branch rename: git branch -m sketch-wip newname or git branch -M sketch-wip newname
402			if i+2 < len(cmd.Args) {
403				// Look for -m or -M flag
404				for j := i + 1; j < len(cmd.Args)-1; j++ {
405					flag := cmd.Args[j].Lit()
406					if flag == "-m" || flag == "-M" {
407						// Check if sketch-wip is the source branch
408						if cmd.Args[j+1].Lit() == "sketch-wip" {
409							return true
410						}
411					}
412				}
413			}
414		case "checkout":
415			// Check for branch switching: git checkout otherbranch
416			// But allow git checkout files/paths
417			if i+1 < len(cmd.Args) {
418				nextArg := cmd.Args[i+1].Lit()
419				// Skip if it's a flag
420				if !strings.HasPrefix(nextArg, "-") {
421					// This might be a branch checkout - we'll be conservative and warn
422					// unless it looks like a file path
423					if !strings.Contains(nextArg, "/") && !strings.Contains(nextArg, ".") {
424						return true
425					}
426				}
427			}
428		case "switch":
429			// Check for branch switching: git switch otherbranch
430			if i+1 < len(cmd.Args) {
431				nextArg := cmd.Args[i+1].Lit()
432				// Skip if it's a flag
433				if !strings.HasPrefix(nextArg, "-") {
434					return true
435				}
436			}
437		}
438	}
439
440	return false
441}