bash.go

  1package tools
  2
import (
	"bytes"
	"context"
	_ "embed"
	"fmt"
	"html/template"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"time"
	"unicode/utf8"

	"charm.land/fantasy"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/shell"
)
 20
// BashParams are the arguments supplied with a call to the bash tool.
// The `description` struct tags carry per-field help text; presumably they are
// used to build the tool's input schema — TODO confirm against fantasy.
type BashParams struct {
	// Command is the shell command line to run; an empty value is rejected.
	Command string `json:"command" description:"The command to execute"`
	// Description is copied into the response metadata and permission params.
	Description string `json:"description,omitempty" description:"A brief description of what the command does"`
	// WorkingDir overrides the tool's default working directory when non-empty.
	WorkingDir string `json:"working_dir,omitempty" description:"The working directory to execute the command in (defaults to current directory)"`
	// RunInBackground starts the command detached and returns its shell ID
	// immediately instead of waiting for output.
	RunInBackground bool `json:"run_in_background,omitempty" description:"Set to true (boolean) to run this command in the background. Use bash_output to read the output later."`
}
 27
// BashPermissionsParams is the payload attached to the permission request for
// a bash command. It deliberately mirrors BashParams field-for-field (same
// names, types and order) because NewBashTool builds it with the conversion
// BashPermissionsParams(params); keep the two structs in sync.
type BashPermissionsParams struct {
	Command         string `json:"command"`
	Description     string `json:"description"`
	WorkingDir      string `json:"working_dir"`
	RunInBackground bool   `json:"run_in_background"`
}
 34
// BashResponseMetadata is attached to every successful bash tool response.
type BashResponseMetadata struct {
	// StartTime and EndTime are Unix timestamps in milliseconds.
	StartTime int64 `json:"start_time"`
	EndTime   int64 `json:"end_time"`
	// Output is the combined text of a command that finished synchronously
	// (without the trailing <cwd> marker); it is empty for background jobs.
	Output           string `json:"output"`
	Description      string `json:"description"`
	WorkingDirectory string `json:"working_directory"`
	// Background and ShellID are set when the command is left running as a
	// background job, either on request or after the auto-background threshold.
	Background bool   `json:"background,omitempty"`
	ShellID    string `json:"shell_id,omitempty"`
}
 44
 45const (
 46	BashToolName = "bash"
 47
 48	AutoBackgroundThreshold = 1 * time.Minute // Commands taking longer automatically become background jobs
 49	MaxOutputLength         = 30000
 50	BashNoOutput            = "no output"
 51)
 52
 53//go:embed bash.tpl
 54var bashDescriptionTmpl []byte
 55
 56var bashDescriptionTpl = template.Must(
 57	template.New("bashDescription").
 58		Parse(string(bashDescriptionTmpl)),
 59)
 60
// bashDescriptionData is the value bashDescription executes bash.tpl with.
// Field names are presumably referenced directly from the template — renaming
// one requires updating bash.tpl (TODO confirm).
type bashDescriptionData struct {
	// BannedCommands is bannedCommands joined with ", ".
	BannedCommands string
	// MaxOutputLength is the per-stream output limit in bytes.
	MaxOutputLength int
	Attribution     config.Attribution
}
 66
// bannedCommands lists commands the bash tool refuses to run at all: the slice
// is handed to shell.CommandsBlocker in blockFuncs and, joined with ", ", is
// shown in the tool description. Order is preserved in that description.
var bannedCommands = []string{
	// Network/Download tools
	// NOTE(review): "alias" is not a network/download tool despite this
	// heading; presumably banned so commands cannot be redefined — TODO confirm.
	"alias",
	"aria2c",
	"axel",
	"chrome",
	"curl",
	"curlie",
	"firefox",
	"http-prompt",
	"httpie",
	"links",
	"lynx",
	"nc",
	"safari",
	"scp",
	"ssh",
	"telnet",
	"w3m",
	"wget",
	"xh",

	// System administration
	"doas",
	"su",
	"sudo",

	// Package managers
	"apk",
	"apt",
	"apt-cache",
	"apt-get",
	"dnf",
	"dpkg",
	"emerge",
	"home-manager",
	"makepkg",
	"opkg",
	"pacman",
	"paru",
	"pkg",
	"pkg_add",
	"pkg_delete",
	"portage",
	"rpm",
	"yay",
	"yum",
	"zypper",

	// System modification
	"at",
	"batch",
	"chkconfig",
	"crontab",
	"fdisk",
	"mkfs",
	"mount",
	"parted",
	"service",
	"systemctl",
	"umount",

	// Network configuration
	"firewall-cmd",
	"ifconfig",
	"ip",
	"iptables",
	"netstat",
	"pfctl",
	"route",
	"ufw",
}
139
140func bashDescription(attribution *config.Attribution) string {
141	bannedCommandsStr := strings.Join(bannedCommands, ", ")
142	var out bytes.Buffer
143	if err := bashDescriptionTpl.Execute(&out, bashDescriptionData{
144		BannedCommands:  bannedCommandsStr,
145		MaxOutputLength: MaxOutputLength,
146		Attribution:     *attribution,
147	}); err != nil {
148		// this should never happen.
149		panic("failed to execute bash description template: " + err.Error())
150	}
151	return out.String()
152}
153
154func blockFuncs() []shell.BlockFunc {
155	return []shell.BlockFunc{
156		shell.CommandsBlocker(bannedCommands),
157
158		// System package managers
159		shell.ArgumentsBlocker("apk", []string{"add"}, nil),
160		shell.ArgumentsBlocker("apt", []string{"install"}, nil),
161		shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
162		shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
163		shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
164		shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
165		shell.ArgumentsBlocker("yum", []string{"install"}, nil),
166		shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
167
168		// Language-specific package managers
169		shell.ArgumentsBlocker("brew", []string{"install"}, nil),
170		shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
171		shell.ArgumentsBlocker("gem", []string{"install"}, nil),
172		shell.ArgumentsBlocker("go", []string{"install"}, nil),
173		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
174		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
175		shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
176		shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
177		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
178		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
179		shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
180
181		// `go test -exec` can run arbitrary commands
182		shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
183	}
184}
185
186func NewBashTool(permissions permission.Service, workingDir string, attribution *config.Attribution) fantasy.AgentTool {
187	return fantasy.NewAgentTool(
188		BashToolName,
189		string(bashDescription(attribution)),
190		func(ctx context.Context, params BashParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
191			if params.Command == "" {
192				return fantasy.NewTextErrorResponse("missing command"), nil
193			}
194
195			// Determine working directory
196			execWorkingDir := workingDir
197			if params.WorkingDir != "" {
198				execWorkingDir = params.WorkingDir
199			}
200
201			isSafeReadOnly := false
202			cmdLower := strings.ToLower(params.Command)
203
204			for _, safe := range safeCommands {
205				if strings.HasPrefix(cmdLower, safe) {
206					if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
207						isSafeReadOnly = true
208						break
209					}
210				}
211			}
212
213			sessionID := GetSessionFromContext(ctx)
214			if sessionID == "" {
215				return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for executing shell command")
216			}
217			if !isSafeReadOnly {
218				p := permissions.Request(
219					permission.CreatePermissionRequest{
220						SessionID:   sessionID,
221						Path:        execWorkingDir,
222						ToolCallID:  call.ID,
223						ToolName:    BashToolName,
224						Action:      "execute",
225						Description: fmt.Sprintf("Execute command: %s", params.Command),
226						Params:      BashPermissionsParams(params),
227					},
228				)
229				if !p {
230					return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
231				}
232			}
233
234			// If explicitly requested as background, start immediately with detached context
235			if params.RunInBackground {
236				startTime := time.Now()
237				bgManager := shell.GetBackgroundShellManager()
238				// Use background context so it continues after tool returns
239				bgShell, err := bgManager.Start(context.Background(), execWorkingDir, blockFuncs(), params.Command)
240				if err != nil {
241					return fantasy.ToolResponse{}, fmt.Errorf("error starting background shell: %w", err)
242				}
243
244				metadata := BashResponseMetadata{
245					StartTime:        startTime.UnixMilli(),
246					EndTime:          time.Now().UnixMilli(),
247					Description:      params.Description,
248					WorkingDirectory: bgShell.GetWorkingDir(),
249					Background:       true,
250					ShellID:          bgShell.ID,
251				}
252				response := fmt.Sprintf("Background shell started with ID: %s\n\nUse bash_output tool to view output or bash_kill to terminate.", bgShell.ID)
253				return fantasy.WithResponseMetadata(fantasy.NewTextResponse(response), metadata), nil
254			}
255
256			// Start synchronous execution with auto-background support
257			startTime := time.Now()
258
259			// Start with detached context so it can survive if moved to background
260			bgManager := shell.GetBackgroundShellManager()
261			bgShell, err := bgManager.Start(context.Background(), execWorkingDir, blockFuncs(), params.Command)
262			if err != nil {
263				return fantasy.ToolResponse{}, fmt.Errorf("error starting shell: %w", err)
264			}
265
266			// Wait for either completion, auto-background threshold, or context cancellation
267			ticker := time.NewTicker(100 * time.Millisecond)
268			defer ticker.Stop()
269			timeout := time.After(AutoBackgroundThreshold)
270
271			var stdout, stderr string
272			var done bool
273			var execErr error
274
275		waitLoop:
276			for {
277				select {
278				case <-ticker.C:
279					stdout, stderr, done, execErr = bgShell.GetOutput()
280					if done {
281						break waitLoop
282					}
283				case <-timeout:
284					stdout, stderr, done, execErr = bgShell.GetOutput()
285					break waitLoop
286				case <-ctx.Done():
287					// Incoming context was cancelled before we moved to background
288					// Kill the shell and return error
289					bgManager.Kill(bgShell.ID)
290					return fantasy.ToolResponse{}, ctx.Err()
291				}
292			}
293
294			if done {
295				// Command completed within threshold - return synchronously
296				// Remove from background manager since we're returning directly
297				// Don't call Kill() as it cancels the context and corrupts the exit code
298				bgManager.Remove(bgShell.ID)
299
300				interrupted := shell.IsInterrupt(execErr)
301				exitCode := shell.ExitCode(execErr)
302				if exitCode == 0 && !interrupted && execErr != nil {
303					return fantasy.ToolResponse{}, fmt.Errorf("error executing command: %w", execErr)
304				}
305
306				stdout = truncateOutput(stdout)
307				stderr = truncateOutput(stderr)
308
309				errorMessage := stderr
310				if errorMessage == "" && execErr != nil {
311					errorMessage = execErr.Error()
312				}
313
314				if interrupted {
315					if errorMessage != "" {
316						errorMessage += "\n"
317					}
318					errorMessage += "Command was aborted before completion"
319				} else if exitCode != 0 {
320					if errorMessage != "" {
321						errorMessage += "\n"
322					}
323					errorMessage += fmt.Sprintf("Exit code %d", exitCode)
324				}
325
326				hasBothOutputs := stdout != "" && stderr != ""
327
328				if hasBothOutputs {
329					stdout += "\n"
330				}
331
332				if errorMessage != "" {
333					stdout += "\n" + errorMessage
334				}
335
336				metadata := BashResponseMetadata{
337					StartTime:        startTime.UnixMilli(),
338					EndTime:          time.Now().UnixMilli(),
339					Output:           stdout,
340					Description:      params.Description,
341					WorkingDirectory: bgShell.GetWorkingDir(),
342				}
343				if stdout == "" {
344					return fantasy.WithResponseMetadata(fantasy.NewTextResponse(BashNoOutput), metadata), nil
345				}
346				stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", normalizeWorkingDir(bgShell.GetWorkingDir()))
347				return fantasy.WithResponseMetadata(fantasy.NewTextResponse(stdout), metadata), nil
348			}
349
350			// Still running - keep as background job
351			metadata := BashResponseMetadata{
352				StartTime:        startTime.UnixMilli(),
353				EndTime:          time.Now().UnixMilli(),
354				Description:      params.Description,
355				WorkingDirectory: bgShell.GetWorkingDir(),
356				Background:       true,
357				ShellID:          bgShell.ID,
358			}
359			response := fmt.Sprintf("Command is taking longer than expected and has been moved to background.\n\nBackground shell ID: %s\n\nUse bash_output tool to view output or bash_kill to terminate.", bgShell.ID)
360			return fantasy.WithResponseMetadata(fantasy.NewTextResponse(response), metadata), nil
361		})
362}
363
364func truncateOutput(content string) string {
365	if len(content) <= MaxOutputLength {
366		return content
367	}
368
369	halfLength := MaxOutputLength / 2
370	start := content[:halfLength]
371	end := content[len(content)-halfLength:]
372
373	truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
374	return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
375}
376
// countLines reports how many lines s spans: zero for the empty string,
// otherwise one more than the number of '\n' characters (so a trailing newline
// starts a final, empty line). This matches len(strings.Split(s, "\n")) for
// non-empty s, but does not allocate a slice of every line just to count them.
func countLines(s string) int {
	if s == "" {
		return 0
	}
	return strings.Count(s, "\n") + 1
}
383
// normalizeWorkingDir renders path with forward slashes for display. On Windows
// it first removes every occurrence of the volume name of the process's current
// directory, using "C:" when the current directory cannot be determined.
func normalizeWorkingDir(path string) string {
	if runtime.GOOS != "windows" {
		return filepath.ToSlash(path)
	}

	volume := "C:"
	if wd, err := os.Getwd(); err == nil {
		volume = filepath.VolumeName(wd)
	}
	return filepath.ToSlash(strings.ReplaceAll(path, volume, ""))
}