bash.go

  1package tools
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"runtime"
  8	"strings"
  9	"time"
 10
 11	"github.com/charmbracelet/crush/internal/permission"
 12	"github.com/charmbracelet/crush/internal/shell"
 13	"mvdan.cc/sh/v3/syntax"
 14)
 15
 16type BashParams struct {
 17	Command string `json:"command"`
 18	Timeout int    `json:"timeout"`
 19}
 20
 21type BashPermissionsParams struct {
 22	Command string `json:"command"`
 23	Timeout int    `json:"timeout"`
 24}
 25
 26type BashResponseMetadata struct {
 27	StartTime int64 `json:"start_time"`
 28	EndTime   int64 `json:"end_time"`
 29}
 30type bashTool struct {
 31	permissions permission.Service
 32	workingDir  string
 33}
 34
 35const (
 36	BashToolName = "bash"
 37
 38	DefaultTimeout  = 1 * 60 * 1000  // 1 minutes in milliseconds
 39	MaxTimeout      = 10 * 60 * 1000 // 10 minutes in milliseconds
 40	MaxOutputLength = 30000
 41	BashNoOutput    = "no output"
 42)
 43
 44func containsBannedCommand(node syntax.Node) bool {
 45	if node == nil {
 46		return false
 47	}
 48
 49	switch n := node.(type) {
 50	case *syntax.CallExpr:
 51		if len(n.Args) > 0 {
 52			cmdName := getWordValue(n.Args[0])
 53			for _, banned := range bannedCommands {
 54				if strings.EqualFold(cmdName, banned) {
 55					return true
 56				}
 57			}
 58		}
 59		for _, arg := range n.Args {
 60			if containsBannedCommand(arg) {
 61				return true
 62			}
 63		}
 64	case *syntax.Word:
 65		if checkWordForBannedCommands(n) {
 66			return true
 67		}
 68		for _, part := range n.Parts {
 69			if containsBannedCommand(part) {
 70				return true
 71			}
 72		}
 73	case *syntax.CmdSubst:
 74		for _, stmt := range n.Stmts {
 75			if containsBannedCommand(stmt) {
 76				return true
 77			}
 78		}
 79	case *syntax.Subshell:
 80		for _, stmt := range n.Stmts {
 81			if containsBannedCommand(stmt) {
 82				return true
 83			}
 84		}
 85	case *syntax.Stmt:
 86		if containsBannedCommand(n.Cmd) {
 87			return true
 88		}
 89		for _, redir := range n.Redirs {
 90			if containsBannedCommand(redir) {
 91				return true
 92			}
 93		}
 94	case *syntax.BinaryCmd:
 95		return containsBannedCommand(n.X) || containsBannedCommand(n.Y)
 96	case *syntax.Block:
 97		for _, stmt := range n.Stmts {
 98			if containsBannedCommand(stmt) {
 99				return true
100			}
101		}
102	case *syntax.IfClause:
103		for _, stmt := range n.Cond {
104			if containsBannedCommand(stmt) {
105				return true
106			}
107		}
108		for _, stmt := range n.Then {
109			if containsBannedCommand(stmt) {
110				return true
111			}
112		}
113		if n.Else != nil && containsBannedCommand(n.Else) {
114			return true
115		}
116	case *syntax.WhileClause:
117		for _, stmt := range n.Cond {
118			if containsBannedCommand(stmt) {
119				return true
120			}
121		}
122		for _, stmt := range n.Do {
123			if containsBannedCommand(stmt) {
124				return true
125			}
126		}
127	case *syntax.ForClause:
128		for _, stmt := range n.Do {
129			if containsBannedCommand(stmt) {
130				return true
131			}
132		}
133		if containsBannedCommand(n.Loop) {
134			return true
135		}
136	case *syntax.CaseClause:
137		for _, item := range n.Items {
138			for _, stmt := range item.Stmts {
139				if containsBannedCommand(stmt) {
140					return true
141				}
142			}
143		}
144	case *syntax.FuncDecl:
145		return containsBannedCommand(n.Body)
146	case *syntax.ArithmExp:
147		return containsBannedCommand(n.X)
148	case *syntax.Redirect:
149		return containsBannedCommand(n.Word)
150	}
151	return false
152}
153
154func checkWordForBannedCommands(word *syntax.Word) bool {
155	if word == nil {
156		return false
157	}
158
159	for _, part := range word.Parts {
160		switch p := part.(type) {
161		case *syntax.SglQuoted:
162			if checkQuotedStringForBannedCommands(p.Value) {
163				return true
164			}
165		case *syntax.DblQuoted:
166			var content strings.Builder
167			for _, qpart := range p.Parts {
168				if lit, ok := qpart.(*syntax.Lit); ok {
169					content.WriteString(lit.Value)
170				}
171			}
172			if checkQuotedStringForBannedCommands(content.String()) {
173				return true
174			}
175		}
176	}
177	return false
178}
179
180func checkQuotedStringForBannedCommands(content string) bool {
181	parser := syntax.NewParser()
182	file, err := parser.Parse(strings.NewReader(content), "")
183	if err != nil {
184		return false
185	}
186
187	if len(file.Stmts) == 0 {
188		return false
189	}
190
191	// Simple heuristic: if it looks like prose rather than commands, don't flag it
192	if len(file.Stmts) == 1 {
193		stmt := file.Stmts[0]
194		if callExpr, ok := stmt.Cmd.(*syntax.CallExpr); ok {
195			if len(callExpr.Args) > 2 {
196				allText := true
197				for i, arg := range callExpr.Args {
198					if i == 0 {
199						continue
200					}
201					argStr := getWordValue(arg)
202					if strings.HasPrefix(argStr, "-") {
203						allText = false
204						break
205					}
206				}
207				if allText {
208					return false
209				}
210			}
211		}
212	}
213
214	for _, stmt := range file.Stmts {
215		if containsBannedCommand(stmt) {
216			return true
217		}
218	}
219	return false
220}
221
222func getWordValue(word *syntax.Word) string {
223	if word == nil || len(word.Parts) == 0 {
224		return ""
225	}
226
227	var result strings.Builder
228	for _, part := range word.Parts {
229		switch p := part.(type) {
230		case *syntax.Lit:
231			result.WriteString(p.Value)
232		case *syntax.SglQuoted:
233			result.WriteString(p.Value)
234		case *syntax.DblQuoted:
235			for _, qpart := range p.Parts {
236				if lit, ok := qpart.(*syntax.Lit); ok {
237					result.WriteString(lit.Value)
238				}
239			}
240		}
241	}
242	return result.String()
243}
244
245func validateCommand(command string) error {
246	parser := syntax.NewParser()
247	file, err := parser.Parse(strings.NewReader(command), "")
248	if err != nil {
249		parts := strings.Fields(command)
250		if len(parts) > 0 {
251			baseCmd := parts[0]
252			for _, banned := range bannedCommands {
253				if strings.EqualFold(baseCmd, banned) {
254					return fmt.Errorf("command '%s' is not allowed", baseCmd)
255				}
256			}
257		}
258		return nil
259	}
260
261	for _, stmt := range file.Stmts {
262		if containsBannedCommand(stmt) {
263			return fmt.Errorf("command contains banned operations")
264		}
265	}
266	return nil
267}
268
269var bannedCommands = []string{
270	"alias", "curl", "curlie", "wget", "axel", "aria2c",
271	"nc", "telnet", "lynx", "w3m", "links", "httpie", "xh",
272	"http-prompt", "chrome", "firefox", "safari", "sudo",
273}
274
275// getSafeReadOnlyCommands returns platform-appropriate safe commands
276func getSafeReadOnlyCommands() []string {
277	// Base commands that work on all platforms
278	baseCommands := []string{
279		// Cross-platform commands
280		"echo", "hostname", "whoami",
281
282		// Git commands (cross-platform)
283		"git status", "git log", "git diff", "git show", "git branch", "git tag", "git remote", "git ls-files", "git ls-remote",
284		"git rev-parse", "git config --get", "git config --list", "git describe", "git blame", "git grep", "git shortlog",
285
286		// Go commands (cross-platform)
287		"go version", "go help", "go list", "go env", "go doc", "go vet", "go fmt", "go mod", "go test", "go build", "go run", "go install", "go clean",
288	}
289
290	if runtime.GOOS == "windows" {
291		// Windows-specific commands
292		windowsCommands := []string{
293			"dir", "type", "where", "ver", "systeminfo", "tasklist", "ipconfig", "ping", "nslookup",
294			"Get-Process", "Get-Location", "Get-ChildItem", "Get-Content", "Get-Date", "Get-Host", "Get-ComputerInfo",
295		}
296		return append(baseCommands, windowsCommands...)
297	} else {
298		// Unix/Linux commands (including WSL, since WSL reports as Linux)
299		unixCommands := []string{
300			"ls", "pwd", "date", "cal", "uptime", "id", "groups", "env", "printenv", "set", "unset", "which", "type", "whereis",
301			"whatis", "uname", "df", "du", "free", "top", "ps", "kill", "killall", "nice", "nohup", "time", "timeout",
302		}
303		return append(baseCommands, unixCommands...)
304	}
305}
306
307func bashDescription() string {
308	bannedCommandsStr := strings.Join(bannedCommands, ", ")
309	return fmt.Sprintf(`Executes a given bash command in a persistent shell session with optional timeout, ensuring proper handling and security measures.
310
311CROSS-PLATFORM SHELL SUPPORT:
312- Unix/Linux/macOS: Uses native bash/sh shell
313- Windows: Intelligent shell selection:
314  * Windows commands (dir, type, copy, etc.) use cmd.exe
315  * PowerShell commands (Get-, Set-, etc.) use PowerShell
316  * Unix-style commands (ls, cat, etc.) use POSIX emulation
317- WSL: Automatically treated as Linux (which is correct)
318- Automatic detection: Chooses the best shell based on command and platform
319- Persistent state: Working directory and environment variables persist between commands
320
321WINDOWS-SPECIFIC FEATURES:
322- Native Windows commands: dir, type, copy, move, del, md, rd, cls, where, tasklist, etc.
323- PowerShell support: Get-Process, Set-Location, and other PowerShell cmdlets
324- Windows path handling: Supports both forward slashes (/) and backslashes (\)
325- Drive letters: Properly handles C:\, D:\, etc.
326- Environment variables: Supports both Unix ($VAR) and Windows (%%VAR%%) syntax
327
328Before executing the command, please follow these steps:
329
3301. Directory Verification:
331 - If the command will create new directories or files, first use the LS tool to verify the parent directory exists and is the correct location
332 - For example, before running "mkdir foo/bar", first use LS to check that "foo" exists and is the intended parent directory
333
3342. Security Check:
335 - For security and to limit the threat of a prompt injection attack, some commands are limited or banned. If you use a disallowed command, you will receive an error message explaining the restriction. Explain the error to the User.
336 - Verify that the command is not one of the banned commands: %s.
337
3383. Command Execution:
339 - After ensuring proper quoting, execute the command.
340 - Capture the output of the command.
341
3424. Output Processing:
343 - If the output exceeds %d characters, output will be truncated before being returned to you.
344 - Prepare the output for display to the user.
345
3465. Return Result:
347 - Provide the processed output of the command.
348 - If any errors occurred during execution, include those in the output.
349
350Usage notes:
351- The command argument is required.
352- You can specify an optional timeout in milliseconds (up to 600000ms / 10 minutes). If not specified, commands will timeout after 30 minutes.
353- VERY IMPORTANT: You MUST avoid using search commands like 'find' and 'grep'. Instead use Grep, Glob, or Agent tools to search. You MUST avoid read tools like 'cat', 'head', 'tail', and 'ls', and use FileRead and LS tools to read files.
354- When issuing multiple commands, use the ';' or '&&' operator to separate them. DO NOT use newlines (newlines are ok in quoted strings).
355- IMPORTANT: All commands share the same shell session. Shell state (environment variables, virtual environments, current directory, etc.) persist between commands. For example, if you set an environment variable as part of a command, the environment variable will persist for subsequent commands.
356- Try to maintain your current working directory throughout the session by using absolute paths and avoiding usage of 'cd'. You may use 'cd' if the User explicitly requests it.
357<good-example>
358pytest /foo/bar/tests
359</good-example>
360<bad-example>
361cd /foo/bar && pytest tests
362</bad-example>
363
364# Committing changes with git
365
366When the user asks you to create a new git commit, follow these steps carefully:
367
3681. Start with a single message that contains exactly three tool_use blocks that do the following (it is VERY IMPORTANT that you send these tool_use blocks in a single message, otherwise it will feel slow to the user!):
369 - Run a git status command to see all untracked files.
370 - Run a git diff command to see both staged and unstaged changes that will be committed.
371 - Run a git log command to see recent commit messages, so that you can follow this repository's commit message style.
372
3732. Use the git context at the start of this conversation to determine which files are relevant to your commit. Add relevant untracked files to the staging area. Do not commit files that were already modified at the start of this conversation, if they are not relevant to your commit.
374
3753. Analyze all staged changes (both previously staged and newly added) and draft a commit message. Wrap your analysis process in <commit_analysis> tags:
376
377<commit_analysis>
378- List the files that have been changed or added
379- Summarize the nature of the changes (eg. new feature, enhancement to an existing feature, bug fix, refactoring, test, docs, etc.)
380- Brainstorm the purpose or motivation behind these changes
381- Do not use tools to explore code, beyond what is available in the git context
382- Assess the impact of these changes on the overall project
383- Check for any sensitive information that shouldn't be committed
384- Draft a concise (1-2 sentences) commit message that focuses on the "why" rather than the "what"
385- Ensure your language is clear, concise, and to the point
386- Ensure the message accurately reflects the changes and their purpose (i.e. "add" means a wholly new feature, "update" means an enhancement to an existing feature, "fix" means a bug fix, etc.)
387- Ensure the message is not generic (avoid words like "Update" or "Fix" without context)
388- Review the draft message to ensure it accurately reflects the changes and their purpose
389</commit_analysis>
390
3914. Create the commit with a message ending with:
392💘 Generated with Crush
393Co-Authored-By: Crush <noreply@crush.charm.land>
394
395- In order to ensure good formatting, ALWAYS pass the commit message via a HEREDOC, a la this example:
396<example>
397git commit -m "$(cat <<'EOF'
398 Commit message here.
399
400 💘 Generated with Crush
401 Co-Authored-By: 💘 Crush <noreply@crush.charm.land>
402 EOF
403 )"
404</example>
405
4065. If the commit fails due to pre-commit hook changes, retry the commit ONCE to include these automated changes. If it fails again, it usually means a pre-commit hook is preventing the commit. If the commit succeeds but you notice that files were modified by the pre-commit hook, you MUST amend your commit to include them.
407
4086. Finally, run git status to make sure the commit succeeded.
409
410Important notes:
411- When possible, combine the "git add" and "git commit" commands into a single "git commit -am" command, to speed things up
412- However, be careful not to stage files (e.g. with 'git add .') for commits that aren't part of the change, they may have untracked files they want to keep around, but not commit.
413- NEVER update the git config
414- DO NOT push to the remote repository
415- IMPORTANT: Never use git commands with the -i flag (like git rebase -i or git add -i) since they require interactive input which is not supported.
416- If there are no changes to commit (i.e., no untracked files and no modifications), do not create an empty commit
417- Ensure your commit message is meaningful and concise. It should explain the purpose of the changes, not just describe them.
418- Return an empty response - the user will see the git output directly
419
420# Creating pull requests
421Use the gh command via the Bash tool for ALL GitHub-related tasks including working with issues, pull requests, checks, and releases. If given a Github URL use the gh command to get the information needed.
422
423IMPORTANT: When the user asks you to create a pull request, follow these steps carefully:
424
4251. Understand the current state of the branch. Remember to send a single message that contains multiple tool_use blocks (it is VERY IMPORTANT that you do this in a single message, otherwise it will feel slow to the user!):
426 - Run a git status command to see all untracked files.
427 - Run a git diff command to see both staged and unstaged changes that will be committed.
428 - Check if the current branch tracks a remote branch and is up to date with the remote, so you know if you need to push to the remote
429 - Run a git log command and 'git diff main...HEAD' to understand the full commit history for the current branch (from the time it diverged from the 'main' branch.)
430
4312. Create new branch if needed
432
4333. Commit changes if needed
434
4354. Push to remote with -u flag if needed
436
4375. Analyze all changes that will be included in the pull request, making sure to look at all relevant commits (not just the latest commit, but all commits that will be included in the pull request!), and draft a pull request summary. Wrap your analysis process in <pr_analysis> tags:
438
439<pr_analysis>
440- List the commits since diverging from the main branch
441- Summarize the nature of the changes (eg. new feature, enhancement to an existing feature, bug fix, refactoring, test, docs, etc.)
442- Brainstorm the purpose or motivation behind these changes
443- Assess the impact of these changes on the overall project
444- Do not use tools to explore code, beyond what is available in the git context
445- Check for any sensitive information that shouldn't be committed
446- Draft a concise (1-2 bullet points) pull request summary that focuses on the "why" rather than the "what"
447- Ensure the summary accurately reflects all changes since diverging from the main branch
448- Ensure your language is clear, concise, and to the point
449- Ensure the summary accurately reflects the changes and their purpose (ie. "add" means a wholly new feature, "update" means an enhancement to an existing feature, "fix" means a bug fix, etc.)
450- Ensure the summary is not generic (avoid words like "Update" or "Fix" without context)
451- Review the draft summary to ensure it accurately reflects the changes and their purpose
452</pr_analysis>
453
4546. Create PR using gh pr create with the format below. Use a HEREDOC to pass the body to ensure correct formatting.
455<example>
456gh pr create --title "the pr title" --body "$(cat <<'EOF'
457## Summary
458<1-3 bullet points>
459
460## Test plan
461[Checklist of TODOs for testing the pull request...]
462
463💘 Generated with Crush
464EOF
465)"
466</example>
467
468Important:
469- Return an empty response - the user will see the gh output directly
470- Never update git config`, bannedCommandsStr, MaxOutputLength)
471}
472
473func NewBashTool(permission permission.Service, workingDir string) BaseTool {
474	return &bashTool{
475		permissions: permission,
476		workingDir:  workingDir,
477	}
478}
479
480func (b *bashTool) Name() string {
481	return BashToolName
482}
483
484func (b *bashTool) Info() ToolInfo {
485	return ToolInfo{
486		Name:        BashToolName,
487		Description: bashDescription(),
488		Parameters: map[string]any{
489			"command": map[string]any{
490				"type":        "string",
491				"description": "The command to execute",
492			},
493			"timeout": map[string]any{
494				"type":        "number",
495				"description": "Optional timeout in milliseconds (max 600000)",
496			},
497		},
498		Required: []string{"command"},
499	}
500}
501
502func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
503	var params BashParams
504	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
505		return NewTextErrorResponse("invalid parameters"), nil
506	}
507
508	if params.Timeout > MaxTimeout {
509		params.Timeout = MaxTimeout
510	} else if params.Timeout <= 0 {
511		params.Timeout = DefaultTimeout
512	}
513
514	if params.Command == "" {
515		return NewTextErrorResponse("missing command"), nil
516	}
517
518
519	if err := validateCommand(params.Command); err != nil {
520		return NewTextErrorResponse(err.Error()), nil
521	}
522
523	isSafeReadOnly := false
524	cmdLower := strings.ToLower(params.Command)
525
526	// Get platform-appropriate safe commands
527	safeReadOnlyCommands := getSafeReadOnlyCommands()
528	for _, safe := range safeReadOnlyCommands {
529		if strings.HasPrefix(cmdLower, strings.ToLower(safe)) {
530			if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
531				isSafeReadOnly = true
532				break
533			}
534		}
535	}
536
537	sessionID, messageID := GetContextValues(ctx)
538	if sessionID == "" || messageID == "" {
539		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
540	}
541	if !isSafeReadOnly {
542		p := b.permissions.Request(
543			permission.CreatePermissionRequest{
544				SessionID:   sessionID,
545				Path:        b.workingDir,
546				ToolName:    BashToolName,
547				Action:      "execute",
548				Description: fmt.Sprintf("Execute command: %s", params.Command),
549				Params: BashPermissionsParams{
550					Command: params.Command,
551				},
552			},
553		)
554		if !p {
555			return ToolResponse{}, permission.ErrorPermissionDenied
556		}
557	}
558	startTime := time.Now()
559	if params.Timeout > 0 {
560		var cancel context.CancelFunc
561		ctx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond)
562		defer cancel()
563	}
564	stdout, stderr, err := shell.
565		GetPersistentShell(b.workingDir).
566		Exec(ctx, params.Command)
567	interrupted := shell.IsInterrupt(err)
568	exitCode := shell.ExitCode(err)
569	if exitCode == 0 && !interrupted && err != nil {
570		return ToolResponse{}, fmt.Errorf("error executing command: %w", err)
571	}
572
573	stdout = truncateOutput(stdout)
574	stderr = truncateOutput(stderr)
575
576	errorMessage := stderr
577	if interrupted {
578		if errorMessage != "" {
579			errorMessage += "\n"
580		}
581		errorMessage += "Command was aborted before completion"
582	} else if exitCode != 0 {
583		if errorMessage != "" {
584			errorMessage += "\n"
585		}
586		errorMessage += fmt.Sprintf("Exit code %d", exitCode)
587	}
588
589	hasBothOutputs := stdout != "" && stderr != ""
590
591	if hasBothOutputs {
592		stdout += "\n"
593	}
594
595	if errorMessage != "" {
596		stdout += "\n" + errorMessage
597	}
598
599	metadata := BashResponseMetadata{
600		StartTime: startTime.UnixMilli(),
601		EndTime:   time.Now().UnixMilli(),
602	}
603	if stdout == "" {
604		return WithResponseMetadata(NewTextResponse(BashNoOutput), metadata), nil
605	}
606	return WithResponseMetadata(NewTextResponse(stdout), metadata), nil
607}
608
609func truncateOutput(content string) string {
610	if len(content) <= MaxOutputLength {
611		return content
612	}
613
614	halfLength := MaxOutputLength / 2
615	start := content[:halfLength]
616	end := content[len(content)-halfLength:]
617
618	truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
619	return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
620}
621
622func countLines(s string) int {
623	if s == "" {
624		return 0
625	}
626	return len(strings.Split(s, "\n"))
627}