bash.go

  1package tools
  2
import (
	"bytes"
	"cmp"
	"context"
	_ "embed"
	"fmt"
	"html/template"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"time"
	"unicode/utf8"

	"charm.land/fantasy"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/shell"
)
 21
// BashParams is the argument object the model supplies when calling the bash
// tool. The `description` struct tags carry per-field help text (presumably
// surfaced to the model via the tool schema — confirm in fantasy).
//
// Field names, types and order must stay identical to BashPermissionsParams:
// NewBashTool converts between the two with a direct type conversion.
type BashParams struct {
	Description     string `json:"description" description:"A brief description of what the command does, try to keep it under 30 characters or so"`
	Command         string `json:"command" description:"The command to execute"`
	WorkingDir      string `json:"working_dir,omitempty" description:"The working directory to execute the command in (defaults to current directory)"`
	RunInBackground bool   `json:"run_in_background,omitempty" description:"Set to true (boolean) to run this command in the background. Use job_output to read the output later."`
}
 28
// BashPermissionsParams is attached as Params to the permission request
// raised before a non-safe bash command runs. It mirrors BashParams
// field-for-field (without the schema descriptions and omitempty) so that it
// can be produced by a plain type conversion.
type BashPermissionsParams struct {
	Description     string `json:"description"`
	Command         string `json:"command"`
	WorkingDir      string `json:"working_dir"`
	RunInBackground bool   `json:"run_in_background"`
}
 35
// BashResponseMetadata is attached to every successful bash tool response.
type BashResponseMetadata struct {
	// StartTime and EndTime are Unix timestamps in milliseconds.
	StartTime int64 `json:"start_time"`
	EndTime   int64 `json:"end_time"`
	// Output is the formatted command output (without the "<cwd>" suffix that
	// is appended to the response text). Left empty for jobs still running.
	Output           string `json:"output"`
	Description      string `json:"description"`
	WorkingDirectory string `json:"working_directory"`
	// Background is true when the command was requested or moved to the background.
	Background bool `json:"background,omitempty"`
	// ShellID identifies the background job when the command is still running.
	ShellID string `json:"shell_id,omitempty"`
}
 45
const (
	// BashToolName is the name under which the bash tool is registered.
	BashToolName = "bash"

	AutoBackgroundThreshold = 1 * time.Minute // Commands taking longer automatically become background jobs
	// MaxOutputLength is the byte length above which stdout and stderr are
	// each truncated to their head and tail (see truncateOutput).
	MaxOutputLength = 30000
	// BashNoOutput is the response text for a finished command with no output.
	BashNoOutput = "no output"
)
 53
// bashDescriptionTmpl holds the raw text of bash.tpl, embedded at build time.
//
//go:embed bash.tpl
var bashDescriptionTmpl []byte

// bashDescriptionTpl is bash.tpl parsed once during package initialization.
// template.Must panics on a parse error, so a malformed template fails at
// startup rather than on first use.
//
// NOTE(review): this is html/template, which HTML-escapes interpolated data.
// The rendered text is a plain-text tool description, so text/template may be
// what was intended — confirm that no template field (e.g. from Attribution)
// can contain characters such as <, >, & or quotes.
var bashDescriptionTpl = template.Must(
	template.New("bashDescription").
		Parse(string(bashDescriptionTmpl)),
)
 61
// bashDescriptionData is the data passed to bash.tpl when rendering the tool
// description.
type bashDescriptionData struct {
	BannedCommands  string             // bannedCommands joined with ", "
	MaxOutputLength int                // output truncation limit, in bytes
	Attribution     config.Attribution // attribution settings forwarded from the tool constructor
}
 67
// bannedCommands lists commands the bash tool refuses to run. The same list is
// rendered into the tool description (bashDescription, in this order) and
// enforced through shell.CommandsBlocker in blockFuncs.
var bannedCommands = []string{
	// Shell builtin that could redefine commands (not a network tool; kept
	// first so the rendered description order is unchanged).
	"alias",

	// Network/Download tools
	"aria2c",
	"axel",
	"chrome",
	"curl",
	"curlie",
	"firefox",
	"http-prompt",
	"httpie",
	"links",
	"lynx",
	"nc",
	"safari",
	"scp",
	"ssh",
	"telnet",
	"w3m",
	"wget",
	"xh",

	// System administration
	"doas",
	"su",
	"sudo",

	// Package managers
	"apk",
	"apt",
	"apt-cache",
	"apt-get",
	"dnf",
	"dpkg",
	"emerge",
	"home-manager",
	"makepkg",
	"opkg",
	"pacman",
	"paru",
	"pkg",
	"pkg_add",
	"pkg_delete",
	"portage",
	"rpm",
	"yay",
	"yum",
	"zypper",

	// System modification
	"at",
	"batch",
	"chkconfig",
	"crontab",
	"fdisk",
	"mkfs",
	"mount",
	"parted",
	"service",
	"systemctl",
	"umount",

	// Network configuration
	"firewall-cmd",
	"ifconfig",
	"ip",
	"iptables",
	"netstat",
	"pfctl",
	"route",
	"ufw",
}
140
141func bashDescription(attribution *config.Attribution) string {
142	bannedCommandsStr := strings.Join(bannedCommands, ", ")
143	var out bytes.Buffer
144	if err := bashDescriptionTpl.Execute(&out, bashDescriptionData{
145		BannedCommands:  bannedCommandsStr,
146		MaxOutputLength: MaxOutputLength,
147		Attribution:     *attribution,
148	}); err != nil {
149		// this should never happen.
150		panic("failed to execute bash description template: " + err.Error())
151	}
152	return out.String()
153}
154
155func blockFuncs() []shell.BlockFunc {
156	return []shell.BlockFunc{
157		shell.CommandsBlocker(bannedCommands),
158
159		// System package managers
160		shell.ArgumentsBlocker("apk", []string{"add"}, nil),
161		shell.ArgumentsBlocker("apt", []string{"install"}, nil),
162		shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
163		shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
164		shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
165		shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
166		shell.ArgumentsBlocker("yum", []string{"install"}, nil),
167		shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
168
169		// Language-specific package managers
170		shell.ArgumentsBlocker("brew", []string{"install"}, nil),
171		shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
172		shell.ArgumentsBlocker("gem", []string{"install"}, nil),
173		shell.ArgumentsBlocker("go", []string{"install"}, nil),
174		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
175		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
176		shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
177		shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
178		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
179		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
180		shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
181
182		// `go test -exec` can run arbitrary commands
183		shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
184	}
185}
186
187func NewBashTool(permissions permission.Service, workingDir string, attribution *config.Attribution) fantasy.AgentTool {
188	return fantasy.NewAgentTool(
189		BashToolName,
190		string(bashDescription(attribution)),
191		func(ctx context.Context, params BashParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
192			if params.Command == "" {
193				return fantasy.NewTextErrorResponse("missing command"), nil
194			}
195
196			// Determine working directory
197			execWorkingDir := cmp.Or(params.WorkingDir, workingDir)
198
199			isSafeReadOnly := false
200			cmdLower := strings.ToLower(params.Command)
201
202			for _, safe := range safeCommands {
203				if strings.HasPrefix(cmdLower, safe) {
204					if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
205						isSafeReadOnly = true
206						break
207					}
208				}
209			}
210
211			sessionID := GetSessionFromContext(ctx)
212			if sessionID == "" {
213				return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for executing shell command")
214			}
215			if !isSafeReadOnly {
216				p := permissions.Request(
217					permission.CreatePermissionRequest{
218						SessionID:   sessionID,
219						Path:        execWorkingDir,
220						ToolCallID:  call.ID,
221						ToolName:    BashToolName,
222						Action:      "execute",
223						Description: fmt.Sprintf("Execute command: %s", params.Command),
224						Params:      BashPermissionsParams(params),
225					},
226				)
227				if !p {
228					return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
229				}
230			}
231
232			// If explicitly requested as background, start immediately with detached context
233			if params.RunInBackground {
234				startTime := time.Now()
235				bgManager := shell.GetBackgroundShellManager()
236				bgManager.Cleanup()
237				// Use background context so it continues after tool returns
238				bgShell, err := bgManager.Start(context.Background(), execWorkingDir, blockFuncs(), params.Command, params.Description)
239				if err != nil {
240					return fantasy.ToolResponse{}, fmt.Errorf("error starting background shell: %w", err)
241				}
242
243				// Wait a short time to detect fast failures (blocked commands, syntax errors, etc.)
244				time.Sleep(1 * time.Second)
245				stdout, stderr, done, execErr := bgShell.GetOutput()
246
247				if done {
248					// Command failed or completed very quickly
249					bgManager.Remove(bgShell.ID)
250
251					interrupted := shell.IsInterrupt(execErr)
252					exitCode := shell.ExitCode(execErr)
253					if exitCode == 0 && !interrupted && execErr != nil {
254						return fantasy.ToolResponse{}, fmt.Errorf("[Job %s] error executing command: %w", bgShell.ID, execErr)
255					}
256
257					stdout = formatOutput(stdout, stderr, execErr)
258
259					metadata := BashResponseMetadata{
260						StartTime:        startTime.UnixMilli(),
261						EndTime:          time.Now().UnixMilli(),
262						Output:           stdout,
263						Description:      params.Description,
264						Background:       params.RunInBackground,
265						WorkingDirectory: bgShell.WorkingDir,
266					}
267					if stdout == "" {
268						return fantasy.WithResponseMetadata(fantasy.NewTextResponse(BashNoOutput), metadata), nil
269					}
270					stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", normalizeWorkingDir(bgShell.WorkingDir))
271					return fantasy.WithResponseMetadata(fantasy.NewTextResponse(stdout), metadata), nil
272				}
273
274				// Still running after fast-failure check - return as background job
275				metadata := BashResponseMetadata{
276					StartTime:        startTime.UnixMilli(),
277					EndTime:          time.Now().UnixMilli(),
278					Description:      params.Description,
279					WorkingDirectory: bgShell.WorkingDir,
280					Background:       true,
281					ShellID:          bgShell.ID,
282				}
283				response := fmt.Sprintf("Background shell started with ID: %s\n\nUse job_output tool to view output or job_kill to terminate.", bgShell.ID)
284				return fantasy.WithResponseMetadata(fantasy.NewTextResponse(response), metadata), nil
285			}
286
287			// Start synchronous execution with auto-background support
288			startTime := time.Now()
289
290			// Start with detached context so it can survive if moved to background
291			bgManager := shell.GetBackgroundShellManager()
292			bgManager.Cleanup()
293			bgShell, err := bgManager.Start(context.Background(), execWorkingDir, blockFuncs(), params.Command, params.Description)
294			if err != nil {
295				return fantasy.ToolResponse{}, fmt.Errorf("error starting shell: %w", err)
296			}
297
298			// Wait for either completion, auto-background threshold, or context cancellation
299			ticker := time.NewTicker(100 * time.Millisecond)
300			defer ticker.Stop()
301			timeout := time.After(AutoBackgroundThreshold)
302
303			var stdout, stderr string
304			var done bool
305			var execErr error
306
307		waitLoop:
308			for {
309				select {
310				case <-ticker.C:
311					stdout, stderr, done, execErr = bgShell.GetOutput()
312					if done {
313						break waitLoop
314					}
315				case <-timeout:
316					stdout, stderr, done, execErr = bgShell.GetOutput()
317					break waitLoop
318				case <-ctx.Done():
319					// Incoming context was cancelled before we moved to background
320					// Kill the shell and return error
321					bgManager.Kill(bgShell.ID)
322					return fantasy.ToolResponse{}, ctx.Err()
323				}
324			}
325
326			if done {
327				// Command completed within threshold - return synchronously
328				// Remove from background manager since we're returning directly
329				// Don't call Kill() as it cancels the context and corrupts the exit code
330				bgManager.Remove(bgShell.ID)
331
332				interrupted := shell.IsInterrupt(execErr)
333				exitCode := shell.ExitCode(execErr)
334				if exitCode == 0 && !interrupted && execErr != nil {
335					return fantasy.ToolResponse{}, fmt.Errorf("[Job %s] error executing command: %w", bgShell.ID, execErr)
336				}
337
338				stdout = formatOutput(stdout, stderr, execErr)
339
340				metadata := BashResponseMetadata{
341					StartTime:        startTime.UnixMilli(),
342					EndTime:          time.Now().UnixMilli(),
343					Output:           stdout,
344					Description:      params.Description,
345					Background:       params.RunInBackground,
346					WorkingDirectory: bgShell.WorkingDir,
347				}
348				if stdout == "" {
349					return fantasy.WithResponseMetadata(fantasy.NewTextResponse(BashNoOutput), metadata), nil
350				}
351				stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", normalizeWorkingDir(bgShell.WorkingDir))
352				return fantasy.WithResponseMetadata(fantasy.NewTextResponse(stdout), metadata), nil
353			}
354
355			// Still running - keep as background job
356			metadata := BashResponseMetadata{
357				StartTime:        startTime.UnixMilli(),
358				EndTime:          time.Now().UnixMilli(),
359				Description:      params.Description,
360				WorkingDirectory: bgShell.WorkingDir,
361				Background:       true,
362				ShellID:          bgShell.ID,
363			}
364			response := fmt.Sprintf("Command is taking longer than expected and has been moved to background.\n\nBackground shell ID: %s\n\nUse job_output tool to view output or job_kill to terminate.", bgShell.ID)
365			return fantasy.WithResponseMetadata(fantasy.NewTextResponse(response), metadata), nil
366		})
367}
368
369// formatOutput formats the output of a completed command with error handling
370func formatOutput(stdout, stderr string, execErr error) string {
371	interrupted := shell.IsInterrupt(execErr)
372	exitCode := shell.ExitCode(execErr)
373
374	stdout = truncateOutput(stdout)
375	stderr = truncateOutput(stderr)
376
377	errorMessage := stderr
378	if errorMessage == "" && execErr != nil {
379		errorMessage = execErr.Error()
380	}
381
382	if interrupted {
383		if errorMessage != "" {
384			errorMessage += "\n"
385		}
386		errorMessage += "Command was aborted before completion"
387	} else if exitCode != 0 {
388		if errorMessage != "" {
389			errorMessage += "\n"
390		}
391		errorMessage += fmt.Sprintf("Exit code %d", exitCode)
392	}
393
394	hasBothOutputs := stdout != "" && stderr != ""
395
396	if hasBothOutputs {
397		stdout += "\n"
398	}
399
400	if errorMessage != "" {
401		stdout += "\n" + errorMessage
402	}
403
404	return stdout
405}
406
407func truncateOutput(content string) string {
408	if len(content) <= MaxOutputLength {
409		return content
410	}
411
412	halfLength := MaxOutputLength / 2
413	start := content[:halfLength]
414	end := content[len(content)-halfLength:]
415
416	truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
417	return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
418}
419
// countLines reports how many lines s spans: 0 for the empty string, otherwise
// one more than the number of '\n' characters (so a trailing newline counts as
// starting a final, empty line).
func countLines(s string) int {
	if len(s) == 0 {
		return 0
	}
	return strings.Count(s, "\n") + 1
}
426
// normalizeWorkingDir formats a directory path with forward slashes for the
// "<cwd>" marker in tool output. On Windows it first removes every occurrence
// of the volume name of the process's current directory. If that directory
// cannot be determined, it falls back to "C:".
func normalizeWorkingDir(path string) string {
	if runtime.GOOS != "windows" {
		return filepath.ToSlash(path)
	}

	cwd, err := os.Getwd()
	if err != nil {
		cwd = "C:"
	}
	volume := filepath.VolumeName(cwd)
	return filepath.ToSlash(strings.ReplaceAll(path, volume, ""))
}