1package tools
  2
import (
	"bytes"
	"context"
	_ "embed"
	"fmt"
	"html/template"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"time"
	"unicode/utf8"

	"charm.land/fantasy"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/shell"
)
 20
// BashParams is the input supplied when calling the bash tool. The json tags
// name the call arguments; the description tags carry per-field help text
// (presumably used by fantasy to build the tool's input schema — confirm).
// Timeout is in milliseconds: NewBashTool replaces values <= 0 with
// DefaultTimeout and caps values above MaxTimeout.
type BashParams struct {
	Command     string `json:"command" description:"The command to execute"`
	Description string `json:"description,omitempty" description:"A brief description of what the command does"`
	Timeout     int    `json:"timeout,omitempty" description:"Optional timeout in milliseconds (max 600000)"`
}
 26
// BashPermissionsParams is attached as Params to the permission request that
// NewBashTool issues before running a command it does not treat as safe and
// read-only.
//
// NOTE(review): NewBashTool only fills Command and Description; Timeout is
// left at its zero value — confirm whether that is intended.
type BashPermissionsParams struct {
	Command     string `json:"command"`
	Description string `json:"description"`
	Timeout     int    `json:"timeout"`
}
 32
// BashResponseMetadata is attached to the bash tool's response for a command
// that was executed. StartTime and EndTime are Unix timestamps in
// milliseconds taken before execution and after the output was assembled.
// Output is the (truncated) stdout followed by any stderr/error text, without
// the <cwd> trailer. WorkingDirectory is the shell's working directory after
// the command ran.
type BashResponseMetadata struct {
	StartTime        int64  `json:"start_time"`
	EndTime          int64  `json:"end_time"`
	Output           string `json:"output"`
	Description      string `json:"description"`
	WorkingDirectory string `json:"working_directory"`
}
 40
 41const (
 42	BashToolName = "bash"
 43
 44	DefaultTimeout  = 1 * 60 * 1000  // 1 minutes in milliseconds
 45	MaxTimeout      = 10 * 60 * 1000 // 10 minutes in milliseconds
 46	MaxOutputLength = 30000
 47	BashNoOutput    = "no output"
 48)
 49
// bashDescriptionTmpl is the raw text of bash.tpl, embedded at build time.
//
//go:embed bash.tpl
var bashDescriptionTmpl []byte

// bashDescriptionTpl is the parsed tool-description template. template.Must
// panics during package initialization if bash.tpl fails to parse.
//
// NOTE(review): this is html/template, which contextually HTML-escapes values
// interpolated from the data (e.g. '<', '&', quotes) even though the rendered
// description is plain text for the model. text/template may be what was
// intended — confirm against the contents of bash.tpl.
var bashDescriptionTpl = template.Must(
	template.New("bashDescription").
		Parse(string(bashDescriptionTmpl)),
)
 57
// bashDescriptionData is the data bash.tpl is executed with: the banned
// commands joined into one comma-separated string, the per-stream output
// truncation limit in bytes, and the configured attribution settings.
type bashDescriptionData struct {
	BannedCommands  string
	MaxOutputLength int
	Attribution     config.Attribution
}
 63
// bannedCommands lists command names the bash tool refuses outright.
// blockFuncs passes the list to shell.CommandsBlocker, and bashDescription
// joins it (in this order) into the tool description shown to the model.
var bannedCommands = []string{
	// Shell builtins
	"alias",

	// Network/Download tools
	"aria2c",
	"axel",
	"chrome",
	"curl",
	"curlie",
	"firefox",
	"http-prompt",
	"httpie",
	"links",
	"lynx",
	"nc",
	"safari",
	"scp",
	"ssh",
	"telnet",
	"w3m",
	"wget",
	"xh",

	// System administration
	"doas",
	"su",
	"sudo",

	// Package managers
	"apk",
	"apt",
	"apt-cache",
	"apt-get",
	"dnf",
	"dpkg",
	"emerge",
	"home-manager",
	"makepkg",
	"opkg",
	"pacman",
	"paru",
	"pkg",
	"pkg_add",
	"pkg_delete",
	"portage",
	"rpm",
	"yay",
	"yum",
	"zypper",

	// System modification
	"at",
	"batch",
	"chkconfig",
	"crontab",
	"fdisk",
	"mkfs",
	"mount",
	"parted",
	"service",
	"systemctl",
	"umount",

	// Network configuration
	"firewall-cmd",
	"ifconfig",
	"ip",
	"iptables",
	"netstat",
	"pfctl",
	"route",
	"ufw",
}
136
137func bashDescription(attribution *config.Attribution) string {
138	bannedCommandsStr := strings.Join(bannedCommands, ", ")
139	var out bytes.Buffer
140	if err := bashDescriptionTpl.Execute(&out, bashDescriptionData{
141		BannedCommands:  bannedCommandsStr,
142		MaxOutputLength: MaxOutputLength,
143		Attribution:     *attribution,
144	}); err != nil {
145		// this should never happen.
146		panic("failed to execute bash description template: " + err.Error())
147	}
148	return out.String()
149}
150
151func blockFuncs() []shell.BlockFunc {
152	return []shell.BlockFunc{
153		shell.CommandsBlocker(bannedCommands),
154
155		// System package managers
156		shell.ArgumentsBlocker("apk", []string{"add"}, nil),
157		shell.ArgumentsBlocker("apt", []string{"install"}, nil),
158		shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
159		shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
160		shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
161		shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
162		shell.ArgumentsBlocker("yum", []string{"install"}, nil),
163		shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
164
165		// Language-specific package managers
166		shell.ArgumentsBlocker("brew", []string{"install"}, nil),
167		shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
168		shell.ArgumentsBlocker("gem", []string{"install"}, nil),
169		shell.ArgumentsBlocker("go", []string{"install"}, nil),
170		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
171		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
172		shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
173		shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
174		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
175		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
176		shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
177
178		// `go test -exec` can run arbitrary commands
179		shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
180	}
181}
182
183func NewBashTool(permissions permission.Service, workingDir string, attribution *config.Attribution) fantasy.AgentTool {
184	// Set up command blocking on the persistent shell
185	persistentShell := shell.GetPersistentShell(workingDir)
186	persistentShell.SetBlockFuncs(blockFuncs())
187	return fantasy.NewAgentTool(
188		BashToolName,
189		string(bashDescription(attribution)),
190		func(ctx context.Context, params BashParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
191			if params.Timeout > MaxTimeout {
192				params.Timeout = MaxTimeout
193			} else if params.Timeout <= 0 {
194				params.Timeout = DefaultTimeout
195			}
196
197			if params.Command == "" {
198				return fantasy.NewTextErrorResponse("missing command"), nil
199			}
200
201			isSafeReadOnly := false
202			cmdLower := strings.ToLower(params.Command)
203
204			for _, safe := range safeCommands {
205				if strings.HasPrefix(cmdLower, safe) {
206					if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
207						isSafeReadOnly = true
208						break
209					}
210				}
211			}
212
213			sessionID := GetSessionFromContext(ctx)
214			if sessionID == "" {
215				return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for executing shell command")
216			}
217			if !isSafeReadOnly {
218				shell := shell.GetPersistentShell(workingDir)
219				p := permissions.Request(
220					permission.CreatePermissionRequest{
221						SessionID:   sessionID,
222						Path:        shell.GetWorkingDir(),
223						ToolCallID:  call.ID,
224						ToolName:    BashToolName,
225						Action:      "execute",
226						Description: fmt.Sprintf("Execute command: %s", params.Command),
227						Params: BashPermissionsParams{
228							Command:     params.Command,
229							Description: params.Description,
230						},
231					},
232				)
233				if !p {
234					return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
235				}
236			}
237			startTime := time.Now()
238			if params.Timeout > 0 {
239				var cancel context.CancelFunc
240				ctx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond)
241				defer cancel()
242			}
243
244			persistentShell := shell.GetPersistentShell(workingDir)
245			stdout, stderr, err := persistentShell.Exec(ctx, params.Command)
246
247			// Get the current working directory after command execution
248			currentWorkingDir := persistentShell.GetWorkingDir()
249			interrupted := shell.IsInterrupt(err)
250			exitCode := shell.ExitCode(err)
251			if exitCode == 0 && !interrupted && err != nil {
252				return fantasy.ToolResponse{}, fmt.Errorf("error executing command: %w", err)
253			}
254
255			stdout = truncateOutput(stdout)
256			stderr = truncateOutput(stderr)
257
258			errorMessage := stderr
259			if errorMessage == "" && err != nil {
260				errorMessage = err.Error()
261			}
262
263			if interrupted {
264				if errorMessage != "" {
265					errorMessage += "\n"
266				}
267				errorMessage += "Command was aborted before completion"
268			} else if exitCode != 0 {
269				if errorMessage != "" {
270					errorMessage += "\n"
271				}
272				errorMessage += fmt.Sprintf("Exit code %d", exitCode)
273			}
274
275			hasBothOutputs := stdout != "" && stderr != ""
276
277			if hasBothOutputs {
278				stdout += "\n"
279			}
280
281			if errorMessage != "" {
282				stdout += "\n" + errorMessage
283			}
284
285			metadata := BashResponseMetadata{
286				StartTime:        startTime.UnixMilli(),
287				EndTime:          time.Now().UnixMilli(),
288				Output:           stdout,
289				Description:      params.Description,
290				WorkingDirectory: currentWorkingDir,
291			}
292			if stdout == "" {
293				return fantasy.WithResponseMetadata(fantasy.NewTextResponse(BashNoOutput), metadata), nil
294			}
295			stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", normalizeWorkingDir(currentWorkingDir))
296			return fantasy.WithResponseMetadata(fantasy.NewTextResponse(stdout), metadata), nil
297		})
298}
299
300func truncateOutput(content string) string {
301	if len(content) <= MaxOutputLength {
302		return content
303	}
304
305	halfLength := MaxOutputLength / 2
306	start := content[:halfLength]
307	end := content[len(content)-halfLength:]
308
309	truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
310	return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
311}
312
// countLines reports how many lines s spans. It returns zero for the empty
// string and otherwise one more than the number of '\n' characters, so a
// trailing newline counts as starting an extra, empty line.
func countLines(s string) int {
	if len(s) == 0 {
		return 0
	}
	return strings.Count(s, "\n") + 1
}
319
// normalizeWorkingDir converts path to forward-slash form for display.
//
// On Windows it first removes every occurrence of the volume name (e.g.
// "C:") of the process's current working directory. If that directory
// cannot be determined, it falls back to "C:".
func normalizeWorkingDir(path string) string {
	if runtime.GOOS != "windows" {
		return filepath.ToSlash(path)
	}

	base := "C:"
	if wd, err := os.Getwd(); err == nil {
		base = wd
	}
	withoutVolume := strings.ReplaceAll(path, filepath.VolumeName(base), "")
	return filepath.ToSlash(withoutVolume)
}