1package bashkit
2
3import (
4 "bytes"
5 "fmt"
6 "strings"
7 "sync"
8
9 "mvdan.cc/sh/v3/syntax"
10)
11
// checks lists the validators that Check applies, in order, to every command
// call found in a script. Each returns a non-nil error to reject the command.
var checks = []func(*syntax.CallExpr) error{
	noBlindGitAdd,
	noDangerousRmRf,
}
16
// processAwareChecks lists validators whose result depends on state that is
// tracked across calls within the process. Check runs them after the regular
// checks for each command call.
var processAwareChecks = []func(*syntax.CallExpr) error{
	noSketchWipBranchChangesOnce,
}
21
// State for the sketch-wip branch warning, which is reported at most once per
// process (until ResetSketchWipWarning is called).
var (
	// sketchWipWarningMu guards sketchWipWarningShown.
	sketchWipWarningMu sync.Mutex
	// sketchWipWarningShown is true once the warning has been returned.
	sketchWipWarningShown bool
)
27
28// ResetSketchWipWarning resets the warning state for testing purposes
29func ResetSketchWipWarning() {
30 sketchWipWarningMu.Lock()
31 sketchWipWarningShown = false
32 sketchWipWarningMu.Unlock()
33}
34
35// Check inspects bashScript and returns an error if it ought not be executed.
36// Check DOES NOT PROVIDE SECURITY against malicious actors.
37// It is intended to catch straightforward mistakes in which a model
38// does things despite having been instructed not to do them.
39func Check(bashScript string) error {
40 r := strings.NewReader(bashScript)
41 parser := syntax.NewParser()
42 file, err := parser.Parse(r, "")
43 if err != nil {
44 // Execution will fail, but we'll get a better error message from bash.
45 // Note that if this were security load bearing, this would be a terrible idea:
46 // You could smuggle stuff past Check by exploiting differences in what is considered syntactically valid.
47 // But it is not.
48 return nil
49 }
50
51 syntax.Walk(file, func(node syntax.Node) bool {
52 if err != nil {
53 return false
54 }
55 callExpr, ok := node.(*syntax.CallExpr)
56 if !ok {
57 return true
58 }
59 // Run regular checks
60 for _, check := range checks {
61 err = check(callExpr)
62 if err != nil {
63 return false
64 }
65 }
66 // Run process-aware checks
67 for _, check := range processAwareChecks {
68 err = check(callExpr)
69 if err != nil {
70 return false
71 }
72 }
73 return true
74 })
75
76 return err
77}
78
79// WillRunGitCommit reports whether bashScript contains a git commit command.
80func WillRunGitCommit(bashScript string) (bool, error) {
81 r := strings.NewReader(bashScript)
82 parser := syntax.NewParser()
83 file, err := parser.Parse(r, "")
84 if err != nil {
85 // Parsing failed, but let's not consider this an error for the same reasons as in Check
86 return false, nil
87 }
88
89 willCommit := false
90
91 syntax.Walk(file, func(node syntax.Node) bool {
92 callExpr, ok := node.(*syntax.CallExpr)
93 if !ok {
94 return true
95 }
96 if isGitCommitCommand(callExpr) {
97 willCommit = true
98 return false
99 }
100 return true
101 })
102
103 return willCommit, nil
104}
105
106// noDangerousRmRf checks for rm -rf commands that could delete critical directories.
107// It rejects patterns that could delete .git directories, home directories (~, $HOME),
108// or root directories.
109func noDangerousRmRf(cmd *syntax.CallExpr) error {
110 if hasDangerousRmRf(cmd) {
111 return fmt.Errorf("permission denied: this rm command could delete critical data (.git, home directory, or root). If you really need to run this command, spell out the full path explicitly (no wildcards, ~, or $HOME). Consider confirming with the user before running destructive cleanup commands")
112 }
113 return nil
114}
115
116// hasDangerousRmRf checks if an rm command could delete critical directories.
117func hasDangerousRmRf(cmd *syntax.CallExpr) bool {
118 if len(cmd.Args) < 1 {
119 return false
120 }
121
122 // Check if the command is rm
123 firstArg := cmd.Args[0].Lit()
124 if firstArg != "rm" {
125 return false
126 }
127
128 // Check if -r or -R is present (recursive)
129 hasRecursive := false
130 hasForce := false
131 for _, arg := range cmd.Args[1:] {
132 lit := arg.Lit()
133 // Handle combined flags like -rf, -fr, -Rf, etc.
134 if strings.HasPrefix(lit, "-") && !strings.HasPrefix(lit, "--") {
135 if strings.ContainsAny(lit, "rR") {
136 hasRecursive = true
137 }
138 if strings.Contains(lit, "f") {
139 hasForce = true
140 }
141 }
142 if lit == "--recursive" {
143 hasRecursive = true
144 }
145 if lit == "--force" {
146 hasForce = true
147 }
148 }
149
150 // Only check for dangerous paths if it's a recursive and forced rm
151 if !hasRecursive || !hasForce {
152 return false
153 }
154
155 // Check arguments for dangerous patterns
156 for _, arg := range cmd.Args[1:] {
157 lit := arg.Lit()
158 // Skip flags
159 if strings.HasPrefix(lit, "-") {
160 continue
161 }
162
163 // Check for .git directory patterns
164 if lit == ".git" || strings.HasSuffix(lit, "/.git") ||
165 strings.Contains(lit, ".git/") || strings.Contains(lit, ".git ") {
166 return true
167 }
168
169 // Check for home directory patterns
170 if lit == "~" || lit == "~/" || strings.HasPrefix(lit, "~/") {
171 return true
172 }
173
174 // Check for root directory
175 if lit == "/" {
176 return true
177 }
178
179 // Check for wildcards that could match .git
180 if lit == ".*" || strings.HasSuffix(lit, "/.*") {
181 return true
182 }
183
184 // Check for broad wildcards at dangerous locations
185 if lit == "*" || lit == "/*" {
186 return true
187 }
188 }
189
190 // Also check if the argument uses variable expansion (like $HOME)
191 // We need to walk the AST more carefully for this
192 for _, arg := range cmd.Args[1:] {
193 if containsHomeVariable(arg) {
194 return true
195 }
196 }
197
198 return false
199}
200
201// containsHomeVariable checks if a word contains $HOME or ${HOME} expansion
202func containsHomeVariable(word *syntax.Word) bool {
203 for _, part := range word.Parts {
204 switch p := part.(type) {
205 case *syntax.ParamExp:
206 if p.Param != nil && p.Param.Value == "HOME" {
207 return true
208 }
209 case *syntax.DblQuoted:
210 for _, inner := range p.Parts {
211 if pe, ok := inner.(*syntax.ParamExp); ok {
212 if pe.Param != nil && pe.Param.Value == "HOME" {
213 return true
214 }
215 }
216 }
217 }
218 }
219 return false
220}
221
222// noBlindGitAdd checks for git add commands that blindly add all files.
223// It rejects patterns like 'git add -A', 'git add .', 'git add --all', 'git add *'.
224func noBlindGitAdd(cmd *syntax.CallExpr) error {
225 if hasBlindGitAdd(cmd) {
226 return fmt.Errorf("permission denied: blind git add commands (git add -A, git add ., git add --all, git add *) are not allowed, specify files explicitly")
227 }
228 return nil
229}
230
231func hasBlindGitAdd(cmd *syntax.CallExpr) bool {
232 if len(cmd.Args) < 2 {
233 return false
234 }
235 if cmd.Args[0].Lit() != "git" {
236 return false
237 }
238
239 // Find the 'add' subcommand
240 addIndex := -1
241 for i, arg := range cmd.Args {
242 if arg.Lit() == "add" {
243 addIndex = i
244 break
245 }
246 }
247
248 if addIndex < 0 {
249 return false
250 }
251
252 // Check arguments after 'add' for blind patterns
253 for i := addIndex + 1; i < len(cmd.Args); i++ {
254 arg := cmd.Args[i].Lit()
255 // Check for blind add patterns
256 if arg == "-A" || arg == "--all" || arg == "." || arg == "*" {
257 return true
258 }
259 }
260
261 return false
262}
263
264// AddCoauthorTrailer modifies a bash script to add a Co-authored-by trailer
265// to any git commit commands. Returns the modified script.
266func AddCoauthorTrailer(bashScript, trailer string) string {
267 r := strings.NewReader(bashScript)
268 parser := syntax.NewParser(syntax.KeepComments(true))
269 file, err := parser.Parse(r, "")
270 if err != nil {
271 // Can't parse, return original
272 return bashScript
273 }
274
275 modified := false
276 syntax.Walk(file, func(node syntax.Node) bool {
277 callExpr, ok := node.(*syntax.CallExpr)
278 if !ok {
279 return true
280 }
281 if addTrailerToGitCommit(callExpr, trailer) {
282 modified = true
283 }
284 return true
285 })
286
287 if !modified {
288 return bashScript
289 }
290
291 var buf bytes.Buffer
292 printer := syntax.NewPrinter()
293 if err := printer.Print(&buf, file); err != nil {
294 return bashScript
295 }
296 return buf.String()
297}
298
299// addTrailerToGitCommit adds --trailer to a git commit command.
300// Returns true if the command was modified.
301func addTrailerToGitCommit(cmd *syntax.CallExpr, trailer string) bool {
302 if !isGitCommitCommand(cmd) {
303 return false
304 }
305
306 // Find where to insert --trailer (right after "commit")
307 insertIdx := -1
308 for i := 1; i < len(cmd.Args); i++ {
309 if cmd.Args[i].Lit() == "commit" {
310 insertIdx = i + 1
311 break
312 }
313 }
314 if insertIdx < 0 {
315 return false
316 }
317
318 // Create --trailer argument
319 trailerArg := &syntax.Word{
320 Parts: []syntax.WordPart{
321 &syntax.Lit{Value: "--trailer"},
322 },
323 }
324 // Create the trailer value argument
325 trailerVal := &syntax.Word{
326 Parts: []syntax.WordPart{
327 &syntax.DblQuoted{
328 Parts: []syntax.WordPart{
329 &syntax.Lit{Value: trailer},
330 },
331 },
332 },
333 }
334
335 // Insert the two new arguments
336 newArgs := make([]*syntax.Word, 0, len(cmd.Args)+2)
337 newArgs = append(newArgs, cmd.Args[:insertIdx]...)
338 newArgs = append(newArgs, trailerArg, trailerVal)
339 newArgs = append(newArgs, cmd.Args[insertIdx:]...)
340 cmd.Args = newArgs
341
342 return true
343}
344
345// isGitCommitCommand checks if a command is 'git commit'.
346func isGitCommitCommand(cmd *syntax.CallExpr) bool {
347 if len(cmd.Args) < 2 {
348 return false
349 }
350
351 // First argument must be 'git'
352 if cmd.Args[0].Lit() != "git" {
353 return false
354 }
355
356 // Look for 'commit' in any position after 'git'
357 for i := 1; i < len(cmd.Args); i++ {
358 if cmd.Args[i].Lit() == "commit" {
359 return true
360 }
361 }
362
363 return false
364}
365
366// noSketchWipBranchChangesOnce checks for git commands that would change the sketch-wip branch.
367// It rejects commands that would rename the sketch-wip branch or switch away from it.
368// This check only shows the warning once per process.
369func noSketchWipBranchChangesOnce(cmd *syntax.CallExpr) error {
370 if hasSketchWipBranchChanges(cmd) {
371 // Check if we've already warned in this process
372 sketchWipWarningMu.Lock()
373 alreadyWarned := sketchWipWarningShown
374 if !alreadyWarned {
375 sketchWipWarningShown = true
376 }
377 sketchWipWarningMu.Unlock()
378
379 if !alreadyWarned {
380 return fmt.Errorf("permission denied: cannot leave 'sketch-wip' branch. This branch is designated for change detection and auto-push; work on other branches may be lost. Warning shown once per session. Repeat command if needed for temporary operations (rebase, bisect, etc.) but return to sketch-wip afterward. Note: users can push to any branch via the Push button in the UI")
381 }
382 }
383 return nil
384}
385
386// hasSketchWipBranchChanges checks if a git command would change the sketch-wip branch.
387func hasSketchWipBranchChanges(cmd *syntax.CallExpr) bool {
388 if len(cmd.Args) < 2 {
389 return false
390 }
391 if cmd.Args[0].Lit() != "git" {
392 return false
393 }
394
395 // Look for subcommands that could change the sketch-wip branch
396 for i := 1; i < len(cmd.Args); i++ {
397 arg := cmd.Args[i].Lit()
398 switch arg {
399 case "branch":
400 // Check for branch rename: git branch -m sketch-wip newname or git branch -M sketch-wip newname
401 if i+2 < len(cmd.Args) {
402 // Look for -m or -M flag
403 for j := i + 1; j < len(cmd.Args)-1; j++ {
404 flag := cmd.Args[j].Lit()
405 if flag == "-m" || flag == "-M" {
406 // Check if sketch-wip is the source branch
407 if cmd.Args[j+1].Lit() == "sketch-wip" {
408 return true
409 }
410 }
411 }
412 }
413 case "checkout":
414 // Check for branch switching: git checkout otherbranch
415 // But allow git checkout files/paths
416 if i+1 < len(cmd.Args) {
417 nextArg := cmd.Args[i+1].Lit()
418 // Skip if it's a flag
419 if !strings.HasPrefix(nextArg, "-") {
420 // This might be a branch checkout - we'll be conservative and warn
421 // unless it looks like a file path
422 if !strings.Contains(nextArg, "/") && !strings.Contains(nextArg, ".") {
423 return true
424 }
425 }
426 }
427 case "switch":
428 // Check for branch switching: git switch otherbranch
429 if i+1 < len(cmd.Args) {
430 nextArg := cmd.Args[i+1].Lit()
431 // Skip if it's a flag
432 if !strings.HasPrefix(nextArg, "-") {
433 return true
434 }
435 }
436 }
437 }
438
439 return false
440}