1package bashkit
2
3import (
4 "bytes"
5 "fmt"
6 "strings"
7 "sync"
8
9 "mvdan.cc/sh/v3/syntax"
10)
11
12var checks = []func(*syntax.CallExpr) error{
13 noBlindGitAdd,
14 noDangerousRmRf,
15}
16
17// Process-level checks that track state across calls
18var processAwareChecks = []func(*syntax.CallExpr) error{
19 noSketchWipBranchChangesOnce,
20}
21
22// Track whether sketch-wip branch warning has been shown in this process
23var (
24 sketchWipWarningMu sync.Mutex
25 sketchWipWarningShown bool
26)
27
28// ResetSketchWipWarning resets the warning state for testing purposes
29func ResetSketchWipWarning() {
30 sketchWipWarningMu.Lock()
31 sketchWipWarningShown = false
32 sketchWipWarningMu.Unlock()
33}
34
35// Check inspects bashScript and returns an error if it ought not be executed.
36// Check DOES NOT PROVIDE SECURITY against malicious actors.
37// It is intended to catch straightforward mistakes in which a model
38// does things despite having been instructed not to do them.
39func Check(bashScript string) error {
40 r := strings.NewReader(bashScript)
41 parser := syntax.NewParser()
42 file, err := parser.Parse(r, "")
43 if err != nil {
44 // Execution will fail, but we'll get a better error message from bash.
45 // Note that if this were security load bearing, this would be a terrible idea:
46 // You could smuggle stuff past Check by exploiting differences in what is considered syntactically valid.
47 // But it is not.
48 return nil
49 }
50
51 syntax.Walk(file, func(node syntax.Node) bool {
52 if err != nil {
53 return false
54 }
55 callExpr, ok := node.(*syntax.CallExpr)
56 if !ok {
57 return true
58 }
59 // Run regular checks
60 for _, check := range checks {
61 err = check(callExpr)
62 if err != nil {
63 return false
64 }
65 }
66 // Run process-aware checks
67 for _, check := range processAwareChecks {
68 err = check(callExpr)
69 if err != nil {
70 return false
71 }
72 }
73 return true
74 })
75
76 return err
77}
78
79// WillRunGitCommit checks if the provided bash script will run 'git commit'.
80// It returns true if any command in the script is a git commit command.
81func WillRunGitCommit(bashScript string) (bool, error) {
82 r := strings.NewReader(bashScript)
83 parser := syntax.NewParser()
84 file, err := parser.Parse(r, "")
85 if err != nil {
86 // Parsing failed, but let's not consider this an error for the same reasons as in Check
87 return false, nil
88 }
89
90 willCommit := false
91
92 syntax.Walk(file, func(node syntax.Node) bool {
93 callExpr, ok := node.(*syntax.CallExpr)
94 if !ok {
95 return true
96 }
97 if isGitCommitCommand(callExpr) {
98 willCommit = true
99 return false
100 }
101 return true
102 })
103
104 return willCommit, nil
105}
106
107// noDangerousRmRf checks for rm -rf commands that could delete critical directories.
108// It rejects patterns that could delete .git directories, home directories (~, $HOME),
109// or root directories.
110func noDangerousRmRf(cmd *syntax.CallExpr) error {
111 if hasDangerousRmRf(cmd) {
112 return fmt.Errorf("permission denied: this rm command could delete critical data (.git, home directory, or root). If you really need to run this command, spell out the full path explicitly (no wildcards, ~, or $HOME). Consider confirming with the user before running destructive cleanup commands")
113 }
114 return nil
115}
116
117// hasDangerousRmRf checks if an rm command could delete critical directories.
118func hasDangerousRmRf(cmd *syntax.CallExpr) bool {
119 if len(cmd.Args) < 1 {
120 return false
121 }
122
123 // Check if the command is rm
124 firstArg := cmd.Args[0].Lit()
125 if firstArg != "rm" {
126 return false
127 }
128
129 // Check if -r or -R is present (recursive)
130 hasRecursive := false
131 hasForce := false
132 for _, arg := range cmd.Args[1:] {
133 lit := arg.Lit()
134 // Handle combined flags like -rf, -fr, -Rf, etc.
135 if strings.HasPrefix(lit, "-") && !strings.HasPrefix(lit, "--") {
136 if strings.ContainsAny(lit, "rR") {
137 hasRecursive = true
138 }
139 if strings.Contains(lit, "f") {
140 hasForce = true
141 }
142 }
143 if lit == "--recursive" {
144 hasRecursive = true
145 }
146 if lit == "--force" {
147 hasForce = true
148 }
149 }
150
151 // Only check for dangerous paths if it's a recursive and forced rm
152 if !hasRecursive || !hasForce {
153 return false
154 }
155
156 // Check arguments for dangerous patterns
157 for _, arg := range cmd.Args[1:] {
158 lit := arg.Lit()
159 // Skip flags
160 if strings.HasPrefix(lit, "-") {
161 continue
162 }
163
164 // Check for .git directory patterns
165 if lit == ".git" || strings.HasSuffix(lit, "/.git") ||
166 strings.Contains(lit, ".git/") || strings.Contains(lit, ".git ") {
167 return true
168 }
169
170 // Check for home directory patterns
171 if lit == "~" || lit == "~/" || strings.HasPrefix(lit, "~/") {
172 return true
173 }
174
175 // Check for root directory
176 if lit == "/" {
177 return true
178 }
179
180 // Check for wildcards that could match .git
181 if lit == ".*" || strings.HasSuffix(lit, "/.*") {
182 return true
183 }
184
185 // Check for broad wildcards at dangerous locations
186 if lit == "*" || lit == "/*" {
187 return true
188 }
189 }
190
191 // Also check if the argument uses variable expansion (like $HOME)
192 // We need to walk the AST more carefully for this
193 for _, arg := range cmd.Args[1:] {
194 if containsHomeVariable(arg) {
195 return true
196 }
197 }
198
199 return false
200}
201
202// containsHomeVariable checks if a word contains $HOME or ${HOME} expansion
203func containsHomeVariable(word *syntax.Word) bool {
204 for _, part := range word.Parts {
205 switch p := part.(type) {
206 case *syntax.ParamExp:
207 if p.Param != nil && p.Param.Value == "HOME" {
208 return true
209 }
210 case *syntax.DblQuoted:
211 for _, inner := range p.Parts {
212 if pe, ok := inner.(*syntax.ParamExp); ok {
213 if pe.Param != nil && pe.Param.Value == "HOME" {
214 return true
215 }
216 }
217 }
218 }
219 }
220 return false
221}
222
223// noBlindGitAdd checks for git add commands that blindly add all files.
224// It rejects patterns like 'git add -A', 'git add .', 'git add --all', 'git add *'.
225func noBlindGitAdd(cmd *syntax.CallExpr) error {
226 if hasBlindGitAdd(cmd) {
227 return fmt.Errorf("permission denied: blind git add commands (git add -A, git add ., git add --all, git add *) are not allowed, specify files explicitly")
228 }
229 return nil
230}
231
232func hasBlindGitAdd(cmd *syntax.CallExpr) bool {
233 if len(cmd.Args) < 2 {
234 return false
235 }
236 if cmd.Args[0].Lit() != "git" {
237 return false
238 }
239
240 // Find the 'add' subcommand
241 addIndex := -1
242 for i, arg := range cmd.Args {
243 if arg.Lit() == "add" {
244 addIndex = i
245 break
246 }
247 }
248
249 if addIndex < 0 {
250 return false
251 }
252
253 // Check arguments after 'add' for blind patterns
254 for i := addIndex + 1; i < len(cmd.Args); i++ {
255 arg := cmd.Args[i].Lit()
256 // Check for blind add patterns
257 if arg == "-A" || arg == "--all" || arg == "." || arg == "*" {
258 return true
259 }
260 }
261
262 return false
263}
264
265// AddCoauthorTrailer modifies a bash script to add a Co-authored-by trailer
266// to any git commit commands. Returns the modified script.
267func AddCoauthorTrailer(bashScript, trailer string) string {
268 r := strings.NewReader(bashScript)
269 parser := syntax.NewParser(syntax.KeepComments(true))
270 file, err := parser.Parse(r, "")
271 if err != nil {
272 // Can't parse, return original
273 return bashScript
274 }
275
276 modified := false
277 syntax.Walk(file, func(node syntax.Node) bool {
278 callExpr, ok := node.(*syntax.CallExpr)
279 if !ok {
280 return true
281 }
282 if addTrailerToGitCommit(callExpr, trailer) {
283 modified = true
284 }
285 return true
286 })
287
288 if !modified {
289 return bashScript
290 }
291
292 var buf bytes.Buffer
293 printer := syntax.NewPrinter()
294 if err := printer.Print(&buf, file); err != nil {
295 return bashScript
296 }
297 return buf.String()
298}
299
300// addTrailerToGitCommit adds --trailer to a git commit command.
301// Returns true if the command was modified.
302func addTrailerToGitCommit(cmd *syntax.CallExpr, trailer string) bool {
303 if !isGitCommitCommand(cmd) {
304 return false
305 }
306
307 // Find where to insert --trailer (right after "commit")
308 insertIdx := -1
309 for i := 1; i < len(cmd.Args); i++ {
310 if cmd.Args[i].Lit() == "commit" {
311 insertIdx = i + 1
312 break
313 }
314 }
315 if insertIdx < 0 {
316 return false
317 }
318
319 // Create --trailer argument
320 trailerArg := &syntax.Word{
321 Parts: []syntax.WordPart{
322 &syntax.Lit{Value: "--trailer"},
323 },
324 }
325 // Create the trailer value argument
326 trailerVal := &syntax.Word{
327 Parts: []syntax.WordPart{
328 &syntax.DblQuoted{
329 Parts: []syntax.WordPart{
330 &syntax.Lit{Value: trailer},
331 },
332 },
333 },
334 }
335
336 // Insert the two new arguments
337 newArgs := make([]*syntax.Word, 0, len(cmd.Args)+2)
338 newArgs = append(newArgs, cmd.Args[:insertIdx]...)
339 newArgs = append(newArgs, trailerArg, trailerVal)
340 newArgs = append(newArgs, cmd.Args[insertIdx:]...)
341 cmd.Args = newArgs
342
343 return true
344}
345
346// isGitCommitCommand checks if a command is 'git commit'.
347func isGitCommitCommand(cmd *syntax.CallExpr) bool {
348 if len(cmd.Args) < 2 {
349 return false
350 }
351
352 // First argument must be 'git'
353 if cmd.Args[0].Lit() != "git" {
354 return false
355 }
356
357 // Look for 'commit' in any position after 'git'
358 for i := 1; i < len(cmd.Args); i++ {
359 if cmd.Args[i].Lit() == "commit" {
360 return true
361 }
362 }
363
364 return false
365}
366
367// noSketchWipBranchChangesOnce checks for git commands that would change the sketch-wip branch.
368// It rejects commands that would rename the sketch-wip branch or switch away from it.
369// This check only shows the warning once per process.
370func noSketchWipBranchChangesOnce(cmd *syntax.CallExpr) error {
371 if hasSketchWipBranchChanges(cmd) {
372 // Check if we've already warned in this process
373 sketchWipWarningMu.Lock()
374 alreadyWarned := sketchWipWarningShown
375 if !alreadyWarned {
376 sketchWipWarningShown = true
377 }
378 sketchWipWarningMu.Unlock()
379
380 if !alreadyWarned {
381 return fmt.Errorf("permission denied: cannot leave 'sketch-wip' branch. This branch is designated for change detection and auto-push; work on other branches may be lost. Warning shown once per session. Repeat command if needed for temporary operations (rebase, bisect, etc.) but return to sketch-wip afterward. Note: users can push to any branch via the Push button in the UI")
382 }
383 }
384 return nil
385}
386
387// hasSketchWipBranchChanges checks if a git command would change the sketch-wip branch.
388func hasSketchWipBranchChanges(cmd *syntax.CallExpr) bool {
389 if len(cmd.Args) < 2 {
390 return false
391 }
392 if cmd.Args[0].Lit() != "git" {
393 return false
394 }
395
396 // Look for subcommands that could change the sketch-wip branch
397 for i := 1; i < len(cmd.Args); i++ {
398 arg := cmd.Args[i].Lit()
399 switch arg {
400 case "branch":
401 // Check for branch rename: git branch -m sketch-wip newname or git branch -M sketch-wip newname
402 if i+2 < len(cmd.Args) {
403 // Look for -m or -M flag
404 for j := i + 1; j < len(cmd.Args)-1; j++ {
405 flag := cmd.Args[j].Lit()
406 if flag == "-m" || flag == "-M" {
407 // Check if sketch-wip is the source branch
408 if cmd.Args[j+1].Lit() == "sketch-wip" {
409 return true
410 }
411 }
412 }
413 }
414 case "checkout":
415 // Check for branch switching: git checkout otherbranch
416 // But allow git checkout files/paths
417 if i+1 < len(cmd.Args) {
418 nextArg := cmd.Args[i+1].Lit()
419 // Skip if it's a flag
420 if !strings.HasPrefix(nextArg, "-") {
421 // This might be a branch checkout - we'll be conservative and warn
422 // unless it looks like a file path
423 if !strings.Contains(nextArg, "/") && !strings.Contains(nextArg, ".") {
424 return true
425 }
426 }
427 }
428 case "switch":
429 // Check for branch switching: git switch otherbranch
430 if i+1 < len(cmd.Args) {
431 nextArg := cmd.Args[i+1].Lit()
432 // Skip if it's a flag
433 if !strings.HasPrefix(nextArg, "-") {
434 return true
435 }
436 }
437 }
438 }
439
440 return false
441}