bash.go

  1package tools
  2
import (
	"bytes"
	"context"
	_ "embed"
	"fmt"
	"html/template"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"time"
	"unicode/utf8"

	"charm.land/fantasy"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/shell"
)
 20
// BashParams are the arguments supplied when invoking the bash tool.
//
//   - Command is required; an empty command is rejected with "missing command".
//   - Description is echoed back in the response metadata.
//   - WorkingDir overrides the tool's default working directory when non-empty.
//   - Background starts the command in a background shell and returns its ID
//     immediately instead of waiting for it.
//
// The `description` struct tags are presumably turned into the tool's input
// schema by fantasy — confirm before relying on their exact wording.
//
// NOTE(review): NewBashTool converts this type directly to
// BashPermissionsParams, so field names, types and order must stay in sync
// with that type.
type BashParams struct {
	Command     string `json:"command" description:"The command to execute"`
	Description string `json:"description,omitempty" description:"A brief description of what the command does"`
	WorkingDir  string `json:"working_dir,omitempty" description:"The working directory to execute the command in (defaults to current directory)"`
	Background  bool   `json:"background,omitempty" description:"Run the command in a background shell. Returns a shell ID for managing the process."`
}
 27
// BashPermissionsParams is the copy of BashParams attached to the permission
// request raised before a non-read-only command runs.
//
// NewBashTool builds it with the direct conversion BashPermissionsParams(params).
// Its fields must therefore keep exactly the same names, types and order as
// BashParams; only the struct tags may differ.
type BashPermissionsParams struct {
	Command     string `json:"command"`
	Description string `json:"description"`
	WorkingDir  string `json:"working_dir"`
	Background  bool   `json:"background"`
}
 34
// BashResponseMetadata is attached to the bash tool's response whenever a
// shell was started for the command.
//
// StartTime and EndTime are Unix timestamps in milliseconds.
//
// For commands left running in a background shell (explicitly requested, or
// still running after AutoBackgroundThreshold):
//   - Background is true and ShellID identifies the shell.
//   - EndTime is when the tool returned, not when the command finished.
//   - Output is left empty.
//
// For commands that complete in the foreground, Output holds the (possibly
// truncated) stdout, followed by stderr and any abort/exit-code text. It does
// not include the trailing <cwd> tag that is appended to the response text.
type BashResponseMetadata struct {
	StartTime        int64  `json:"start_time"`
	EndTime          int64  `json:"end_time"`
	Output           string `json:"output"`
	Description      string `json:"description"`
	WorkingDirectory string `json:"working_directory"`
	Background       bool   `json:"background,omitempty"`
	ShellID          string `json:"shell_id,omitempty"`
}
 44
 45const (
 46	BashToolName = "bash"
 47
 48	AutoBackgroundThreshold = 1 * time.Minute // Commands taking longer automatically become background jobs
 49	MaxOutputLength         = 30000
 50	BashNoOutput            = "no output"
 51)
 52
// bashDescriptionTmpl holds the raw source of bash.tpl, embedded into the
// binary at build time.
//
//go:embed bash.tpl
var bashDescriptionTmpl []byte

// bashDescriptionTpl is the parsed description template rendered by
// bashDescription. template.Must makes a malformed bash.tpl panic during
// package initialization rather than on first use.
//
// NOTE(review): this is html/template, which contextually escapes interpolated
// values (e.g. '<', '>', '&' and quotes in text context). The rendered output
// is a plain-text tool description, so text/template is probably what was
// intended — confirm before switching.
var bashDescriptionTpl = template.Must(
	template.New("bashDescription").
		Parse(string(bashDescriptionTmpl)),
)
 60
// bashDescriptionData is the data bash.tpl is executed against in
// bashDescription.
//
//   - BannedCommands is bannedCommands joined with ", ".
//   - MaxOutputLength is the output truncation limit (MaxOutputLength).
//   - Attribution carries the configured attribution settings; how the
//     template uses them is not visible from this file.
type bashDescriptionData struct {
	BannedCommands  string
	MaxOutputLength int
	Attribution     config.Attribution
}
 66
// bannedCommands lists command names that the bash tool refuses to run.
// blockFuncs hands the list to shell.CommandsBlocker. bashDescription joins it,
// in this order, into the tool description. Reordering or adding entries
// therefore also changes the description text.
var bannedCommands = []string{
	// Shell builtins
	"alias",

	// Network/Download tools
	"aria2c",
	"axel",
	"chrome",
	"curl",
	"curlie",
	"firefox",
	"http-prompt",
	"httpie",
	"links",
	"lynx",
	"nc",
	"safari",
	"scp",
	"ssh",
	"telnet",
	"w3m",
	"wget",
	"xh",

	// System administration
	"doas",
	"su",
	"sudo",

	// Package managers
	"apk",
	"apt",
	"apt-cache",
	"apt-get",
	"dnf",
	"dpkg",
	"emerge",
	"home-manager",
	"makepkg",
	"opkg",
	"pacman",
	"paru",
	"pkg",
	"pkg_add",
	"pkg_delete",
	"portage",
	"rpm",
	"yay",
	"yum",
	"zypper",

	// System modification
	"at",
	"batch",
	"chkconfig",
	"crontab",
	"fdisk",
	"mkfs",
	"mount",
	"parted",
	"service",
	"systemctl",
	"umount",

	// Network configuration
	"firewall-cmd",
	"ifconfig",
	"ip",
	"iptables",
	"netstat",
	"pfctl",
	"route",
	"ufw",
}
139
140func bashDescription(attribution *config.Attribution) string {
141	bannedCommandsStr := strings.Join(bannedCommands, ", ")
142	var out bytes.Buffer
143	if err := bashDescriptionTpl.Execute(&out, bashDescriptionData{
144		BannedCommands:  bannedCommandsStr,
145		MaxOutputLength: MaxOutputLength,
146		Attribution:     *attribution,
147	}); err != nil {
148		// this should never happen.
149		panic("failed to execute bash description template: " + err.Error())
150	}
151	return out.String()
152}
153
154func blockFuncs() []shell.BlockFunc {
155	return []shell.BlockFunc{
156		shell.CommandsBlocker(bannedCommands),
157
158		// System package managers
159		shell.ArgumentsBlocker("apk", []string{"add"}, nil),
160		shell.ArgumentsBlocker("apt", []string{"install"}, nil),
161		shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
162		shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
163		shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
164		shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
165		shell.ArgumentsBlocker("yum", []string{"install"}, nil),
166		shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
167
168		// Language-specific package managers
169		shell.ArgumentsBlocker("brew", []string{"install"}, nil),
170		shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
171		shell.ArgumentsBlocker("gem", []string{"install"}, nil),
172		shell.ArgumentsBlocker("go", []string{"install"}, nil),
173		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
174		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
175		shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
176		shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
177		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
178		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
179		shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
180
181		// `go test -exec` can run arbitrary commands
182		shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
183	}
184}
185
186func NewBashTool(permissions permission.Service, workingDir string, attribution *config.Attribution) fantasy.AgentTool {
187	return fantasy.NewAgentTool(
188		BashToolName,
189		string(bashDescription(attribution)),
190		func(ctx context.Context, params BashParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
191			if params.Command == "" {
192				return fantasy.NewTextErrorResponse("missing command"), nil
193			}
194
195			// Determine working directory
196			execWorkingDir := workingDir
197			if params.WorkingDir != "" {
198				execWorkingDir = params.WorkingDir
199			}
200
201			isSafeReadOnly := false
202			cmdLower := strings.ToLower(params.Command)
203
204			for _, safe := range safeCommands {
205				if strings.HasPrefix(cmdLower, safe) {
206					if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
207						isSafeReadOnly = true
208						break
209					}
210				}
211			}
212
213			sessionID := GetSessionFromContext(ctx)
214			if sessionID == "" {
215				return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for executing shell command")
216			}
217			if !isSafeReadOnly {
218				p := permissions.Request(
219					permission.CreatePermissionRequest{
220						SessionID:   sessionID,
221						Path:        execWorkingDir,
222						ToolCallID:  call.ID,
223						ToolName:    BashToolName,
224						Action:      "execute",
225						Description: fmt.Sprintf("Execute command: %s", params.Command),
226						Params:      BashPermissionsParams(params),
227					},
228				)
229				if !p {
230					return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
231				}
232			}
233
234			// If explicitly requested as background, start immediately
235			if params.Background {
236				startTime := time.Now()
237				bgManager := shell.GetBackgroundShellManager()
238				bgShell, err := bgManager.Start(ctx, execWorkingDir, blockFuncs(), params.Command)
239				if err != nil {
240					return fantasy.ToolResponse{}, fmt.Errorf("error starting background shell: %w", err)
241				}
242
243				metadata := BashResponseMetadata{
244					StartTime:        startTime.UnixMilli(),
245					EndTime:          time.Now().UnixMilli(),
246					Description:      params.Description,
247					WorkingDirectory: bgShell.GetWorkingDir(),
248					Background:       true,
249					ShellID:          bgShell.ID,
250				}
251				response := fmt.Sprintf("Background shell started with ID: %s\n\nUse bash_output tool to view output or bash_kill to terminate.", bgShell.ID)
252				return fantasy.WithResponseMetadata(fantasy.NewTextResponse(response), metadata), nil
253			}
254
255			// Start synchronous execution with auto-background support
256			startTime := time.Now()
257
258			// Start background shell immediately but wait for threshold before deciding
259			bgManager := shell.GetBackgroundShellManager()
260			bgShell, err := bgManager.Start(ctx, execWorkingDir, blockFuncs(), params.Command)
261			if err != nil {
262				return fantasy.ToolResponse{}, fmt.Errorf("error starting shell: %w", err)
263			}
264
265			// Wait for either completion, auto-background threshold, or context cancellation
266			ticker := time.NewTicker(100 * time.Millisecond)
267			defer ticker.Stop()
268			timeout := time.After(AutoBackgroundThreshold)
269
270			var stdout, stderr string
271			var done bool
272			var execErr error
273
274		waitLoop:
275			for {
276				select {
277				case <-ticker.C:
278					stdout, stderr, done, execErr = bgShell.GetOutput()
279					if done {
280						break waitLoop
281					}
282				case <-timeout:
283					stdout, stderr, done, execErr = bgShell.GetOutput()
284					break waitLoop
285				case <-ctx.Done():
286					// Context was cancelled, kill the shell and return error
287					bgManager.Kill(bgShell.ID)
288					return fantasy.ToolResponse{}, ctx.Err()
289				}
290			}
291
292			if done {
293				// Command completed within threshold - return synchronously
294				// Remove from background manager since we're returning directly
295				// Don't call Kill() as it cancels the context and corrupts the exit code
296				bgManager.Remove(bgShell.ID)
297
298				interrupted := shell.IsInterrupt(execErr)
299				exitCode := shell.ExitCode(execErr)
300				if exitCode == 0 && !interrupted && execErr != nil {
301					return fantasy.ToolResponse{}, fmt.Errorf("error executing command: %w", execErr)
302				}
303
304				stdout = truncateOutput(stdout)
305				stderr = truncateOutput(stderr)
306
307				errorMessage := stderr
308				if errorMessage == "" && execErr != nil {
309					errorMessage = execErr.Error()
310				}
311
312				if interrupted {
313					if errorMessage != "" {
314						errorMessage += "\n"
315					}
316					errorMessage += "Command was aborted before completion"
317				} else if exitCode != 0 {
318					if errorMessage != "" {
319						errorMessage += "\n"
320					}
321					errorMessage += fmt.Sprintf("Exit code %d", exitCode)
322				}
323
324				hasBothOutputs := stdout != "" && stderr != ""
325
326				if hasBothOutputs {
327					stdout += "\n"
328				}
329
330				if errorMessage != "" {
331					stdout += "\n" + errorMessage
332				}
333
334				metadata := BashResponseMetadata{
335					StartTime:        startTime.UnixMilli(),
336					EndTime:          time.Now().UnixMilli(),
337					Output:           stdout,
338					Description:      params.Description,
339					WorkingDirectory: bgShell.GetWorkingDir(),
340				}
341				if stdout == "" {
342					return fantasy.WithResponseMetadata(fantasy.NewTextResponse(BashNoOutput), metadata), nil
343				}
344				stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", normalizeWorkingDir(bgShell.GetWorkingDir()))
345				return fantasy.WithResponseMetadata(fantasy.NewTextResponse(stdout), metadata), nil
346			}
347
348			// Still running - keep as background job
349			metadata := BashResponseMetadata{
350				StartTime:        startTime.UnixMilli(),
351				EndTime:          time.Now().UnixMilli(),
352				Description:      params.Description,
353				WorkingDirectory: bgShell.GetWorkingDir(),
354				Background:       true,
355				ShellID:          bgShell.ID,
356			}
357			response := fmt.Sprintf("Command is taking longer than expected and has been moved to background.\n\nBackground shell ID: %s\n\nUse bash_output tool to view output or bash_kill to terminate.", bgShell.ID)
358			return fantasy.WithResponseMetadata(fantasy.NewTextResponse(response), metadata), nil
359		})
360}
361
362func truncateOutput(content string) string {
363	if len(content) <= MaxOutputLength {
364		return content
365	}
366
367	halfLength := MaxOutputLength / 2
368	start := content[:halfLength]
369	end := content[len(content)-halfLength:]
370
371	truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
372	return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
373}
374
// countLines reports how many lines s spans. The empty string has zero lines;
// any other string has one more line than it has '\n' characters. A trailing
// newline therefore counts as opening an extra, empty line.
func countLines(s string) int {
	if len(s) == 0 {
		return 0
	}
	return strings.Count(s, "\n") + 1
}
381
// normalizeWorkingDir converts path to forward-slash form for display.
//
// On Windows it first removes every occurrence of the volume name (such as
// "C:") of the process's current directory. It falls back to "C:" when the
// current directory cannot be determined. On other platforms only the slash
// conversion applies.
func normalizeWorkingDir(path string) string {
	if runtime.GOOS != "windows" {
		return filepath.ToSlash(path)
	}

	cwd := "C:"
	if wd, err := os.Getwd(); err == nil {
		cwd = wd
	}
	stripped := strings.ReplaceAll(path, filepath.VolumeName(cwd), "")
	return filepath.ToSlash(stripped)
}