bash.go

  1package tools
  2
import (
	"bytes"
	"cmp"
	"context"
	_ "embed"
	"fmt"
	"html/template"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"time"
	"unicode/utf8"

	"charm.land/fantasy"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/shell"
)
 21
// BashParams is the input of the bash tool. The `description` struct tags
// carry the per-field help text attached to each parameter.
type BashParams struct {
	Description     string `json:"description" description:"A brief description of what the command does, try to keep it under 30 characters or so"`
	Command         string `json:"command" description:"The command to execute"`
	WorkingDir      string `json:"working_dir,omitempty" description:"The working directory to execute the command in (defaults to current directory)"`
	RunInBackground bool   `json:"run_in_background,omitempty" description:"Set to true (boolean) to run this command in the background. Use job_output to read the output later."`
}
 28
// BashPermissionsParams is the payload attached to a permission request for
// a bash command. It has the same fields, in the same order, as BashParams,
// which lets NewBashTool build it with a plain type conversion; keep the two
// structs in sync.
type BashPermissionsParams struct {
	Description     string `json:"description"`
	Command         string `json:"command"`
	WorkingDir      string `json:"working_dir"`
	RunInBackground bool   `json:"run_in_background"`
}
 35
// BashResponseMetadata is the structured metadata attached to every bash
// tool response.
type BashResponseMetadata struct {
	// StartTime and EndTime are Unix timestamps in milliseconds.
	StartTime int64 `json:"start_time"`
	EndTime   int64 `json:"end_time"`
	// Output is the formatted command output; it is only filled in when the
	// command finished before the response was built.
	Output           string `json:"output"`
	Description      string `json:"description"`
	WorkingDirectory string `json:"working_directory"`
	// Background is true when the command was left running as a background
	// job (for finished commands it echoes the request's run_in_background).
	Background bool `json:"background,omitempty"`
	// ShellID identifies the background shell; set only when the command is
	// still running when the response is returned.
	ShellID string `json:"shell_id,omitempty"`
}
 45
const (
	// BashToolName is the name under which the bash tool is registered and
	// the tool name used in permission requests.
	BashToolName = "bash"

	// AutoBackgroundThreshold is how long a foreground command may run
	// before it is handed over to the background shell manager.
	AutoBackgroundThreshold = 1 * time.Minute // Commands taking longer automatically become background jobs
	// MaxOutputLength is the largest stdout or stderr size, in bytes, that
	// is returned verbatim; longer output has its middle elided.
	MaxOutputLength = 30000
	// BashNoOutput is the response text used when a finished command
	// produced no output at all.
	BashNoOutput = "no output"
)
 53
// bashDescriptionTmpl holds the raw contents of bash.tpl, embedded at build
// time.
//
//go:embed bash.tpl
var bashDescriptionTmpl []byte

// bashDescriptionTpl is the parsed tool-description template. Parsing happens
// at package initialisation and panics (template.Must) if bash.tpl is
// malformed.
//
// NOTE(review): this file imports html/template, so values interpolated into
// the description are HTML-escaped. Confirm that is intended for what appears
// to be a plain-text tool description; text/template would not escape.
var bashDescriptionTpl = template.Must(
	template.New("bashDescription").
		Parse(string(bashDescriptionTmpl)),
)
 61
// bashDescriptionData is the value passed to bashDescriptionTpl when the
// tool description is rendered.
type bashDescriptionData struct {
	// BannedCommands is the resolved banned-command list joined with ", ".
	BannedCommands string
	// MaxOutputLength mirrors the MaxOutputLength constant.
	MaxOutputLength int
	Attribution     config.Attribution
	ModelName       string
}
 68
// defaultBannedCommands lists the command names that are banned unless the
// bash configuration disables the defaults (see resolveBannedCommandsList).
// Entries are grouped by category and sorted alphabetically within a group.
var defaultBannedCommands = []string{
	// Network/Download tools and browsers.
	// NOTE(review): "alias" is not a network tool; it sits here only because
	// of the alphabetical ordering. Presumably it is banned so that other
	// banned commands cannot be re-exposed under a new name — confirm.
	"alias",
	"aria2c",
	"axel",
	"chrome",
	"curl",
	"curlie",
	"firefox",
	"http-prompt",
	"httpie",
	"links",
	"lynx",
	"nc",
	"safari",
	"scp",
	"ssh",
	"telnet",
	"w3m",
	"wget",
	"xh",

	// System administration (privilege escalation)
	"doas",
	"su",
	"sudo",

	// Package managers
	"apk",
	"apt",
	"apt-cache",
	"apt-get",
	"dnf",
	"dpkg",
	"emerge",
	"home-manager",
	"makepkg",
	"opkg",
	"pacman",
	"paru",
	"pkg",
	"pkg_add",
	"pkg_delete",
	"portage",
	"rpm",
	"yay",
	"yum",
	"zypper",

	// System modification (job scheduling, services, disks and mounts)
	"at",
	"batch",
	"chkconfig",
	"crontab",
	"fdisk",
	"mkfs",
	"mount",
	"parted",
	"service",
	"systemctl",
	"umount",

	// Network configuration
	"firewall-cmd",
	"ifconfig",
	"ip",
	"iptables",
	"netstat",
	"pfctl",
	"route",
	"ufw",
}
141
142func bashDescription(attribution *config.Attribution, modelName string, bashConfig config.ToolBash) string {
143	bannedCommandsList := resolveBannedCommandsList(bashConfig)
144	bannedCommandsStr := strings.Join(bannedCommandsList, ", ")
145	var out bytes.Buffer
146	if err := bashDescriptionTpl.Execute(&out, bashDescriptionData{
147		BannedCommands:  bannedCommandsStr,
148		MaxOutputLength: MaxOutputLength,
149		Attribution:     *attribution,
150		ModelName:       modelName,
151	}); err != nil {
152		// this should never happen.
153		panic("failed to execute bash description template: " + err.Error())
154	}
155	return out.String()
156}
157
// defaultBannedSubCommands holds the built-in argument-level blockers. Each
// entry is built with shell.ArgumentsBlocker(command, args, flags).
//
// NOTE(review): the exact matching rules (e.g. whether all listed args and
// flags must be present) live in shell.ArgumentsBlocker — verify there.
var defaultBannedSubCommands = []shell.BlockFunc{
	// System package managers
	shell.ArgumentsBlocker("apk", []string{"add"}, nil),
	shell.ArgumentsBlocker("apt", []string{"install"}, nil),
	shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
	shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
	shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
	shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
	shell.ArgumentsBlocker("yum", []string{"install"}, nil),
	shell.ArgumentsBlocker("zypper", []string{"install"}, nil),

	// Language-specific package managers. For npm, pip, pip3 and pnpm the
	// entries name a global/user flag in addition to the install subcommand.
	shell.ArgumentsBlocker("brew", []string{"install"}, nil),
	shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
	shell.ArgumentsBlocker("gem", []string{"install"}, nil),
	shell.ArgumentsBlocker("go", []string{"install"}, nil),
	shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
	shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
	shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
	shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
	shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
	shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
	shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),

	// `go test -exec` can run arbitrary commands
	shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
}
185
186func blockFuncs(bannedCommands []string, bannedSubCommands []config.BannedToolArgsAndOrParams, includeSubCommandDefaults bool) []shell.BlockFunc {
187	blockFuncs := []shell.BlockFunc{}
188	blockFuncs = append(blockFuncs, shell.CommandsBlocker(bannedCommands))
189
190	for _, bannedSubCmd := range bannedSubCommands {
191		blockFuncs = append(blockFuncs, shell.ArgumentsBlocker(bannedSubCmd.Command, bannedSubCmd.Args, bannedSubCmd.Flags))
192	}
193
194	if includeSubCommandDefaults {
195		blockFuncs = append(blockFuncs, defaultBannedSubCommands...)
196	}
197	return blockFuncs
198}
199
200func resolveBannedCommandsList(cfg config.ToolBash) []string {
201	bannedCommands := cfg.BannedCommands
202	if !cfg.DisableDefaultCommands {
203		if len(bannedCommands) == 0 {
204			return defaultBannedCommands
205		}
206		bannedCommands = append(bannedCommands, defaultBannedCommands...)
207	}
208	return bannedCommands
209}
210
211func resolveBlockFuncs(cfg config.ToolBash) []shell.BlockFunc {
212	return blockFuncs(resolveBannedCommandsList(cfg), cfg.BannedSubCommands, cfg.DisableDefaultSubCommands)
213}
214
215func NewBashTool(
216	permissions permission.Service,
217	workingDir string, attribution *config.Attribution,
218	modelName string,
219	bashConfig config.ToolBash,
220) fantasy.AgentTool {
221	return fantasy.NewAgentTool(
222		BashToolName,
223		string(bashDescription(attribution, modelName, bashConfig)),
224		func(ctx context.Context, params BashParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
225			if params.Command == "" {
226				return fantasy.NewTextErrorResponse("missing command"), nil
227			}
228
229			// Determine working directory
230			execWorkingDir := cmp.Or(params.WorkingDir, workingDir)
231
232			isSafeReadOnly := false
233			cmdLower := strings.ToLower(params.Command)
234
235			for _, safe := range safeCommands {
236				if strings.HasPrefix(cmdLower, safe) {
237					if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
238						isSafeReadOnly = true
239						break
240					}
241				}
242			}
243
244			sessionID := GetSessionFromContext(ctx)
245			if sessionID == "" {
246				return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for executing shell command")
247			}
248			if !isSafeReadOnly {
249				p := permissions.Request(
250					permission.CreatePermissionRequest{
251						SessionID:   sessionID,
252						Path:        execWorkingDir,
253						ToolCallID:  call.ID,
254						ToolName:    BashToolName,
255						Action:      "execute",
256						Description: fmt.Sprintf("Execute command: %s", params.Command),
257						Params:      BashPermissionsParams(params),
258					},
259				)
260				if !p {
261					return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
262				}
263			}
264
265			// If explicitly requested as background, start immediately with detached context
266			if params.RunInBackground {
267				startTime := time.Now()
268				bgManager := shell.GetBackgroundShellManager()
269				bgManager.Cleanup()
270				// Use background context so it continues after tool returns
271				bgShell, err := bgManager.Start(context.Background(), execWorkingDir, resolveBlockFuncs(bashConfig), params.Command, params.Description)
272				if err != nil {
273					return fantasy.ToolResponse{}, fmt.Errorf("error starting background shell: %w", err)
274				}
275
276				// Wait a short time to detect fast failures (blocked commands, syntax errors, etc.)
277				time.Sleep(1 * time.Second)
278				stdout, stderr, done, execErr := bgShell.GetOutput()
279
280				if done {
281					// Command failed or completed very quickly
282					bgManager.Remove(bgShell.ID)
283
284					interrupted := shell.IsInterrupt(execErr)
285					exitCode := shell.ExitCode(execErr)
286					if exitCode == 0 && !interrupted && execErr != nil {
287						return fantasy.ToolResponse{}, fmt.Errorf("[Job %s] error executing command: %w", bgShell.ID, execErr)
288					}
289
290					stdout = formatOutput(stdout, stderr, execErr)
291
292					metadata := BashResponseMetadata{
293						StartTime:        startTime.UnixMilli(),
294						EndTime:          time.Now().UnixMilli(),
295						Output:           stdout,
296						Description:      params.Description,
297						Background:       params.RunInBackground,
298						WorkingDirectory: bgShell.WorkingDir,
299					}
300					if stdout == "" {
301						return fantasy.WithResponseMetadata(fantasy.NewTextResponse(BashNoOutput), metadata), nil
302					}
303					stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", normalizeWorkingDir(bgShell.WorkingDir))
304					return fantasy.WithResponseMetadata(fantasy.NewTextResponse(stdout), metadata), nil
305				}
306
307				// Still running after fast-failure check - return as background job
308				metadata := BashResponseMetadata{
309					StartTime:        startTime.UnixMilli(),
310					EndTime:          time.Now().UnixMilli(),
311					Description:      params.Description,
312					WorkingDirectory: bgShell.WorkingDir,
313					Background:       true,
314					ShellID:          bgShell.ID,
315				}
316				response := fmt.Sprintf("Background shell started with ID: %s\n\nUse job_output tool to view output or job_kill to terminate.", bgShell.ID)
317				return fantasy.WithResponseMetadata(fantasy.NewTextResponse(response), metadata), nil
318			}
319
320			// Start synchronous execution with auto-background support
321			startTime := time.Now()
322
323			// Start with detached context so it can survive if moved to background
324			bgManager := shell.GetBackgroundShellManager()
325			bgManager.Cleanup()
326			bgShell, err := bgManager.Start(context.Background(), execWorkingDir, resolveBlockFuncs(bashConfig), params.Command, params.Description)
327			if err != nil {
328				return fantasy.ToolResponse{}, fmt.Errorf("error starting shell: %w", err)
329			}
330
331			// Wait for either completion, auto-background threshold, or context cancellation
332			ticker := time.NewTicker(100 * time.Millisecond)
333			defer ticker.Stop()
334			timeout := time.After(AutoBackgroundThreshold)
335
336			var stdout, stderr string
337			var done bool
338			var execErr error
339
340		waitLoop:
341			for {
342				select {
343				case <-ticker.C:
344					stdout, stderr, done, execErr = bgShell.GetOutput()
345					if done {
346						break waitLoop
347					}
348				case <-timeout:
349					stdout, stderr, done, execErr = bgShell.GetOutput()
350					break waitLoop
351				case <-ctx.Done():
352					// Incoming context was cancelled before we moved to background
353					// Kill the shell and return error
354					bgManager.Kill(bgShell.ID)
355					return fantasy.ToolResponse{}, ctx.Err()
356				}
357			}
358
359			if done {
360				// Command completed within threshold - return synchronously
361				// Remove from background manager since we're returning directly
362				// Don't call Kill() as it cancels the context and corrupts the exit code
363				bgManager.Remove(bgShell.ID)
364
365				interrupted := shell.IsInterrupt(execErr)
366				exitCode := shell.ExitCode(execErr)
367				if exitCode == 0 && !interrupted && execErr != nil {
368					return fantasy.ToolResponse{}, fmt.Errorf("[Job %s] error executing command: %w", bgShell.ID, execErr)
369				}
370
371				stdout = formatOutput(stdout, stderr, execErr)
372
373				metadata := BashResponseMetadata{
374					StartTime:        startTime.UnixMilli(),
375					EndTime:          time.Now().UnixMilli(),
376					Output:           stdout,
377					Description:      params.Description,
378					Background:       params.RunInBackground,
379					WorkingDirectory: bgShell.WorkingDir,
380				}
381				if stdout == "" {
382					return fantasy.WithResponseMetadata(fantasy.NewTextResponse(BashNoOutput), metadata), nil
383				}
384				stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", normalizeWorkingDir(bgShell.WorkingDir))
385				return fantasy.WithResponseMetadata(fantasy.NewTextResponse(stdout), metadata), nil
386			}
387
388			// Still running - keep as background job
389			metadata := BashResponseMetadata{
390				StartTime:        startTime.UnixMilli(),
391				EndTime:          time.Now().UnixMilli(),
392				Description:      params.Description,
393				WorkingDirectory: bgShell.WorkingDir,
394				Background:       true,
395				ShellID:          bgShell.ID,
396			}
397			response := fmt.Sprintf("Command is taking longer than expected and has been moved to background.\n\nBackground shell ID: %s\n\nUse job_output tool to view output or job_kill to terminate.", bgShell.ID)
398			return fantasy.WithResponseMetadata(fantasy.NewTextResponse(response), metadata), nil
399		})
400}
401
402// formatOutput formats the output of a completed command with error handling
403func formatOutput(stdout, stderr string, execErr error) string {
404	interrupted := shell.IsInterrupt(execErr)
405	exitCode := shell.ExitCode(execErr)
406
407	stdout = truncateOutput(stdout)
408	stderr = truncateOutput(stderr)
409
410	errorMessage := stderr
411	if errorMessage == "" && execErr != nil {
412		errorMessage = execErr.Error()
413	}
414
415	if interrupted {
416		if errorMessage != "" {
417			errorMessage += "\n"
418		}
419		errorMessage += "Command was aborted before completion"
420	} else if exitCode != 0 {
421		if errorMessage != "" {
422			errorMessage += "\n"
423		}
424		errorMessage += fmt.Sprintf("Exit code %d", exitCode)
425	}
426
427	hasBothOutputs := stdout != "" && stderr != ""
428
429	if hasBothOutputs {
430		stdout += "\n"
431	}
432
433	if errorMessage != "" {
434		stdout += "\n" + errorMessage
435	}
436
437	return stdout
438}
439
440func truncateOutput(content string) string {
441	if len(content) <= MaxOutputLength {
442		return content
443	}
444
445	halfLength := MaxOutputLength / 2
446	start := content[:halfLength]
447	end := content[len(content)-halfLength:]
448
449	truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
450	return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
451}
452
// countLines returns the number of newline-separated segments in s: 0 for
// the empty string, otherwise the number of "\n" characters plus one (so a
// trailing newline counts as starting one more, empty, line).
func countLines(s string) int {
	if s == "" {
		return 0
	}
	// Counting separators gives the same result as len(strings.Split(s, "\n"))
	// without allocating a slice of every line.
	return strings.Count(s, "\n") + 1
}
459
// normalizeWorkingDir returns path with forward slashes as separators.
//
// On Windows it first removes every occurrence of the volume name of the
// process's current working directory (falling back to "C:" when the working
// directory cannot be determined). On other platforms only the separator
// conversion applies.
func normalizeWorkingDir(path string) string {
	if runtime.GOOS != "windows" {
		return filepath.ToSlash(path)
	}

	cwd, err := os.Getwd()
	if err != nil {
		cwd = "C:"
	}
	volume := filepath.VolumeName(cwd)
	return filepath.ToSlash(strings.ReplaceAll(path, volume, ""))
}