1package tools
  2
import (
	"bytes"
	"context"
	_ "embed"
	"encoding/json"
	"fmt"
	"html/template"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/proto"
	"github.com/charmbracelet/crush/internal/shell"
)
 18
 19type (
 20	BashParams            = proto.BashParams
 21	BashPermissionsParams = proto.BashPermissionsParams
 22	BashResponseMetadata  = proto.BashResponseMetadata
 23)
 24
// bashTool implements the bash tool. It runs commands in the persistent shell
// for workingDir. Commands that are not on the read-only allow-list are first
// submitted to permissions for approval.
type bashTool struct {
	// permissions is asked to approve non-read-only commands in Run.
	permissions permission.Service
	// workingDir selects the persistent shell (shell.GetPersistentShell) that
	// commands are executed in.
	workingDir  string
	// attribution controls the git commit/PR attribution text rendered into
	// the tool description. nil behaves as if every option were enabled.
	attribution *config.Attribution
}
 30
 31const (
 32	BashToolName = proto.BashToolName
 33
 34	DefaultTimeout  = 1 * 60 * 1000  // 1 minutes in milliseconds
 35	MaxTimeout      = 10 * 60 * 1000 // 10 minutes in milliseconds
 36	MaxOutputLength = 30000
 37	BashNoOutput    = "no output"
 38)
 39
// bashDescription holds the raw bytes of bash.md, embedded at build time. It
// is parsed into bashDescriptionTpl. Not to be confused with the
// (*bashTool).bashDescription method, which renders that template.
//
//go:embed bash.md
var bashDescription []byte
 42
 43var bashDescriptionTpl = template.Must(
 44	template.New("bashDescription").
 45		Parse(string(bashDescription)),
 46)
 47
// bashDescriptionData holds the values interpolated into the bash.md template
// by (*bashTool).bashDescription.
type bashDescriptionData struct {
	// BannedCommands is bannedCommands joined with ", ".
	BannedCommands     string
	// MaxOutputLength is the output truncation threshold, in bytes.
	MaxOutputLength    int
	// AttributionStep is the "4. Create the commit ..." instruction, ending
	// with the attribution trailers when any are enabled.
	AttributionStep    string
	// AttributionExample is an <example> block showing the commit command.
	AttributionExample string
	// PRAttribution is the PR attribution line. It is empty when the
	// "Generated with" attribution is disabled.
	PRAttribution      string
}
 55
// bannedCommands lists commands the bash tool refuses to run. The list is
// rendered, comma-separated, into the tool description (bashDescription).
// It is also enforced on the persistent shell through shell.CommandsBlocker
// (blockFuncs).
var bannedCommands = []string{
	// Shell builtins
	"alias",

	// Network/Download tools
	"aria2c",
	"axel",
	"chrome",
	"curl",
	"curlie",
	"firefox",
	"http-prompt",
	"httpie",
	"links",
	"lynx",
	"nc",
	"safari",
	"scp",
	"ssh",
	"telnet",
	"w3m",
	"wget",
	"xh",

	// System administration
	"doas",
	"su",
	"sudo",

	// Package managers
	"apk",
	"apt",
	"apt-cache",
	"apt-get",
	"dnf",
	"dpkg",
	"emerge",
	"home-manager",
	"makepkg",
	"opkg",
	"pacman",
	"paru",
	"pkg",
	"pkg_add",
	"pkg_delete",
	"portage",
	"rpm",
	"yay",
	"yum",
	"zypper",

	// System modification
	"at",
	"batch",
	"chkconfig",
	"crontab",
	"fdisk",
	"mkfs",
	"mount",
	"parted",
	"service",
	"systemctl",
	"umount",

	// Network configuration
	"firewall-cmd",
	"ifconfig",
	"ip",
	"iptables",
	"netstat",
	"pfctl",
	"route",
	"ufw",
}
128
// bashDescription renders the tool description from the embedded bash.md
// template. It fills in the banned-command list, the output truncation limit,
// and the git commit/PR attribution instructions selected by b.attribution.
//
// A nil b.attribution enables both the "Generated with" line and the
// Co-Authored-By trailer, for backward compatibility. The method panics if
// the template cannot be executed.
func (b *bashTool) bashDescription() string {
	attr := b.attribution
	includeGenerated := attr == nil || attr.GeneratedWith
	includeCoAuthor := attr == nil || attr.CoAuthoredBy

	// Wording used when no attribution trailer is requested.
	commitStep := "4. Create the commit with your commit message."
	commitExample := `<example>
git commit -m "$(cat <<'EOF'
 Commit message here.
 EOF
)"</example>`

	var prLine string
	var trailers []string
	if includeGenerated {
		prLine = "💘 Generated with Crush"
		trailers = append(trailers, "💘 Generated with Crush")
	}
	if includeCoAuthor {
		trailers = append(trailers, "Co-Authored-By: Crush <crush@charm.land>")
	}

	if len(trailers) > 0 {
		commitStep = fmt.Sprintf("4. Create the commit with a message ending with:\n%s", strings.Join(trailers, "\n"))
		// Lines inside the example heredoc carry a one-space indent, so the
		// trailers are joined with "\n " to line up with the message body.
		// NOTE(review): the closing " EOF" is indented as well. Bash only ends
		// a <<'EOF' here-document on a line that is exactly "EOF", so a
		// verbatim copy of this example would not terminate the heredoc.
		// Confirm the indentation is intentional.
		commitExample = fmt.Sprintf(`<example>
git commit -m "$(cat <<'EOF'
 Commit message here.

 %s
 EOF
)"</example>`, strings.Join(trailers, "\n "))
	}

	var rendered bytes.Buffer
	err := bashDescriptionTpl.Execute(&rendered, bashDescriptionData{
		BannedCommands:     strings.Join(bannedCommands, ", "),
		MaxOutputLength:    MaxOutputLength,
		AttributionStep:    commitStep,
		AttributionExample: commitExample,
		PRAttribution:      prLine,
	})
	if err != nil {
		// Not expected: bash.md is fixed at build time and bytes.Buffer
		// writes do not fail.
		panic("failed to execute bash description template: " + err.Error())
	}
	return rendered.String()
}
189
190func blockFuncs() []shell.BlockFunc {
191	return []shell.BlockFunc{
192		shell.CommandsBlocker(bannedCommands),
193
194		// System package managers
195		shell.ArgumentsBlocker("apk", []string{"add"}, nil),
196		shell.ArgumentsBlocker("apt", []string{"install"}, nil),
197		shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
198		shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
199		shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
200		shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
201		shell.ArgumentsBlocker("yum", []string{"install"}, nil),
202		shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
203
204		// Language-specific package managers
205		shell.ArgumentsBlocker("brew", []string{"install"}, nil),
206		shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
207		shell.ArgumentsBlocker("gem", []string{"install"}, nil),
208		shell.ArgumentsBlocker("go", []string{"install"}, nil),
209		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
210		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
211		shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
212		shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
213		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
214		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
215		shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
216
217		// `go test -exec` can run arbitrary commands
218		shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
219	}
220}
221
222func NewBashTool(permission permission.Service, workingDir string, attribution *config.Attribution) BaseTool {
223	// Set up command blocking on the persistent shell
224	persistentShell := shell.GetPersistentShell(workingDir)
225	persistentShell.SetBlockFuncs(blockFuncs())
226
227	return &bashTool{
228		permissions: permission,
229		workingDir:  workingDir,
230		attribution: attribution,
231	}
232}
233
234func (b *bashTool) Name() string {
235	return BashToolName
236}
237
// Info returns the bash tool's metadata: its name, the description rendered
// from bash.md, and the parameter schema.
//
// Only "command" is required. "timeout" is in milliseconds; Run clamps it to
// MaxTimeout and uses DefaultTimeout for non-positive values.
func (b *bashTool) Info() ToolInfo {
	return ToolInfo{
		Name:        BashToolName,
		Description: b.bashDescription(),
		Parameters: map[string]any{
			"command": map[string]any{
				"type":        "string",
				"description": "The command to execute",
			},
			// The advertised maximum is derived from MaxTimeout so it cannot
			// drift from the clamp applied in Run.
			"timeout": map[string]any{
				"type":        "number",
				"description": fmt.Sprintf("Optional timeout in milliseconds (max %d)", MaxTimeout),
			},
		},
		Required: []string{"command"},
	}
}
255
// Run executes the command in call.Input (a JSON-encoded BashParams) in the
// persistent shell for b.workingDir and returns its output.
//
// Timeout handling: the timeout is in milliseconds. It is clamped to
// MaxTimeout, and zero or negative values fall back to DefaultTimeout.
//
// Permission: a command skips the prompt only if it is a plain invocation of
// an allow-listed read-only command. Any other command is submitted to
// b.permissions first, and a refusal returns permission.ErrorPermissionDenied.
//
// Errors:
//   - Malformed input or an empty command produce a text error response with
//     a nil error.
//   - A missing session/message ID in ctx is returned as an error.
//   - An execution error that is neither an interruption nor a non-zero exit
//     is returned as an error.
//
// Otherwise the response text combines:
//   - stdout and stderr, each truncated by truncateOutput;
//   - an exit-code or abort notice;
//   - a trailing <cwd> tag with the shell's working directory.
//
// BashResponseMetadata is attached as well.
func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
	var params BashParams
	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
		return NewTextErrorResponse("invalid parameters"), nil
	}

	if params.Timeout > MaxTimeout {
		params.Timeout = MaxTimeout
	} else if params.Timeout <= 0 {
		params.Timeout = DefaultTimeout
	}

	if params.Command == "" {
		return NewTextErrorResponse("missing command"), nil
	}

	// A command may skip the permission prompt only if it starts with an
	// allow-listed read-only command AND contains no shell operators.
	// Checking the prefix alone let compound commands through unprompted,
	// e.g. "ls && rm -rf ~", "ls; rm -rf ~" or "ls > file".
	const shellOperators = ";&|<>`$()\n\r"
	isSafeReadOnly := false
	cmdLower := strings.ToLower(params.Command)
	if !strings.ContainsAny(cmdLower, shellOperators) {
		for _, safe := range safeCommands {
			if !strings.HasPrefix(cmdLower, safe) {
				continue
			}
			// The match must be the whole command or be followed by a space
			// or '-', so that e.g. "lsx" is not treated as "ls".
			if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
				isSafeReadOnly = true
				break
			}
		}
	}

	sessionID, messageID := GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for executing shell command")
	}

	persistentShell := shell.GetPersistentShell(b.workingDir)
	if !isSafeReadOnly {
		granted := b.permissions.Request(
			permission.CreatePermissionRequest{
				SessionID:   sessionID,
				Path:        persistentShell.GetWorkingDir(),
				ToolCallID:  call.ID,
				ToolName:    BashToolName,
				Action:      "execute",
				Description: fmt.Sprintf("Execute command: %s", params.Command),
				Params: BashPermissionsParams{
					Command: params.Command,
				},
			},
		)
		if !granted {
			return ToolResponse{}, permission.ErrorPermissionDenied
		}
	}

	startTime := time.Now()
	// params.Timeout is always positive after the clamping above.
	ctx, cancel := context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond)
	defer cancel()

	stdout, stderr, err := persistentShell.Exec(ctx, params.Command)

	// The command may have changed directory; report where the shell ended up.
	currentWorkingDir := persistentShell.GetWorkingDir()
	interrupted := shell.IsInterrupt(err)
	exitCode := shell.ExitCode(err)
	if exitCode == 0 && !interrupted && err != nil {
		return ToolResponse{}, fmt.Errorf("error executing command: %w", err)
	}

	stdout = truncateOutput(stdout)
	stderr = truncateOutput(stderr)

	errorMessage := stderr
	if errorMessage == "" && err != nil {
		errorMessage = err.Error()
	}

	if interrupted {
		if errorMessage != "" {
			errorMessage += "\n"
		}
		errorMessage += "Command was aborted before completion"
	} else if exitCode != 0 {
		if errorMessage != "" {
			errorMessage += "\n"
		}
		errorMessage += fmt.Sprintf("Exit code %d", exitCode)
	}

	hasBothOutputs := stdout != "" && stderr != ""

	if hasBothOutputs {
		stdout += "\n"
	}

	if errorMessage != "" {
		stdout += "\n" + errorMessage
	}

	metadata := BashResponseMetadata{
		StartTime:        startTime.UnixMilli(),
		EndTime:          time.Now().UnixMilli(),
		Output:           stdout,
		WorkingDirectory: currentWorkingDir,
	}
	if stdout == "" {
		return WithResponseMetadata(NewTextResponse(BashNoOutput), metadata), nil
	}
	stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", currentWorkingDir)
	return WithResponseMetadata(NewTextResponse(stdout), metadata), nil
}
367
// truncateOutput shortens content longer than MaxOutputLength bytes.
//
// It keeps about MaxOutputLength/2 bytes from each end and replaces the middle
// with a "... [N lines truncated] ..." marker, where N is countLines of the
// removed part. Content within the limit is returned unchanged.
//
// Both cut points are moved by at most utf8.UTFMax-1 bytes onto UTF-8 rune
// boundaries, so a multi-byte character is never split. Splitting one would
// leave invalid UTF-8 at the seams. Because the adjustment is bounded,
// non-UTF-8 (binary) output is still cut near the midpoint.
func truncateOutput(content string) string {
	if len(content) <= MaxOutputLength {
		return content
	}

	halfLength := MaxOutputLength / 2

	// Step the head cut back over continuation bytes so the head ends on a
	// complete rune.
	headEnd := halfLength
	for i := 1; i < utf8.UTFMax && headEnd > 0 && !utf8.RuneStart(content[headEnd]); i++ {
		headEnd--
	}

	// Step the tail cut forward over continuation bytes so the tail starts on
	// a rune boundary.
	tailStart := len(content) - halfLength
	for i := 1; i < utf8.UTFMax && tailStart < len(content) && !utf8.RuneStart(content[tailStart]); i++ {
		tailStart++
	}

	start := content[:headEnd]
	end := content[tailStart:]

	truncatedLinesCount := countLines(content[headEnd:tailStart])
	return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
}
380
// countLines reports how many "\n"-separated lines s spans. A trailing
// newline starts a final empty line, so "a\n" counts as 2. The empty string
// counts as 0.
func countLines(s string) int {
	if len(s) == 0 {
		return 0
	}
	// Same result as len(strings.Split(s, "\n")) without building the slice.
	return strings.Count(s, "\n") + 1
}