1package tools
  2
  3import (
  4	"bytes"
  5	"context"
  6	_ "embed"
  7	"encoding/json"
  8	"fmt"
  9	"html/template"
 10	"strings"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/permission"
 15	"github.com/charmbracelet/crush/internal/shell"
 16)
 17
 18type BashParams struct {
 19	Command string `json:"command"`
 20	Timeout int    `json:"timeout"`
 21}
 22
 23type BashPermissionsParams struct {
 24	Command string `json:"command"`
 25	Timeout int    `json:"timeout"`
 26}
 27
 28type BashResponseMetadata struct {
 29	StartTime        int64  `json:"start_time"`
 30	EndTime          int64  `json:"end_time"`
 31	Output           string `json:"output"`
 32	WorkingDirectory string `json:"working_directory"`
 33}
// bashTool runs commands in the persistent shell associated with workingDir
// and asks for permission before running commands it does not consider safe
// read-only.
type bashTool struct {
	// permissions is asked to approve commands that are not safe read-only.
	permissions permission.Service

	// workingDir selects the persistent shell the commands run in.
	workingDir string

	// attribution selects the git attribution text in the tool description;
	// nil enables both the "Generated with" and "Co-Authored-By" lines.
	attribution *config.Attribution
}
 39
 40const (
 41	BashToolName = "bash"
 42
 43	DefaultTimeout  = 1 * 60 * 1000  // 1 minutes in milliseconds
 44	MaxTimeout      = 10 * 60 * 1000 // 10 minutes in milliseconds
 45	MaxOutputLength = 30000
 46	BashNoOutput    = "no output"
 47)
 48
// bashDescription holds the raw bash.md template, embedded at build time.
//
//go:embed bash.md
var bashDescription []byte

// bashDescriptionTpl is bash.md parsed once during package initialization;
// template.Must panics at startup if the template does not parse.
//
// NOTE(review): `template` here is html/template (see the imports), which
// HTML-escapes plain string values substituted into text ("<" -> "&lt;",
// `"` -> "&#34;", "'" -> "&#39;"). The rendered description is a plain-text
// prompt, so any value that must appear verbatim has to be passed as
// template.HTML — or this template should be built with text/template.
var bashDescriptionTpl = template.Must(
	template.New("bashDescription").
		Parse(string(bashDescription)),
)
 56
 57type bashDescriptionData struct {
 58	BannedCommands     string
 59	MaxOutputLength    int
 60	AttributionStep    string
 61	AttributionExample string
 62	PRAttribution      string
 63}
 64
 65var bannedCommands = []string{
 66	// Network/Download tools
 67	"alias",
 68	"aria2c",
 69	"axel",
 70	"chrome",
 71	"curl",
 72	"curlie",
 73	"firefox",
 74	"http-prompt",
 75	"httpie",
 76	"links",
 77	"lynx",
 78	"nc",
 79	"safari",
 80	"scp",
 81	"ssh",
 82	"telnet",
 83	"w3m",
 84	"wget",
 85	"xh",
 86
 87	// System administration
 88	"doas",
 89	"su",
 90	"sudo",
 91
 92	// Package managers
 93	"apk",
 94	"apt",
 95	"apt-cache",
 96	"apt-get",
 97	"dnf",
 98	"dpkg",
 99	"emerge",
100	"home-manager",
101	"makepkg",
102	"opkg",
103	"pacman",
104	"paru",
105	"pkg",
106	"pkg_add",
107	"pkg_delete",
108	"portage",
109	"rpm",
110	"yay",
111	"yum",
112	"zypper",
113
114	// System modification
115	"at",
116	"batch",
117	"chkconfig",
118	"crontab",
119	"fdisk",
120	"mkfs",
121	"mount",
122	"parted",
123	"service",
124	"systemctl",
125	"umount",
126
127	// Network configuration
128	"firewall-cmd",
129	"ifconfig",
130	"ip",
131	"iptables",
132	"netstat",
133	"pfctl",
134	"route",
135	"ufw",
136}
137
138func (b *bashTool) bashDescription() string {
139	bannedCommandsStr := strings.Join(bannedCommands, ", ")
140
141	// Build attribution text based on settings
142	var attributionStep, attributionExample, prAttribution string
143
144	// Default to true if attribution is nil (backward compatibility)
145	generatedWith := b.attribution == nil || b.attribution.GeneratedWith
146	coAuthoredBy := b.attribution == nil || b.attribution.CoAuthoredBy
147
148	// Build PR attribution
149	if generatedWith {
150		prAttribution = "💘 Generated with Crush"
151	}
152
153	if generatedWith || coAuthoredBy {
154		var attributionParts []string
155		if generatedWith {
156			attributionParts = append(attributionParts, "💘 Generated with Crush")
157		}
158		if coAuthoredBy {
159			attributionParts = append(attributionParts, "Co-Authored-By: Crush <crush@charm.land>")
160		}
161
162		if len(attributionParts) > 0 {
163			attributionStep = fmt.Sprintf("4. Create the commit with a message ending with:\n%s", strings.Join(attributionParts, "\n"))
164
165			attributionText := strings.Join(attributionParts, "\n ")
166			attributionExample = fmt.Sprintf(`<example>
167git commit -m "$(cat <<'EOF'
168 Commit message here.
169
170 %s
171 EOF
172)"</example>`, attributionText)
173		}
174	}
175
176	if attributionStep == "" {
177		attributionStep = "4. Create the commit with your commit message."
178		attributionExample = `<example>
179git commit -m "$(cat <<'EOF'
180 Commit message here.
181 EOF
182)"</example>`
183	}
184
185	var out bytes.Buffer
186	if err := bashDescriptionTpl.Execute(&out, bashDescriptionData{
187		BannedCommands:     bannedCommandsStr,
188		MaxOutputLength:    MaxOutputLength,
189		AttributionStep:    attributionStep,
190		AttributionExample: attributionExample,
191		PRAttribution:      prAttribution,
192	}); err != nil {
193		// this should never happen.
194		panic("failed to execute bash description template: " + err.Error())
195	}
196	return out.String()
197}
198
199func blockFuncs() []shell.BlockFunc {
200	return []shell.BlockFunc{
201		shell.CommandsBlocker(bannedCommands),
202
203		// System package managers
204		shell.ArgumentsBlocker("apk", []string{"add"}, nil),
205		shell.ArgumentsBlocker("apt", []string{"install"}, nil),
206		shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
207		shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
208		shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
209		shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
210		shell.ArgumentsBlocker("yum", []string{"install"}, nil),
211		shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
212
213		// Language-specific package managers
214		shell.ArgumentsBlocker("brew", []string{"install"}, nil),
215		shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
216		shell.ArgumentsBlocker("gem", []string{"install"}, nil),
217		shell.ArgumentsBlocker("go", []string{"install"}, nil),
218		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
219		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
220		shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
221		shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
222		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
223		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
224		shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
225
226		// `go test -exec` can run arbitrary commands
227		shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
228	}
229}
230
231func NewBashTool(permission permission.Service, workingDir string, attribution *config.Attribution) BaseTool {
232	// Set up command blocking on the persistent shell
233	persistentShell := shell.GetPersistentShell(workingDir)
234	persistentShell.SetBlockFuncs(blockFuncs())
235
236	return &bashTool{
237		permissions: permission,
238		workingDir:  workingDir,
239		attribution: attribution,
240	}
241}
242
243func (b *bashTool) Name() string {
244	return BashToolName
245}
246
247func (b *bashTool) Info() ToolInfo {
248	return ToolInfo{
249		Name:        BashToolName,
250		Description: b.bashDescription(),
251		Parameters: map[string]any{
252			"command": map[string]any{
253				"type":        "string",
254				"description": "The command to execute",
255			},
256			"timeout": map[string]any{
257				"type":        "number",
258				"description": "Optional timeout in milliseconds (max 600000)",
259			},
260		},
261		Required: []string{"command"},
262	}
263}
264
265func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
266	var params BashParams
267	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
268		return NewTextErrorResponse("invalid parameters"), nil
269	}
270
271	if params.Timeout > MaxTimeout {
272		params.Timeout = MaxTimeout
273	} else if params.Timeout <= 0 {
274		params.Timeout = DefaultTimeout
275	}
276
277	if params.Command == "" {
278		return NewTextErrorResponse("missing command"), nil
279	}
280
281	isSafeReadOnly := false
282	cmdLower := strings.ToLower(params.Command)
283
284	for _, safe := range safeCommands {
285		if strings.HasPrefix(cmdLower, safe) {
286			if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
287				isSafeReadOnly = true
288				break
289			}
290		}
291	}
292
293	sessionID, messageID := GetContextValues(ctx)
294	if sessionID == "" || messageID == "" {
295		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for executing shell command")
296	}
297	if !isSafeReadOnly {
298		shell := shell.GetPersistentShell(b.workingDir)
299		p := b.permissions.Request(
300			permission.CreatePermissionRequest{
301				SessionID:   sessionID,
302				Path:        shell.GetWorkingDir(),
303				ToolCallID:  call.ID,
304				ToolName:    BashToolName,
305				Action:      "execute",
306				Description: fmt.Sprintf("Execute command: %s", params.Command),
307				Params: BashPermissionsParams{
308					Command: params.Command,
309				},
310			},
311		)
312		if !p {
313			return ToolResponse{}, permission.ErrorPermissionDenied
314		}
315	}
316	startTime := time.Now()
317	if params.Timeout > 0 {
318		var cancel context.CancelFunc
319		ctx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond)
320		defer cancel()
321	}
322
323	persistentShell := shell.GetPersistentShell(b.workingDir)
324	stdout, stderr, err := persistentShell.Exec(ctx, params.Command)
325
326	// Get the current working directory after command execution
327	currentWorkingDir := persistentShell.GetWorkingDir()
328	interrupted := shell.IsInterrupt(err)
329	exitCode := shell.ExitCode(err)
330	if exitCode == 0 && !interrupted && err != nil {
331		return ToolResponse{}, fmt.Errorf("error executing command: %w", err)
332	}
333
334	stdout = truncateOutput(stdout)
335	stderr = truncateOutput(stderr)
336
337	errorMessage := stderr
338	if errorMessage == "" && err != nil {
339		errorMessage = err.Error()
340	}
341
342	if interrupted {
343		if errorMessage != "" {
344			errorMessage += "\n"
345		}
346		errorMessage += "Command was aborted before completion"
347	} else if exitCode != 0 {
348		if errorMessage != "" {
349			errorMessage += "\n"
350		}
351		errorMessage += fmt.Sprintf("Exit code %d", exitCode)
352	}
353
354	hasBothOutputs := stdout != "" && stderr != ""
355
356	if hasBothOutputs {
357		stdout += "\n"
358	}
359
360	if errorMessage != "" {
361		stdout += "\n" + errorMessage
362	}
363
364	metadata := BashResponseMetadata{
365		StartTime:        startTime.UnixMilli(),
366		EndTime:          time.Now().UnixMilli(),
367		Output:           stdout,
368		WorkingDirectory: currentWorkingDir,
369	}
370	if stdout == "" {
371		return WithResponseMetadata(NewTextResponse(BashNoOutput), metadata), nil
372	}
373	stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", currentWorkingDir)
374	return WithResponseMetadata(NewTextResponse(stdout), metadata), nil
375}
376
377func truncateOutput(content string) string {
378	if len(content) <= MaxOutputLength {
379		return content
380	}
381
382	halfLength := MaxOutputLength / 2
383	start := content[:halfLength]
384	end := content[len(content)-halfLength:]
385
386	truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
387	return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
388}
389
// countLines reports how many lines s spans: 0 for the empty string,
// otherwise one more than the number of '\n' characters. A trailing newline
// therefore counts as starting an extra, empty line.
func countLines(s string) int {
	if len(s) == 0 {
		return 0
	}
	return strings.Count(s, "\n") + 1
}