From a07dd6bc0d29329fb0cf2dffd46a9521f7b43849 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sat, 5 Jul 2025 17:47:22 +0200 Subject: [PATCH] chore: make tools config independent --- cmd/root.go | 6 +++--- internal/llm/agent/agent.go | 19 ++++++++++--------- internal/llm/agent/mcp-tools.go | 20 +++++++++++--------- internal/llm/prompt/coder.go | 2 +- internal/llm/tools/bash.go | 9 +++++---- internal/llm/tools/edit.go | 20 ++++++++++---------- internal/llm/tools/fetch.go | 7 ++++--- internal/llm/tools/glob.go | 13 ++++++++----- internal/llm/tools/grep.go | 13 ++++++++----- internal/llm/tools/ls.go | 15 +++++++++------ internal/llm/tools/view.go | 9 +++++---- internal/llm/tools/write.go | 11 ++++++----- 12 files changed, 80 insertions(+), 64 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 3c1ccf6825c54c20c7ed41c8fba97508ff5c15e0..5747145284fbed25201b58a6b43eccf975db1304 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -100,7 +100,7 @@ to assist developers in writing, debugging, and understanding code directly from defer app.Shutdown() // Initialize MCP tools early for both modes - initMCPTools(ctx, app) + initMCPTools(ctx, app, cfg) prompt, err = maybePrependStdin(prompt) if err != nil { @@ -192,7 +192,7 @@ func attemptTUIRecovery(program *tea.Program) { program.Quit() } -func initMCPTools(ctx context.Context, app *app.App) { +func initMCPTools(ctx context.Context, app *app.App, cfg *config.Config) { go func() { defer log.RecoverPanic("MCP-goroutine", nil) @@ -201,7 +201,7 @@ func initMCPTools(ctx context.Context, app *app.App) { defer cancel() // Set this up once with proper error handling - agent.GetMcpTools(ctxWithTimeout, app.Permissions) + agent.GetMcpTools(ctxWithTimeout, app.Permissions, cfg) slog.Info("MCP message handling goroutine exiting") }() } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 28044d2bc058a5e07996546eb5a260c20936de34..6d9a825f600e79d3161c6b653669abe8773db116 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -94,21 +94,22 @@ func NewAgent( ) (Service, error) { ctx := context.Background() cfg := config.Get() - otherTools := GetMcpTools(ctx, permissions) + otherTools := GetMcpTools(ctx, permissions, cfg) if len(lspClients) > 0 { otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) } + cwd := cfg.WorkingDir() allTools := []tools.BaseTool{ - tools.NewBashTool(permissions), - tools.NewEditTool(lspClients, permissions, history), - tools.NewFetchTool(permissions), - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), + tools.NewBashTool(permissions, cwd), + tools.NewEditTool(lspClients, permissions, history, cwd), + tools.NewFetchTool(permissions, cwd), + tools.NewGlobTool(cwd), + tools.NewGrepTool(cwd), + tools.NewLsTool(cwd), tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients), - tools.NewWriteTool(lspClients, permissions, history), + tools.NewViewTool(lspClients, cwd), + tools.NewWriteTool(lspClients, permissions, history, cwd), } if agentCfg.ID == "coder" { diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index a0826bdf75895c378aac8d0c4a823b9a094a5e79..bb9b497722151c137dba1575390744230a3d6e99 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -22,6 +22,7 @@ type mcpTool struct { tool mcp.Tool mcpConfig config.MCPConfig permissions permission.Service + workingDir string } type MCPClient interface { @@ -98,7 +99,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes p := b.permissions.Request( permission.CreatePermissionRequest{ SessionID: sessionID, - Path: config.Get().WorkingDir(), + Path: b.workingDir, ToolName: b.Info().Name, Action: "execute", Description: permissionDescription, @@ -143,18 +144,19 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse("invalid mcp type"), nil } -func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig) tools.BaseTool { +func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool { return &mcpTool{ mcpName: name, tool: tool, mcpConfig: mcpConfig, permissions: permissions, + workingDir: workingDir, } } var mcpTools []tools.BaseTool -func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient) []tools.BaseTool { +func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool { var stdioTools []tools.BaseTool initRequest := mcp.InitializeRequest{} initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION @@ -175,17 +177,17 @@ func getTools(ctx context.Context, name string, m config.MCPConfig, permissions return stdioTools } for _, t := range tools.Tools { - stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m)) + stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m, workingDir)) } defer c.Close() return stdioTools } -func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.BaseTool { +func GetMcpTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool { if len(mcpTools) > 0 { return mcpTools } - for name, m := range config.Get().MCP { + for name, m := range cfg.MCP { switch m.Type { case config.MCPStdio: c, err := client.NewStdioMCPClient( @@ -198,7 +200,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba continue } - mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...) + mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) case config.MCPHttp: c, err := client.NewStreamableHttpClient( m.URL, @@ -208,7 +210,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba slog.Error("error creating mcp client", "error", err) continue } - mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...) + mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) case config.MCPSse: c, err := client.NewSSEMCPClient( m.URL, @@ -218,7 +220,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba slog.Error("error creating mcp client", "error", err) continue } - mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...) + mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) } } diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index 4a1dceb1cf5949de4520a25a3e7703e8435eadd3..dfe2068cd45edf515291b2d759fac4e133912980 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -384,7 +384,7 @@ func getEnvironmentInfo() string { isGit := isGitRepo(cwd) platform := runtime.GOOS date := time.Now().Format("1/2/2006") - ls := tools.NewLsTool() + ls := tools.NewLsTool(cwd) r, _ := ls.Run(context.Background(), tools.ToolCall{ Input: `{"path":"."}`, }) diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 03d3af32d95fd032d2f0f7092d66493c60867db8..0a10568a39315f6c4077385b8ca83f6b3e52691c 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -8,7 +8,6 @@ import ( "strings" "time" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/shell" ) @@ -29,6 +28,7 @@ type BashResponseMetadata struct { } type bashTool struct { permissions permission.Service + workingDir string } const ( @@ -244,9 +244,10 @@ Important: - Never update git config`, bannedCommandsStr, MaxOutputLength) } -func NewBashTool(permission permission.Service) BaseTool { +func NewBashTool(permission permission.Service, workingDir string) BaseTool { return &bashTool{ permissions: permission, + workingDir: workingDir, } } @@ -317,7 +318,7 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) p := b.permissions.Request( permission.CreatePermissionRequest{ SessionID: sessionID, - Path: config.Get().WorkingDir(), + Path: b.workingDir, ToolName: BashToolName, Action: "execute", Description: fmt.Sprintf("Execute command: %s", params.Command), @@ -337,7 +338,7 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) defer cancel() } stdout, stderr, err := shell. - GetPersistentShell(config.Get().WorkingDir()). + GetPersistentShell(b.workingDir). Exec(ctx, params.Command) interrupted := shell.IsInterrupt(err) exitCode := shell.ExitCode(err) diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 5b27223df8b6dfd4eb3ecf1bc5eb5762b17879a9..e09151781cf7f3c53fd0d23de46f1b9ca7dd3607 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -10,7 +10,6 @@ import ( "strings" "time" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/diff" "github.com/charmbracelet/crush/internal/history" @@ -41,6 +40,7 @@ type editTool struct { lspClients map[string]*lsp.Client permissions permission.Service files history.Service + workingDir string } const ( @@ -99,11 +99,12 @@ WINDOWS NOTES: Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.` ) -func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool { +func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service, workingDir string) BaseTool { return &editTool{ lspClients: lspClients, permissions: permissions, files: files, + workingDir: workingDir, } } @@ -144,8 +145,7 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) } if !filepath.IsAbs(params.FilePath) { - wd := config.Get().WorkingDir() - params.FilePath = filepath.Join(wd, params.FilePath) + params.FilePath = filepath.Join(e.workingDir, params.FilePath) } var response ToolResponse @@ -206,9 +206,9 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) _, additions, removals := diff.GenerateDiff( "", content, - strings.TrimPrefix(filePath, config.Get().WorkingDir()), + strings.TrimPrefix(filePath, e.workingDir), ) - rootDir := config.Get().WorkingDir() + rootDir := e.workingDir permissionPath := filepath.Dir(filePath) if strings.HasPrefix(filePath, rootDir) { permissionPath = rootDir @@ -318,10 +318,10 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string _, additions, removals := diff.GenerateDiff( oldContent, newContent, - strings.TrimPrefix(filePath, config.Get().WorkingDir()), + strings.TrimPrefix(filePath, e.workingDir), ) - rootDir := config.Get().WorkingDir() + rootDir := e.workingDir permissionPath := filepath.Dir(filePath) if strings.HasPrefix(filePath, rootDir) { permissionPath = rootDir @@ -441,9 +441,9 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS _, additions, removals := diff.GenerateDiff( oldContent, newContent, - strings.TrimPrefix(filePath, config.Get().WorkingDir()), + strings.TrimPrefix(filePath, e.workingDir), ) - rootDir := config.Get().WorkingDir() + rootDir := e.workingDir permissionPath := filepath.Dir(filePath) if strings.HasPrefix(filePath, rootDir) { permissionPath = rootDir diff --git a/internal/llm/tools/fetch.go b/internal/llm/tools/fetch.go index 6895556dbd925d8396b0258ab16c422cdc1a1810..28e15d19cee8219ccc4575ed036f29e8286db229 100644 --- a/internal/llm/tools/fetch.go +++ b/internal/llm/tools/fetch.go @@ -11,7 +11,6 @@ import ( md "github.com/JohannesKaufmann/html-to-markdown" "github.com/PuerkitoBio/goquery" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/permission" ) @@ -30,6 +29,7 @@ type FetchPermissionsParams struct { type fetchTool struct { client *http.Client permissions permission.Service + workingDir string } const ( @@ -65,7 +65,7 @@ TIPS: - Set appropriate timeouts for potentially slow websites` ) -func NewFetchTool(permissions permission.Service) BaseTool { +func NewFetchTool(permissions permission.Service, workingDir string) BaseTool { return &fetchTool{ client: &http.Client{ Timeout: 30 * time.Second, @@ -76,6 +76,7 @@ func NewFetchTool(permissions permission.Service) BaseTool { }, }, permissions: permissions, + workingDir: workingDir, } } @@ -133,7 +134,7 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error p := t.permissions.Request( permission.CreatePermissionRequest{ SessionID: sessionID, - Path: config.Get().WorkingDir(), + Path: t.workingDir, ToolName: FetchToolName, Action: "fetch", Description: fmt.Sprintf("Fetch content from URL: %s", params.URL), diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go index 8dcd7c07586f98c87290ebe83e02a7dddd613a40..c70c76b7d2dbd798118a54859e5672dacc6e1304 100644 --- a/internal/llm/tools/glob.go +++ b/internal/llm/tools/glob.go @@ -11,7 +11,6 @@ import ( "sort" "strings" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/fsext" ) @@ -68,10 +67,14 @@ type GlobResponseMetadata struct { Truncated bool `json:"truncated"` } -type globTool struct{} +type globTool struct { + workingDir string +} -func NewGlobTool() BaseTool { - return &globTool{} +func NewGlobTool(workingDir string) BaseTool { + return &globTool{ + workingDir: workingDir, + } } func (g *globTool) Name() string { @@ -108,7 +111,7 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) searchPath := params.Path if searchPath == "" { - searchPath = config.Get().WorkingDir() + searchPath = g.workingDir } files, truncated, err := globFiles(params.Pattern, searchPath, 100) diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index ede19c1daa75c3fea16bb52f0ed4b9ff5093e247..0b39c63484bb9508e14865215bf5e57430efe468 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -16,7 +16,6 @@ import ( "sync" "time" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/fsext" ) @@ -89,7 +88,9 @@ type GrepResponseMetadata struct { Truncated bool `json:"truncated"` } -type grepTool struct{} +type grepTool struct { + workingDir string +} const ( GrepToolName = "grep" @@ -136,8 +137,10 @@ TIPS: - Use literal_text=true when searching for exact text containing special characters like dots, parentheses, etc.` ) -func NewGrepTool() BaseTool { - return &grepTool{} +func NewGrepTool(workingDir string) BaseTool { + return &grepTool{ + workingDir: workingDir, + } } func (g *grepTool) Name() string { @@ -200,7 +203,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) searchPath := params.Path if searchPath == "" { - searchPath = config.Get().WorkingDir() + searchPath = g.workingDir } matches, truncated, err := searchFiles(searchPattern, searchPath, params.Include, 100) diff --git a/internal/llm/tools/ls.go b/internal/llm/tools/ls.go index 6e858a6990b64c795ad6f4df957f9e0d5c7ad6d3..6526a57274c0ef6c169f17979d277a3031fe1f72 100644 --- a/internal/llm/tools/ls.go +++ b/internal/llm/tools/ls.go @@ -8,7 +8,6 @@ import ( "path/filepath" "strings" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/fsext" ) @@ -29,7 +28,9 @@ type LSResponseMetadata struct { Truncated bool `json:"truncated"` } -type lsTool struct{} +type lsTool struct { + workingDir string +} const ( LSToolName = "ls" @@ -70,8 +71,10 @@ TIPS: - Combine with other tools for more effective exploration` ) -func NewLsTool() BaseTool { - return &lsTool{} +func NewLsTool(workingDir string) BaseTool { + return &lsTool{ + workingDir: workingDir, + } } func (l *lsTool) Name() string { @@ -107,11 +110,11 @@ func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { searchPath := params.Path if searchPath == "" { - searchPath = config.Get().WorkingDir() + searchPath = l.workingDir } if !filepath.IsAbs(searchPath) { - searchPath = filepath.Join(config.Get().WorkingDir(), searchPath) + searchPath = filepath.Join(l.workingDir, searchPath) } if _, err := os.Stat(searchPath); os.IsNotExist(err) { diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go index b156f89a26628982417aee1ab23354abf3415f61..27bbc237209e64637cfefb0f4ff1097f96641c2e 100644 --- a/internal/llm/tools/view.go +++ b/internal/llm/tools/view.go @@ -10,7 +10,6 @@ import ( "path/filepath" "strings" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/lsp" ) @@ -22,6 +21,7 @@ type ViewParams struct { type viewTool struct { lspClients map[string]*lsp.Client + workingDir string } type ViewResponseMetadata struct { @@ -71,9 +71,10 @@ TIPS: - When viewing large files, use the offset parameter to read specific sections` ) -func NewViewTool(lspClients map[string]*lsp.Client) BaseTool { +func NewViewTool(lspClients map[string]*lsp.Client, workingDir string) BaseTool { return &viewTool{ - lspClients, + lspClients: lspClients, + workingDir: workingDir, } } @@ -117,7 +118,7 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) // Handle relative paths filePath := params.FilePath if !filepath.IsAbs(filePath) { - filePath = filepath.Join(config.Get().WorkingDir(), filePath) + filePath = filepath.Join(v.workingDir, filePath) } // Check if file exists diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 9c64f5bff68e2a1d0fcea198ae82defb25a05dab..50f472bf2e65dba2b3c7e9efd9ecc88136764d2f 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -10,7 +10,6 @@ import ( "strings" "time" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/diff" "github.com/charmbracelet/crush/internal/history" @@ -33,6 +32,7 @@ type writeTool struct { lspClients map[string]*lsp.Client permissions permission.Service files history.Service + workingDir string } type WriteResponseMetadata struct { @@ -77,11 +77,12 @@ TIPS: - Always include descriptive comments when making changes to existing code` ) -func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool { +func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service, workingDir string) BaseTool { return &writeTool{ lspClients: lspClients, permissions: permissions, files: files, + workingDir: workingDir, } } @@ -123,7 +124,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error filePath := params.FilePath if !filepath.IsAbs(filePath) { - filePath = filepath.Join(config.Get().WorkingDir(), filePath) + filePath = filepath.Join(w.workingDir, filePath) } fileInfo, err := os.Stat(filePath) @@ -168,10 +169,10 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error diff, additions, removals := diff.GenerateDiff( oldContent, params.Content, - strings.TrimPrefix(filePath, config.Get().WorkingDir()), + strings.TrimPrefix(filePath, w.workingDir), ) - rootDir := config.Get().WorkingDir() + rootDir := w.workingDir permissionPath := filepath.Dir(filePath) if strings.HasPrefix(filePath, rootDir) { permissionPath = rootDir