From fa7fe6be3d2f5b9f00de5060faf9cc0b4aa14c33 Mon Sep 17 00:00:00 2001 From: kujtimiihoxha Date: Thu, 2 Oct 2025 12:32:49 +0200 Subject: [PATCH] wip: remove old implementation --- go.mod | 2 +- go.sum | 2 + internal/agent/agent-tool.go | 98 ++ internal/agent/agent.go | 27 +- internal/agent/coder.go | 18 - internal/agent/common_test.go | 8 +- internal/agent/coordinator.go | 484 ++++++- internal/agent/prompts.go | 36 + internal/agent/templates/agentTool.md | 16 + internal/agent/tools/mcp-tools.go | 219 ++-- internal/app/app.go | 65 +- internal/config/config.go | 22 +- internal/config/load_test.go | 12 +- internal/llm/agent/agent-tool.go | 106 -- internal/llm/agent/agent.go | 1133 ----------------- internal/llm/agent/errors.go | 15 - internal/llm/agent/event.go | 53 - internal/llm/agent/mcp-tools.go | 466 ------- internal/llm/prompt/anthropic.md | 108 -- internal/llm/prompt/coder.go | 100 -- internal/llm/prompt/gemini.md | 165 --- internal/llm/prompt/init.md | 9 - internal/llm/prompt/initialize.go | 10 - internal/llm/prompt/prompt.go | 143 --- internal/llm/prompt/prompt_test.go | 69 - internal/llm/prompt/summarize.md | 11 - internal/llm/prompt/summarizer.go | 10 - internal/llm/prompt/task.go | 15 - internal/llm/prompt/title.go | 10 - internal/llm/prompt/title.md | 8 - internal/llm/prompt/v2.md | 267 ---- internal/llm/provider/anthropic.go | 598 --------- internal/llm/provider/azure.go | 39 - internal/llm/provider/bedrock.go | 93 -- internal/llm/provider/gemini.go | 579 --------- internal/llm/provider/openai.go | 604 --------- internal/llm/provider/openai_test.go | 166 --- internal/llm/provider/provider.go | 208 --- internal/llm/provider/vertexai.go | 40 - internal/llm/tools/bash.go | 395 ------ internal/llm/tools/bash.md | 161 --- internal/llm/tools/diagnostics.go | 213 ---- internal/llm/tools/diagnostics.md | 21 - internal/llm/tools/download.go | 196 --- internal/llm/tools/download.md | 34 - internal/llm/tools/edit.go | 486 ------- internal/llm/tools/edit.md | 60 - internal/llm/tools/fetch.go | 236 ---- internal/llm/tools/fetch.md | 34 - internal/llm/tools/file.go | 53 - internal/llm/tools/glob.go | 150 --- internal/llm/tools/glob.md | 46 - internal/llm/tools/grep.go | 470 ------- internal/llm/tools/grep.md | 54 - internal/llm/tools/grep_test.go | 200 --- internal/llm/tools/ls.go | 283 ---- internal/llm/tools/ls.md | 40 - internal/llm/tools/multiedit.go | 424 ------ internal/llm/tools/multiedit.md | 48 - internal/llm/tools/rg.go | 53 - internal/llm/tools/safe.go | 70 - internal/llm/tools/sourcegraph.go | 302 ----- internal/llm/tools/sourcegraph.md | 102 -- internal/llm/tools/tools.go | 85 -- internal/llm/tools/view.go | 343 ----- internal/llm/tools/view.md | 42 - internal/llm/tools/write.go | 208 --- internal/llm/tools/write.md | 38 - internal/tui/components/chat/chat.go | 8 +- internal/tui/components/chat/editor/editor.go | 6 +- internal/tui/components/chat/header/header.go | 2 +- .../tui/components/chat/messages/renderer.go | 4 +- internal/tui/components/chat/messages/tool.go | 4 +- .../tui/components/chat/sidebar/sidebar.go | 2 +- internal/tui/components/chat/splash/splash.go | 6 +- .../components/dialogs/commands/commands.go | 8 +- .../tui/components/dialogs/compact/compact.go | 196 +-- .../dialogs/permissions/permissions.go | 2 +- .../components/dialogs/reasoning/reasoning.go | 2 +- internal/tui/components/mcp/mcp.go | 12 +- internal/tui/page/chat/chat.go | 48 +- internal/tui/tui.go | 70 +- 82 files changed, 977 insertions(+), 10274 deletions(-) create mode 100644 
internal/agent/agent-tool.go delete mode 100644 internal/agent/coder.go create mode 100644 internal/agent/prompts.go create mode 100644 internal/agent/templates/agentTool.md delete mode 100644 internal/llm/agent/agent-tool.go delete mode 100644 internal/llm/agent/agent.go delete mode 100644 internal/llm/agent/errors.go delete mode 100644 internal/llm/agent/event.go delete mode 100644 internal/llm/agent/mcp-tools.go delete mode 100644 internal/llm/prompt/anthropic.md delete mode 100644 internal/llm/prompt/coder.go delete mode 100644 internal/llm/prompt/gemini.md delete mode 100644 internal/llm/prompt/init.md delete mode 100644 internal/llm/prompt/initialize.go delete mode 100644 internal/llm/prompt/prompt.go delete mode 100644 internal/llm/prompt/prompt_test.go delete mode 100644 internal/llm/prompt/summarize.md delete mode 100644 internal/llm/prompt/summarizer.go delete mode 100644 internal/llm/prompt/task.go delete mode 100644 internal/llm/prompt/title.go delete mode 100644 internal/llm/prompt/title.md delete mode 100644 internal/llm/prompt/v2.md delete mode 100644 internal/llm/provider/anthropic.go delete mode 100644 internal/llm/provider/azure.go delete mode 100644 internal/llm/provider/bedrock.go delete mode 100644 internal/llm/provider/gemini.go delete mode 100644 internal/llm/provider/openai.go delete mode 100644 internal/llm/provider/openai_test.go delete mode 100644 internal/llm/provider/provider.go delete mode 100644 internal/llm/provider/vertexai.go delete mode 100644 internal/llm/tools/bash.go delete mode 100644 internal/llm/tools/bash.md delete mode 100644 internal/llm/tools/diagnostics.go delete mode 100644 internal/llm/tools/diagnostics.md delete mode 100644 internal/llm/tools/download.go delete mode 100644 internal/llm/tools/download.md delete mode 100644 internal/llm/tools/edit.go delete mode 100644 internal/llm/tools/edit.md delete mode 100644 internal/llm/tools/fetch.go delete mode 100644 internal/llm/tools/fetch.md delete mode 100644 internal/llm/tools/file.go delete mode 100644 internal/llm/tools/glob.go delete mode 100644 internal/llm/tools/glob.md delete mode 100644 internal/llm/tools/grep.go delete mode 100644 internal/llm/tools/grep.md delete mode 100644 internal/llm/tools/grep_test.go delete mode 100644 internal/llm/tools/ls.go delete mode 100644 internal/llm/tools/ls.md delete mode 100644 internal/llm/tools/multiedit.go delete mode 100644 internal/llm/tools/multiedit.md delete mode 100644 internal/llm/tools/rg.go delete mode 100644 internal/llm/tools/safe.go delete mode 100644 internal/llm/tools/sourcegraph.go delete mode 100644 internal/llm/tools/sourcegraph.md delete mode 100644 internal/llm/tools/tools.go delete mode 100644 internal/llm/tools/view.go delete mode 100644 internal/llm/tools/view.md delete mode 100644 internal/llm/tools/write.go delete mode 100644 internal/llm/tools/write.md diff --git a/go.mod b/go.mod index c9c9a28541a5566f8beccdcd86ff2ac33c4d0c66..c438c64afdd4d272903b5ccf6b6c2e5fc43eb6f1 100644 --- a/go.mod +++ b/go.mod @@ -97,7 +97,7 @@ require ( require ( github.com/anthropics/anthropic-sdk-go v1.12.0 - github.com/charmbracelet/fantasy v0.0.0-20250930080945-d8adb55feb9a + github.com/charmbracelet/fantasy v0.0.0-20251002051643-c96822199d77 github.com/go-viper/mapstructure/v2 v2.4.0 // indirect ) diff --git a/go.sum b/go.sum index 54a6a80bd08064033fab29f765d52f72e2e67f24..d36d59227547f9e247400449e2785e0e0fee0331 100644 --- a/go.sum +++ b/go.sum @@ -91,6 +91,8 @@ github.com/charmbracelet/fang v0.4.2 h1:nWr7Tb82/TTNNGMGG35aTZ1X68loAOQmpb0qxkKX 
github.com/charmbracelet/fang v0.4.2/go.mod h1:wHJKQYO5ReYsxx+yZl+skDtrlKO/4LLEQ6EXsdHhRhg= github.com/charmbracelet/fantasy v0.0.0-20250930080945-d8adb55feb9a h1:pD/ATiDbIRm8i5cKf8k1NR/0HhmSdAejY2kbIpmBiis= github.com/charmbracelet/fantasy v0.0.0-20250930080945-d8adb55feb9a/go.mod h1:RZotHpq44tKZDe6Vf0kk1iDqnUgH7Scx+K/7uJ9Qwnw= +github.com/charmbracelet/fantasy v0.0.0-20251002051643-c96822199d77 h1:YHuUqaojkeu00YtQeXPqM/1RNJH/jqGNaQYFwa7JQTk= +github.com/charmbracelet/fantasy v0.0.0-20251002051643-c96822199d77/go.mod h1:RZotHpq44tKZDe6Vf0kk1iDqnUgH7Scx+K/7uJ9Qwnw= github.com/charmbracelet/glamour/v2 v2.0.0-20250811143442-a27abb32f018 h1:PU4Zvpagsk5sgaDxn5W4sxHuLp9QRMBZB3bFSk40A4w= github.com/charmbracelet/glamour/v2 v2.0.0-20250811143442-a27abb32f018/go.mod h1:Z/GLmp9fzaqX4ze3nXG7StgWez5uBM5XtlLHK8V/qSk= github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3.0.20250917201909-41ff0bf215ea h1:g1HfUgSMvye8mgecMD1mPscpt+pzJoDEiSA+p2QXzdQ= diff --git a/internal/agent/agent-tool.go b/internal/agent/agent-tool.go new file mode 100644 index 0000000000000000000000000000000000000000..82ba6ef41f5144117e0f8bff12104b97d287da5f --- /dev/null +++ b/internal/agent/agent-tool.go @@ -0,0 +1,98 @@ +package agent + +import ( + "context" + _ "embed" + "encoding/json" + "errors" + "fmt" + + "github.com/charmbracelet/fantasy/ai" + + "github.com/charmbracelet/crush/internal/agent/prompt" + "github.com/charmbracelet/crush/internal/agent/tools" + "github.com/charmbracelet/crush/internal/config" +) + +//go:embed templates/agentTool.md +var agentToolDescription []byte + +type AgentParams struct { + Prompt string `json:"prompt" description:"The task for the agent to perform"` +} + +const ( + AgentToolName = "agent" ) + +func (c *coordinator) agentTool() (ai.AgentTool, error) { + agentCfg, ok := c.cfg.Agents[config.AgentTask] + if !ok { + return nil, errors.New("task agent not configured") + } + prompt, err := taskPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir())) + if err != nil { + return nil, err + } + + agent, err := c.buildAgent(prompt, agentCfg) + if err != nil { + return nil, err + } + return ai.NewAgentTool( + AgentToolName, + string(agentToolDescription), + func(ctx context.Context, params AgentParams, call ai.ToolCall) (ai.ToolResponse, error) { + if err := json.Unmarshal([]byte(call.Input), &params); err != nil { + return ai.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil + } + if params.Prompt == "" { + return ai.NewTextErrorResponse("prompt is required"), nil + } + + sessionID := tools.GetSessionFromContext(ctx) + if sessionID == "" { + return ai.ToolResponse{}, errors.New("session id missing from context") + } + + session, err := c.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session") + if err != nil { + return ai.ToolResponse{}, fmt.Errorf("error creating session: %s", err) + } + model := agent.Model() + maxTokens := model.CatwalkCfg.DefaultMaxTokens + if model.ModelCfg.MaxTokens != 0 { + maxTokens = model.ModelCfg.MaxTokens + } + result, err := agent.Run(ctx, SessionAgentCall{ + SessionID: session.ID, + Prompt: params.Prompt, + MaxOutputTokens: maxTokens, + ProviderOptions: c.getProviderOptions(model), + Temperature: model.ModelCfg.Temperature, + TopP: model.ModelCfg.TopP, + TopK: model.ModelCfg.TopK, + FrequencyPenalty: model.ModelCfg.FrequencyPenalty, + PresencePenalty: model.ModelCfg.PresencePenalty, + }) + if err != nil { + return ai.ToolResponse{}, fmt.Errorf("error generating agent: %s", err) + } + updatedSession, err := c.sessions.Get(ctx, session.ID) + if
err != nil { + return ai.ToolResponse{}, fmt.Errorf("error getting session: %s", err) + } + parentSession, err := c.sessions.Get(ctx, sessionID) + if err != nil { + return ai.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err) + } + + parentSession.Cost += updatedSession.Cost + + _, err = c.sessions.Save(ctx, parentSession) + if err != nil { + return ai.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err) + } + return ai.NewTextResponse(result.Response.Content.Text()), nil + }), nil +} diff --git a/internal/agent/agent.go b/internal/agent/agent.go index d9181c10eec49639aa510718f69aff7353707adb..6272bdf20e33be03db174c89a498fd5ae4377fbf 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -12,6 +12,7 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/agent/tools" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/permission" @@ -50,11 +51,13 @@ type SessionAgent interface { QueuedPrompts(sessionID string) int ClearQueue(sessionID string) Summarize(context.Context, string) error + Model() Model } type Model struct { - model ai.LanguageModel - config catwalk.Model + Model ai.LanguageModel + CatwalkCfg catwalk.Model + ModelCfg config.SelectedModel } type sessionAgent struct { @@ -116,7 +119,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen } agent := ai.NewAgent( - a.largeModel.model, + a.largeModel.Model, ai.WithSystemPrompt(a.systemPrompt), ai.WithTools(a.tools...), ) @@ -174,8 +177,8 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen assistantMsg, err = a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{ Role: message.Assistant, Parts: []message.ContentPart{}, - Model: a.largeModel.model.Model(), - Provider: a.largeModel.model.Provider(), + Model: a.largeModel.ModelCfg.Model, + Provider: a.largeModel.ModelCfg.Provider, }) if err != nil { return prepared, err @@ -384,13 +387,13 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error { defer a.activeRequests.Del(sessionID) defer cancel() - agent := ai.NewAgent(a.largeModel.model, + agent := ai.NewAgent(a.largeModel.Model, ai.WithSystemPrompt(string(summaryPrompt)), ) summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, - Model: a.largeModel.model.Model(), - Provider: a.largeModel.model.Provider(), + Model: a.largeModel.Model.Model(), + Provider: a.largeModel.Model.Provider(), }) if err != nil { return err @@ -516,7 +519,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, session session.Sessio return } - agent := ai.NewAgent(a.smallModel.model, + agent := ai.NewAgent(a.smallModel.Model, ai.WithSystemPrompt(string(titlePrompt)), ai.WithMaxOutputTokens(40), ) @@ -554,7 +557,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, session session.Sessio } func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage ai.Usage) { - modelConfig := model.config + modelConfig := model.CatwalkCfg cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) + modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) + @@ -641,3 +644,7 @@ func (a *sessionAgent) SetModels(large Model, small Model) { func (a *sessionAgent) SetTools(tools 
[]ai.AgentTool) { a.tools = tools } + +func (a *sessionAgent) Model() Model { + return a.largeModel +} diff --git a/internal/agent/coder.go b/internal/agent/coder.go deleted file mode 100644 index 65e4870aee616f9f5548c532e5e78e6999aa8b8c..0000000000000000000000000000000000000000 --- a/internal/agent/coder.go +++ /dev/null @@ -1,18 +0,0 @@ -package agent - -import ( - _ "embed" - - "github.com/charmbracelet/crush/internal/agent/prompt" -) - -//go:embed templates/coder.gotmpl -var coderPromptTmpl []byte - -func coderPrompt(opts ...prompt.Option) (*prompt.Prompt, error) { - systemPrompt, err := prompt.NewPrompt("coder", string(coderPromptTmpl), opts...) - if err != nil { - return nil, err - } - return systemPrompt, nil -} diff --git a/internal/agent/common_test.go b/internal/agent/common_test.go index 2a0781ce6eac12bca0f65beb34f86403023109db..875552c222115d65137295d4e5121bccb4001294 100644 --- a/internal/agent/common_test.go +++ b/internal/agent/common_test.go @@ -114,14 +114,14 @@ func testEnv(t *testing.T) env { func testSessionAgent(env env, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) SessionAgent { largeModel := Model{ - model: large, - config: catwalk.Model{ + Model: large, + CatwalkCfg: catwalk.Model{ // todo: add values }, } smallModel := Model{ - model: small, - config: catwalk.Model{ + Model: small, + CatwalkCfg: catwalk.Model{ // todo: add values }, } diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 3cbccaa9bd3d849adf07ebb0e0985baa131994d2..520f6046ba6b42360e227e3e6f804eb85e4ce8ec 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -2,19 +2,499 @@ package agent import ( "context" + "errors" + "slices" + "strings" + "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/agent/prompt" + "github.com/charmbracelet/crush/internal/agent/tools" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/history" + "github.com/charmbracelet/crush/internal/log" + "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/session" "github.com/charmbracelet/fantasy/ai" + "github.com/charmbracelet/fantasy/anthropic" + "github.com/charmbracelet/fantasy/google" + "github.com/charmbracelet/fantasy/openai" + "github.com/charmbracelet/fantasy/openaicompat" + "github.com/charmbracelet/fantasy/openrouter" ) type Coordinator interface { // INFO: (kujtim) this is not used yet we will use this when we have multiple agents - SetMainAgent(string) + // SetMainAgent(string) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error) + Cancel(sessionID string) + CancelAll() + IsSessionBusy(sessionID string) bool + IsBusy() bool + QueuedPrompts(sessionID string) int + ClearQueue(sessionID string) + Summarize(context.Context, string) error + Model() Model + UpdateModels() error } type coordinator struct { - cfg *config.Config + cfg *config.Config + sessions session.Service + messages message.Service + permissions permission.Service + history history.Service + lspClients *csync.Map[string, *lsp.Client] + currentAgent SessionAgent + agents map[string]SessionAgent +} + +func NewCoordinator( + cfg *config.Config, + sessions session.Service, + messages message.Service, + permissions permission.Service, + history 
history.Service, + lspClients *csync.Map[string, *lsp.Client], +) (Coordinator, error) { + c := &coordinator{ + cfg: cfg, + sessions: sessions, + messages: messages, + permissions: permissions, + history: history, + lspClients: lspClients, + agents: make(map[string]SessionAgent), + } + + agentCfg, ok := cfg.Agents[config.AgentCoder] + if !ok { + return nil, errors.New("coder agent not configured") + } + + // TODO: make this dynamic when we support multiple agents + prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir())) + if err != nil { + return nil, err + } + + agent, err := c.buildAgent(prompt, agentCfg) + if err != nil { + return nil, err + } + c.currentAgent = agent + c.agents[config.AgentCoder] = agent + return c, nil +} + +// Run implements Coordinator. +func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error) { + model := c.currentAgent.Model() + maxTokens := model.CatwalkCfg.DefaultMaxTokens + if model.ModelCfg.MaxTokens != 0 { + maxTokens = model.ModelCfg.MaxTokens + } + + if !model.CatwalkCfg.SupportsImages && attachments != nil { + attachments = nil + } + + return c.currentAgent.Run(ctx, SessionAgentCall{ + SessionID: sessionID, + Prompt: prompt, + Attachments: attachments, + MaxOutputTokens: maxTokens, + ProviderOptions: c.getProviderOptions(model), + Temperature: model.ModelCfg.Temperature, + TopP: model.ModelCfg.TopP, + TopK: model.ModelCfg.TopK, + FrequencyPenalty: model.ModelCfg.FrequencyPenalty, + PresencePenalty: model.ModelCfg.PresencePenalty, + }) +} + +func (c *coordinator) getProviderOptions(model Model) ai.ProviderOptions { + options := ai.ProviderOptions{} + + switch model.Model.Provider() { + case openai.Name: + parsed, err := openai.ParseOptions(model.ModelCfg.ProviderOptions) + if err == nil { + options[openai.Name] = parsed + } + case anthropic.Name: + parsed, err := anthropic.ParseOptions(model.ModelCfg.ProviderOptions) + if err == nil { + options[anthropic.Name] = parsed + } + case openrouter.Name: + parsed, err := openrouter.ParseOptions(model.ModelCfg.ProviderOptions) + if err == nil { + options[openrouter.Name] = parsed + } + case google.Name: + parsed, err := google.ParseOptions(model.ModelCfg.ProviderOptions) + if err == nil { + options[google.Name] = parsed + } + case openaicompat.Name: + parsed, err := openaicompat.ParseOptions(model.ModelCfg.ProviderOptions) + if err == nil { + options[openaicompat.Name] = parsed + } + } + + return options +} + +func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) { + large, small, err := c.buildAgentModels() + if err != nil { + return nil, err + } + + systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg) + if err != nil { + return nil, err + } + + tools, err := c.buildTools(agent) + if err != nil { + return nil, err + } + return NewSessionAgent(large, small, systemPrompt, c.sessions, c.messages, tools...), nil +} + +func (c *coordinator) buildTools(agent config.Agent) ([]ai.AgentTool, error) { + var allTools []ai.AgentTool + if slices.Contains(agent.AllowedTools, AgentToolName) { + agentTool, err := c.agentTool() + if err != nil { + return nil, err + } + allTools = append(allTools, agentTool) + } + + allTools = append(allTools, + tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution), + tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil), + tools.NewEditTool(c.lspClients, c.permissions, c.history, 
c.cfg.WorkingDir()), + tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()), + tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil), + tools.NewGlobTool(c.cfg.WorkingDir()), + tools.NewGrepTool(c.cfg.WorkingDir()), + tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls), + tools.NewSourcegraphTool(nil), + tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()), + tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()), + ) + + var filteredTools []ai.AgentTool + for _, tool := range allTools { + if slices.Contains(agent.AllowedTools, tool.Info().Name) { + filteredTools = append(filteredTools, tool) + } + } + + mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg) + + for _, mcpTool := range mcpTools { + if agent.AllowedMCP == nil { + // No MCP restrictions + filteredTools = append(filteredTools, mcpTool) + } else if len(agent.AllowedMCP) == 0 { + // no mcps allowed + break + } + + for mcp, tools := range agent.AllowedMCP { + if mcp == mcpTool.MCP() { + if len(tools) == 0 { + filteredTools = append(filteredTools, mcpTool) + } + for _, t := range tools { + if t == mcpTool.MCPToolName() { + filteredTools = append(filteredTools, mcpTool) + } + } + break + } + } + } + + return filteredTools, nil +} + +// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config +func (c *coordinator) buildAgentModels() (Model, Model, error) { + largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge] + if !ok { + return Model{}, Model{}, errors.New("large model not selected") + } + smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall] + if !ok { + return Model{}, Model{}, errors.New("small model not selected") + } + + largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider) + if !ok { + return Model{}, Model{}, errors.New("large model provider not configured") + } + + largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg) + if err != nil { + return Model{}, Model{}, err + } + + smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider) + if !ok { + return Model{}, Model{}, errors.New("small model provider not configured") + } + + smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg) + if err != nil { + return Model{}, Model{}, err + } + + var largeCatwalkModel *catwalk.Model + var smallCatwalkModel *catwalk.Model + + for _, m := range largeProviderCfg.Models { + if m.ID == largeModelCfg.Model { + largeCatwalkModel = &m + } + } + for _, m := range smallProviderCfg.Models { + if m.ID == smallModelCfg.Model { + smallCatwalkModel = &m + } + } + + if largeCatwalkModel == nil { + return Model{}, Model{}, errors.New("large model not found in provider config") + } + + if smallCatwalkModel == nil { + return Model{}, Model{}, errors.New("small model not found in provider config") + } + + largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model) + if err != nil { + return Model{}, Model{}, err + } + smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model) + if err != nil { + return Model{}, Model{}, err + } + + return Model{ + Model: largeModel, + CatwalkCfg: *largeCatwalkModel, + ModelCfg: largeModelCfg, + }, Model{ + Model: smallModel, + CatwalkCfg: *smallCatwalkModel, + ModelCfg: smallModelCfg, + }, nil +} + +func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) ai.Provider { + hasBearerAuth := false + for key := range headers { + if
strings.ToLower(key) == "authorization" { + hasBearerAuth = true + break + } + } + if hasBearerAuth { + apiKey = "" // clear apiKey to avoid using X-Api-Key header + } + + var opts []anthropic.Option + + if apiKey != "" { + // Use standard X-Api-Key header + opts = append(opts, anthropic.WithAPIKey(apiKey)) + } + + if len(headers) > 0 { + opts = append(opts, anthropic.WithHeaders(headers)) + } + + if baseURL != "" { + opts = append(opts, anthropic.WithBaseURL(baseURL)) + } + + if c.cfg.Options.Debug { + httpClient := log.NewHTTPClient() + opts = append(opts, anthropic.WithHTTPClient(httpClient)) + } + + return anthropic.New(opts...) +} + +func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) ai.Provider { + opts := []openai.Option{ + openai.WithAPIKey(apiKey), + } + if c.cfg.Options.Debug { + httpClient := log.NewHTTPClient() + opts = append(opts, openai.WithHTTPClient(httpClient)) + } + if len(headers) > 0 { + opts = append(opts, openai.WithHeaders(headers)) + } + if baseURL != "" { + opts = append(opts, openai.WithBaseURL(baseURL)) + } + return openai.New(opts...) +} + +func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) ai.Provider { + opts := []openrouter.Option{ + openrouter.WithAPIKey(apiKey), + openrouter.WithLanguageUniqueToolCallIds(), + } + if c.cfg.Options.Debug { + httpClient := log.NewHTTPClient() + opts = append(opts, openrouter.WithHTTPClient(httpClient)) + } + if len(headers) > 0 { + opts = append(opts, openrouter.WithHeaders(headers)) + } + return openrouter.New(opts...) +} + +func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) ai.Provider { + opts := []openaicompat.Option{ + openaicompat.WithAPIKey(apiKey), + } + if c.cfg.Options.Debug { + httpClient := log.NewHTTPClient() + opts = append(opts, openaicompat.WithHTTPClient(httpClient)) + } + if len(headers) > 0 { + opts = append(opts, openaicompat.WithHeaders(headers)) + } + + return openaicompat.New(baseURL, opts...) +} + +// TODO: add baseURL for google +func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) ai.Provider { + opts := []google.Option{ + google.WithAPIKey(apiKey), + } + if c.cfg.Options.Debug { + httpClient := log.NewHTTPClient() + opts = append(opts, google.WithHTTPClient(httpClient)) + } + if len(headers) > 0 { + opts = append(opts, google.WithHeaders(headers)) + } + return google.New(opts...) 
+} + +func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool { + if model.Think { + return true + } + + if model.ProviderOptions == nil { + return false + } + + opts, err := anthropic.ParseOptions(model.ProviderOptions) + if err != nil { + return false + } + if opts.Thinking != nil { + return true + } + return false +} + +func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (ai.Provider, error) { + headers := providerCfg.ExtraHeaders + + // handle special headers for anthropic + if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) { + headers["anthropic-beta"] = "interleaved-thinking-2025-05-14" + } + + // TODO: make sure we have + apiKey, _ := c.cfg.Resolve(providerCfg.APIKey) + baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL) + var provider ai.Provider + switch providerCfg.Type { + case openai.Name: + provider = c.buildOpenaiProvider(baseURL, apiKey, headers) + case anthropic.Name: + provider = c.buildAnthropicProvider(baseURL, apiKey, headers) + case openrouter.Name: + provider = c.buildOpenrouterProvider(baseURL, apiKey, headers) + case google.Name: + provider = c.buildGoogleProvider(baseURL, apiKey, headers) + case openaicompat.Name: + provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers) + default: + return nil, errors.New("provider type not supported") + } + return provider, nil +} + +func (c *coordinator) Cancel(sessionID string) { + c.currentAgent.Cancel(sessionID) +} + +func (c *coordinator) CancelAll() { + c.currentAgent.CancelAll() +} + +func (c *coordinator) ClearQueue(sessionID string) { + c.currentAgent.ClearQueue(sessionID) +} + +func (c *coordinator) IsBusy() bool { + return c.currentAgent.IsBusy() +} + +func (c *coordinator) IsSessionBusy(sessionID string) bool { + return c.currentAgent.IsSessionBusy(sessionID) +} + +func (c *coordinator) Model() Model { + return c.currentAgent.Model() +} + +func (c *coordinator) UpdateModels() error { + // build the models again so we make sure we get the latest config + large, small, err := c.buildAgentModels() + if err != nil { + return err + } + c.currentAgent.SetModels(large, small) + + agentCfg, ok := c.cfg.Agents[config.AgentCoder] + if !ok { + return errors.New("coder agent not configured") + } + + tools, err := c.buildTools(agentCfg) + if err != nil { + return err + } + c.currentAgent.SetTools(tools) + return nil +} + +func (c *coordinator) QueuedPrompts(sessionID string) int { + return c.currentAgent.QueuedPrompts(sessionID) +} + +func (c *coordinator) Summarize(ctx context.Context, sessionID string) error { + return c.currentAgent.Summarize(ctx, sessionID) } diff --git a/internal/agent/prompts.go b/internal/agent/prompts.go new file mode 100644 index 0000000000000000000000000000000000000000..f0f430e60e8bf35b0ab26bf4e62fc67f26d4b6f0 --- /dev/null +++ b/internal/agent/prompts.go @@ -0,0 +1,36 @@ +package agent + +import ( + _ "embed" + + "github.com/charmbracelet/crush/internal/agent/prompt" +) + +//go:embed templates/coder.gotmpl +var coderPromptTmpl []byte + +//go:embed templates/task.gotmpl +var taskPromptTmpl []byte + +//go:embed templates/initialize.md +var initializePrompt []byte + +func coderPrompt(opts ...prompt.Option) (*prompt.Prompt, error) { + systemPrompt, err := prompt.NewPrompt("coder", string(coderPromptTmpl), opts...) 
+ if err != nil { + return nil, err + } + return systemPrompt, nil +} + +func taskPrompt(opts ...prompt.Option) (*prompt.Prompt, error) { + systemPrompt, err := prompt.NewPrompt("task", string(taskPromptTmpl), opts...) + if err != nil { + return nil, err + } + return systemPrompt, nil +} + +func InitializePrompt() string { + return string(initializePrompt) +} diff --git a/internal/agent/templates/agentTool.md b/internal/agent/templates/agentTool.md new file mode 100644 index 0000000000000000000000000000000000000000..50a98afec87653609c9fb7c523a60a46220c6c92 --- /dev/null +++ b/internal/agent/templates/agentTool.md @@ -0,0 +1,16 @@ +Launch a new agent that has access to the following tools: GlobTool, GrepTool, LS, View. When you are searching for a keyword or file and are not confident that you will find the right match on the first try, use the Agent tool to perform the search for you. + + +- If you are searching for a keyword like "config" or "logger", or for questions like "which file does X?", the Agent tool is strongly recommended +- If you want to read a specific file path, use the View or GlobTool tool instead of the Agent tool, to find the match more quickly +- If you are searching for a specific class definition like "class Foo", use the GlobTool tool instead, to find the match more quickly + + + +1. Launch multiple agents concurrently whenever possible, to maximize performance; to do that, use a single message with multiple tool uses +2. When the agent is done, it will return a single message back to you. The result returned by the agent is not visible to the user. To show the user the result, you should send a text message back to the user with a concise summary of the result. +3. Each agent invocation is stateless. You will not be able to send additional messages to the agent, nor will the agent be able to communicate with you outside of its final report. Therefore, your prompt should contain a highly detailed task description for the agent to perform autonomously and you should specify exactly what information the agent should return back to you in its final and only message to you. +4. The agent's outputs should generally be trusted. +5. IMPORTANT: The agent cannot use Bash, Replace, Edit, so it cannot modify files. If you want to use these tools, use them directly instead of going through the agent.
+ diff --git a/internal/agent/tools/mcp-tools.go b/internal/agent/tools/mcp-tools.go index e0e9d14fdc19682019ca9fc15bcf72890f980a30..53891dd045858434aa2eb3e4fde1e4a21870cd5f 100644 --- a/internal/agent/tools/mcp-tools.go +++ b/internal/agent/tools/mcp-tools.go @@ -54,7 +54,8 @@ func (s MCPState) String() string { type MCPEventType string const ( - MCPEventStateChanged MCPEventType = "state_changed" + MCPEventStateChanged MCPEventType = "state_changed" + MCPEventToolsListChanged MCPEventType = "tools_list_changed" ) // MCPEvent represents an event in the MCP system @@ -77,11 +78,12 @@ type MCPClientInfo struct { } var ( - mcpToolsOnce sync.Once - mcpTools []ai.AgentTool - mcpClients = csync.NewMap[string, *client.Client]() - mcpStates = csync.NewMap[string, MCPClientInfo]() - mcpBroker = pubsub.NewBroker[MCPEvent]() + mcpToolsOnce sync.Once + mcpTools = csync.NewMap[string, *McpTool]() + mcpClient2Tools = csync.NewMap[string, []*McpTool]() + mcpClients = csync.NewMap[string, *client.Client]() + mcpStates = csync.NewMap[string, MCPClientInfo]() + mcpBroker = pubsub.NewBroker[MCPEvent]() ) type McpTool struct { @@ -100,22 +102,30 @@ func (m *McpTool) ProviderOptions() ai.ProviderOptions { return m.providerOptions } -func (b *McpTool) Name() string { - return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name) +func (m *McpTool) Name() string { + return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name) } -func (b *McpTool) Info() ai.ToolInfo { - required := b.tool.InputSchema.Required +func (m *McpTool) MCP() string { + return m.mcpName +} + +func (m *McpTool) MCPToolName() string { + return m.tool.Name +} + +func (m *McpTool) Info() ai.ToolInfo { + required := m.tool.InputSchema.Required if required == nil { required = make([]string, 0) } - parameters := b.tool.InputSchema.Properties + parameters := m.tool.InputSchema.Properties if parameters == nil { parameters = make(map[string]any) } return ai.ToolInfo{ - Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name), - Description: b.tool.Description, + Name: fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name), + Description: m.tool.Description, Parameters: parameters, Required: required, } @@ -181,18 +191,18 @@ func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) return c, nil } -func (b *McpTool) Run(ctx context.Context, params ai.ToolCall) (ai.ToolResponse, error) { +func (m *McpTool) Run(ctx context.Context, params ai.ToolCall) (ai.ToolResponse, error) { sessionID := GetSessionFromContext(ctx) if sessionID == "" { return ai.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file") } - permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name) - p := b.permissions.Request( + permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name) + p := m.permissions.Request( permission.CreatePermissionRequest{ SessionID: sessionID, ToolCallID: params.ID, - Path: b.workingDir, - ToolName: b.Info().Name, + Path: m.workingDir, + ToolName: m.Info().Name, Action: "execute", Description: permissionDescription, Params: params.Input, @@ -202,19 +212,15 @@ func (b *McpTool) Run(ctx context.Context, params ai.ToolCall) (ai.ToolResponse, return ai.ToolResponse{}, permission.ErrorPermissionDenied } - return runTool(ctx, b.mcpName, b.tool.Name, params.Input) + return runTool(ctx, m.mcpName, m.tool.Name, params.Input) } -func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []ai.AgentTool { 
+func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) ([]*McpTool, error) { result, err := c.ListTools(ctx, mcp.ListToolsRequest{}) if err != nil { - slog.Error("error listing tools", "error", err) - updateMCPState(name, MCPStateError, err, nil, 0) - c.Close() - mcpClients.Del(name) - return nil + return nil, err } - mcpTools := make([]ai.AgentTool, 0, len(result.Tools)) + mcpTools := make([]*McpTool, 0, len(result.Tools)) for _, tool := range result.Tools { mcpTools = append(mcpTools, &McpTool{ mcpName: name, @@ -223,7 +229,7 @@ func getTools(ctx context.Context, name string, permissions permission.Service, workingDir: workingDir, }) } - return mcpTools + return mcpTools, nil } // SubscribeMCPEvents returns a channel for MCP events @@ -250,8 +256,12 @@ func updateMCPState(name string, state MCPState, err error, client *client.Clien Client: client, ToolCount: toolCount, } - if state == MCPStateConnected { + switch state { + case MCPStateConnected: info.ConnectedAt = time.Now() + case MCPStateError: + updateMcpTools(name, nil) + mcpClients.Del(name) } mcpStates.Set(name, info) @@ -265,6 +275,14 @@ func updateMCPState(name string, state MCPState, err error, client *client.Clien }) } +// publishMCPEventToolsListChanged publishes a tool list changed event +func publishMCPEventToolsListChanged(name string) { + mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{ + Type: MCPEventToolsListChanged, + Name: name, + }) +} + // CloseMCPClients closes all MCP clients. This should be called during application shutdown. func CloseMCPClients() error { var errs []error @@ -287,55 +305,79 @@ var mcpInitRequest = mcp.InitializeRequest{ }, } -func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []ai.AgentTool { - var wg sync.WaitGroup - result := csync.NewSlice[ai.AgentTool]() - - // Initialize states for all configured MCPs - for name, m := range cfg.MCP { - if m.Disabled { - updateMCPState(name, MCPStateDisabled, nil, nil, 0) - slog.Debug("skipping disabled mcp", "name", name) - continue - } +func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []*McpTool { + mcpToolsOnce.Do(func() { + var wg sync.WaitGroup + // Initialize states for all configured MCPs + for name, m := range cfg.MCP { + if m.Disabled { + updateMCPState(name, MCPStateDisabled, nil, nil, 0) + slog.Debug("skipping disabled mcp", "name", name) + continue + } - // Set initial starting state - updateMCPState(name, MCPStateStarting, nil, nil, 0) - - wg.Add(1) - go func(name string, m config.MCPConfig) { - defer func() { - wg.Done() - if r := recover(); r != nil { - var err error - switch v := r.(type) { - case error: - err = v - case string: - err = fmt.Errorf("panic: %s", v) - default: - err = fmt.Errorf("panic: %v", v) + // Set initial starting state + updateMCPState(name, MCPStateStarting, nil, nil, 0) + + wg.Add(1) + go func(name string, m config.MCPConfig) { + defer func() { + wg.Done() + if r := recover(); r != nil { + var err error + switch v := r.(type) { + case error: + err = v + case string: + err = fmt.Errorf("panic: %s", v) + default: + err = fmt.Errorf("panic: %v", v) + } + updateMCPState(name, MCPStateError, err, nil, 0) + slog.Error("panic in mcp client initialization", "error", err, "name", name) } + }() + + mcpCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m)) + defer cancel() + + c, err := createAndInitializeClient(mcpCtx, name, m, cfg.Resolver()) + if err != nil { + return + } + + 
mcpClients.Set(name, c) + + tools, err := getTools(mcpCtx, name, permissions, c, cfg.WorkingDir()) + if err != nil { + slog.Error("error listing tools", "error", err) updateMCPState(name, MCPStateError, err, nil, 0) - slog.Error("panic in mcp client initialization", "error", err, "name", name) + c.Close() + return } - }() - ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m)) - defer cancel() - c, err := createAndInitializeClient(ctx, name, m, cfg.Resolver()) - if err != nil { - return - } - mcpClients.Set(name, c) + updateMcpTools(name, tools) + mcpClients.Set(name, c) + updateMCPState(name, MCPStateConnected, nil, c, len(tools)) + }(name, m) + } + wg.Wait() + }) + return slices.Collect(mcpTools.Seq()) +} - tools := getTools(ctx, name, permissions, c, cfg.WorkingDir()) - updateMCPState(name, MCPStateConnected, nil, c, len(tools)) - result.Append(tools...) - }(name, m) +// updateMcpTools updates the global mcpTools and mcpClient2Tools maps +func updateMcpTools(mcpName string, tools []*McpTool) { + if len(tools) == 0 { + mcpClient2Tools.Del(mcpName) + } else { + mcpClient2Tools.Set(mcpName, tools) + } + for _, tools := range mcpClient2Tools.Seq2() { + for _, t := range tools { + mcpTools.Set(t.Info().Name, t) + } } - wg.Wait() - return slices.Collect(result.Seq()) } func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) { @@ -346,32 +388,45 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon return nil, err } + c.OnNotification(func(n mcp.JSONRPCNotification) { + slog.Debug("Received MCP notification", "name", name, "notification", n) + switch n.Method { + case "notifications/tools/list_changed": + publishMCPEventToolsListChanged(name) + default: + slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method) + } + }) + + // XXX: ideally we should be able to use context.WithTimeout here, but, + // the SSE MCP client will start failing once that context is canceled. 
timeout := mcpTimeout(m) - initCtx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() + mcpCtx, cancel := context.WithCancel(ctx) + cancelTimer := time.AfterFunc(timeout, cancel) - // Only call Start() for non-stdio clients, as stdio clients auto-start - if m.Type != config.MCPStdio { - if err := c.Start(initCtx); err != nil { - updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) - slog.Error("error starting mcp client", "error", err, "name", name) - _ = c.Close() - return nil, err - } + if err := c.Start(mcpCtx); err != nil { + updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) + slog.Error("error starting mcp client", "error", err, "name", name) + _ = c.Close() + cancel() + return nil, err } - if _, err := c.Initialize(initCtx, mcpInitRequest); err != nil { + + if _, err := c.Initialize(mcpCtx, mcpInitRequest); err != nil { updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) slog.Error("error initializing mcp client", "error", err, "name", name) _ = c.Close() + cancel() return nil, err } + cancelTimer.Stop() slog.Info("Initialized mcp client", "name", name) return c, nil } func maybeTimeoutErr(err error, timeout time.Duration) error { - if errors.Is(err, context.DeadlineExceeded) { + if errors.Is(err, context.Canceled) { return fmt.Errorf("timed out after %s", timeout) } return err diff --git a/internal/app/app.go b/internal/app/app.go index 8afec8a0c2aba1f271e083aaedaaa0ca06713363..1c86d1508169accf7f080d3a6a886fe46cc2cd43 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -10,14 +10,16 @@ import ( "time" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/agent/tools" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/format" "github.com/charmbracelet/crush/internal/history" - "github.com/charmbracelet/crush/internal/llm/agent" "github.com/charmbracelet/crush/internal/log" "github.com/charmbracelet/crush/internal/pubsub" + "github.com/charmbracelet/fantasy/ai" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/message" @@ -32,7 +34,7 @@ type App struct { Permissions permission.Service // TODO: (kujtim) remove this when fully implemented - CoderAgent agent.Service + AgentCoordinator agent.Coordinator LSPClients *csync.Map[string, *lsp.Client] @@ -144,10 +146,23 @@ func (app *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool // Automatically approve all permission requests for this non-interactive session app.Permissions.AutoApproveSession(sess.ID) - done, err := app.CoderAgent.Run(ctx, sess.ID, prompt) - if err != nil { - return fmt.Errorf("failed to start agent processing stream: %w", err) + type response struct { + result *ai.AgentResult + err error } + done := make(chan response, 1) + + go func(ctx context.Context, sessionID, prompt string) { + result, err := app.AgentCoordinator.Run(ctx, sess.ID, prompt) + if err != nil { + err = fmt.Errorf("failed to start agent processing stream: %w", err) + done <- response{err: err} + return + } + done <- response{ + result: result, + } + }(ctx, sess.ID, prompt) messageEvents := app.Messages.Subscribe(ctx) messageReadBytes := make(map[string]int) @@ -156,26 +171,13 @@ func (app *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool select { case result := <-done: stopSpinner() - - if
result.Error != nil { - if errors.Is(result.Error, context.Canceled) || errors.Is(result.Error, agent.ErrRequestCancelled) { + if result.err != nil { + if errors.Is(result.err, context.Canceled) || errors.Is(result.err, agent.ErrRequestCancelled) { slog.Info("Non-interactive: agent processing cancelled", "session_id", sess.ID) return nil } - return fmt.Errorf("agent processing failed: %w", result.Error) + return fmt.Errorf("agent processing failed: %w", result.err) } - - msgContent := result.Message.Content().String() - readBts := messageReadBytes[result.Message.ID] - - if len(msgContent) < readBts { - slog.Error("Non-interactive: message content is shorter than read bytes", "message_length", len(msgContent), "read_bytes", readBts) - return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(msgContent), readBts) - } - fmt.Println(msgContent[readBts:]) - messageReadBytes[result.Message.ID] = len(msgContent) - - slog.Info("Non-interactive: run completed", "session_id", sess.ID) return nil case event := <-messageEvents: @@ -204,7 +206,7 @@ func (app *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool } func (app *App) UpdateAgentModel() error { - return app.CoderAgent.UpdateModel() + return app.AgentCoordinator.UpdateModels() } func (app *App) setupEvents() { @@ -215,7 +217,7 @@ func (app *App) setupEvents() { setupSubscriber(ctx, app.serviceEventsWG, "permissions", app.Permissions.Subscribe, app.events) setupSubscriber(ctx, app.serviceEventsWG, "permissions-notifications", app.Permissions.SubscribeNotifications, app.events) setupSubscriber(ctx, app.serviceEventsWG, "history", app.History.Subscribe, app.events) - setupSubscriber(ctx, app.serviceEventsWG, "mcp", agent.SubscribeMCPEvents, app.events) + setupSubscriber(ctx, app.serviceEventsWG, "mcp", tools.SubscribeMCPEvents, app.events) setupSubscriber(ctx, app.serviceEventsWG, "lsp", SubscribeLSPEvents, app.events) cleanupFunc := func() error { cancel() @@ -259,17 +261,16 @@ func setupSubscriber[T any]( } func (app *App) InitCoderAgent() error { - coderAgentCfg := app.config.Agents["coder"] + coderAgentCfg := app.config.Agents[config.AgentCoder] if coderAgentCfg.ID == "" { return fmt.Errorf("coder agent configuration is missing") } var err error - app.CoderAgent, err = agent.NewAgent( - app.globalCtx, - coderAgentCfg, - app.Permissions, + app.AgentCoordinator, err = agent.NewCoordinator( + app.config, app.Sessions, app.Messages, + app.Permissions, app.History, app.LSPClients, ) @@ -279,9 +280,7 @@ func (app *App) InitCoderAgent() error { } // Add MCP client cleanup to shutdown process - app.cleanupFuncs = append(app.cleanupFuncs, agent.CloseMCPClients) - - setupSubscriber(app.eventsCtx, app.serviceEventsWG, "coderAgent", app.CoderAgent.Subscribe, app.events) + app.cleanupFuncs = append(app.cleanupFuncs, tools.CloseMCPClients) return nil } @@ -319,8 +318,8 @@ func (app *App) Subscribe(program *tea.Program) { // Shutdown performs a graceful shutdown of the application. func (app *App) Shutdown() { - if app.CoderAgent != nil { - app.CoderAgent.CancelAll() + if app.AgentCoordinator != nil { + app.AgentCoordinator.CancelAll() } // Shutdown all LSP clients. 
diff --git a/internal/config/config.go b/internal/config/config.go index 27e4e8d39964032bbbc488e218d461151c6e4ac2..9ccf837c9738773469b457cbeeef3474b1319e63 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -48,6 +48,11 @@ const ( SelectedModelTypeSmall SelectedModelType = "small" ) +const ( + AgentCoder string = "coder" + AgentTask string = "task" +) + type SelectedModel struct { // The model id as used by the provider API. // Required. @@ -103,9 +108,6 @@ type ProviderConfig struct { // The provider models Models []catwalk.Model `json:"models,omitempty" jsonschema:"description=List of models available from this provider"` - - // Override provider specific options. - ProviderOptions map[string]any `json:"provider_options,omitempty" jsonschema:"description=Additional provider-specific options for the model"` } type MCPType string @@ -264,10 +266,6 @@ type Agent struct { // if the string array is nil, all tools from the AllowedMCP are available AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"` - // The list of LSPs that this agent can use - // if this is nil, all LSPs are available - AllowedLSP []string `json:"allowed_lsp,omitempty"` - // Overrides the context paths for this agent ContextPaths []string `json:"context_paths,omitempty"` } @@ -514,16 +512,17 @@ func (c *Config) SetupAgents() { allowedTools := resolveAllowedTools(allToolNames(), c.Options.DisabledTools) agents := map[string]Agent{ - "coder": { - ID: "coder", + AgentCoder: { + ID: AgentCoder, Name: "Coder", Description: "An agent that helps with executing coding tasks.", Model: SelectedModelTypeLarge, ContextPaths: c.Options.ContextPaths, AllowedTools: allowedTools, }, - "task": { - ID: "task", + + AgentTask: { + ID: AgentTask, Name: "Task", Description: "An agent that helps with searching for context and finding implementation details.", Model: SelectedModelTypeLarge, @@ -531,7 +530,6 @@ func (c *Config) SetupAgents() { AllowedTools: resolveReadOnlyTools(allowedTools), // NO MCPs or LSPs by default AllowedMCP: map[string][]string{}, - AllowedLSP: []string{}, }, } c.Agents = agents diff --git a/internal/config/load_test.go b/internal/config/load_test.go index 406fe07d523c8b0d5d7f038f8d94cc74a0b58f89..7d324bb005e2633b189c25d8bc93881ca2dea6e9 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -462,11 +462,11 @@ func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) { } cfg.SetupAgents() - coderAgent, ok := cfg.Agents["coder"] + coderAgent, ok := cfg.Agents[AgentCoder] require.True(t, ok) assert.Equal(t, allToolNames(), coderAgent.AllowedTools) - taskAgent, ok := cfg.Agents["task"] + taskAgent, ok := cfg.Agents[AgentTask] require.True(t, ok) assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools) } @@ -483,11 +483,11 @@ func TestConfig_setupAgentsWithDisabledTools(t *testing.T) { } cfg.SetupAgents() - coderAgent, ok := cfg.Agents["coder"] + coderAgent, ok := cfg.Agents[AgentCoder] require.True(t, ok) assert.Equal(t, []string{"agent", "bash", "multiedit", "fetch", "glob", "ls", "sourcegraph", "view", "write"}, coderAgent.AllowedTools) - taskAgent, ok := cfg.Agents["task"] + taskAgent, ok := cfg.Agents[AgentTask] require.True(t, ok) assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools) } @@ -506,11 +506,11 @@ func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) { } cfg.SetupAgents() - coderAgent, ok := cfg.Agents["coder"] + coderAgent, ok := cfg.Agents[AgentCoder]
require.True(t, ok) assert.Equal(t, []string{"agent", "bash", "download", "edit", "multiedit", "fetch", "write"}, coderAgent.AllowedTools) - taskAgent, ok := cfg.Agents["task"] + taskAgent, ok := cfg.Agents[AgentTask] require.True(t, ok) assert.Equal(t, []string{}, taskAgent.AllowedTools) } diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go deleted file mode 100644 index ae15c5867e7321f2ac29e1809f5eb7effb830fdc..0000000000000000000000000000000000000000 --- a/internal/llm/agent/agent-tool.go +++ /dev/null @@ -1,106 +0,0 @@ -package agent - -import ( - "context" - "encoding/json" - "fmt" - - "github.com/charmbracelet/crush/internal/llm/tools" - "github.com/charmbracelet/crush/internal/message" - "github.com/charmbracelet/crush/internal/session" -) - -type agentTool struct { - agent Service - sessions session.Service - messages message.Service -} - -const ( - AgentToolName = "agent" -) - -type AgentParams struct { - Prompt string `json:"prompt"` -} - -func (b *agentTool) Name() string { - return AgentToolName -} - -func (b *agentTool) Info() tools.ToolInfo { - return tools.ToolInfo{ - Name: AgentToolName, - Description: "Launch a new agent that has access to the following tools: GlobTool, GrepTool, LS, View. When you are searching for a keyword or file and are not confident that you will find the right match on the first try, use the Agent tool to perform the search for you. For example:\n\n- If you are searching for a keyword like \"config\" or \"logger\", or for questions like \"which file does X?\", the Agent tool is strongly recommended\n- If you want to read a specific file path, use the View or GlobTool tool instead of the Agent tool, to find the match more quickly\n- If you are searching for a specific class definition like \"class Foo\", use the GlobTool tool instead, to find the match more quickly\n\nUsage notes:\n1. Launch multiple agents concurrently whenever possible, to maximize performance; to do that, use a single message with multiple tool uses\n2. When the agent is done, it will return a single message back to you. The result returned by the agent is not visible to the user. To show the user the result, you should send a text message back to the user with a concise summary of the result.\n3. Each agent invocation is stateless. You will not be able to send additional messages to the agent, nor will the agent be able to communicate with you outside of its final report. Therefore, your prompt should contain a highly detailed task description for the agent to perform autonomously and you should specify exactly what information the agent should return back to you in its final and only message to you.\n4. The agent's outputs should generally be trusted\n5. IMPORTANT: The agent can not use Bash, Replace, Edit, so can not modify files. 
If you want to use these tools, use them directly instead of going through the agent.", - Parameters: map[string]any{ - "prompt": map[string]any{ - "type": "string", - "description": "The task for the agent to perform", - }, - }, - Required: []string{"prompt"}, - } -} - -func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) { - var params AgentParams - if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { - return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil - } - if params.Prompt == "" { - return tools.NewTextErrorResponse("prompt is required"), nil - } - - sessionID, messageID := tools.GetContextValues(ctx) - if sessionID == "" || messageID == "" { - return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required") - } - - session, err := b.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session") - if err != nil { - return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err) - } - - done, err := b.agent.Run(ctx, session.ID, params.Prompt) - if err != nil { - return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err) - } - result := <-done - if result.Error != nil { - return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", result.Error) - } - - response := result.Message - if response.Role != message.Assistant { - return tools.NewTextErrorResponse("no response"), nil - } - - updatedSession, err := b.sessions.Get(ctx, session.ID) - if err != nil { - return tools.ToolResponse{}, fmt.Errorf("error getting session: %s", err) - } - parentSession, err := b.sessions.Get(ctx, sessionID) - if err != nil { - return tools.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err) - } - - parentSession.Cost += updatedSession.Cost - - _, err = b.sessions.Save(ctx, parentSession) - if err != nil { - return tools.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err) - } - return tools.NewTextResponse(response.Content().String()), nil -} - -func NewAgentTool( - agent Service, - sessions session.Service, - messages message.Service, -) tools.BaseTool { - return &agentTool{ - sessions: sessions, - messages: messages, - agent: agent, - } -} diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go deleted file mode 100644 index 1efc3fc268392c06481d61ae6e11c9d67cdc13e8..0000000000000000000000000000000000000000 --- a/internal/llm/agent/agent.go +++ /dev/null @@ -1,1133 +0,0 @@ -package agent - -import ( - "context" - "errors" - "fmt" - "log/slog" - "maps" - "slices" - "strings" - "time" - - "github.com/charmbracelet/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/event" - "github.com/charmbracelet/crush/internal/history" - "github.com/charmbracelet/crush/internal/llm/prompt" - "github.com/charmbracelet/crush/internal/llm/provider" - "github.com/charmbracelet/crush/internal/llm/tools" - "github.com/charmbracelet/crush/internal/log" - "github.com/charmbracelet/crush/internal/lsp" - "github.com/charmbracelet/crush/internal/message" - "github.com/charmbracelet/crush/internal/permission" - "github.com/charmbracelet/crush/internal/pubsub" - "github.com/charmbracelet/crush/internal/session" - "github.com/charmbracelet/crush/internal/shell" -) - -type AgentEventType string - -const ( - AgentEventTypeError AgentEventType = "error" - AgentEventTypeResponse AgentEventType = "response" - AgentEventTypeSummarize 
AgentEventType = "summarize" -) - -type AgentEvent struct { - Type AgentEventType - Message message.Message - Error error - - // When summarizing - SessionID string - Progress string - Done bool -} - -type Service interface { - pubsub.Suscriber[AgentEvent] - Model() catwalk.Model - Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) - Cancel(sessionID string) - CancelAll() - IsSessionBusy(sessionID string) bool - IsBusy() bool - Summarize(ctx context.Context, sessionID string) error - UpdateModel() error - QueuedPrompts(sessionID string) int - ClearQueue(sessionID string) -} - -type agent struct { - *pubsub.Broker[AgentEvent] - agentCfg config.Agent - sessions session.Service - messages message.Service - permissions permission.Service - baseTools *csync.Map[string, tools.BaseTool] - mcpTools *csync.Map[string, tools.BaseTool] - lspClients *csync.Map[string, *lsp.Client] - - // We need this to be able to update it when model changes - agentToolFn func() (tools.BaseTool, error) - cleanupFuncs []func() - - provider provider.Provider - providerID string - - titleProvider provider.Provider - summarizeProvider provider.Provider - summarizeProviderID string - - activeRequests *csync.Map[string, context.CancelFunc] - promptQueue *csync.Map[string, []string] -} - -var agentPromptMap = map[string]prompt.PromptID{ - "coder": prompt.PromptCoder, - "task": prompt.PromptTask, -} - -func NewAgent( - ctx context.Context, - agentCfg config.Agent, - // These services are needed in the tools - permissions permission.Service, - sessions session.Service, - messages message.Service, - history history.Service, - lspClients *csync.Map[string, *lsp.Client], -) (Service, error) { - cfg := config.Get() - - var agentToolFn func() (tools.BaseTool, error) - if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) { - agentToolFn = func() (tools.BaseTool, error) { - taskAgentCfg := config.Get().Agents["task"] - if taskAgentCfg.ID == "" { - return nil, fmt.Errorf("task agent not found in config") - } - taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients) - if err != nil { - return nil, fmt.Errorf("failed to create task agent: %w", err) - } - return NewAgentTool(taskAgent, sessions, messages), nil - } - } - - providerCfg := config.Get().GetProviderForModel(agentCfg.Model) - if providerCfg == nil { - return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name) - } - model := config.Get().GetModelByType(agentCfg.Model) - - if model == nil { - return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name) - } - - promptID := agentPromptMap[agentCfg.ID] - if promptID == "" { - promptID = prompt.PromptDefault - } - opts := []provider.ProviderClientOption{ - provider.WithModel(agentCfg.Model), - provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)), - } - agentProvider, err := provider.NewProvider(*providerCfg, opts...) 
- if err != nil { - return nil, err - } - - smallModelCfg := cfg.Models[config.SelectedModelTypeSmall] - var smallModelProviderCfg *config.ProviderConfig - if smallModelCfg.Provider == providerCfg.ID { - smallModelProviderCfg = providerCfg - } else { - smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall) - - if smallModelProviderCfg.ID == "" { - return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider) - } - } - smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall) - if smallModel.ID == "" { - return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID) - } - - titleOpts := []provider.ProviderClientOption{ - provider.WithModel(config.SelectedModelTypeSmall), - provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), - } - titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...) - if err != nil { - return nil, err - } - - summarizeOpts := []provider.ProviderClientOption{ - provider.WithModel(config.SelectedModelTypeLarge), - provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)), - } - summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...) - if err != nil { - return nil, err - } - - baseToolsFn := func() map[string]tools.BaseTool { - slog.Info("Initializing agent base tools", "agent", agentCfg.ID) - defer func() { - slog.Info("Initialized agent base tools", "agent", agentCfg.ID) - }() - - // Base tools available to all agents - cwd := cfg.WorkingDir() - result := make(map[string]tools.BaseTool) - for _, tool := range []tools.BaseTool{ - tools.NewBashTool(permissions, cwd, cfg.Options.Attribution), - tools.NewDownloadTool(permissions, cwd), - tools.NewEditTool(lspClients, permissions, history, cwd), - tools.NewMultiEditTool(lspClients, permissions, history, cwd), - tools.NewFetchTool(permissions, cwd), - tools.NewGlobTool(cwd), - tools.NewGrepTool(cwd), - tools.NewLsTool(permissions, cwd), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients, permissions, cwd), - tools.NewWriteTool(lspClients, permissions, history, cwd), - } { - result[tool.Name()] = tool - } - return result - } - mcpToolsFn := func() map[string]tools.BaseTool { - slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID) - defer func() { - slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID) - }() - - mcpToolsOnce.Do(func() { - doGetMCPTools(ctx, permissions, cfg) - }) - - return maps.Collect(mcpTools.Seq2()) - } - - a := &agent{ - Broker: pubsub.NewBroker[AgentEvent](), - agentCfg: agentCfg, - provider: agentProvider, - providerID: string(providerCfg.ID), - messages: messages, - sessions: sessions, - titleProvider: titleProvider, - summarizeProvider: summarizeProvider, - summarizeProviderID: string(providerCfg.ID), - agentToolFn: agentToolFn, - activeRequests: csync.NewMap[string, context.CancelFunc](), - mcpTools: csync.NewLazyMap(mcpToolsFn), - baseTools: csync.NewLazyMap(baseToolsFn), - promptQueue: csync.NewMap[string, []string](), - permissions: permissions, - lspClients: lspClients, - } - a.setupEvents(ctx) - return a, nil -} - -func (a *agent) Model() catwalk.Model { - return *config.Get().GetModelByType(a.agentCfg.Model) -} - -func (a *agent) Cancel(sessionID string) { - // Cancel regular requests - if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil { - slog.Info("Request cancellation initiated", "session_id", sessionID) - cancel() - } - - // Also check for 
summarize requests - if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil { - slog.Info("Summarize cancellation initiated", "session_id", sessionID) - cancel() - } - - if a.QueuedPrompts(sessionID) > 0 { - slog.Info("Clearing queued prompts", "session_id", sessionID) - a.promptQueue.Del(sessionID) - } -} - -func (a *agent) IsBusy() bool { - var busy bool - for cancelFunc := range a.activeRequests.Seq() { - if cancelFunc != nil { - busy = true - break - } - } - return busy -} - -func (a *agent) IsSessionBusy(sessionID string) bool { - _, busy := a.activeRequests.Get(sessionID) - return busy -} - -func (a *agent) QueuedPrompts(sessionID string) int { - l, ok := a.promptQueue.Get(sessionID) - if !ok { - return 0 - } - return len(l) -} - -func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error { - if content == "" { - return nil - } - if a.titleProvider == nil { - return nil - } - session, err := a.sessions.Get(ctx, sessionID) - if err != nil { - return err - } - parts := []message.ContentPart{message.TextContent{ - Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content), - }} - - // Use streaming approach like summarization - response := a.titleProvider.StreamResponse( - ctx, - []message.Message{ - { - Role: message.User, - Parts: parts, - }, - }, - nil, - ) - - var finalResponse *provider.ProviderResponse - for r := range response { - if r.Error != nil { - return r.Error - } - finalResponse = r.Response - } - - if finalResponse == nil { - return fmt.Errorf("no response received from title provider") - } - - title := strings.ReplaceAll(finalResponse.Content, "\n", " ") - - if idx := strings.Index(title, ""); idx > 0 { - title = title[idx+len(""):] - } - - title = strings.TrimSpace(title) - if title == "" { - return nil - } - - session.Title = title - _, err = a.sessions.Save(ctx, session) - return err -} - -func (a *agent) err(err error) AgentEvent { - return AgentEvent{ - Type: AgentEventTypeError, - Error: err, - } -} - -func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) { - if !a.Model().SupportsImages && attachments != nil { - attachments = nil - } - events := make(chan AgentEvent, 1) - if a.IsSessionBusy(sessionID) { - existing, ok := a.promptQueue.Get(sessionID) - if !ok { - existing = []string{} - } - existing = append(existing, content) - a.promptQueue.Set(sessionID, existing) - return nil, nil - } - - genCtx, cancel := context.WithCancel(ctx) - a.activeRequests.Set(sessionID, cancel) - startTime := time.Now() - - go func() { - slog.Debug("Request started", "sessionID", sessionID) - defer log.RecoverPanic("agent.Run", func() { - events <- a.err(fmt.Errorf("panic while running the agent")) - }) - var attachmentParts []message.ContentPart - for _, attachment := range attachments { - attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content}) - } - result := a.processGeneration(genCtx, sessionID, content, attachmentParts) - if result.Error != nil { - if isCancelledErr(result.Error) { - slog.Error("Request canceled", "sessionID", sessionID) - } else { - slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error()) - event.Error(result.Error) - } - } else { - slog.Debug("Request completed", "sessionID", sessionID) - } - a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second)) 
- a.activeRequests.Del(sessionID) - cancel() - a.Publish(pubsub.CreatedEvent, result) - events <- result - close(events) - }() - a.eventPromptSent(sessionID) - return events, nil -} - -func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent { - cfg := config.Get() - // List existing messages; if none, start title generation asynchronously. - msgs, err := a.messages.List(ctx, sessionID) - if err != nil { - return a.err(fmt.Errorf("failed to list messages: %w", err)) - } - if len(msgs) == 0 { - go func() { - defer log.RecoverPanic("agent.Run", func() { - slog.Error("panic while generating title") - }) - titleErr := a.generateTitle(ctx, sessionID, content) - if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) { - slog.Error("failed to generate title", "error", titleErr) - } - }() - } - session, err := a.sessions.Get(ctx, sessionID) - if err != nil { - return a.err(fmt.Errorf("failed to get session: %w", err)) - } - if session.SummaryMessageID != "" { - summaryMsgInex := -1 - for i, msg := range msgs { - if msg.ID == session.SummaryMessageID { - summaryMsgInex = i - break - } - } - if summaryMsgInex != -1 { - msgs = msgs[summaryMsgInex:] - msgs[0].Role = message.User - } - } - - userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts) - if err != nil { - return a.err(fmt.Errorf("failed to create user message: %w", err)) - } - // Append the new user message to the conversation history. - msgHistory := append(msgs, userMsg) - - for { - // Check for cancellation before each iteration - select { - case <-ctx.Done(): - return a.err(ctx.Err()) - default: - // Continue processing - } - agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory) - if err != nil { - if errors.Is(err, context.Canceled) { - agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "") - a.messages.Update(context.Background(), agentMessage) - return a.err(ErrRequestCancelled) - } - return a.err(fmt.Errorf("failed to process events: %w", err)) - } - if cfg.Options.Debug { - slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults) - } - if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil { - // We are not done, we need to respond with the tool response - msgHistory = append(msgHistory, agentMessage, *toolResults) - // If there are queued prompts, process the next one - nextPrompt, ok := a.promptQueue.Take(sessionID) - if ok { - for _, prompt := range nextPrompt { - // Create a new user message for the queued prompt - userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil) - if err != nil { - return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err)) - } - // Append the new user message to the conversation history - msgHistory = append(msgHistory, userMsg) - } - } - - continue - } else if agentMessage.FinishReason() == message.FinishReasonEndTurn { - queuePrompts, ok := a.promptQueue.Take(sessionID) - if ok { - for _, prompt := range queuePrompts { - if prompt == "" { - continue - } - userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil) - if err != nil { - return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err)) - } - msgHistory = append(msgHistory, userMsg) - } - continue - } - } - if agentMessage.FinishReason() == "" { - // Kujtim: could not track down where this is happening but this means 
its cancelled - agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "") - _ = a.messages.Update(context.Background(), agentMessage) - return a.err(ErrRequestCancelled) - } - return AgentEvent{ - Type: AgentEventTypeResponse, - Message: agentMessage, - Done: true, - } - } -} - -func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) { - parts := []message.ContentPart{message.TextContent{Text: content}} - parts = append(parts, attachmentParts...) - return a.messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.User, - Parts: parts, - }) -} - -func (a *agent) getAllTools() ([]tools.BaseTool, error) { - var allTools []tools.BaseTool - for tool := range a.baseTools.Seq() { - if a.agentCfg.AllowedTools == nil || slices.Contains(a.agentCfg.AllowedTools, tool.Name()) { - allTools = append(allTools, tool) - } - } - if a.agentCfg.ID == "coder" { - allTools = slices.AppendSeq(allTools, a.mcpTools.Seq()) - if a.lspClients.Len() > 0 { - allTools = append(allTools, tools.NewDiagnosticsTool(a.lspClients)) - } - } - if a.agentToolFn != nil { - agentTool, agentToolErr := a.agentToolFn() - if agentToolErr != nil { - return nil, agentToolErr - } - allTools = append(allTools, agentTool) - } - return allTools, nil -} - -func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) { - ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) - - // Create the assistant message first so the spinner shows immediately - assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.Assistant, - Parts: []message.ContentPart{}, - Model: a.Model().ID, - Provider: a.providerID, - }) - if err != nil { - return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err) - } - - allTools, toolsErr := a.getAllTools() - if toolsErr != nil { - return assistantMsg, nil, toolsErr - } - // Now collect tools (which may block on MCP initialization) - eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools) - - // Add the session and message ID into the context if needed by tools. 
- ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) - -loop: - for { - select { - case event, ok := <-eventChan: - if !ok { - break loop - } - if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil { - if errors.Is(processErr, context.Canceled) { - a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "") - } else { - a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error()) - } - return assistantMsg, nil, processErr - } - case <-ctx.Done(): - a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "") - return assistantMsg, nil, ctx.Err() - } - } - - toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls())) - toolCalls := assistantMsg.ToolCalls() - for i, toolCall := range toolCalls { - select { - case <-ctx.Done(): - a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "") - // Make all future tool calls cancelled - for j := i; j < len(toolCalls); j++ { - toolResults[j] = message.ToolResult{ - ToolCallID: toolCalls[j].ID, - Content: "Tool execution canceled by user", - IsError: true, - } - } - goto out - default: - // Continue processing - var tool tools.BaseTool - allTools, _ = a.getAllTools() - for _, availableTool := range allTools { - if availableTool.Info().Name == toolCall.Name { - tool = availableTool - break - } - } - - // Tool not found - if tool == nil { - toolResults[i] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: fmt.Sprintf("Tool not found: %s", toolCall.Name), - IsError: true, - } - continue - } - - // Run tool in goroutine to allow cancellation - type toolExecResult struct { - response tools.ToolResponse - err error - } - resultChan := make(chan toolExecResult, 1) - - go func() { - response, err := tool.Run(ctx, tools.ToolCall{ - ID: toolCall.ID, - Name: toolCall.Name, - Input: toolCall.Input, - }) - resultChan <- toolExecResult{response: response, err: err} - }() - - var toolResponse tools.ToolResponse - var toolErr error - - select { - case <-ctx.Done(): - a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "") - // Mark remaining tool calls as cancelled - for j := i; j < len(toolCalls); j++ { - toolResults[j] = message.ToolResult{ - ToolCallID: toolCalls[j].ID, - Content: "Tool execution canceled by user", - IsError: true, - } - } - goto out - case result := <-resultChan: - toolResponse = result.response - toolErr = result.err - } - - if toolErr != nil { - slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr) - if errors.Is(toolErr, permission.ErrorPermissionDenied) { - toolResults[i] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: "Permission denied", - IsError: true, - } - for j := i + 1; j < len(toolCalls); j++ { - toolResults[j] = message.ToolResult{ - ToolCallID: toolCalls[j].ID, - Content: "Tool execution canceled by user", - IsError: true, - } - } - a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "") - break - } - } - toolResults[i] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: toolResponse.Content, - Metadata: toolResponse.Metadata, - IsError: toolResponse.IsError, - } - } - } -out: - if len(toolResults) == 0 { - return assistantMsg, nil, nil - } - parts := make([]message.ContentPart, 0) - for _, tr := range toolResults { - parts = append(parts, 
tr) - } - msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{ - Role: message.Tool, - Parts: parts, - Provider: a.providerID, - }) - if err != nil { - return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err) - } - - return assistantMsg, &msg, err -} - -func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) { - msg.AddFinish(finishReason, message, details) - _ = a.messages.Update(ctx, *msg) -} - -func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error { - select { - case <-ctx.Done(): - return ctx.Err() - default: - // Continue processing. - } - - switch event.Type { - case provider.EventThinkingDelta: - assistantMsg.AppendReasoningContent(event.Thinking) - return a.messages.Update(ctx, *assistantMsg) - case provider.EventSignatureDelta: - assistantMsg.AppendReasoningSignature(event.Signature) - return a.messages.Update(ctx, *assistantMsg) - case provider.EventContentDelta: - assistantMsg.FinishThinking() - assistantMsg.AppendContent(event.Content) - return a.messages.Update(ctx, *assistantMsg) - case provider.EventToolUseStart: - assistantMsg.FinishThinking() - slog.Info("Tool call started", "toolCall", event.ToolCall) - assistantMsg.AddToolCall(*event.ToolCall) - return a.messages.Update(ctx, *assistantMsg) - case provider.EventToolUseDelta: - assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input) - return a.messages.Update(ctx, *assistantMsg) - case provider.EventToolUseStop: - slog.Info("Finished tool call", "toolCall", event.ToolCall) - assistantMsg.FinishToolCall(event.ToolCall.ID) - return a.messages.Update(ctx, *assistantMsg) - case provider.EventError: - return event.Error - case provider.EventComplete: - assistantMsg.FinishThinking() - assistantMsg.SetToolCalls(event.Response.ToolCalls) - assistantMsg.AddFinish(event.Response.FinishReason, "", "") - if err := a.messages.Update(ctx, *assistantMsg); err != nil { - return fmt.Errorf("failed to update message: %w", err) - } - return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage) - } - - return nil -} - -func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error { - sess, err := a.sessions.Get(ctx, sessionID) - if err != nil { - return fmt.Errorf("failed to get session: %w", err) - } - - cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + - model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) + - model.CostPer1MIn/1e6*float64(usage.InputTokens) + - model.CostPer1MOut/1e6*float64(usage.OutputTokens) - - a.eventTokensUsed(sessionID, usage, cost) - - sess.Cost += cost - sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens - sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens - - _, err = a.sessions.Save(ctx, sess) - if err != nil { - return fmt.Errorf("failed to save session: %w", err) - } - return nil -} - -func (a *agent) Summarize(ctx context.Context, sessionID string) error { - if a.summarizeProvider == nil { - return fmt.Errorf("summarize provider not available") - } - - // Check if session is busy - if a.IsSessionBusy(sessionID) { - return ErrSessionBusy - } - - // Create a new context with cancellation - summarizeCtx, cancel := context.WithCancel(ctx) - - // Store the cancel function in activeRequests to allow cancellation - 
a.activeRequests.Set(sessionID+"-summarize", cancel) - - go func() { - defer a.activeRequests.Del(sessionID + "-summarize") - defer cancel() - event := AgentEvent{ - Type: AgentEventTypeSummarize, - Progress: "Starting summarization...", - } - - a.Publish(pubsub.CreatedEvent, event) - // Get all messages from the session - msgs, err := a.messages.List(summarizeCtx, sessionID) - if err != nil { - event = AgentEvent{ - Type: AgentEventTypeError, - Error: fmt.Errorf("failed to list messages: %w", err), - Done: true, - } - a.Publish(pubsub.CreatedEvent, event) - return - } - summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID) - - if len(msgs) == 0 { - event = AgentEvent{ - Type: AgentEventTypeError, - Error: fmt.Errorf("no messages to summarize"), - Done: true, - } - a.Publish(pubsub.CreatedEvent, event) - return - } - - event = AgentEvent{ - Type: AgentEventTypeSummarize, - Progress: "Analyzing conversation...", - } - a.Publish(pubsub.CreatedEvent, event) - - // Add a system message to guide the summarization - summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next." - - // Create a new message with the summarize prompt - promptMsg := message.Message{ - Role: message.User, - Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}}, - } - - // Append the prompt to the messages - msgsWithPrompt := append(msgs, promptMsg) - - event = AgentEvent{ - Type: AgentEventTypeSummarize, - Progress: "Generating summary...", - } - - a.Publish(pubsub.CreatedEvent, event) - - // Send the messages to the summarize provider - response := a.summarizeProvider.StreamResponse( - summarizeCtx, - msgsWithPrompt, - nil, - ) - var finalResponse *provider.ProviderResponse - for r := range response { - if r.Error != nil { - event = AgentEvent{ - Type: AgentEventTypeError, - Error: fmt.Errorf("failed to summarize: %w", r.Error), - Done: true, - } - a.Publish(pubsub.CreatedEvent, event) - return - } - finalResponse = r.Response - } - - summary := strings.TrimSpace(finalResponse.Content) - if summary == "" { - event = AgentEvent{ - Type: AgentEventTypeError, - Error: fmt.Errorf("empty summary returned"), - Done: true, - } - a.Publish(pubsub.CreatedEvent, event) - return - } - shell := shell.GetPersistentShell(config.Get().WorkingDir()) - summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir() - event = AgentEvent{ - Type: AgentEventTypeSummarize, - Progress: "Creating new session...", - } - - a.Publish(pubsub.CreatedEvent, event) - oldSession, err := a.sessions.Get(summarizeCtx, sessionID) - if err != nil { - event = AgentEvent{ - Type: AgentEventTypeError, - Error: fmt.Errorf("failed to get session: %w", err), - Done: true, - } - - a.Publish(pubsub.CreatedEvent, event) - return - } - // Create a message in the new session with the summary - msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{ - Role: message.Assistant, - Parts: []message.ContentPart{ - message.TextContent{Text: summary}, - message.Finish{ - Reason: message.FinishReasonEndTurn, - Time: time.Now().Unix(), - }, - }, - Model: a.summarizeProvider.Model().ID, - Provider: a.summarizeProviderID, - }) - if err != nil { - event = AgentEvent{ - Type: AgentEventTypeError, - Error: fmt.Errorf("failed to create summary message: %w", err), 
- Done: true, - } - - a.Publish(pubsub.CreatedEvent, event) - return - } - oldSession.SummaryMessageID = msg.ID - oldSession.CompletionTokens = finalResponse.Usage.OutputTokens - oldSession.PromptTokens = 0 - model := a.summarizeProvider.Model() - usage := finalResponse.Usage - cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + - model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) + - model.CostPer1MIn/1e6*float64(usage.InputTokens) + - model.CostPer1MOut/1e6*float64(usage.OutputTokens) - oldSession.Cost += cost - _, err = a.sessions.Save(summarizeCtx, oldSession) - if err != nil { - event = AgentEvent{ - Type: AgentEventTypeError, - Error: fmt.Errorf("failed to save session: %w", err), - Done: true, - } - a.Publish(pubsub.CreatedEvent, event) - } - - event = AgentEvent{ - Type: AgentEventTypeSummarize, - SessionID: oldSession.ID, - Progress: "Summary complete", - Done: true, - } - a.Publish(pubsub.CreatedEvent, event) - // Send final success event with the new session ID - }() - - return nil -} - -func (a *agent) ClearQueue(sessionID string) { - if a.QueuedPrompts(sessionID) > 0 { - slog.Info("Clearing queued prompts", "session_id", sessionID) - a.promptQueue.Del(sessionID) - } -} - -func (a *agent) CancelAll() { - if !a.IsBusy() { - return - } - for key := range a.activeRequests.Seq2() { - a.Cancel(key) // key is sessionID - } - - for _, cleanup := range a.cleanupFuncs { - if cleanup != nil { - cleanup() - } - } - - timeout := time.After(5 * time.Second) - for a.IsBusy() { - select { - case <-timeout: - return - default: - time.Sleep(200 * time.Millisecond) - } - } -} - -func (a *agent) UpdateModel() error { - cfg := config.Get() - - // Get current provider configuration - currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model) - if currentProviderCfg == nil || currentProviderCfg.ID == "" { - return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name) - } - - // Check if provider has changed - if string(currentProviderCfg.ID) != a.providerID { - // Provider changed, need to recreate the main provider - model := cfg.GetModelByType(a.agentCfg.Model) - if model.ID == "" { - return fmt.Errorf("model not found for agent %s", a.agentCfg.Name) - } - - promptID := agentPromptMap[a.agentCfg.ID] - if promptID == "" { - promptID = prompt.PromptDefault - } - - opts := []provider.ProviderClientOption{ - provider.WithModel(a.agentCfg.Model), - provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)), - } - - newProvider, err := provider.NewProvider(*currentProviderCfg, opts...) 
- if err != nil { - return fmt.Errorf("failed to create new provider: %w", err) - } - - // Update the provider and provider ID - a.provider = newProvider - a.providerID = string(currentProviderCfg.ID) - } - - // Check if providers have changed for title (small) and summarize (large) - smallModelCfg := cfg.Models[config.SelectedModelTypeSmall] - var smallModelProviderCfg config.ProviderConfig - for p := range cfg.Providers.Seq() { - if p.ID == smallModelCfg.Provider { - smallModelProviderCfg = p - break - } - } - if smallModelProviderCfg.ID == "" { - return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider) - } - - largeModelCfg := cfg.Models[config.SelectedModelTypeLarge] - var largeModelProviderCfg config.ProviderConfig - for p := range cfg.Providers.Seq() { - if p.ID == largeModelCfg.Provider { - largeModelProviderCfg = p - break - } - } - if largeModelProviderCfg.ID == "" { - return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider) - } - - var maxTitleTokens int64 = 40 - - // if the max output is too low for the gemini provider it won't return anything - if smallModelCfg.Provider == "gemini" { - maxTitleTokens = 1000 - } - // Recreate title provider - titleOpts := []provider.ProviderClientOption{ - provider.WithModel(config.SelectedModelTypeSmall), - provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), - provider.WithMaxTokens(maxTitleTokens), - } - newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...) - if err != nil { - return fmt.Errorf("failed to create new title provider: %w", err) - } - a.titleProvider = newTitleProvider - - // Recreate summarize provider if provider changed (now large model) - if string(largeModelProviderCfg.ID) != a.summarizeProviderID { - largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge) - if largeModel == nil { - return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID) - } - summarizeOpts := []provider.ProviderClientOption{ - provider.WithModel(config.SelectedModelTypeLarge), - provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)), - } - newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...) 
- if err != nil { - return fmt.Errorf("failed to create new summarize provider: %w", err) - } - a.summarizeProvider = newSummarizeProvider - a.summarizeProviderID = string(largeModelProviderCfg.ID) - } - - return nil -} - -func (a *agent) setupEvents(ctx context.Context) { - ctx, cancel := context.WithCancel(ctx) - - go func() { - subCh := SubscribeMCPEvents(ctx) - - for { - select { - case event, ok := <-subCh: - if !ok { - slog.Debug("MCPEvents subscription channel closed") - return - } - switch event.Payload.Type { - case MCPEventToolsListChanged: - name := event.Payload.Name - c, ok := mcpClients.Get(name) - if !ok { - slog.Warn("MCP client not found for tools update", "name", name) - continue - } - cfg := config.Get() - tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir()) - if err != nil { - slog.Error("error listing tools", "error", err) - updateMCPState(name, MCPStateError, err, nil, 0) - _ = c.Close() - continue - } - updateMcpTools(name, tools) - a.mcpTools.Reset(maps.Collect(mcpTools.Seq2())) - updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len()) - default: - continue - } - case <-ctx.Done(): - slog.Debug("MCPEvents subscription cancelled") - return - } - } - }() - - a.cleanupFuncs = append(a.cleanupFuncs, cancel) -} diff --git a/internal/llm/agent/errors.go b/internal/llm/agent/errors.go deleted file mode 100644 index 0e2f983d64b42b93ad3a51f32ce0335b0374a613..0000000000000000000000000000000000000000 --- a/internal/llm/agent/errors.go +++ /dev/null @@ -1,15 +0,0 @@ -package agent - -import ( - "context" - "errors" -) - -var ( - ErrRequestCancelled = errors.New("request canceled by user") - ErrSessionBusy = errors.New("session is currently processing another request") -) - -func isCancelledErr(err error) bool { - return errors.Is(err, context.Canceled) || errors.Is(err, ErrRequestCancelled) -} diff --git a/internal/llm/agent/event.go b/internal/llm/agent/event.go deleted file mode 100644 index 8642d9990dc31689292abe9f2b39e685462f158e..0000000000000000000000000000000000000000 --- a/internal/llm/agent/event.go +++ /dev/null @@ -1,53 +0,0 @@ -package agent - -import ( - "time" - - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/event" - "github.com/charmbracelet/crush/internal/llm/provider" -) - -func (a *agent) eventPromptSent(sessionID string) { - event.PromptSent( - a.eventCommon(sessionID)..., - ) -} - -func (a *agent) eventPromptResponded(sessionID string, duration time.Duration) { - event.PromptResponded( - append( - a.eventCommon(sessionID), - "prompt duration pretty", duration.String(), - "prompt duration in seconds", int64(duration.Seconds()), - )..., - ) -} - -func (a *agent) eventTokensUsed(sessionID string, usage provider.TokenUsage, cost float64) { - event.TokensUsed( - append( - a.eventCommon(sessionID), - "input tokens", usage.InputTokens, - "output tokens", usage.OutputTokens, - "cache read tokens", usage.CacheReadTokens, - "cache creation tokens", usage.CacheCreationTokens, - "total tokens", usage.InputTokens+usage.OutputTokens+usage.CacheReadTokens+usage.CacheCreationTokens, - "cost", cost, - )..., - ) -} - -func (a *agent) eventCommon(sessionID string) []any { - cfg := config.Get() - currentModel := cfg.Models[cfg.Agents["coder"].Model] - - return []any{ - "session id", sessionID, - "provider", currentModel.Provider, - "model", currentModel.Model, - "reasoning effort", currentModel.ReasoningEffort, - "thinking mode", currentModel.Think, - "yolo mode", a.permissions.SkipRequests(), - } -} diff 
--git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go deleted file mode 100644 index 181f32b7280faf3eb36040d2ebecf3f892350f53..0000000000000000000000000000000000000000 --- a/internal/llm/agent/mcp-tools.go +++ /dev/null @@ -1,466 +0,0 @@ -package agent - -import ( - "cmp" - "context" - "encoding/json" - "errors" - "fmt" - "log/slog" - "maps" - "strings" - "sync" - "time" - - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/home" - "github.com/charmbracelet/crush/internal/llm/tools" - "github.com/charmbracelet/crush/internal/permission" - "github.com/charmbracelet/crush/internal/pubsub" - "github.com/charmbracelet/crush/internal/version" - "github.com/mark3labs/mcp-go/client" - "github.com/mark3labs/mcp-go/client/transport" - "github.com/mark3labs/mcp-go/mcp" -) - -// MCPState represents the current state of an MCP client -type MCPState int - -const ( - MCPStateDisabled MCPState = iota - MCPStateStarting - MCPStateConnected - MCPStateError -) - -func (s MCPState) String() string { - switch s { - case MCPStateDisabled: - return "disabled" - case MCPStateStarting: - return "starting" - case MCPStateConnected: - return "connected" - case MCPStateError: - return "error" - default: - return "unknown" - } -} - -// MCPEventType represents the type of MCP event -type MCPEventType string - -const ( - MCPEventStateChanged MCPEventType = "state_changed" - MCPEventToolsListChanged MCPEventType = "tools_list_changed" -) - -// MCPEvent represents an event in the MCP system -type MCPEvent struct { - Type MCPEventType - Name string - State MCPState - Error error - ToolCount int -} - -// MCPClientInfo holds information about an MCP client's state -type MCPClientInfo struct { - Name string - State MCPState - Error error - Client *client.Client - ToolCount int - ConnectedAt time.Time -} - -var ( - mcpToolsOnce sync.Once - mcpTools = csync.NewMap[string, tools.BaseTool]() - mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]() - mcpClients = csync.NewMap[string, *client.Client]() - mcpStates = csync.NewMap[string, MCPClientInfo]() - mcpBroker = pubsub.NewBroker[MCPEvent]() -) - -type McpTool struct { - mcpName string - tool mcp.Tool - permissions permission.Service - workingDir string -} - -func (b *McpTool) Name() string { - return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name) -} - -func (b *McpTool) Info() tools.ToolInfo { - required := b.tool.InputSchema.Required - if required == nil { - required = make([]string, 0) - } - parameters := b.tool.InputSchema.Properties - if parameters == nil { - parameters = make(map[string]any) - } - return tools.ToolInfo{ - Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name), - Description: b.tool.Description, - Parameters: parameters, - Required: required, - } -} - -func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) { - var args map[string]any - if err := json.Unmarshal([]byte(input), &args); err != nil { - return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil - } - - c, err := getOrRenewClient(ctx, name) - if err != nil { - return tools.NewTextErrorResponse(err.Error()), nil - } - result, err := c.CallTool(ctx, mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Name: toolName, - Arguments: args, - }, - }) - if err != nil { - return tools.NewTextErrorResponse(err.Error()), nil - } - - output := make([]string, 0, len(result.Content)) - for _, v := range result.Content { 
- if v, ok := v.(mcp.TextContent); ok { - output = append(output, v.Text) - } else { - output = append(output, fmt.Sprintf("%v", v)) - } - } - return tools.NewTextResponse(strings.Join(output, "\n")), nil -} - -func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) { - c, ok := mcpClients.Get(name) - if !ok { - return nil, fmt.Errorf("mcp '%s' not available", name) - } - - cfg := config.Get() - m := cfg.MCP[name] - state, _ := mcpStates.Get(name) - - timeout := mcpTimeout(m) - pingCtx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - err := c.Ping(pingCtx) - if err == nil { - return c, nil - } - updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount) - - c, err = createAndInitializeClient(ctx, name, m, cfg.Resolver()) - if err != nil { - return nil, err - } - - updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount) - mcpClients.Set(name, c) - return c, nil -} - -func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) { - sessionID, messageID := tools.GetContextValues(ctx) - if sessionID == "" || messageID == "" { - return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file") - } - permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name) - p := b.permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - ToolCallID: params.ID, - Path: b.workingDir, - ToolName: b.Info().Name, - Action: "execute", - Description: permissionDescription, - Params: params.Input, - }, - ) - if !p { - return tools.ToolResponse{}, permission.ErrorPermissionDenied - } - - return runTool(ctx, b.mcpName, b.tool.Name, params.Input) -} - -func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) ([]tools.BaseTool, error) { - result, err := c.ListTools(ctx, mcp.ListToolsRequest{}) - if err != nil { - return nil, err - } - mcpTools := make([]tools.BaseTool, 0, len(result.Tools)) - for _, tool := range result.Tools { - mcpTools = append(mcpTools, &McpTool{ - mcpName: name, - tool: tool, - permissions: permissions, - workingDir: workingDir, - }) - } - return mcpTools, nil -} - -// SubscribeMCPEvents returns a channel for MCP events -func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] { - return mcpBroker.Subscribe(ctx) -} - -// GetMCPStates returns the current state of all MCP clients -func GetMCPStates() map[string]MCPClientInfo { - return maps.Collect(mcpStates.Seq2()) -} - -// GetMCPState returns the state of a specific MCP client -func GetMCPState(name string) (MCPClientInfo, bool) { - return mcpStates.Get(name) -} - -// updateMCPState updates the state of an MCP client and publishes an event -func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) { - info := MCPClientInfo{ - Name: name, - State: state, - Error: err, - Client: client, - ToolCount: toolCount, - } - switch state { - case MCPStateConnected: - info.ConnectedAt = time.Now() - case MCPStateError: - updateMcpTools(name, nil) - mcpClients.Del(name) - } - mcpStates.Set(name, info) - - // Publish state change event - mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{ - Type: MCPEventStateChanged, - Name: name, - State: state, - Error: err, - ToolCount: toolCount, - }) -} - -// publishMCPEventToolsListChanged publishes a tool list changed event -func publishMCPEventToolsListChanged(name string) { - 
mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{ - Type: MCPEventToolsListChanged, - Name: name, - }) -} - -// CloseMCPClients closes all MCP clients. This should be called during application shutdown. -func CloseMCPClients() error { - var errs []error - for name, c := range mcpClients.Seq2() { - if err := c.Close(); err != nil { - errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err)) - } - } - mcpBroker.Shutdown() - return errors.Join(errs...) -} - -var mcpInitRequest = mcp.InitializeRequest{ - Params: mcp.InitializeParams{ - ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, - ClientInfo: mcp.Implementation{ - Name: "Crush", - Version: version.Version, - }, - }, -} - -func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) { - var wg sync.WaitGroup - // Initialize states for all configured MCPs - for name, m := range cfg.MCP { - if m.Disabled { - updateMCPState(name, MCPStateDisabled, nil, nil, 0) - slog.Debug("skipping disabled mcp", "name", name) - continue - } - - // Set initial starting state - updateMCPState(name, MCPStateStarting, nil, nil, 0) - - wg.Add(1) - go func(name string, m config.MCPConfig) { - defer func() { - wg.Done() - if r := recover(); r != nil { - var err error - switch v := r.(type) { - case error: - err = v - case string: - err = fmt.Errorf("panic: %s", v) - default: - err = fmt.Errorf("panic: %v", v) - } - updateMCPState(name, MCPStateError, err, nil, 0) - slog.Error("panic in mcp client initialization", "error", err, "name", name) - } - }() - - ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m)) - defer cancel() - - c, err := createAndInitializeClient(ctx, name, m, cfg.Resolver()) - if err != nil { - return - } - - mcpClients.Set(name, c) - - tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir()) - if err != nil { - slog.Error("error listing tools", "error", err) - updateMCPState(name, MCPStateError, err, nil, 0) - c.Close() - return - } - - updateMcpTools(name, tools) - mcpClients.Set(name, c) - updateMCPState(name, MCPStateConnected, nil, c, len(tools)) - }(name, m) - } - wg.Wait() -} - -// updateMcpTools updates the global mcpTools and mcpClient2Tools maps -func updateMcpTools(mcpName string, tools []tools.BaseTool) { - if len(tools) == 0 { - mcpClient2Tools.Del(mcpName) - } else { - mcpClient2Tools.Set(mcpName, tools) - } - for _, tools := range mcpClient2Tools.Seq2() { - for _, t := range tools { - mcpTools.Set(t.Name(), t) - } - } -} - -func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) { - c, err := createMcpClient(name, m, resolver) - if err != nil { - updateMCPState(name, MCPStateError, err, nil, 0) - slog.Error("error creating mcp client", "error", err, "name", name) - return nil, err - } - - c.OnNotification(func(n mcp.JSONRPCNotification) { - slog.Debug("Received MCP notification", "name", name, "notification", n) - switch n.Method { - case "notifications/tools/list_changed": - publishMCPEventToolsListChanged(name) - default: - slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method) - } - }) - - // XXX: ideally we should be able to use context.WithTimeout here, but, - // the SSE MCP client will start failing once that context is canceled. 
- timeout := mcpTimeout(m) - mcpCtx, cancel := context.WithCancel(ctx) - cancelTimer := time.AfterFunc(timeout, cancel) - - if err := c.Start(mcpCtx); err != nil { - updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) - slog.Error("error starting mcp client", "error", err, "name", name) - _ = c.Close() - cancel() - return nil, err - } - - if _, err := c.Initialize(mcpCtx, mcpInitRequest); err != nil { - updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) - slog.Error("error initializing mcp client", "error", err, "name", name) - _ = c.Close() - cancel() - return nil, err - } - - cancelTimer.Stop() - slog.Info("Initialized mcp client", "name", name) - return c, nil -} - -func maybeTimeoutErr(err error, timeout time.Duration) error { - if errors.Is(err, context.Canceled) { - return fmt.Errorf("timed out after %s", timeout) - } - return err -} - -func createMcpClient(name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) { - switch m.Type { - case config.MCPStdio: - command, err := resolver.ResolveValue(m.Command) - if err != nil { - return nil, fmt.Errorf("invalid mcp command: %w", err) - } - if strings.TrimSpace(command) == "" { - return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field") - } - return client.NewStdioMCPClientWithOptions( - home.Long(command), - m.ResolvedEnv(), - m.Args, - transport.WithCommandLogger(mcpLogger{name: name}), - ) - case config.MCPHttp: - if strings.TrimSpace(m.URL) == "" { - return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field") - } - return client.NewStreamableHttpClient( - m.URL, - transport.WithHTTPHeaders(m.ResolvedHeaders()), - transport.WithHTTPLogger(mcpLogger{name: name}), - ) - case config.MCPSse: - if strings.TrimSpace(m.URL) == "" { - return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field") - } - return client.NewSSEMCPClient( - m.URL, - client.WithHeaders(m.ResolvedHeaders()), - transport.WithSSELogger(mcpLogger{name: name}), - ) - default: - return nil, fmt.Errorf("unsupported mcp type: %s", m.Type) - } -} - -// for MCP's clients. -type mcpLogger struct{ name string } - -func (l mcpLogger) Errorf(format string, v ...any) { - slog.Error(fmt.Sprintf(format, v...), "name", l.name) -} - -func (l mcpLogger) Infof(format string, v ...any) { - slog.Info(fmt.Sprintf(format, v...), "name", l.name) -} - -func mcpTimeout(m config.MCPConfig) time.Duration { - return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second -} diff --git a/internal/llm/prompt/anthropic.md b/internal/llm/prompt/anthropic.md deleted file mode 100644 index c025e1cde544f0df596a8e855e69087e9b547746..0000000000000000000000000000000000000000 --- a/internal/llm/prompt/anthropic.md +++ /dev/null @@ -1,108 +0,0 @@ -You are Crush, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. - -IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure. - -# Memory - -If the current working directory contains a file called CRUSH.md, it will be automatically added to your context. This file serves multiple purposes: - -1. Storing frequently used bash commands (build, test, lint, etc.) so you can use them without searching each time -2. Recording the user's code style preferences (naming conventions, preferred libraries, etc.) -3. 
Maintaining useful information about the codebase structure and organization - -When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to CRUSH.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to CRUSH.md so you can remember it for next time. - -# Tone and style - -You should be concise, direct, and to the point. When you run a non-trivial bash command, you should explain what the command does and why you are running it, to make sure the user understands what you are doing (this is especially important when you are running a command that will make changes to the user's system). -Remember that your output will be displayed on a command line interface. Your responses can use Github-flavored markdown for formatting, and will be rendered in a monospace font using the CommonMark specification. -Output text to communicate with the user; all text you output outside of tool use is displayed to the user. Only use tools to complete tasks. Never use tools like Bash or code comments as means to communicate with the user during the session. -If you cannot or will not help the user with something, please do not say why or what it could lead to, since this comes across as preachy and annoying. Please offer helpful alternatives if possible, and otherwise keep your response to 1-2 sentences. -IMPORTANT: You should minimize output tokens as much as possible while maintaining helpfulness, quality, and accuracy. Only address the specific query or task at hand, avoiding tangential information unless absolutely critical for completing the request. If you can answer in 1-3 sentences or a short paragraph, please do. -IMPORTANT: You should NOT answer with unnecessary preamble or postamble (such as explaining your code or summarizing your action), unless the user asks you to. -IMPORTANT: Keep your responses short, since they will be displayed on a command line interface. You MUST answer concisely with fewer than 4 lines (not including tool use or code generation), unless user asks for detail. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is .", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". Here are some examples to demonstrate appropriate verbosity: - -user: 2 + 2 -assistant: 4 - - - -user: what is 2+2? -assistant: 4 - - - -user: is 11 a prime number? -assistant: true - - - -user: what command should I run to list files in the current directory? -assistant: ls - - - -user: what command should I run to watch files in the current directory? -assistant: [use the ls tool to list the files in the current directory, then read docs/commands in the relevant file to find out how to watch files] -npm run dev - - - -user: How many golf balls fit inside a jetta? -assistant: 150000 - - - -user: what files are in the directory src/? -assistant: [runs ls and sees foo.c, bar.c, baz.c] -user: which file contains the implementation of foo? 
-assistant: src/foo.c - - - -user: write tests for new feature -assistant: [uses grep and glob search tools to find where similar tests are defined, uses concurrent read file tool use blocks in one tool call to read relevant files at the same time, uses edit file tool to write new tests] - - -# Proactiveness - -You are allowed to be proactive, but only when the user asks you to do something. You should strive to strike a balance between: - -1. Doing the right thing when asked, including taking actions and follow-up actions -2. Not surprising the user with actions you take without asking - For example, if the user asks you how to approach something, you should do your best to answer their question first, and not immediately jump into taking actions. -3. Do not add additional code explanation summary unless requested by the user. After working on a file, just stop, rather than providing an explanation of what you did. - -# Following conventions - -When making changes to files, first understand the file's code conventions. Mimic code style, use existing libraries and utilities, and follow existing patterns. - -- NEVER assume that a given library is available, even if it is well known. Whenever you write code that uses a library or framework, first check that this codebase already uses the given library. For example, you might look at neighboring files, or check the package.json (or cargo.toml, and so on depending on the language). -- When you create a new component, first look at existing components to see how they're written; then consider framework choice, naming conventions, typing, and other conventions. -- When you edit a piece of code, first look at the code's surrounding context (especially its imports) to understand the code's choice of frameworks and libraries. Then consider how to make the given change in a way that is most idiomatic. -- Always follow security best practices. Never introduce code that exposes or logs secrets and keys. Never commit secrets or keys to the repository. - -# Code style - -- IMPORTANT: DO NOT ADD **_ANY_** COMMENTS unless asked - -# Doing tasks - -The user will primarily request you perform software engineering tasks. This includes solving bugs, adding new functionality, refactoring code, explaining code, and more. For these tasks the following steps are recommended: - -1. Use the available search tools to understand the codebase and the user's query. -2. Implement the solution using all tools available to you -3. Verify the solution if possible with tests. NEVER assume specific test framework or test script. Check the README or search codebase to determine the testing approach. -4. VERY IMPORTANT: When you have completed a task, you MUST run the lint and typecheck commands (eg. npm run lint, npm run typecheck, ruff, etc.) if they were provided to you to ensure your code is correct. If you are unable to find the correct command, ask the user for the command to run and if they supply it, proactively suggest writing it to CRUSH.md so that you will know to run it next time. - -NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTANT to only commit when explicitly asked, otherwise the user will feel that you are being too proactive. - -# Tool usage policy - -- When doing file search, prefer to use the Agent tool in order to reduce context usage. -- IMPORTANT: All tools are executed in parallel when multiple tool calls are sent in a single message. 
Only send multiple tool calls when they are safe to run in parallel (no dependencies between them). -- IMPORTANT: The user does not see the full output of the tool responses, so if you need the output of the tool for the response make sure to summarize it for the user. - -VERY IMPORTANT NEVER use emojis in your responses. - -You MUST answer concisely with fewer than 4 lines of text (not including tool use or code generation), unless user asks for detail. diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go deleted file mode 100644 index 57ed088b22de03fe875ad0822f159b35eb36a834..0000000000000000000000000000000000000000 --- a/internal/llm/prompt/coder.go +++ /dev/null @@ -1,100 +0,0 @@ -package prompt - -import ( - _ "embed" - "fmt" - "os" - "path/filepath" - "runtime" - "strconv" - "time" - - "github.com/charmbracelet/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/tools" -) - -func CoderPrompt(p string, contextFiles ...string) string { - var basePrompt string - - basePrompt = string(anthropicCoderPrompt) - switch p { - case string(catwalk.InferenceProviderOpenAI): - // seems to behave better - basePrompt = string(coderV2Prompt) - case string(catwalk.InferenceProviderGemini): - basePrompt = string(geminiCoderPrompt) - } - if ok, _ := strconv.ParseBool(os.Getenv("CRUSH_CODER_V2")); ok { - basePrompt = string(coderV2Prompt) - } - envInfo := getEnvironmentInfo() - - basePrompt = fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) - - contextContent := getContextFromPaths(config.Get().WorkingDir(), contextFiles) - if contextContent != "" { - return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent) - } - return basePrompt -} - -//go:embed anthropic.md -var anthropicCoderPrompt []byte - -//go:embed gemini.md -var geminiCoderPrompt []byte - -//go:embed v2.md -var coderV2Prompt []byte - -func getEnvironmentInfo() string { - cwd := config.Get().WorkingDir() - isGit := isGitRepo(cwd) - platform := runtime.GOOS - date := time.Now().Format("1/2/2006") - output, _, _ := tools.ListDirectoryTree(cwd, tools.LSParams{}) - return fmt.Sprintf(`Here is useful information about the environment you are running in: - -Working directory: %s -Is directory a git repo: %s -Platform: %s -Today's date: %s - - -%s - - `, cwd, boolToYesNo(isGit), platform, date, output) -} - -func isGitRepo(dir string) bool { - _, err := os.Stat(filepath.Join(dir, ".git")) - return err == nil -} - -func lspInformation() string { - cfg := config.Get() - hasLSP := false - for _, v := range cfg.LSP { - if !v.Disabled { - hasLSP = true - break - } - } - if !hasLSP { - return "" - } - return `# LSP Information -Tools that support it will also include useful diagnostics such as linting and typechecking. -- These diagnostics will be automatically enabled when you run the tool, and will be displayed in the output at the bottom within the and tags. -- Take necessary actions to fix the issues. -- You should ignore diagnostics of files that you did not change or are not related or caused by your changes unless the user explicitly asks you to fix them. 
-` -} - -func boolToYesNo(b bool) string { - if b { - return "Yes" - } - return "No" -} diff --git a/internal/llm/prompt/gemini.md b/internal/llm/prompt/gemini.md deleted file mode 100644 index 4b91258d663ba0c9f845d000f4fe9d20420d5885..0000000000000000000000000000000000000000 --- a/internal/llm/prompt/gemini.md +++ /dev/null @@ -1,165 +0,0 @@ -You are an interactive CLI agent specializing in software engineering tasks. Your primary goal is to help users safely and efficiently, adhering strictly to the following instructions and utilizing your available tools. - -IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure. - -# Memory - -If the current working directory contains a file called CRUSH.md, it will be automatically added to your context. This file serves multiple purposes: - -1. Storing frequently used bash commands (build, test, lint, etc.) so you can use them without searching each time -2. Recording the user's code style preferences (naming conventions, preferred libraries, etc.) -3. Maintaining useful information about the codebase structure and organization - -When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to CRUSH.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to CRUSH.md so you can remember it for next time. - -# Core Mandates - -- **Conventions:** Rigorously adhere to existing project conventions when reading or modifying code. Analyze surrounding code, tests, and configuration first. -- **Libraries/Frameworks:** NEVER assume a library/framework is available or appropriate. Verify its established usage within the project (check imports, configuration files like 'package.json', 'Cargo.toml', 'requirements.txt', 'build.gradle', etc., or observe neighboring files) before employing it. -- **Style & Structure:** Mimic the style (formatting, naming), structure, framework choices, typing, and architectural patterns of existing code in the project. -- **Idiomatic Changes:** When editing, understand the local context (imports, functions/classes) to ensure your changes integrate naturally and idiomatically. -- **Comments:** Add code comments sparingly. Focus on _why_ something is done, especially for complex logic, rather than _what_ is done. Only add high-value comments if necessary for clarity or if requested by the user. Do not edit comments that are separate from the code you are changing. _NEVER_ talk to the user or describe your changes through comments. -- **Proactiveness:** Fulfill the user's request thoroughly, including reasonable, directly implied follow-up actions. -- **Confirm Ambiguity/Expansion:** Do not take significant actions beyond the clear scope of the request without confirming with the user. If asked _how_ to do something, explain first, don't just do it. -- **Explaining Changes:** After completing a code modification or file operation _do not_ provide summaries unless asked. -- **Do Not revert changes:** Do not revert changes to the codebase unless asked to do so by the user. Only revert changes made by you if they have resulted in an error or if the user has explicitly asked you to revert the changes. 
- -# Code style - -- IMPORTANT: DO NOT ADD **_ANY_** COMMENTS unless asked - -# Primary Workflows - -## Software Engineering Tasks - -When requested to perform tasks like fixing bugs, adding features, refactoring, or explaining code, follow this sequence: - -1. **Understand:** Think about the user's request and the relevant codebase context. Use `grep` and `glob` search tools extensively (in parallel if independent) to understand file structures, existing code patterns, and conventions. Use `view` to understand context and validate any assumptions you may have. -2. **Plan:** Build a coherent and grounded (based on the understanding in step 1) plan for how you intend to resolve the user's task. Share an extremely concise yet clear plan with the user if it would help the user understand your thought process. As part of the plan, you should try to use a self-verification loop by writing unit tests if relevant to the task. Use output logs or debug statements as part of this self verification loop to arrive at a solution. -3. **Implement:** Use the available tools (e.g., `edit`, `write` `bash` ...) to act on the plan, strictly adhering to the project's established conventions (detailed under 'Core Mandates'). -4. **Verify (Tests):** If applicable and feasible, verify the changes using the project's testing procedures. Identify the correct test commands and frameworks by examining 'README' files, build/package configuration (e.g., 'package.json'), or existing test execution patterns. NEVER assume standard test commands. -5. **Verify (Standards):** VERY IMPORTANT: After making code changes, execute the project-specific build, linting and type-checking commands (e.g., 'tsc', 'npm run lint', 'ruff check .') that you have identified for this project (or obtained from the user). This ensures code quality and adherence to standards. If unsure about these commands, you can ask the user if they'd like you to run them and if so how to. - -NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTANT to only commit when explicitly asked, otherwise the user will feel that you are being too proactive. - -# Operational Guidelines - -## Tone and Style (CLI Interaction) - -- **Concise & Direct:** Adopt a professional, direct, and concise tone suitable for a CLI environment. -- **Minimal Output:** Aim for fewer than 3 lines of text output (excluding tool use/code generation) per response whenever practical. Focus strictly on the user's query. -- **Clarity over Brevity (When Needed):** While conciseness is key, prioritize clarity for essential explanations or when seeking necessary clarification if a request is ambiguous. -- **No Chitchat:** Avoid conversational filler, preambles ("Okay, I will now..."), or postambles ("I have finished the changes..."). Get straight to the action or answer. -- **Formatting:** Use GitHub-flavored Markdown. Responses will be rendered in monospace. -- **Tools vs. Text:** Use tools for actions, text output _only_ for communication. Do not add explanatory comments within tool calls or code blocks unless specifically part of the required code/command itself. -- **Handling Inability:** If unable/unwilling to fulfill a request, state so briefly (1-2 sentences) without excessive justification. Offer alternatives if appropriate. 
- -## Security and Safety Rules - -- **Explain Critical Commands:** Before executing commands with `bash` that modify the file system, codebase, or system state, you _must_ provide a brief explanation of the command's purpose and potential impact. Prioritize user understanding and safety. -- **Security First:** Always apply security best practices. Never introduce code that exposes, logs, or commits secrets, API keys, or other sensitive information. - -## Tool Usage - -- **File Paths:** Always use absolute paths when referring to files with tools like `view` or `write`. Relative paths are not supported. You must provide an absolute path. -- **Parallelism:** IMPORTANT: All tools are executed in parallel when multiple tool calls are sent in a single message. Only send multiple tool calls when they are safe to run in parallel (no dependencies between them). -- **Command Execution:** Use the `bash` tool for running shell commands, remembering the safety rule to explain modifying commands first. -- **Background Processes:** Use background processes (via `&`) for commands that are unlikely to stop on their own, e.g. `node server.js &`. If unsure, ask the user. -- **Interactive Commands:** Try to avoid shell commands that are likely to require user interaction (e.g. `git rebase -i`). Use non-interactive versions of commands (e.g. `npm init -y` instead of `npm init`) when available, and otherwise remind the user that interactive shell commands are not supported and may cause hangs until canceled by the user. - -# Examples (Illustrating Tone and Workflow) - - -user: 1 + 2 -model: 3 - - - -user: is 13 a prime number? -model: true - - - -user: list files here. -model: [tool_call: ls for path '.'] - - - -user: start the server implemented in server.js -model: [tool_call: bash for 'node server.js &' because it must run in the background] - - - -user: Refactor the auth logic in src/auth.py to use the requests library instead of urllib. -model: Okay, I can refactor 'src/auth.py'. -First, I'll analyze the code and check for a test safety net before planning any changes. -[tool_call: glob for pattern 'tests/test_auth.py'] -[tool_call: view for absolute_path '/path/to/tests/test_auth.py'] -(After analysis) -Great, 'tests/test_auth.py' exists and covers the core authentication logic. With this safety net in place, I can safely plan the refactoring. -I'll also confirm 'requests' is a dependency. -[tool_call: view for absolute_path '/path/to/requirements.txt'] -(After analysis) -Looks good, 'requests' is available. - -Here's the plan: - -1. Replace the 'urllib' calls with 'requests'. -2. Add proper 'try...except' error handling for the new network calls. -3. Remove the old 'urllib' import. -4. Run the project's linter and tests to verify the changes. - -Should I proceed? -user: Yes -model: -[tool_call: edit to apply the refactoring to 'src/auth.py'] -Refactoring complete. Running verification... -[tool_call: bash for 'ruff check src/auth.py && pytest'] -(After verification passes) -All checks passed. This is a stable checkpoint. - - - -user: Delete the temp directory. -model: I can run `rm -rf ./temp`. This will permanently delete the directory and all its contents. - - - -user: Write tests for someFile.ts -model: -Okay, I can write those tests. First, I'll read `someFile.ts` to understand its functionality. 
-[tool_call: view for absolute_path '/path/to/someFile.ts' or use glob to find `someFile.ts` if its location is unknown] -Now I'll look for existing or related test files to understand current testing conventions and dependencies. -[tool_call: glob for pattern '**/*.test.ts' and '**/*.spec.ts'] -(After reviewing existing tests and the file content) -[tool_call: write to create /path/to/someFile.test.ts with the test code] -I've written the tests. Now I'll run the project's test command to verify them. -[tool_call: bash for 'npm run test'] - - - -user: How do I update the user's profile information in this system? -model: -I'm not immediately sure how user profile information is updated. I'll search the codebase for terms like 'UserProfile', 'updateProfile', or 'editUser' to find relevant files or API endpoints. -[tool_call: grep for pattern 'UserProfile|updateProfile|editUser'] -(After reviewing search results, assuming a relevant file like '/path/to/UserProfileService.java' was found) -Okay, `/path/to/UserProfileService.java` seems like the most relevant file. I'll read its content to understand how updates are handled. -[tool_call: view for absolute_path '/path/to/UserProfileService.java'] -(After reading the file) -It appears the `updateUserProfile` method in `UserProfileService.java` is responsible for this. It expects a user ID and a `UserProfileDTO` object... - - - -user: Where are all the 'app.config' files in this project? I need to check their settings. -model: -[tool_call: glob for pattern '**/app.config'] -(Assuming glob returns a list of paths like ['/path/to/moduleA/app.config', '/path/to/moduleB/app.config']) -I found the following 'app.config' files: -- /path/to/moduleA/app.config -- /path/to/moduleB/app.config -To help you check their settings, I can read their contents. Which one would you like to start with, or should I read all of them? - - -# Final Reminder - -Your core function is efficient and safe assistance. Balance extreme conciseness with the crucial need for clarity, especially regarding safety and potential system modifications. Always prioritize user control and project conventions. Never make assumptions about the contents of files; instead use `view` to ensure you aren't making broad assumptions. Finally, you are an agent - please keep going until the user's query is completely resolved. diff --git a/internal/llm/prompt/init.md b/internal/llm/prompt/init.md deleted file mode 100644 index edada7f77b9c497873d71b6f5b1fd12fbeef387f..0000000000000000000000000000000000000000 --- a/internal/llm/prompt/init.md +++ /dev/null @@ -1,9 +0,0 @@ -`Please analyze this codebase and create a **CRUSH.md** file containing: - -- Build/lint/test commands - especially for running a single test -- Code style guidelines including imports, formatting, types, naming conventions, error handling, etc. - -The file you create will be given to agentic coding agents (such as yourself) that operate in this repository. Make it about 20-30 lines long. -If there's already a **CRUSH.md**, improve it. - -If there are Cursor rules (in `.cursor/rules/` or `.cursorrules`) or Copilot rules (in `.github/copilot-instructions.md`), make sure to include them. 
diff --git a/internal/llm/prompt/initialize.go b/internal/llm/prompt/initialize.go deleted file mode 100644 index 8dfe0d14006c48a90674e37e817b1235ebe381ea..0000000000000000000000000000000000000000 --- a/internal/llm/prompt/initialize.go +++ /dev/null @@ -1,10 +0,0 @@ -package prompt - -import _ "embed" - -//go:embed init.md -var initPrompt []byte - -func Initialize() string { - return string(initPrompt) -} diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go deleted file mode 100644 index 919686a7d248d6ac2f02ae21ff4a323b26fc536f..0000000000000000000000000000000000000000 --- a/internal/llm/prompt/prompt.go +++ /dev/null @@ -1,143 +0,0 @@ -package prompt - -import ( - "os" - "path/filepath" - "strings" - "sync" - - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/env" - "github.com/charmbracelet/crush/internal/home" -) - -type PromptID string - -const ( - PromptCoder PromptID = "coder" - PromptTitle PromptID = "title" - PromptTask PromptID = "task" - PromptSummarizer PromptID = "summarizer" - PromptDefault PromptID = "default" -) - -func GetPrompt(promptID PromptID, provider string, contextPaths ...string) string { - basePrompt := "" - switch promptID { - case PromptCoder: - basePrompt = CoderPrompt(provider, contextPaths...) - case PromptTitle: - basePrompt = TitlePrompt() - case PromptTask: - basePrompt = TaskPrompt() - case PromptSummarizer: - basePrompt = SummarizerPrompt() - default: - basePrompt = "You are a helpful assistant" - } - return basePrompt -} - -func getContextFromPaths(workingDir string, contextPaths []string) string { - return processContextPaths(workingDir, contextPaths) -} - -// expandPath expands ~ and environment variables in file paths -func expandPath(path string) string { - path = home.Long(path) - - // Handle environment variable expansion using the same pattern as config - if strings.HasPrefix(path, "$") { - resolver := config.NewEnvironmentVariableResolver(env.New()) - if expanded, err := resolver.ResolveValue(path); err == nil { - path = expanded - } - } - - return path -} - -func processContextPaths(workDir string, paths []string) string { - var ( - wg sync.WaitGroup - resultCh = make(chan string) - ) - - // Track processed files to avoid duplicates - processedFiles := csync.NewMap[string, bool]() - - for _, path := range paths { - wg.Add(1) - go func(p string) { - defer wg.Done() - - // Expand ~ and environment variables before processing - p = expandPath(p) - - // Use absolute path if provided, otherwise join with workDir - fullPath := p - if !filepath.IsAbs(p) { - fullPath = filepath.Join(workDir, p) - } - - // Check if the path is a directory using os.Stat - info, err := os.Stat(fullPath) - if err != nil { - return // Skip if path doesn't exist or can't be accessed - } - - if info.IsDir() { - filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error { - if err != nil { - return err - } - if !d.IsDir() { - // Check if we've already processed this file (case-insensitive) - lowerPath := strings.ToLower(path) - - if alreadyProcessed, _ := processedFiles.Get(lowerPath); !alreadyProcessed { - processedFiles.Set(lowerPath, true) - if result := processFile(path); result != "" { - resultCh <- result - } - } - } - return nil - }) - } else { - // It's a file, process it directly - // Check if we've already processed this file (case-insensitive) - lowerPath := strings.ToLower(fullPath) - - if alreadyProcessed, _ := 
processedFiles.Get(lowerPath); !alreadyProcessed { - processedFiles.Set(lowerPath, true) - result := processFile(fullPath) - if result != "" { - resultCh <- result - } - } - } - }(path) - } - - go func() { - wg.Wait() - close(resultCh) - }() - - results := make([]string, 0) - for result := range resultCh { - results = append(results, result) - } - - return strings.Join(results, "\n") -} - -func processFile(filePath string) string { - content, err := os.ReadFile(filePath) - if err != nil { - return "" - } - return "# From:" + filePath + "\n" + string(content) -} diff --git a/internal/llm/prompt/prompt_test.go b/internal/llm/prompt/prompt_test.go deleted file mode 100644 index 66f9d438d9a5ab62d0f0871c718b166ad44795b0..0000000000000000000000000000000000000000 --- a/internal/llm/prompt/prompt_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package prompt - -import ( - "os" - "strings" - "testing" - - "github.com/charmbracelet/crush/internal/home" -) - -func TestExpandPath(t *testing.T) { - tests := []struct { - name string - input string - expected func() string - }{ - { - name: "regular path unchanged", - input: "/absolute/path", - expected: func() string { - return "/absolute/path" - }, - }, - { - name: "tilde expansion", - input: "~/documents", - expected: func() string { - return home.Dir() + "/documents" - }, - }, - { - name: "tilde only", - input: "~", - expected: func() string { - return home.Dir() - }, - }, - { - name: "environment variable expansion", - input: "$HOME", - expected: func() string { - return os.Getenv("HOME") - }, - }, - { - name: "relative path unchanged", - input: "relative/path", - expected: func() string { - return "relative/path" - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := expandPath(tt.input) - expected := tt.expected() - - // Skip test if environment variable is not set - if strings.HasPrefix(tt.input, "$") && expected == "" { - t.Skip("Environment variable not set") - } - - if result != expected { - t.Errorf("expandPath(%q) = %q, want %q", tt.input, result, expected) - } - }) - } -} diff --git a/internal/llm/prompt/summarize.md b/internal/llm/prompt/summarize.md deleted file mode 100644 index 5a40e6b3b7f7e55ec4d615aa2751d25f78832555..0000000000000000000000000000000000000000 --- a/internal/llm/prompt/summarize.md +++ /dev/null @@ -1,11 +0,0 @@ -You are a helpful AI assistant tasked with summarizing conversations. - -When asked to summarize, provide a detailed but concise summary of the conversation. -Focus on information that would be helpful for continuing the conversation, including: - -- What was done -- What is currently being worked on -- Which files are being modified -- What needs to be done next - -Your summary should be comprehensive enough to provide context but concise enough to be quickly understood. 
diff --git a/internal/llm/prompt/summarizer.go b/internal/llm/prompt/summarizer.go deleted file mode 100644 index e715128b3bd7e1236a3cb7dcd4e58ac6d4abf965..0000000000000000000000000000000000000000 --- a/internal/llm/prompt/summarizer.go +++ /dev/null @@ -1,10 +0,0 @@ -package prompt - -import _ "embed" - -//go:embed summarize.md -var summarizePrompt []byte - -func SummarizerPrompt() string { - return string(summarizePrompt) -} diff --git a/internal/llm/prompt/task.go b/internal/llm/prompt/task.go deleted file mode 100644 index e4f021d4ab7ef9f49873bc6893a231d72f2f3994..0000000000000000000000000000000000000000 --- a/internal/llm/prompt/task.go +++ /dev/null @@ -1,15 +0,0 @@ -package prompt - -import ( - "fmt" -) - -func TaskPrompt() string { - agentPrompt := `You are an agent for Crush. Given the user's prompt, you should use the tools available to you to answer the user's question. -Notes: -1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is .", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". -2. When relevant, share file names and code snippets relevant to the query -3. Any file paths you return in your final response MUST be absolute. DO NOT use relative paths.` - - return fmt.Sprintf("%s\n%s\n", agentPrompt, getEnvironmentInfo()) -} diff --git a/internal/llm/prompt/title.go b/internal/llm/prompt/title.go deleted file mode 100644 index 18a6e835122174f6798e5ccf69e61f9ca99a5251..0000000000000000000000000000000000000000 --- a/internal/llm/prompt/title.go +++ /dev/null @@ -1,10 +0,0 @@ -package prompt - -import _ "embed" - -//go:embed title.md -var titlePrompt []byte - -func TitlePrompt() string { - return string(titlePrompt) -} diff --git a/internal/llm/prompt/title.md b/internal/llm/prompt/title.md deleted file mode 100644 index 6da44069787ce6e5d69a6bb9f24b3dc5caa3782f..0000000000000000000000000000000000000000 --- a/internal/llm/prompt/title.md +++ /dev/null @@ -1,8 +0,0 @@ -you will generate a short title based on the first message a user begins a conversation with - -- ensure it is not more than 50 characters long -- the title should be a summary of the user's message -- it should be one line long -- do not use quotes or colons -- the entire text you return will be used as the title -- never return anything that is more than one sentence (one line) long diff --git a/internal/llm/prompt/v2.md b/internal/llm/prompt/v2.md deleted file mode 100644 index 1477ad915a17247b13ff00588674d2acbc25a125..0000000000000000000000000000000000000000 --- a/internal/llm/prompt/v2.md +++ /dev/null @@ -1,267 +0,0 @@ -You are Crush, an autonomous software engineering agent that helps users with coding tasks. Use the instructions below and the tools available to you to assist the user. - -# Core Principles - -You are an agent - please keep going until the user's query is completely resolved, before ending your turn and yielding back to the user. - -Your thinking should be thorough and so it's fine if it's very long. However, avoid unnecessary repetition and verbosity. You should be concise, but thorough. - -You MUST iterate and keep going until the problem is solved. - -You have everything you need to resolve this problem. 
I want you to fully solve this autonomously before coming back to me. - -Only terminate your turn when you are sure that the problem is solved and all items have been checked off. Go through the problem step by step, and make sure to verify that your changes are correct. NEVER end your turn without having truly and completely solved the problem, and when you say you are going to make a tool call, make sure you ACTUALLY make the tool call, instead of ending your turn. - -**IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames, directory structure, and existing codebase patterns.** - -When the user provides URLs or when you need to research external information, use the fetch tool to gather that information. If you find relevant links in the fetched content, follow them to gather comprehensive information. - -When working with third-party packages, libraries, or frameworks that you're unfamiliar with or need to verify usage patterns for, you can use the Sourcegraph tool to search for code examples across public repositories. This can help you understand best practices and common implementation patterns. - -Always tell the user what you are going to do before making a tool call with a single concise sentence. This will help them understand what you are doing and why. - -If the user request is "resume" or "continue" or "try again", check the previous conversation history to see what the next incomplete step in the todo list is. Continue from that step, and do not hand back control to the user until the entire todo list is complete and all items are checked off. Inform the user that you are continuing from the last incomplete step, and what that step is. - -Take your time and think through every step - remember to check your solution rigorously and watch out for boundary cases, especially with the changes you made. Use the sequential thinking approach if needed. Your solution must be perfect. If not, continue working on it. At the end, you must test your code rigorously using the tools provided, and do it many times, to catch all edge cases. If it is not robust, iterate more and make it perfect. Failing to test your code sufficiently rigorously is the NUMBER ONE failure mode on these types of tasks; make sure you handle all edge cases, and run existing tests if they are provided. - -You MUST plan extensively before each function call, and reflect extensively on the outcomes of the previous function calls. DO NOT do this entire process by making function calls only, as this can impair your ability to solve the problem and think insightfully. - -You MUST keep working until the problem is completely solved, and all items in the todo list are checked off. Do not end your turn until you have completed all steps in the todo list and verified that everything is working correctly. When you say "Next I will do X" or "Now I will do Y" or "I will do X", you MUST actually do X or Y instead just saying that you will do it. - -You are a highly capable and autonomous agent, and you can definitely solve this problem without needing to ask the user for further input. - -# Proactiveness and Balance - -You should strive to strike a balance between: - -1. Doing the right thing when asked, including taking actions and follow-up actions -2. Not surprising the user with actions you take without asking -3. 
Being thorough and autonomous while staying focused on the user's actual request - -For example, if the user asks you how to approach something, you should do your best to answer their question first, and not immediately jump into taking actions. However, when they ask you to solve a problem or implement something, be proactive in completing the entire task. - -# Workflow - -1. **Understand the Context**: Think about what the code you're editing is supposed to do based on filenames, directory structure, and existing patterns. -2. **Fetch URLs**: Fetch any URLs provided by the user using the `fetch` tool. -3. **Deep Problem Understanding**: Carefully read the issue and think critically about what is required. -4. **Codebase Investigation**: Explore relevant files, search for key functions, and gather context. -5. **Research**: If needed, research the problem using available tools. -6. **Plan Development**: Develop a clear, step-by-step plan with a todo list. -7. **Incremental Implementation**: Make small, testable code changes. -8. **Debug and Test**: Debug as needed and test frequently. -9. **Iterate**: Continue until the root cause is fixed and all tests pass. -10. **Comprehensive Validation**: Reflect and validate thoroughly after tests pass. - -Refer to the detailed sections below for more information on each step. - -## 1. Understanding Context and Fetching URLs - -- **Context First**: Before diving into code, understand what the existing code is supposed to do based on file names, directory structure, imports, and existing patterns. -- **URL Fetching**: If the user provides a URL, use the `fetch` tool to retrieve the content. -- **Recursive Information Gathering**: If you find additional relevant URLs or links, fetch those as well until you have all necessary information. - -## 2. Deep Problem Understanding - -Carefully read the issue and think hard about a plan to solve it before coding. Consider: - -- What is the expected behavior? -- What are the edge cases? -- What are the potential pitfalls? -- How does this fit into the larger context of the codebase? -- What are the dependencies and interactions with other parts of the code? - -## 3. Codebase Investigation - -- Explore relevant files and directories using `ls`, `view`, `glob`, and `grep` tools. -- Search for key functions, classes, or variables related to the issue. -- Read and understand relevant code snippets. -- Identify the root cause of the problem. -- Validate and update your understanding continuously as you gather more context. - -## 4. Research When Needed - -- Use the `sourcegraph` tool when you need to find code examples or verify usage patterns for libraries/frameworks. -- Use the `fetch` tool to retrieve documentation or other web resources. -- Look for patterns, best practices, and implementation examples. -- Focus your research on what's necessary to solve the specific problem at hand. - -## 5. Develop a Detailed Plan - -- Outline a specific, simple, and verifiable sequence of steps to fix the problem. -- Create a todo list in markdown format to track your progress. -- Each time you complete a step, check it off using `[x]` syntax. -- Each time you check off a step, display the updated todo list to the user. -- Make sure that you ACTUALLY continue on to the next step after checking off a step instead of ending your turn. - -## 6. Making Code Changes - -- Before editing, always read the relevant file contents or section to ensure complete context using the `view` tool. 
-- Always read at least 2000 lines of code at a time to ensure you have enough context. -- If a patch is not applied correctly, attempt to reapply it. -- Make small, testable, incremental changes that logically follow from your investigation and plan. -- Whenever you detect that a project requires an environment variable (such as an API key or secret), always check if a .env file exists in the project root. If it does not exist, automatically create a .env file with a placeholder for the required variable(s) and inform the user. Do this proactively, without waiting for the user to request it. -- Prefer using the `multiedit` tool when making multiple edits to the same file. - -## 7. Debugging and Testing - -- Use the `bash` tool to run commands and check for errors. -- Make code changes only if you have high confidence they can solve the problem. -- When debugging, try to determine the root cause rather than addressing symptoms. -- Debug for as long as needed to identify the root cause and identify a fix. -- Use print statements, logs, or temporary code to inspect program state, including descriptive statements or error messages to understand what's happening. -- To test hypotheses, you can also add test statements or functions. -- Revisit your assumptions if unexpected behavior occurs. -- **Test rigorously and frequently** - this is critical for success. - -# Memory - -If the current working directory contains a file called CRUSH.md, it will be automatically added to your context. This file serves multiple purposes: - -1. Storing frequently used bash commands (build, test, lint, etc.) so you can use them without searching each time -2. Recording the user's code style preferences (naming conventions, preferred libraries, etc.) -3. Maintaining useful information about the codebase structure and organization - -When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to CRUSH.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to CRUSH.md so you can remember it for next time. - -# How to Create a Todo List - -Use the following format to create a todo list: - -```markdown -- [ ] Step 1: Description of the first step -- [ ] Step 2: Description of the second step -- [ ] Step 3: Description of the third step -``` - -Do not ever use HTML tags or any other formatting for the todo list, as it will not be rendered correctly. Always use the markdown format shown above. Always wrap the todo list in triple backticks so that it is formatted correctly and can be easily copied from the chat. - -Always show the completed todo list to the user as the last item in your message, so that they can see that you have addressed all of the steps. - -# Communication Guidelines - -Always communicate clearly and concisely in a casual, friendly yet professional tone. - - -"Let me fetch the URL you provided to gather more information." -"Ok, I've got all of the information I need on the API and I know how to use it." -"Now, I will search the codebase for the function that handles the API requests." -"I need to update several files here - stand by" -"OK! Now let's run the tests to make sure everything is working correctly." -"Whelp - I see we have some problems. Let's fix those up." - - -- Respond with clear, direct answers. Use bullet points and code blocks for structure. -- Avoid unnecessary explanations, repetition, and filler. 
-- Always write code directly to the correct files. -- Do not display code to the user unless they specifically ask for it. -- Only elaborate when clarification is essential for accuracy or user understanding. - -# Tone and Style - -You should be concise, direct, and to the point. When you run a non-trivial bash command, you should explain what the command does and why you are running it, to make sure the user understands what you are doing (this is especially important when you are running a command that will make changes to the user's system). - -Remember that your output will be displayed on a command line interface. Your responses can use Github-flavored markdown for formatting, and will be rendered in a monospace font using the CommonMark specification. - -Output text to communicate with the user; all text you output outside of tool use is displayed to the user. Only use tools to complete tasks. Never use tools like Bash or code comments as means to communicate with the user during the session. - -If you cannot or will not help the user with something, please do not say why or what it could lead to, since this comes across as preachy and annoying. Please offer helpful alternatives if possible, and otherwise keep your response to 1-2 sentences. - -IMPORTANT: You should minimize output tokens as much as possible while maintaining helpfulness, quality, and accuracy. Only address the specific query or task at hand, avoiding tangential information unless absolutely critical for completing the request. - -IMPORTANT: You should NOT answer with unnecessary preamble or postamble (such as explaining your code or summarizing your action), unless the user asks you to. - -VERY IMPORTANT: NEVER use emojis in your responses. - -# Following Conventions - -When making changes to files, first understand the file's code conventions. Mimic code style, use existing libraries and utilities, and follow existing patterns. - -- NEVER assume that a given library is available, even if it is well known. Whenever you write code that uses a library or framework, first check that this codebase already uses the given library. For example, you might look at neighboring files, or check the package.json (or cargo.toml, and so on depending on the language). -- When you create a new component, first look at existing components to see how they're written; then consider framework choice, naming conventions, typing, and other conventions. -- When you edit a piece of code, first look at the code's surrounding context (especially its imports) to understand the code's choice of frameworks and libraries. Then consider how to make the given change in a way that is most idiomatic. -- Always follow security best practices. Never introduce code that exposes or logs secrets and keys. Never commit secrets or keys to the repository. - -# Code Style - -- IMPORTANT: DO NOT ADD **_ANY_** COMMENTS unless asked - -# Task Execution - -The user will primarily request you perform software engineering tasks. This includes solving bugs, adding new functionality, refactoring code, explaining code, and more. For these tasks the following steps are recommended: - -1. Use the available search tools to understand the codebase and the user's query. -2. Implement the solution using all tools available to you -3. Verify the solution if possible with tests. NEVER assume specific test framework or test script. Check the README or search codebase to determine the testing approach. -4. 
VERY IMPORTANT: When you have completed a task, you MUST run the lint and typecheck commands (eg. npm run lint, npm run typecheck, ruff, etc.) if they were provided to you to ensure your code is correct. If you are unable to find the correct command, ask the user for the command to run and if they supply it, proactively suggest writing it to CRUSH.md so that you will know to run it next time. - -NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTANT to only commit when explicitly asked, otherwise the user will feel that you are being too proactive. - -# Tool Usage Policy - -- When doing file search, prefer to use the Agent tool in order to reduce context usage. -- **IMPORTANT**: If you intend to call multiple tools and there are no dependencies between the calls, make all of the independent calls in parallel for efficiency. -- **IMPORTANT**: The user does not see the full output of the tool responses, so if you need the output of the tool for your response, make sure to summarize it for the user. -- All tools are executed in parallel when multiple tool calls are sent in a single message. Only send multiple tool calls when they are safe to run in parallel (no dependencies between them). - -# Reading Files and Folders - -**Always check if you have already read a file, folder, or workspace structure before reading it again.** - -- If you have already read the content and it has not changed, do NOT re-read it. -- Only re-read files or folders if: - - You suspect the content has changed since your last read. - - You have made edits to the file or folder. - - You encounter an error that suggests the context may be stale or incomplete. -- Use your internal memory and previous context to avoid redundant reads. -- This will save time, reduce unnecessary operations, and make your workflow more efficient. - -# Directory Context and Navigation - -**Always maintain awareness of your current working directory by tracking it mentally from the command history.** - -- **Remember directory changes**: When you use `cd` to change directories, mentally note and remember the new location for all subsequent operations. -- **Track your location from context**: Use the command history and previous `cd` commands to know where you currently are without constantly checking. -- **Check location only when commands fail**: If a command fails unexpectedly with file/path errors, then use `pwd` to verify your current directory as the failure might be due to being in the wrong location. -- **Use relative paths confidently**: Once you know your location, use relative paths appropriately based on your mental model of the current directory. -- **Maintain directory awareness across operations**: Keep track of where you are throughout a multi-step task, especially when working with files in different directories. - -**When to verify with `pwd`:** - -- After a command fails with "file not found" or similar path-related or `exit status 1` errors -- When resuming work or continuing from a previous step if uncertain -- When you realize you may have lost track of your current location - -**Mental tracking example:** - -```bash -# You start in /project/root -cd src/components # Now mentally note: I'm in /project/root/src/components -# Work with files here using relative paths -ls ./Button.tsx # This should work because I know I'm in components/ -# If this fails, THEN run pwd to double-check location -``` - -# Git and Version Control - -If the user tells you to stage and commit, you may do so. 
- -You are NEVER allowed to stage and commit files automatically. Only do this when explicitly requested. - -# Error Handling and Recovery - -- When you encounter errors, don't give up - analyze the error carefully and try alternative approaches. -- If a tool fails, try a different tool or approach to accomplish the same goal. -- When debugging, be systematic: isolate the problem, test hypotheses, and iterate until resolved. -- Always validate your solutions work correctly before considering the task complete. - -# Final Validation - -Before completing any task: - -1. Ensure all todo items are checked off -2. Run all relevant tests -3. Run linting and type checking if available -4. Verify the original problem is solved -5. Test edge cases and boundary conditions -6. Confirm no regressions were introduced diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go deleted file mode 100644 index 981ff4590fd7db92288ff11b3d8f607e594cb0fd..0000000000000000000000000000000000000000 --- a/internal/llm/provider/anthropic.go +++ /dev/null @@ -1,598 +0,0 @@ -package provider - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "log/slog" - "net/http" - "regexp" - "strconv" - "strings" - "time" - - "github.com/anthropics/anthropic-sdk-go" - "github.com/anthropics/anthropic-sdk-go/bedrock" - "github.com/anthropics/anthropic-sdk-go/option" - "github.com/anthropics/anthropic-sdk-go/vertex" - "github.com/charmbracelet/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/tools" - "github.com/charmbracelet/crush/internal/log" - "github.com/charmbracelet/crush/internal/message" -) - -// Pre-compiled regex for parsing context limit errors. -var contextLimitRegex = regexp.MustCompile(`input length and ` + "`max_tokens`" + ` exceed context limit: (\d+) \+ (\d+) > (\d+)`) - -type anthropicClient struct { - providerOptions providerClientOptions - tp AnthropicClientType - client anthropic.Client - adjustedMaxTokens int // Used when context limit is hit -} - -type AnthropicClient ProviderClient - -type AnthropicClientType string - -const ( - AnthropicClientTypeNormal AnthropicClientType = "normal" - AnthropicClientTypeBedrock AnthropicClientType = "bedrock" - AnthropicClientTypeVertex AnthropicClientType = "vertex" -) - -func newAnthropicClient(opts providerClientOptions, tp AnthropicClientType) AnthropicClient { - return &anthropicClient{ - providerOptions: opts, - tp: tp, - client: createAnthropicClient(opts, tp), - } -} - -func createAnthropicClient(opts providerClientOptions, tp AnthropicClientType) anthropic.Client { - anthropicClientOptions := []option.RequestOption{} - - // Check if Authorization header is provided in extra headers - hasBearerAuth := false - if opts.extraHeaders != nil { - for key := range opts.extraHeaders { - if strings.ToLower(key) == "authorization" { - hasBearerAuth = true - break - } - } - } - - isBearerToken := strings.HasPrefix(opts.apiKey, "Bearer ") - - if opts.apiKey != "" && !hasBearerAuth { - if isBearerToken { - slog.Debug("API key starts with 'Bearer ', using as Authorization header") - anthropicClientOptions = append(anthropicClientOptions, option.WithHeader("Authorization", opts.apiKey)) - } else { - // Use standard X-Api-Key header - anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey)) - } - } else if hasBearerAuth { - slog.Debug("Skipping X-Api-Key header because Authorization header is provided") - } - - if opts.baseURL != "" { - 
resolvedBaseURL, err := config.Get().Resolve(opts.baseURL) - if err == nil && resolvedBaseURL != "" { - anthropicClientOptions = append(anthropicClientOptions, option.WithBaseURL(resolvedBaseURL)) - } - } - - if config.Get().Options.Debug { - httpClient := log.NewHTTPClient() - anthropicClientOptions = append(anthropicClientOptions, option.WithHTTPClient(httpClient)) - } - - switch tp { - case AnthropicClientTypeBedrock: - anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background())) - case AnthropicClientTypeVertex: - project := opts.extraParams["project"] - location := opts.extraParams["location"] - anthropicClientOptions = append(anthropicClientOptions, vertex.WithGoogleAuth(context.Background(), location, project)) - } - for key, header := range opts.extraHeaders { - anthropicClientOptions = append(anthropicClientOptions, option.WithHeaderAdd(key, header)) - } - for key, value := range opts.extraBody { - anthropicClientOptions = append(anthropicClientOptions, option.WithJSONSet(key, value)) - } - return anthropic.NewClient(anthropicClientOptions...) -} - -func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) { - for i, msg := range messages { - cache := false - if i > len(messages)-3 { - cache = true - } - switch msg.Role { - case message.User: - content := anthropic.NewTextBlock(msg.Content().String()) - if cache && !a.providerOptions.disableCache { - content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - } - } - var contentBlocks []anthropic.ContentBlockParamUnion - contentBlocks = append(contentBlocks, content) - for _, binaryContent := range msg.BinaryContent() { - base64Image := binaryContent.String(catwalk.InferenceProviderAnthropic) - imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image) - contentBlocks = append(contentBlocks, imageBlock) - } - anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(contentBlocks...)) - - case message.Assistant: - blocks := []anthropic.ContentBlockParamUnion{} - - // Add thinking blocks first if present (required when thinking is enabled with tool use) - if reasoningContent := msg.ReasoningContent(); reasoningContent.Thinking != "" { - thinkingBlock := anthropic.NewThinkingBlock(reasoningContent.Signature, reasoningContent.Thinking) - blocks = append(blocks, thinkingBlock) - } - - if msg.Content().String() != "" { - content := anthropic.NewTextBlock(msg.Content().String()) - if cache && !a.providerOptions.disableCache { - content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - } - } - blocks = append(blocks, content) - } - - for _, toolCall := range msg.ToolCalls() { - if !toolCall.Finished { - continue - } - var inputMap map[string]any - err := json.Unmarshal([]byte(toolCall.Input), &inputMap) - if err != nil { - continue - } - blocks = append(blocks, anthropic.NewToolUseBlock(toolCall.ID, inputMap, toolCall.Name)) - } - - if len(blocks) == 0 { - continue - } - anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) - - case message.Tool: - results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults())) - for i, toolResult := range msg.ToolResults() { - results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError) - } - anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...)) - } - } - return anthropicMessages -} - 
-func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam { - if len(tools) == 0 { - return nil - } - anthropicTools := make([]anthropic.ToolUnionParam, len(tools)) - - for i, tool := range tools { - info := tool.Info() - toolParam := anthropic.ToolParam{ - Name: info.Name, - Description: anthropic.String(info.Description), - InputSchema: anthropic.ToolInputSchemaParam{ - Properties: info.Parameters, - Required: info.Required, - }, - } - - if i == len(tools)-1 && !a.providerOptions.disableCache { - toolParam.CacheControl = anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - } - } - - anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam} - } - - return anthropicTools -} - -func (a *anthropicClient) finishReason(reason string) message.FinishReason { - switch reason { - case "end_turn": - return message.FinishReasonEndTurn - case "max_tokens": - return message.FinishReasonMaxTokens - case "tool_use": - return message.FinishReasonToolUse - case "stop_sequence": - return message.FinishReasonEndTurn - default: - return message.FinishReasonUnknown - } -} - -func (a *anthropicClient) isThinkingEnabled() bool { - cfg := config.Get() - modelConfig := cfg.Models[config.SelectedModelTypeLarge] - if a.providerOptions.modelType == config.SelectedModelTypeSmall { - modelConfig = cfg.Models[config.SelectedModelTypeSmall] - } - return a.Model().CanReason && modelConfig.Think -} - -func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams { - model := a.providerOptions.model(a.providerOptions.modelType) - var thinkingParam anthropic.ThinkingConfigParamUnion - cfg := config.Get() - modelConfig := cfg.Models[config.SelectedModelTypeLarge] - if a.providerOptions.modelType == config.SelectedModelTypeSmall { - modelConfig = cfg.Models[config.SelectedModelTypeSmall] - } - temperature := anthropic.Float(0) - - maxTokens := model.DefaultMaxTokens - if modelConfig.MaxTokens > 0 { - maxTokens = modelConfig.MaxTokens - } - if a.isThinkingEnabled() { - thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(maxTokens) * 0.8)) - temperature = anthropic.Float(1) - } - // Override max tokens if set in provider options - if a.providerOptions.maxTokens > 0 { - maxTokens = a.providerOptions.maxTokens - } - - // Use adjusted max tokens if context limit was hit - if a.adjustedMaxTokens > 0 { - maxTokens = int64(a.adjustedMaxTokens) - } - - systemBlocks := []anthropic.TextBlockParam{} - - // Add custom system prompt prefix if configured - if a.providerOptions.systemPromptPrefix != "" { - systemBlocks = append(systemBlocks, anthropic.TextBlockParam{ - Text: a.providerOptions.systemPromptPrefix, - }) - } - - systemBlocks = append(systemBlocks, anthropic.TextBlockParam{ - Text: a.providerOptions.systemMessage, - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - }, - }) - - return anthropic.MessageNewParams{ - Model: anthropic.Model(model.ID), - MaxTokens: maxTokens, - Temperature: temperature, - Messages: messages, - Tools: tools, - Thinking: thinkingParam, - System: systemBlocks, - } -} - -func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) { - attempts := 0 - for { - attempts++ - // Prepare messages on each attempt in case max_tokens was adjusted - preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) - - var opts 
[]option.RequestOption - if a.isThinkingEnabled() { - opts = append(opts, option.WithHeaderAdd("anthropic-beta", "interleaved-thinking-2025-05-14")) - } - anthropicResponse, err := a.client.Messages.New( - ctx, - preparedMessages, - opts..., - ) - // If there is an error we are going to see if we can retry the call - if err != nil { - retry, after, retryErr := a.shouldRetry(attempts, err) - if retryErr != nil { - return nil, retryErr - } - if retry { - slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err) - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(time.Duration(after) * time.Millisecond): - continue - } - } - return nil, retryErr - } - - content := "" - for _, block := range anthropicResponse.Content { - if text, ok := block.AsAny().(anthropic.TextBlock); ok { - content += text.Text - } - } - - return &ProviderResponse{ - Content: content, - ToolCalls: a.toolCalls(*anthropicResponse), - Usage: a.usage(*anthropicResponse), - }, nil - } -} - -func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { - attempts := 0 - eventChan := make(chan ProviderEvent) - go func() { - for { - attempts++ - // Prepare messages on each attempt in case max_tokens was adjusted - preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) - - var opts []option.RequestOption - if a.isThinkingEnabled() { - opts = append(opts, option.WithHeaderAdd("anthropic-beta", "interleaved-thinking-2025-05-14")) - } - - anthropicStream := a.client.Messages.NewStreaming( - ctx, - preparedMessages, - opts..., - ) - accumulatedMessage := anthropic.Message{} - - currentToolCallID := "" - for anthropicStream.Next() { - event := anthropicStream.Current() - err := accumulatedMessage.Accumulate(event) - if err != nil { - slog.Warn("Error accumulating message", "error", err) - continue - } - - switch event := event.AsAny().(type) { - case anthropic.ContentBlockStartEvent: - switch event.ContentBlock.Type { - case "text": - eventChan <- ProviderEvent{Type: EventContentStart} - case "tool_use": - currentToolCallID = event.ContentBlock.ID - eventChan <- ProviderEvent{ - Type: EventToolUseStart, - ToolCall: &message.ToolCall{ - ID: event.ContentBlock.ID, - Name: event.ContentBlock.Name, - Finished: false, - }, - } - } - - case anthropic.ContentBlockDeltaEvent: - if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" { - eventChan <- ProviderEvent{ - Type: EventThinkingDelta, - Thinking: event.Delta.Thinking, - } - } else if event.Delta.Type == "signature_delta" && event.Delta.Signature != "" { - eventChan <- ProviderEvent{ - Type: EventSignatureDelta, - Signature: event.Delta.Signature, - } - } else if event.Delta.Type == "text_delta" && event.Delta.Text != "" { - eventChan <- ProviderEvent{ - Type: EventContentDelta, - Content: event.Delta.Text, - } - } else if event.Delta.Type == "input_json_delta" { - if currentToolCallID != "" { - eventChan <- ProviderEvent{ - Type: EventToolUseDelta, - ToolCall: &message.ToolCall{ - ID: currentToolCallID, - Finished: false, - Input: event.Delta.PartialJSON, - }, - } - } - } - case anthropic.ContentBlockStopEvent: - if currentToolCallID != "" { - eventChan <- ProviderEvent{ - Type: EventToolUseStop, - ToolCall: &message.ToolCall{ - ID: currentToolCallID, - }, - } - currentToolCallID = "" - } else { - eventChan <- ProviderEvent{Type: EventContentStop} - } - - case anthropic.MessageStopEvent: - content := "" - 
for _, block := range accumulatedMessage.Content { - if text, ok := block.AsAny().(anthropic.TextBlock); ok { - content += text.Text - } - } - - eventChan <- ProviderEvent{ - Type: EventComplete, - Response: &ProviderResponse{ - Content: content, - ToolCalls: a.toolCalls(accumulatedMessage), - Usage: a.usage(accumulatedMessage), - FinishReason: a.finishReason(string(accumulatedMessage.StopReason)), - }, - Content: content, - } - } - } - - err := anthropicStream.Err() - if err == nil || errors.Is(err, io.EOF) { - close(eventChan) - return - } - - // If there is an error we are going to see if we can retry the call - retry, after, retryErr := a.shouldRetry(attempts, err) - if retryErr != nil { - eventChan <- ProviderEvent{Type: EventError, Error: retryErr} - close(eventChan) - return - } - if retry { - slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err) - select { - case <-ctx.Done(): - // context cancelled - if ctx.Err() != nil { - eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} - } - close(eventChan) - return - case <-time.After(time.Duration(after) * time.Millisecond): - continue - } - } - if ctx.Err() != nil { - eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} - } - - close(eventChan) - return - } - }() - return eventChan -} - -func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) { - var apiErr *anthropic.Error - if !errors.As(err, &apiErr) { - return false, 0, err - } - - if attempts > maxRetries { - return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) - } - - if apiErr.StatusCode == http.StatusUnauthorized { - prev := a.providerOptions.apiKey - // in case the key comes from a script, we try to re-evaluate it. - a.providerOptions.apiKey, err = config.Get().Resolve(a.providerOptions.config.APIKey) - if err != nil { - return false, 0, fmt.Errorf("failed to resolve API key: %w", err) - } - // if it didn't change, do not retry. 
- if prev == a.providerOptions.apiKey { - return false, 0, err - } - a.client = createAnthropicClient(a.providerOptions, a.tp) - return true, 0, nil - } - - // Handle context limit exceeded error (400 Bad Request) - if apiErr.StatusCode == http.StatusBadRequest { - if adjusted, ok := a.handleContextLimitError(apiErr); ok { - a.adjustedMaxTokens = adjusted - slog.Debug("Adjusted max_tokens due to context limit", "new_max_tokens", adjusted) - return true, 0, nil - } - } - - isOverloaded := strings.Contains(apiErr.Error(), "overloaded") || strings.Contains(apiErr.Error(), "rate limit exceeded") - // 529 (unofficial): The service is overloaded - if apiErr.StatusCode != http.StatusTooManyRequests && apiErr.StatusCode != 529 && !isOverloaded { - return false, 0, err - } - - retryMs := 0 - retryAfterValues := apiErr.Response.Header.Values("Retry-After") - - backoffMs := 2000 * (1 << (attempts - 1)) - jitterMs := int(float64(backoffMs) * 0.2) - retryMs = backoffMs + jitterMs - if len(retryAfterValues) > 0 { - if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil { - retryMs = retryMs * 1000 - } - } - return true, int64(retryMs), nil -} - -// handleContextLimitError parses context limit error and returns adjusted max_tokens -func (a *anthropicClient) handleContextLimitError(apiErr *anthropic.Error) (int, bool) { - // Parse error message like: "input length and max_tokens exceed context limit: 154978 + 50000 > 200000" - errorMsg := apiErr.Error() - - matches := contextLimitRegex.FindStringSubmatch(errorMsg) - - if len(matches) != 4 { - return 0, false - } - - inputTokens, err1 := strconv.Atoi(matches[1]) - contextLimit, err2 := strconv.Atoi(matches[3]) - - if err1 != nil || err2 != nil { - return 0, false - } - - // Calculate safe max_tokens with a buffer of 1000 tokens - safeMaxTokens := contextLimit - inputTokens - 1000 - - // Ensure we don't go below a minimum threshold - safeMaxTokens = max(safeMaxTokens, 1000) - - return safeMaxTokens, true -} - -func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall { - var toolCalls []message.ToolCall - - for _, block := range msg.Content { - switch variant := block.AsAny().(type) { - case anthropic.ToolUseBlock: - toolCall := message.ToolCall{ - ID: variant.ID, - Name: variant.Name, - Input: string(variant.Input), - Type: string(variant.Type), - Finished: true, - } - toolCalls = append(toolCalls, toolCall) - } - } - - return toolCalls -} - -func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage { - return TokenUsage{ - InputTokens: msg.Usage.InputTokens, - OutputTokens: msg.Usage.OutputTokens, - CacheCreationTokens: msg.Usage.CacheCreationInputTokens, - CacheReadTokens: msg.Usage.CacheReadInputTokens, - } -} - -func (a *anthropicClient) Model() catwalk.Model { - return a.providerOptions.model(a.providerOptions.modelType) -} diff --git a/internal/llm/provider/azure.go b/internal/llm/provider/azure.go deleted file mode 100644 index 9042d66876c6f22bd9c06a5f52f6b4502e32c0f2..0000000000000000000000000000000000000000 --- a/internal/llm/provider/azure.go +++ /dev/null @@ -1,39 +0,0 @@ -package provider - -import ( - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/log" - "github.com/openai/openai-go" - "github.com/openai/openai-go/azure" - "github.com/openai/openai-go/option" -) - -type azureClient struct { - *openaiClient -} - -type AzureClient ProviderClient - -func newAzureClient(opts providerClientOptions) AzureClient { - apiVersion := opts.extraParams["apiVersion"] 
- if apiVersion == "" { - apiVersion = "2025-01-01-preview" - } - - reqOpts := []option.RequestOption{ - azure.WithEndpoint(opts.baseURL, apiVersion), - } - - if config.Get().Options.Debug { - httpClient := log.NewHTTPClient() - reqOpts = append(reqOpts, option.WithHTTPClient(httpClient)) - } - - reqOpts = append(reqOpts, azure.WithAPIKey(opts.apiKey)) - base := &openaiClient{ - providerOptions: opts, - client: openai.NewClient(reqOpts...), - } - - return &azureClient{openaiClient: base} -} diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go deleted file mode 100644 index 526d11b5597859853be9314ed618748e3ae40f38..0000000000000000000000000000000000000000 --- a/internal/llm/provider/bedrock.go +++ /dev/null @@ -1,93 +0,0 @@ -package provider - -import ( - "context" - "errors" - "fmt" - "strings" - - "github.com/charmbracelet/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/tools" - "github.com/charmbracelet/crush/internal/message" -) - -type bedrockClient struct { - providerOptions providerClientOptions - childProvider ProviderClient -} - -type BedrockClient ProviderClient - -func newBedrockClient(opts providerClientOptions) BedrockClient { - // Get AWS region from environment - region := opts.extraParams["region"] - if region == "" { - region = "us-east-1" // default region - } - if len(region) < 2 { - return &bedrockClient{ - providerOptions: opts, - childProvider: nil, // Will cause an error when used - } - } - - opts.model = func(modelType config.SelectedModelType) catwalk.Model { - model := config.Get().GetModelByType(modelType) - - // Prefix the model name with region - regionPrefix := region[:2] - modelName := model.ID - model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName) - return *model - } - - model := opts.model(opts.modelType) - - // Determine which provider to use based on the model - if strings.Contains(string(model.ID), "anthropic") { - // Create Anthropic client with Bedrock configuration - anthropicOpts := opts - // TODO: later find a way to check if the AWS account has caching enabled - opts.disableCache = true // Disable cache for Bedrock - return &bedrockClient{ - providerOptions: opts, - childProvider: newAnthropicClient(anthropicOpts, AnthropicClientTypeBedrock), - } - } - - // Return client with nil childProvider if model is not supported - // This will cause an error when used - return &bedrockClient{ - providerOptions: opts, - childProvider: nil, - } -} - -func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - if b.childProvider == nil { - return nil, errors.New("unsupported model for bedrock provider") - } - return b.childProvider.send(ctx, messages, tools) -} - -func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { - eventChan := make(chan ProviderEvent) - - if b.childProvider == nil { - go func() { - eventChan <- ProviderEvent{ - Type: EventError, - Error: errors.New("unsupported model for bedrock provider"), - } - close(eventChan) - }() - return eventChan - } - - return b.childProvider.stream(ctx, messages, tools) -} - -func (b *bedrockClient) Model() catwalk.Model { - return b.providerOptions.model(b.providerOptions.modelType) -} diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go deleted file mode 100644 index 
a846d8d582524bb6bf9c8ed31e3796ec8d94b419..0000000000000000000000000000000000000000 --- a/internal/llm/provider/gemini.go +++ /dev/null @@ -1,579 +0,0 @@ -package provider - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "log/slog" - "strings" - "time" - - "github.com/charmbracelet/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/tools" - "github.com/charmbracelet/crush/internal/log" - "github.com/charmbracelet/crush/internal/message" - "github.com/google/uuid" - "google.golang.org/genai" -) - -type geminiClient struct { - providerOptions providerClientOptions - client *genai.Client -} - -type GeminiClient ProviderClient - -func newGeminiClient(opts providerClientOptions) GeminiClient { - client, err := createGeminiClient(opts) - if err != nil { - slog.Error("Failed to create Gemini client", "error", err) - return nil - } - - return &geminiClient{ - providerOptions: opts, - client: client, - } -} - -func createGeminiClient(opts providerClientOptions) (*genai.Client, error) { - cc := &genai.ClientConfig{ - APIKey: opts.apiKey, - Backend: genai.BackendGeminiAPI, - } - if opts.baseURL != "" { - resolvedBaseURL, err := config.Get().Resolve(opts.baseURL) - if err == nil && resolvedBaseURL != "" { - cc.HTTPOptions = genai.HTTPOptions{ - BaseURL: resolvedBaseURL, - } - } - } - if config.Get().Options.Debug { - cc.HTTPClient = log.NewHTTPClient() - } - client, err := genai.NewClient(context.Background(), cc) - if err != nil { - return nil, err - } - return client, nil -} - -func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content { - var history []*genai.Content - for _, msg := range messages { - switch msg.Role { - case message.User: - var parts []*genai.Part - parts = append(parts, &genai.Part{Text: msg.Content().String()}) - for _, binaryContent := range msg.BinaryContent() { - parts = append(parts, &genai.Part{InlineData: &genai.Blob{ - MIMEType: binaryContent.MIMEType, - Data: binaryContent.Data, - }}) - } - history = append(history, &genai.Content{ - Parts: parts, - Role: genai.RoleUser, - }) - case message.Assistant: - var assistantParts []*genai.Part - - if msg.Content().String() != "" { - assistantParts = append(assistantParts, &genai.Part{Text: msg.Content().String()}) - } - - if len(msg.ToolCalls()) > 0 { - for _, call := range msg.ToolCalls() { - if !call.Finished { - continue - } - args, _ := parseJSONToMap(call.Input) - assistantParts = append(assistantParts, &genai.Part{ - FunctionCall: &genai.FunctionCall{ - Name: call.Name, - Args: args, - }, - }) - } - } - - if len(assistantParts) > 0 { - history = append(history, &genai.Content{ - Role: genai.RoleModel, - Parts: assistantParts, - }) - } - - case message.Tool: - var toolParts []*genai.Part - for _, result := range msg.ToolResults() { - response := map[string]any{"result": result.Content} - parsed, err := parseJSONToMap(result.Content) - if err == nil { - response = parsed - } - - var toolCall message.ToolCall - for _, m := range messages { - if m.Role == message.Assistant { - for _, call := range m.ToolCalls() { - if call.ID == result.ToolCallID { - toolCall = call - break - } - } - } - } - - toolParts = append(toolParts, &genai.Part{ - FunctionResponse: &genai.FunctionResponse{ - Name: toolCall.Name, - Response: response, - }, - }) - } - if len(toolParts) > 0 { - history = append(history, &genai.Content{ - Parts: toolParts, - Role: genai.RoleUser, - }) - } - } - } - - return history -} - -func (g *geminiClient) 
convertTools(tools []tools.BaseTool) []*genai.Tool { - geminiTool := &genai.Tool{} - geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools)) - - for _, tool := range tools { - info := tool.Info() - declaration := &genai.FunctionDeclaration{ - Name: info.Name, - Description: info.Description, - Parameters: &genai.Schema{ - Type: genai.TypeObject, - Properties: convertSchemaProperties(info.Parameters), - Required: info.Required, - }, - } - - geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration) - } - - return []*genai.Tool{geminiTool} -} - -func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason { - switch reason { - case genai.FinishReasonStop: - return message.FinishReasonEndTurn - case genai.FinishReasonMaxTokens: - return message.FinishReasonMaxTokens - default: - return message.FinishReasonUnknown - } -} - -func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - // Convert messages - geminiMessages := g.convertMessages(messages) - model := g.providerOptions.model(g.providerOptions.modelType) - cfg := config.Get() - - modelConfig := cfg.Models[config.SelectedModelTypeLarge] - if g.providerOptions.modelType == config.SelectedModelTypeSmall { - modelConfig = cfg.Models[config.SelectedModelTypeSmall] - } - - maxTokens := model.DefaultMaxTokens - if modelConfig.MaxTokens > 0 { - maxTokens = modelConfig.MaxTokens - } - systemMessage := g.providerOptions.systemMessage - if g.providerOptions.systemPromptPrefix != "" { - systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage - } - history := geminiMessages[:len(geminiMessages)-1] // All but last message - lastMsg := geminiMessages[len(geminiMessages)-1] - config := &genai.GenerateContentConfig{ - MaxOutputTokens: int32(maxTokens), - SystemInstruction: &genai.Content{ - Parts: []*genai.Part{{Text: systemMessage}}, - }, - } - config.Tools = g.convertTools(tools) - chat, _ := g.client.Chats.Create(ctx, model.ID, config, history) - - attempts := 0 - for { - attempts++ - var toolCalls []message.ToolCall - - var lastMsgParts []genai.Part - for _, part := range lastMsg.Parts { - lastMsgParts = append(lastMsgParts, *part) - } - resp, err := chat.SendMessage(ctx, lastMsgParts...) 
- // If there is an error we are going to see if we can retry the call - if err != nil { - retry, after, retryErr := g.shouldRetry(attempts, err) - if retryErr != nil { - return nil, retryErr - } - if retry { - slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err) - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(time.Duration(after) * time.Millisecond): - continue - } - } - return nil, retryErr - } - - content := "" - - if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { - for _, part := range resp.Candidates[0].Content.Parts { - switch { - case part.Text != "": - content = string(part.Text) - case part.FunctionCall != nil: - id := "call_" + uuid.New().String() - args, _ := json.Marshal(part.FunctionCall.Args) - toolCalls = append(toolCalls, message.ToolCall{ - ID: id, - Name: part.FunctionCall.Name, - Input: string(args), - Type: "function", - Finished: true, - }) - } - } - } - finishReason := message.FinishReasonEndTurn - if len(resp.Candidates) > 0 { - finishReason = g.finishReason(resp.Candidates[0].FinishReason) - } - if len(toolCalls) > 0 { - finishReason = message.FinishReasonToolUse - } - - return &ProviderResponse{ - Content: content, - ToolCalls: toolCalls, - Usage: g.usage(resp), - FinishReason: finishReason, - }, nil - } -} - -func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { - // Convert messages - geminiMessages := g.convertMessages(messages) - - model := g.providerOptions.model(g.providerOptions.modelType) - cfg := config.Get() - - modelConfig := cfg.Models[config.SelectedModelTypeLarge] - if g.providerOptions.modelType == config.SelectedModelTypeSmall { - modelConfig = cfg.Models[config.SelectedModelTypeSmall] - } - maxTokens := model.DefaultMaxTokens - if modelConfig.MaxTokens > 0 { - maxTokens = modelConfig.MaxTokens - } - - // Override max tokens if set in provider options - if g.providerOptions.maxTokens > 0 { - maxTokens = g.providerOptions.maxTokens - } - systemMessage := g.providerOptions.systemMessage - if g.providerOptions.systemPromptPrefix != "" { - systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage - } - history := geminiMessages[:len(geminiMessages)-1] // All but last message - lastMsg := geminiMessages[len(geminiMessages)-1] - config := &genai.GenerateContentConfig{ - MaxOutputTokens: int32(maxTokens), - SystemInstruction: &genai.Content{ - Parts: []*genai.Part{{Text: systemMessage}}, - }, - } - config.Tools = g.convertTools(tools) - chat, _ := g.client.Chats.Create(ctx, model.ID, config, history) - - attempts := 0 - eventChan := make(chan ProviderEvent) - - go func() { - defer close(eventChan) - - for { - attempts++ - - currentContent := "" - toolCalls := []message.ToolCall{} - var finalResp *genai.GenerateContentResponse - - eventChan <- ProviderEvent{Type: EventContentStart} - - var lastMsgParts []genai.Part - - for _, part := range lastMsg.Parts { - lastMsgParts = append(lastMsgParts, *part) - } - - for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) 
{ - if err != nil { - retry, after, retryErr := g.shouldRetry(attempts, err) - if retryErr != nil { - eventChan <- ProviderEvent{Type: EventError, Error: retryErr} - return - } - if retry { - slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err) - select { - case <-ctx.Done(): - if ctx.Err() != nil { - eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} - } - - return - case <-time.After(time.Duration(after) * time.Millisecond): - continue - } - } else { - eventChan <- ProviderEvent{Type: EventError, Error: err} - return - } - } - - finalResp = resp - - if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { - for _, part := range resp.Candidates[0].Content.Parts { - switch { - case part.Text != "": - delta := string(part.Text) - if delta != "" { - eventChan <- ProviderEvent{ - Type: EventContentDelta, - Content: delta, - } - currentContent += delta - } - case part.FunctionCall != nil: - id := "call_" + uuid.New().String() - args, _ := json.Marshal(part.FunctionCall.Args) - newCall := message.ToolCall{ - ID: id, - Name: part.FunctionCall.Name, - Input: string(args), - Type: "function", - Finished: true, - } - - toolCalls = append(toolCalls, newCall) - } - } - } else { - // no content received - break - } - } - - eventChan <- ProviderEvent{Type: EventContentStop} - - if finalResp != nil { - finishReason := message.FinishReasonEndTurn - if len(finalResp.Candidates) > 0 { - finishReason = g.finishReason(finalResp.Candidates[0].FinishReason) - } - if len(toolCalls) > 0 { - finishReason = message.FinishReasonToolUse - } - eventChan <- ProviderEvent{ - Type: EventComplete, - Response: &ProviderResponse{ - Content: currentContent, - ToolCalls: toolCalls, - Usage: g.usage(finalResp), - FinishReason: finishReason, - }, - } - return - } else { - eventChan <- ProviderEvent{ - Type: EventError, - Error: errors.New("no content received"), - } - } - } - }() - - return eventChan -} - -func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) { - // Check if error is a rate limit error - if attempts > maxRetries { - return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) - } - - // Gemini doesn't have a standard error type we can check against - // So we'll check the error message for rate limit indicators - if errors.Is(err, io.EOF) { - return false, 0, err - } - - errMsg := err.Error() - isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests") - - // Check for token expiration (401 Unauthorized) - if contains(errMsg, "unauthorized", "invalid api key", "api key expired") { - prev := g.providerOptions.apiKey - // in case the key comes from a script, we try to re-evaluate it. - g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey) - if err != nil { - return false, 0, fmt.Errorf("failed to resolve API key: %w", err) - } - // if it didn't change, do not retry. 
- if prev == g.providerOptions.apiKey { - return false, 0, err - } - g.client, err = createGeminiClient(g.providerOptions) - if err != nil { - return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err) - } - return true, 0, nil - } - - // Check for common rate limit error messages - - if !isRateLimit { - return false, 0, err - } - - // Calculate backoff with jitter - backoffMs := 2000 * (1 << (attempts - 1)) - jitterMs := int(float64(backoffMs) * 0.2) - retryMs := backoffMs + jitterMs - - return true, int64(retryMs), nil -} - -func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage { - if resp == nil || resp.UsageMetadata == nil { - return TokenUsage{} - } - - return TokenUsage{ - InputTokens: int64(resp.UsageMetadata.PromptTokenCount), - OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount), - CacheCreationTokens: 0, // Not directly provided by Gemini - CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount), - } -} - -func (g *geminiClient) Model() catwalk.Model { - return g.providerOptions.model(g.providerOptions.modelType) -} - -// Helper functions -func parseJSONToMap(jsonStr string) (map[string]any, error) { - var result map[string]any - err := json.Unmarshal([]byte(jsonStr), &result) - return result, err -} - -func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema { - properties := make(map[string]*genai.Schema) - - for name, param := range parameters { - properties[name] = convertToSchema(param) - } - - return properties -} - -func convertToSchema(param any) *genai.Schema { - schema := &genai.Schema{Type: genai.TypeString} - - paramMap, ok := param.(map[string]any) - if !ok { - return schema - } - - if desc, ok := paramMap["description"].(string); ok { - schema.Description = desc - } - - typeVal, hasType := paramMap["type"] - if !hasType { - return schema - } - - typeStr, ok := typeVal.(string) - if !ok { - return schema - } - - schema.Type = mapJSONTypeToGenAI(typeStr) - - switch typeStr { - case "array": - schema.Items = processArrayItems(paramMap) - case "object": - if props, ok := paramMap["properties"].(map[string]any); ok { - schema.Properties = convertSchemaProperties(props) - } - } - - return schema -} - -func processArrayItems(paramMap map[string]any) *genai.Schema { - items, ok := paramMap["items"].(map[string]any) - if !ok { - return nil - } - - return convertToSchema(items) -} - -func mapJSONTypeToGenAI(jsonType string) genai.Type { - switch jsonType { - case "string": - return genai.TypeString - case "number": - return genai.TypeNumber - case "integer": - return genai.TypeInteger - case "boolean": - return genai.TypeBoolean - case "array": - return genai.TypeArray - case "object": - return genai.TypeObject - default: - return genai.TypeString // Default to string for unknown types - } -} - -func contains(s string, substrs ...string) bool { - for _, substr := range substrs { - if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) { - return true - } - } - return false -} diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go deleted file mode 100644 index 3e92e077b3156ddccc186e0b104b7db174290c18..0000000000000000000000000000000000000000 --- a/internal/llm/provider/openai.go +++ /dev/null @@ -1,604 +0,0 @@ -package provider - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "log/slog" - "net/http" - "strings" - "time" - - "github.com/charmbracelet/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/config" - 
"github.com/charmbracelet/crush/internal/llm/tools" - "github.com/charmbracelet/crush/internal/log" - "github.com/charmbracelet/crush/internal/message" - "github.com/google/uuid" - "github.com/openai/openai-go" - "github.com/openai/openai-go/option" - "github.com/openai/openai-go/packages/param" - "github.com/openai/openai-go/shared" -) - -type openaiClient struct { - providerOptions providerClientOptions - client openai.Client -} - -type OpenAIClient ProviderClient - -func newOpenAIClient(opts providerClientOptions) OpenAIClient { - return &openaiClient{ - providerOptions: opts, - client: createOpenAIClient(opts), - } -} - -func createOpenAIClient(opts providerClientOptions) openai.Client { - openaiClientOptions := []option.RequestOption{} - if opts.apiKey != "" { - openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey)) - } - if opts.baseURL != "" { - resolvedBaseURL, err := config.Get().Resolve(opts.baseURL) - if err == nil && resolvedBaseURL != "" { - openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL)) - } - } - - if config.Get().Options.Debug { - httpClient := log.NewHTTPClient() - openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(httpClient)) - } - - for key, value := range opts.extraHeaders { - openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value)) - } - - for extraKey, extraValue := range opts.extraBody { - openaiClientOptions = append(openaiClientOptions, option.WithJSONSet(extraKey, extraValue)) - } - - return openai.NewClient(openaiClientOptions...) -} - -func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) { - isAnthropicModel := o.providerOptions.config.ID == string(catwalk.InferenceProviderOpenRouter) && strings.HasPrefix(o.Model().ID, "anthropic/") - // Add system message first - systemMessage := o.providerOptions.systemMessage - if o.providerOptions.systemPromptPrefix != "" { - systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage - } - - system := openai.SystemMessage(systemMessage) - if isAnthropicModel && !o.providerOptions.disableCache { - systemTextBlock := openai.ChatCompletionContentPartTextParam{Text: systemMessage} - systemTextBlock.SetExtraFields( - map[string]any{ - "cache_control": map[string]string{ - "type": "ephemeral", - }, - }, - ) - var content []openai.ChatCompletionContentPartTextParam - content = append(content, systemTextBlock) - system = openai.SystemMessage(content) - } - openaiMessages = append(openaiMessages, system) - - for i, msg := range messages { - cache := false - if i > len(messages)-3 { - cache = true - } - switch msg.Role { - case message.User: - var content []openai.ChatCompletionContentPartUnionParam - - textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()} - content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock}) - hasBinaryContent := false - for _, binaryContent := range msg.BinaryContent() { - hasBinaryContent = true - imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(catwalk.InferenceProviderOpenAI)} - imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL} - - content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock}) - } - if cache && !o.providerOptions.disableCache && isAnthropicModel { - textBlock.SetExtraFields(map[string]any{ - "cache_control": map[string]string{ - 
"type": "ephemeral", - }, - }) - } - if hasBinaryContent || (isAnthropicModel && !o.providerOptions.disableCache) { - openaiMessages = append(openaiMessages, openai.UserMessage(content)) - } else { - openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String())) - } - - case message.Assistant: - assistantMsg := openai.ChatCompletionAssistantMessageParam{ - Role: "assistant", - } - - // Only include finished tool calls; interrupted tool calls must not be resent. - if len(msg.ToolCalls()) > 0 { - finished := make([]message.ToolCall, 0, len(msg.ToolCalls())) - for _, call := range msg.ToolCalls() { - if call.Finished { - finished = append(finished, call) - } - } - if len(finished) > 0 { - assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(finished)) - for i, call := range finished { - assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{ - ID: call.ID, - Type: "function", - Function: openai.ChatCompletionMessageToolCallFunctionParam{ - Name: call.Name, - Arguments: call.Input, - }, - } - } - } - } - if msg.Content().String() != "" { - assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{ - OfString: param.NewOpt(msg.Content().Text), - } - } - - if cache && !o.providerOptions.disableCache && isAnthropicModel { - assistantMsg.SetExtraFields(map[string]any{ - "cache_control": map[string]string{ - "type": "ephemeral", - }, - }) - } - // Skip empty assistant messages (no content and no finished tool calls) - if msg.Content().String() == "" && len(assistantMsg.ToolCalls) == 0 { - continue - } - - openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{ - OfAssistant: &assistantMsg, - }) - - case message.Tool: - for _, result := range msg.ToolResults() { - openaiMessages = append(openaiMessages, - openai.ToolMessage(result.Content, result.ToolCallID), - ) - } - } - } - - return openaiMessages -} - -func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam { - openaiTools := make([]openai.ChatCompletionToolParam, len(tools)) - - for i, tool := range tools { - info := tool.Info() - openaiTools[i] = openai.ChatCompletionToolParam{ - Function: openai.FunctionDefinitionParam{ - Name: info.Name, - Description: openai.String(info.Description), - Parameters: openai.FunctionParameters{ - "type": "object", - "properties": info.Parameters, - "required": info.Required, - }, - }, - } - } - - return openaiTools -} - -func (o *openaiClient) finishReason(reason string) message.FinishReason { - switch reason { - case "stop": - return message.FinishReasonEndTurn - case "length": - return message.FinishReasonMaxTokens - case "tool_calls": - return message.FinishReasonToolUse - default: - return message.FinishReasonUnknown - } -} - -func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams { - model := o.providerOptions.model(o.providerOptions.modelType) - cfg := config.Get() - - modelConfig := cfg.Models[config.SelectedModelTypeLarge] - if o.providerOptions.modelType == config.SelectedModelTypeSmall { - modelConfig = cfg.Models[config.SelectedModelTypeSmall] - } - - reasoningEffort := modelConfig.ReasoningEffort - - params := openai.ChatCompletionNewParams{ - Model: openai.ChatModel(model.ID), - Messages: messages, - Tools: tools, - } - - maxTokens := model.DefaultMaxTokens - if modelConfig.MaxTokens > 0 { - maxTokens = modelConfig.MaxTokens - } - - // Override max tokens 
if set in provider options - if o.providerOptions.maxTokens > 0 { - maxTokens = o.providerOptions.maxTokens - } - if model.CanReason { - params.MaxCompletionTokens = openai.Int(maxTokens) - switch reasoningEffort { - case "low": - params.ReasoningEffort = shared.ReasoningEffortLow - case "medium": - params.ReasoningEffort = shared.ReasoningEffortMedium - case "high": - params.ReasoningEffort = shared.ReasoningEffortHigh - case "minimal": - params.ReasoningEffort = shared.ReasoningEffort("minimal") - default: - params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort) - } - } else { - params.MaxTokens = openai.Int(maxTokens) - } - - return params -} - -func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) { - params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools)) - attempts := 0 - for { - attempts++ - openaiResponse, err := o.client.Chat.Completions.New( - ctx, - params, - ) - // If there is an error we are going to see if we can retry the call - if err != nil { - retry, after, retryErr := o.shouldRetry(attempts, err) - if retryErr != nil { - return nil, retryErr - } - if retry { - slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err) - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(time.Duration(after) * time.Millisecond): - continue - } - } - return nil, retryErr - } - - if len(openaiResponse.Choices) == 0 { - return nil, fmt.Errorf("received empty response from OpenAI API - check endpoint configuration") - } - - content := "" - if openaiResponse.Choices[0].Message.Content != "" { - content = openaiResponse.Choices[0].Message.Content - } - - toolCalls := o.toolCalls(*openaiResponse) - finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason)) - - if len(toolCalls) > 0 { - finishReason = message.FinishReasonToolUse - } - - return &ProviderResponse{ - Content: content, - ToolCalls: toolCalls, - Usage: o.usage(*openaiResponse), - FinishReason: finishReason, - }, nil - } -} - -func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { - params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools)) - params.StreamOptions = openai.ChatCompletionStreamOptionsParam{ - IncludeUsage: openai.Bool(true), - } - - attempts := 0 - eventChan := make(chan ProviderEvent) - - go func() { - for { - attempts++ - // Kujtim: fixes an issue with anthropig models on openrouter - if len(params.Tools) == 0 { - params.Tools = nil - } - openaiStream := o.client.Chat.Completions.NewStreaming( - ctx, - params, - ) - - acc := openai.ChatCompletionAccumulator{} - currentContent := "" - toolCalls := make([]message.ToolCall, 0) - msgToolCalls := make(map[int64]openai.ChatCompletionMessageToolCall) - toolMap := make(map[string]openai.ChatCompletionMessageToolCall) - toolCallIDMap := make(map[string]string) - for openaiStream.Next() { - chunk := openaiStream.Current() - // Kujtim: this is an issue with openrouter qwen, its sending -1 for the tool index - if len(chunk.Choices) != 0 && len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 { - chunk.Choices[0].Delta.ToolCalls[0].Index = 0 - } - acc.AddChunk(chunk) - for i, choice := range chunk.Choices { - reasoning, ok := choice.Delta.JSON.ExtraFields["reasoning"] - if ok && reasoning.Raw() != "" { - reasoningStr := "" - 
json.Unmarshal([]byte(reasoning.Raw()), &reasoningStr) - if reasoningStr != "" { - eventChan <- ProviderEvent{ - Type: EventThinkingDelta, - Thinking: reasoningStr, - } - } - } - if choice.Delta.Content != "" { - eventChan <- ProviderEvent{ - Type: EventContentDelta, - Content: choice.Delta.Content, - } - currentContent += choice.Delta.Content - } else if len(choice.Delta.ToolCalls) > 0 { - toolCall := choice.Delta.ToolCalls[0] - if strings.HasPrefix(toolCall.ID, "functions.") { - exID, ok := toolCallIDMap[toolCall.ID] - if !ok { - newID := uuid.NewString() - toolCallIDMap[toolCall.ID] = newID - toolCall.ID = newID - } else { - toolCall.ID = exID - } - } - newToolCall := false - if existingToolCall, ok := msgToolCalls[toolCall.Index]; ok { // tool call exists - if toolCall.ID != "" && toolCall.ID != existingToolCall.ID { - found := false - // try to find the tool based on the ID - for _, tool := range msgToolCalls { - if tool.ID == toolCall.ID { - existingToolCall.Function.Arguments += toolCall.Function.Arguments - msgToolCalls[toolCall.Index] = existingToolCall - toolMap[existingToolCall.ID] = existingToolCall - found = true - } - } - if !found { - newToolCall = true - } - } else { - existingToolCall.Function.Arguments += toolCall.Function.Arguments - msgToolCalls[toolCall.Index] = existingToolCall - toolMap[existingToolCall.ID] = existingToolCall - } - } else { - newToolCall = true - } - if newToolCall { // new tool call - if toolCall.ID == "" { - toolCall.ID = uuid.NewString() - } - eventChan <- ProviderEvent{ - Type: EventToolUseStart, - ToolCall: &message.ToolCall{ - ID: toolCall.ID, - Name: toolCall.Function.Name, - Finished: false, - }, - } - msgToolCalls[toolCall.Index] = openai.ChatCompletionMessageToolCall{ - ID: toolCall.ID, - Type: "function", - Function: openai.ChatCompletionMessageToolCallFunction{ - Name: toolCall.Function.Name, - Arguments: toolCall.Function.Arguments, - }, - } - toolMap[toolCall.ID] = msgToolCalls[toolCall.Index] - } - toolCalls := []openai.ChatCompletionMessageToolCall{} - for _, tc := range toolMap { - toolCalls = append(toolCalls, tc) - } - acc.Choices[i].Message.ToolCalls = toolCalls - } - } - } - - err := openaiStream.Err() - if err == nil || errors.Is(err, io.EOF) { - if len(acc.Choices) == 0 { - eventChan <- ProviderEvent{ - Type: EventError, - Error: fmt.Errorf("received empty streaming response from OpenAI API - check endpoint configuration"), - } - return - } - - resultFinishReason := acc.Choices[0].FinishReason - if resultFinishReason == "" { - // If the finish reason is empty, we assume it was a successful completion - // INFO: this is happening for openrouter for some reason - resultFinishReason = "stop" - } - // Stream completed successfully - finishReason := o.finishReason(resultFinishReason) - if len(acc.Choices[0].Message.ToolCalls) > 0 { - toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...) 
- } - if len(toolCalls) > 0 { - finishReason = message.FinishReasonToolUse - } - - eventChan <- ProviderEvent{ - Type: EventComplete, - Response: &ProviderResponse{ - Content: currentContent, - ToolCalls: toolCalls, - Usage: o.usage(acc.ChatCompletion), - FinishReason: finishReason, - }, - } - close(eventChan) - return - } - - // If there is an error we are going to see if we can retry the call - retry, after, retryErr := o.shouldRetry(attempts, err) - if retryErr != nil { - eventChan <- ProviderEvent{Type: EventError, Error: retryErr} - close(eventChan) - return - } - if retry { - slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries, "error", err) - select { - case <-ctx.Done(): - // context cancelled - if ctx.Err() != nil { - eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} - } - close(eventChan) - return - case <-time.After(time.Duration(after) * time.Millisecond): - continue - } - } - eventChan <- ProviderEvent{Type: EventError, Error: retryErr} - close(eventChan) - return - } - }() - - return eventChan -} - -func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) { - if attempts > maxRetries { - return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) - } - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return false, 0, err - } - var apiErr *openai.Error - retryMs := 0 - retryAfterValues := []string{} - if errors.As(err, &apiErr) { - // Check for token expiration (401 Unauthorized) - if apiErr.StatusCode == http.StatusUnauthorized { - prev := o.providerOptions.apiKey - // in case the key comes from a script, we try to re-evaluate it. - o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey) - if err != nil { - return false, 0, fmt.Errorf("failed to resolve API key: %w", err) - } - // if it didn't change, do not retry. - if prev == o.providerOptions.apiKey { - return false, 0, err - } - o.client = createOpenAIClient(o.providerOptions) - return true, 0, nil - } - - if apiErr.StatusCode == http.StatusTooManyRequests { - // Check if this is an insufficient quota error (permanent) - if apiErr.Type == "insufficient_quota" || apiErr.Code == "insufficient_quota" { - return false, 0, fmt.Errorf("OpenAI quota exceeded: %s. 
Please check your plan and billing details", apiErr.Message) - } - // Other 429 errors (rate limiting) can be retried - } else if apiErr.StatusCode != http.StatusInternalServerError { - return false, 0, err - } - - if apiErr.Response != nil { - retryAfterValues = apiErr.Response.Header.Values("Retry-After") - } - } - - if apiErr != nil { - slog.Warn("OpenAI API error", "status_code", apiErr.StatusCode, "message", apiErr.Message, "type", apiErr.Type) - if len(retryAfterValues) > 0 { - slog.Warn("Retry-After header", "values", retryAfterValues) - } - } else { - slog.Error("OpenAI API error", "error", err.Error(), "attempt", attempts, "max_retries", maxRetries) - } - - backoffMs := 2000 * (1 << (attempts - 1)) - jitterMs := int(float64(backoffMs) * 0.2) - retryMs = backoffMs + jitterMs - if len(retryAfterValues) > 0 { - if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil { - retryMs = retryMs * 1000 - } - } - return true, int64(retryMs), nil -} - -func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall { - var toolCalls []message.ToolCall - - if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 { - for _, call := range completion.Choices[0].Message.ToolCalls { - // accumulator for some reason does this. - if call.Function.Name == "" { - continue - } - toolCall := message.ToolCall{ - ID: call.ID, - Name: call.Function.Name, - Input: call.Function.Arguments, - Type: "function", - Finished: true, - } - toolCalls = append(toolCalls, toolCall) - } - } - - return toolCalls -} - -func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage { - cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens - inputTokens := completion.Usage.PromptTokens - cachedTokens - - return TokenUsage{ - InputTokens: inputTokens, - OutputTokens: completion.Usage.CompletionTokens, - CacheCreationTokens: 0, // OpenAI doesn't provide this directly - CacheReadTokens: cachedTokens, - } -} - -func (o *openaiClient) Model() catwalk.Model { - return o.providerOptions.model(o.providerOptions.modelType) -} diff --git a/internal/llm/provider/openai_test.go b/internal/llm/provider/openai_test.go deleted file mode 100644 index 52b0a20c9316d67ba987ccc5051aa2f6d321aff4..0000000000000000000000000000000000000000 --- a/internal/llm/provider/openai_test.go +++ /dev/null @@ -1,166 +0,0 @@ -package provider - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "os" - "strings" - "testing" - "time" - - "github.com/charmbracelet/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/message" - "github.com/openai/openai-go" - "github.com/openai/openai-go/option" -) - -func TestMain(m *testing.M) { - _, err := config.Init(".", "", true) - if err != nil { - panic("Failed to initialize config: " + err.Error()) - } - - os.Exit(m.Run()) -} - -func TestOpenAIClientStreamChoices(t *testing.T) { - // Create a mock server that returns Server-Sent Events with empty choices - // This simulates the 🤡 behavior when a server returns 200 instead of 404 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.WriteHeader(http.StatusOK) - - emptyChoicesChunk := map[string]any{ - "id": "chat-completion-test", - "object": "chat.completion.chunk", - "created": time.Now().Unix(), - "model": 
"test-model", - "choices": []any{}, // Empty choices array that causes panic - } - - jsonData, _ := json.Marshal(emptyChoicesChunk) - w.Write([]byte("data: " + string(jsonData) + "\n\n")) - w.Write([]byte("data: [DONE]\n\n")) - })) - defer server.Close() - - // Create OpenAI client pointing to our mock server - client := &openaiClient{ - providerOptions: providerClientOptions{ - modelType: config.SelectedModelTypeLarge, - apiKey: "test-key", - systemMessage: "test", - model: func(config.SelectedModelType) catwalk.Model { - return catwalk.Model{ - ID: "test-model", - Name: "test-model", - } - }, - }, - client: openai.NewClient( - option.WithAPIKey("test-key"), - option.WithBaseURL(server.URL), - ), - } - - // Create test messages - messages := []message.Message{ - { - Role: message.User, - Parts: []message.ContentPart{message.TextContent{Text: "Hello"}}, - }, - } - - ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) - defer cancel() - - eventsChan := client.stream(ctx, messages, nil) - - // Collect events - this will panic without the bounds check - for event := range eventsChan { - t.Logf("Received event: %+v", event) - if event.Type == EventError || event.Type == EventComplete { - break - } - } -} - -func TestOpenAIClient429InsufficientQuotaError(t *testing.T) { - client := &openaiClient{ - providerOptions: providerClientOptions{ - modelType: config.SelectedModelTypeLarge, - apiKey: "test-key", - systemMessage: "test", - config: config.ProviderConfig{ - ID: "test-openai", - APIKey: "test-key", - }, - model: func(config.SelectedModelType) catwalk.Model { - return catwalk.Model{ - ID: "test-model", - Name: "test-model", - } - }, - }, - } - - // Test insufficient_quota error should not retry - apiErr := &openai.Error{ - StatusCode: 429, - Message: "You exceeded your current quota, please check your plan and billing details. 
For more information on this error, read the docs: https://platform.openai.com/docs/guides/error-codes/api-errors.", - Type: "insufficient_quota", - Code: "insufficient_quota", - } - - retry, _, err := client.shouldRetry(1, apiErr) - if retry { - t.Error("Expected shouldRetry to return false for insufficient_quota error, but got true") - } - if err == nil { - t.Error("Expected shouldRetry to return an error for insufficient_quota, but got nil") - } - if err != nil && !strings.Contains(err.Error(), "quota") { - t.Errorf("Expected error message to mention quota, got: %v", err) - } -} - -func TestOpenAIClient429RateLimitError(t *testing.T) { - client := &openaiClient{ - providerOptions: providerClientOptions{ - modelType: config.SelectedModelTypeLarge, - apiKey: "test-key", - systemMessage: "test", - config: config.ProviderConfig{ - ID: "test-openai", - APIKey: "test-key", - }, - model: func(config.SelectedModelType) catwalk.Model { - return catwalk.Model{ - ID: "test-model", - Name: "test-model", - } - }, - }, - } - - // Test regular rate limit error should retry - apiErr := &openai.Error{ - StatusCode: 429, - Message: "Rate limit reached for requests", - Type: "rate_limit_exceeded", - Code: "rate_limit_exceeded", - } - - retry, _, err := client.shouldRetry(1, apiErr) - if !retry { - t.Error("Expected shouldRetry to return true for rate_limit_exceeded error, but got false") - } - if err != nil { - t.Errorf("Expected shouldRetry to return nil error for rate_limit_exceeded, but got: %v", err) - } -} diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go deleted file mode 100644 index 0dada9d8b1e353801fde43b1d9ebb1fc6eaa0a1e..0000000000000000000000000000000000000000 --- a/internal/llm/provider/provider.go +++ /dev/null @@ -1,208 +0,0 @@ -package provider - -import ( - "context" - "fmt" - - "github.com/charmbracelet/catwalk/pkg/catwalk" - - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/tools" - "github.com/charmbracelet/crush/internal/message" -) - -type EventType string - -const maxRetries = 3 - -const ( - EventContentStart EventType = "content_start" - EventToolUseStart EventType = "tool_use_start" - EventToolUseDelta EventType = "tool_use_delta" - EventToolUseStop EventType = "tool_use_stop" - EventContentDelta EventType = "content_delta" - EventThinkingDelta EventType = "thinking_delta" - EventSignatureDelta EventType = "signature_delta" - EventContentStop EventType = "content_stop" - EventComplete EventType = "complete" - EventError EventType = "error" - EventWarning EventType = "warning" -) - -type TokenUsage struct { - InputTokens int64 - OutputTokens int64 - CacheCreationTokens int64 - CacheReadTokens int64 -} - -type ProviderResponse struct { - Content string - ToolCalls []message.ToolCall - Usage TokenUsage - FinishReason message.FinishReason -} - -type ProviderEvent struct { - Type EventType - - Content string - Thinking string - Signature string - Response *ProviderResponse - ToolCall *message.ToolCall - Error error -} -type Provider interface { - SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) - - StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent - - Model() catwalk.Model -} - -type providerClientOptions struct { - baseURL string - config config.ProviderConfig - apiKey string - modelType config.SelectedModelType - model func(config.SelectedModelType) catwalk.Model - disableCache bool 
- systemMessage string - systemPromptPrefix string - maxTokens int64 - extraHeaders map[string]string - extraBody map[string]any - extraParams map[string]string -} - -type ProviderClientOption func(*providerClientOptions) - -type ProviderClient interface { - send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) - stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent - - Model() catwalk.Model -} - -type baseProvider[C ProviderClient] struct { - options providerClientOptions - client C -} - -func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) { - for _, msg := range messages { - // The message has no content - if len(msg.Parts) == 0 { - continue - } - cleaned = append(cleaned, msg) - } - return cleaned -} - -func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - messages = p.cleanMessages(messages) - return p.client.send(ctx, messages, tools) -} - -func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { - messages = p.cleanMessages(messages) - return p.client.stream(ctx, messages, tools) -} - -func (p *baseProvider[C]) Model() catwalk.Model { - return p.client.Model() -} - -func WithModel(model config.SelectedModelType) ProviderClientOption { - return func(options *providerClientOptions) { - options.modelType = model - } -} - -func WithDisableCache(disableCache bool) ProviderClientOption { - return func(options *providerClientOptions) { - options.disableCache = disableCache - } -} - -func WithSystemMessage(systemMessage string) ProviderClientOption { - return func(options *providerClientOptions) { - options.systemMessage = systemMessage - } -} - -func WithMaxTokens(maxTokens int64) ProviderClientOption { - return func(options *providerClientOptions) { - options.maxTokens = maxTokens - } -} - -func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) { - restore := config.PushPopCrushEnv() - defer restore() - resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey) - if err != nil { - return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err) - } - - // Resolve extra headers - resolvedExtraHeaders := make(map[string]string) - for key, value := range cfg.ExtraHeaders { - resolvedValue, err := config.Get().Resolve(value) - if err != nil { - return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err) - } - resolvedExtraHeaders[key] = resolvedValue - } - - clientOptions := providerClientOptions{ - baseURL: cfg.BaseURL, - config: cfg, - apiKey: resolvedAPIKey, - extraHeaders: resolvedExtraHeaders, - extraBody: cfg.ExtraBody, - extraParams: cfg.ExtraParams, - systemPromptPrefix: cfg.SystemPromptPrefix, - model: func(tp config.SelectedModelType) catwalk.Model { - return *config.Get().GetModelByType(tp) - }, - } - for _, o := range opts { - o(&clientOptions) - } - switch cfg.Type { - case catwalk.TypeAnthropic: - return &baseProvider[AnthropicClient]{ - options: clientOptions, - client: newAnthropicClient(clientOptions, AnthropicClientTypeNormal), - }, nil - case catwalk.TypeOpenAI: - return &baseProvider[OpenAIClient]{ - options: clientOptions, - client: newOpenAIClient(clientOptions), - }, nil - case catwalk.TypeGemini: - return &baseProvider[GeminiClient]{ - options: clientOptions, - client: 
newGeminiClient(clientOptions), - }, nil - case catwalk.TypeBedrock: - return &baseProvider[BedrockClient]{ - options: clientOptions, - client: newBedrockClient(clientOptions), - }, nil - case catwalk.TypeAzure: - return &baseProvider[AzureClient]{ - options: clientOptions, - client: newAzureClient(clientOptions), - }, nil - case catwalk.TypeVertexAI: - return &baseProvider[VertexAIClient]{ - options: clientOptions, - client: newVertexAIClient(clientOptions), - }, nil - } - return nil, fmt.Errorf("provider not supported: %s", cfg.Type) -} diff --git a/internal/llm/provider/vertexai.go b/internal/llm/provider/vertexai.go deleted file mode 100644 index 871ff092b058af70833ba615260efcdbc09f2514..0000000000000000000000000000000000000000 --- a/internal/llm/provider/vertexai.go +++ /dev/null @@ -1,40 +0,0 @@ -package provider - -import ( - "context" - "log/slog" - "strings" - - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/log" - "google.golang.org/genai" -) - -type VertexAIClient ProviderClient - -func newVertexAIClient(opts providerClientOptions) VertexAIClient { - project := opts.extraParams["project"] - location := opts.extraParams["location"] - cc := &genai.ClientConfig{ - Project: project, - Location: location, - Backend: genai.BackendVertexAI, - } - if config.Get().Options.Debug { - cc.HTTPClient = log.NewHTTPClient() - } - client, err := genai.NewClient(context.Background(), cc) - if err != nil { - slog.Error("Failed to create VertexAI client", "error", err) - return nil - } - - model := opts.model(opts.modelType) - if strings.Contains(model.ID, "anthropic") || strings.Contains(model.ID, "claude-sonnet") { - return newAnthropicClient(opts, AnthropicClientTypeVertex) - } - return &geminiClient{ - providerOptions: opts, - client: client, - } -} diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go deleted file mode 100644 index 7f91ecb78b3d79380d029a1d58bb62083cdf27fe..0000000000000000000000000000000000000000 --- a/internal/llm/tools/bash.go +++ /dev/null @@ -1,395 +0,0 @@ -package tools - -import ( - "bytes" - "context" - _ "embed" - "encoding/json" - "fmt" - "html/template" - "strings" - "time" - - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/permission" - "github.com/charmbracelet/crush/internal/shell" -) - -type BashParams struct { - Command string `json:"command"` - Timeout int `json:"timeout"` -} - -type BashPermissionsParams struct { - Command string `json:"command"` - Timeout int `json:"timeout"` -} - -type BashResponseMetadata struct { - StartTime int64 `json:"start_time"` - EndTime int64 `json:"end_time"` - Output string `json:"output"` - WorkingDirectory string `json:"working_directory"` -} -type bashTool struct { - permissions permission.Service - workingDir string - attribution *config.Attribution -} - -const ( - BashToolName = "bash" - - DefaultTimeout = 1 * 60 * 1000 // 1 minutes in milliseconds - MaxTimeout = 10 * 60 * 1000 // 10 minutes in milliseconds - MaxOutputLength = 30000 - BashNoOutput = "no output" -) - -//go:embed bash.md -var bashDescription []byte - -var bashDescriptionTpl = template.Must( - template.New("bashDescription"). 
- Parse(string(bashDescription)), -) - -type bashDescriptionData struct { - BannedCommands string - MaxOutputLength int - AttributionStep string - AttributionExample string - PRAttribution string -} - -var bannedCommands = []string{ - // Network/Download tools - "alias", - "aria2c", - "axel", - "chrome", - "curl", - "curlie", - "firefox", - "http-prompt", - "httpie", - "links", - "lynx", - "nc", - "safari", - "scp", - "ssh", - "telnet", - "w3m", - "wget", - "xh", - - // System administration - "doas", - "su", - "sudo", - - // Package managers - "apk", - "apt", - "apt-cache", - "apt-get", - "dnf", - "dpkg", - "emerge", - "home-manager", - "makepkg", - "opkg", - "pacman", - "paru", - "pkg", - "pkg_add", - "pkg_delete", - "portage", - "rpm", - "yay", - "yum", - "zypper", - - // System modification - "at", - "batch", - "chkconfig", - "crontab", - "fdisk", - "mkfs", - "mount", - "parted", - "service", - "systemctl", - "umount", - - // Network configuration - "firewall-cmd", - "ifconfig", - "ip", - "iptables", - "netstat", - "pfctl", - "route", - "ufw", -} - -func (b *bashTool) bashDescription() string { - bannedCommandsStr := strings.Join(bannedCommands, ", ") - - // Build attribution text based on settings - var attributionStep, attributionExample, prAttribution string - - // Default to true if attribution is nil (backward compatibility) - generatedWith := b.attribution == nil || b.attribution.GeneratedWith - coAuthoredBy := b.attribution == nil || b.attribution.CoAuthoredBy - - // Build PR attribution - if generatedWith { - prAttribution = "💘 Generated with Crush" - } - - if generatedWith || coAuthoredBy { - var attributionParts []string - if generatedWith { - attributionParts = append(attributionParts, "💘 Generated with Crush") - } - if coAuthoredBy { - attributionParts = append(attributionParts, "Co-Authored-By: Crush ") - } - - if len(attributionParts) > 0 { - attributionStep = fmt.Sprintf("4. Create the commit with a message ending with:\n%s", strings.Join(attributionParts, "\n")) - - attributionText := strings.Join(attributionParts, "\n ") - attributionExample = fmt.Sprintf(` -git commit -m "$(cat <<'EOF' - Commit message here. - - %s - EOF -)"`, attributionText) - } - } - - if attributionStep == "" { - attributionStep = "4. Create the commit with your commit message." - attributionExample = ` -git commit -m "$(cat <<'EOF' - Commit message here. - EOF -)"` - } - - var out bytes.Buffer - if err := bashDescriptionTpl.Execute(&out, bashDescriptionData{ - BannedCommands: bannedCommandsStr, - MaxOutputLength: MaxOutputLength, - AttributionStep: attributionStep, - AttributionExample: attributionExample, - PRAttribution: prAttribution, - }); err != nil { - // this should never happen. 
- panic("failed to execute bash description template: " + err.Error()) - } - return out.String() -} - -func blockFuncs() []shell.BlockFunc { - return []shell.BlockFunc{ - shell.CommandsBlocker(bannedCommands), - - // System package managers - shell.ArgumentsBlocker("apk", []string{"add"}, nil), - shell.ArgumentsBlocker("apt", []string{"install"}, nil), - shell.ArgumentsBlocker("apt-get", []string{"install"}, nil), - shell.ArgumentsBlocker("dnf", []string{"install"}, nil), - shell.ArgumentsBlocker("pacman", nil, []string{"-S"}), - shell.ArgumentsBlocker("pkg", []string{"install"}, nil), - shell.ArgumentsBlocker("yum", []string{"install"}, nil), - shell.ArgumentsBlocker("zypper", []string{"install"}, nil), - - // Language-specific package managers - shell.ArgumentsBlocker("brew", []string{"install"}, nil), - shell.ArgumentsBlocker("cargo", []string{"install"}, nil), - shell.ArgumentsBlocker("gem", []string{"install"}, nil), - shell.ArgumentsBlocker("go", []string{"install"}, nil), - shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}), - shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}), - shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}), - shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}), - shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}), - shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}), - shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil), - - // `go test -exec` can run arbitrary commands - shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}), - } -} - -func NewBashTool(permission permission.Service, workingDir string, attribution *config.Attribution) BaseTool { - // Set up command blocking on the persistent shell - persistentShell := shell.GetPersistentShell(workingDir) - persistentShell.SetBlockFuncs(blockFuncs()) - - return &bashTool{ - permissions: permission, - workingDir: workingDir, - attribution: attribution, - } -} - -func (b *bashTool) Name() string { - return BashToolName -} - -func (b *bashTool) Info() ToolInfo { - return ToolInfo{ - Name: BashToolName, - Description: b.bashDescription(), - Parameters: map[string]any{ - "command": map[string]any{ - "type": "string", - "description": "The command to execute", - }, - "timeout": map[string]any{ - "type": "number", - "description": "Optional timeout in milliseconds (max 600000)", - }, - }, - Required: []string{"command"}, - } -} - -func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { - var params BashParams - if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { - return NewTextErrorResponse("invalid parameters"), nil - } - - if params.Timeout > MaxTimeout { - params.Timeout = MaxTimeout - } else if params.Timeout <= 0 { - params.Timeout = DefaultTimeout - } - - if params.Command == "" { - return NewTextErrorResponse("missing command"), nil - } - - isSafeReadOnly := false - cmdLower := strings.ToLower(params.Command) - - for _, safe := range safeCommands { - if strings.HasPrefix(cmdLower, safe) { - if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' { - isSafeReadOnly = true - break - } - } - } - - sessionID, messageID := GetContextValues(ctx) - if sessionID == "" || messageID == "" { - return ToolResponse{}, fmt.Errorf("session ID and message ID are required for executing shell command") - } - if !isSafeReadOnly { - shell := shell.GetPersistentShell(b.workingDir) - p := b.permissions.Request( - 
permission.CreatePermissionRequest{ - SessionID: sessionID, - Path: shell.GetWorkingDir(), - ToolCallID: call.ID, - ToolName: BashToolName, - Action: "execute", - Description: fmt.Sprintf("Execute command: %s", params.Command), - Params: BashPermissionsParams{ - Command: params.Command, - }, - }, - ) - if !p { - return ToolResponse{}, permission.ErrorPermissionDenied - } - } - startTime := time.Now() - if params.Timeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond) - defer cancel() - } - - persistentShell := shell.GetPersistentShell(b.workingDir) - stdout, stderr, err := persistentShell.Exec(ctx, params.Command) - - // Get the current working directory after command execution - currentWorkingDir := persistentShell.GetWorkingDir() - interrupted := shell.IsInterrupt(err) - exitCode := shell.ExitCode(err) - if exitCode == 0 && !interrupted && err != nil { - return ToolResponse{}, fmt.Errorf("error executing command: %w", err) - } - - stdout = truncateOutput(stdout) - stderr = truncateOutput(stderr) - - errorMessage := stderr - if errorMessage == "" && err != nil { - errorMessage = err.Error() - } - - if interrupted { - if errorMessage != "" { - errorMessage += "\n" - } - errorMessage += "Command was aborted before completion" - } else if exitCode != 0 { - if errorMessage != "" { - errorMessage += "\n" - } - errorMessage += fmt.Sprintf("Exit code %d", exitCode) - } - - hasBothOutputs := stdout != "" && stderr != "" - - if hasBothOutputs { - stdout += "\n" - } - - if errorMessage != "" { - stdout += "\n" + errorMessage - } - - metadata := BashResponseMetadata{ - StartTime: startTime.UnixMilli(), - EndTime: time.Now().UnixMilli(), - Output: stdout, - WorkingDirectory: currentWorkingDir, - } - if stdout == "" { - return WithResponseMetadata(NewTextResponse(BashNoOutput), metadata), nil - } - stdout += fmt.Sprintf("\n\n%s", currentWorkingDir) - return WithResponseMetadata(NewTextResponse(stdout), metadata), nil -} - -func truncateOutput(content string) string { - if len(content) <= MaxOutputLength { - return content - } - - halfLength := MaxOutputLength / 2 - start := content[:halfLength] - end := content[len(content)-halfLength:] - - truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength]) - return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end) -} - -func countLines(s string) int { - if s == "" { - return 0 - } - return len(strings.Split(s, "\n")) -} diff --git a/internal/llm/tools/bash.md b/internal/llm/tools/bash.md deleted file mode 100644 index 464b2749fb0205729caafbccc2dde57ffe267cd1..0000000000000000000000000000000000000000 --- a/internal/llm/tools/bash.md +++ /dev/null @@ -1,161 +0,0 @@ -Executes a given bash command in a persistent shell session with optional timeout, ensuring proper handling and security measures. - -CROSS-PLATFORM SHELL SUPPORT: - -- This tool uses a shell interpreter (mvdan/sh) that mimics the Bash language, - so you should use Bash syntax in all platforms, including Windows. - The most common shell builtins and core utils are available in Windows as - well. -- Make sure to use forward slashes (/) as path separators in commands, even on - Windows. Example: "ls C:/foo/bar" instead of "ls C:\foo\bar". - -Before executing the command, please follow these steps: - -1. 
Directory Verification: - -- If the command will create new directories or files, first use the LS tool to verify the parent directory exists and is the correct location -- For example, before running "mkdir foo/bar", first use LS to check that "foo" exists and is the intended parent directory - -2. Security Check: - -- For security and to limit the threat of a prompt injection attack, some commands are limited or banned. If you use a disallowed command, you will receive an error message explaining the restriction. Explain the error to the User. -- Verify that the command is not one of the banned commands: {{ .BannedCommands }}. - -3. Command Execution: - -- After ensuring proper quoting, execute the command. -- Capture the output of the command. - -4. Output Processing: - -- If the output exceeds {{ .MaxOutputLength }} characters, output will be truncated before being returned to you. -- Prepare the output for display to the user. - -5. Return Result: - -- Provide the processed output of the command. -- If any errors occurred during execution, include those in the output. -- The result will also have metadata like the cwd (current working directory) at the end, included with tags. - -Usage notes: - -- The command argument is required. -- You can specify an optional timeout in milliseconds (up to 600000ms / 10 minutes). If not specified, commands will timeout after 30 minutes. -- VERY IMPORTANT: You MUST avoid using search commands like 'find' and 'grep'. Instead use Grep, Glob, or Agent tools to search. You MUST avoid read tools like 'cat', 'head', 'tail', and 'ls', and use FileRead and LS tools to read files. -- When issuing multiple commands, use the ';' or '&&' operator to separate them. DO NOT use newlines (newlines are ok in quoted strings). -- IMPORTANT: All commands share the same shell session. Shell state (environment variables, virtual environments, current directory, etc.) persist between commands. For example, if you set an environment variable as part of a command, the environment variable will persist for subsequent commands. -- Try to maintain your current working directory throughout the session by using absolute paths and avoiding usage of 'cd'. You may use 'cd' if the User explicitly requests it. - - pytest /foo/bar/tests - - - cd /foo/bar && pytest tests - - -# Committing changes with git - -When the user asks you to create a new git commit, follow these steps carefully: - -1. Start with a single message that contains exactly three tool_use blocks that do the following (it is VERY IMPORTANT that you send these tool_use blocks in a single message, otherwise it will feel slow to the user!): - -- Run a git status command to see all untracked files. -- Run a git diff command to see both staged and unstaged changes that will be committed. -- Run a git log command to see recent commit messages, so that you can follow this repository's commit message style. - -2. Use the git context at the start of this conversation to determine which files are relevant to your commit. Add relevant untracked files to the staging area. Do not commit files that were already modified at the start of this conversation, if they are not relevant to your commit. - -3. Analyze all staged changes (both previously staged and newly added) and draft a commit message. Wrap your analysis process in tags: - - - -- List the files that have been changed or added -- Summarize the nature of the changes (eg. new feature, enhancement to an existing feature, bug fix, refactoring, test, docs, etc.) 
-- Brainstorm the purpose or motivation behind these changes -- Do not use tools to explore code, beyond what is available in the git context -- Assess the impact of these changes on the overall project -- Check for any sensitive information that shouldn't be committed -- Draft a concise (1-2 sentences) commit message that focuses on the "why" rather than the "what" -- Ensure your language is clear, concise, and to the point -- Ensure the message accurately reflects the changes and their purpose (i.e. "add" means a wholly new feature, "update" means an enhancement to an existing feature, "fix" means a bug fix, etc.) -- Ensure the message is not generic (avoid words like "Update" or "Fix" without context) -- Review the draft message to ensure it accurately reflects the changes and their purpose - - -{{ .AttributionStep }} - -- In order to ensure good formatting, ALWAYS pass the commit message via a HEREDOC, a la this example: - {{ .AttributionExample }} - -5. If the commit fails due to pre-commit hook changes, retry the commit ONCE to include these automated changes. If it fails again, it usually means a pre-commit hook is preventing the commit. If the commit succeeds but you notice that files were modified by the pre-commit hook, you MUST amend your commit to include them. - -6. Finally, run git status to make sure the commit succeeded. - -Important notes: - -- When possible, combine the "git add" and "git commit" commands into a single "git commit -am" command, to speed things up -- However, be careful not to stage files (e.g. with 'git add .') for commits that aren't part of the change, they may have untracked files they want to keep around, but not commit. -- NEVER update the git config -- DO NOT push to the remote repository -- IMPORTANT: Never use git commands with the -i flag (like git rebase -i or git add -i) since they require interactive input which is not supported. -- If there are no changes to commit (i.e., no untracked files and no modifications), do not create an empty commit -- Ensure your commit message is meaningful and concise. It should explain the purpose of the changes, not just describe them. -- Return an empty response - the user will see the git output directly - -# Creating pull requests - -Use the gh command via the Bash tool for ALL GitHub-related tasks including working with issues, pull requests, checks, and releases. If given a Github URL use the gh command to get the information needed. - -IMPORTANT: When the user asks you to create a pull request, follow these steps carefully: - -1. Understand the current state of the branch. Remember to send a single message that contains multiple tool_use blocks (it is VERY IMPORTANT that you do this in a single message, otherwise it will feel slow to the user!): - -- Run a git status command to see all untracked files. -- Run a git diff command to see both staged and unstaged changes that will be committed. -- Check if the current branch tracks a remote branch and is up to date with the remote, so you know if you need to push to the remote -- Run a git log command and 'git diff main...HEAD' to understand the full commit history for the current branch (from the time it diverged from the 'main' branch.) - -2. Create new branch if needed - -3. Commit changes if needed - -4. Push to remote with -u flag if needed - -5. 
Analyze all changes that will be included in the pull request, making sure to look at all relevant commits (not just the latest commit, but all commits that will be included in the pull request!), and draft a pull request summary. Wrap your analysis process in tags: - - - -- List the commits since diverging from the main branch -- Summarize the nature of the changes (eg. new feature, enhancement to an existing feature, bug fix, refactoring, test, docs, etc.) -- Brainstorm the purpose or motivation behind these changes -- Assess the impact of these changes on the overall project -- Do not use tools to explore code, beyond what is available in the git context -- Check for any sensitive information that shouldn't be committed -- Draft a concise (1-2 bullet points) pull request summary that focuses on the "why" rather than the "what" -- Ensure the summary accurately reflects all changes since diverging from the main branch -- Ensure your language is clear, concise, and to the point -- Ensure the summary accurately reflects the changes and their purpose (ie. "add" means a wholly new feature, "update" means an enhancement to an existing feature, "fix" means a bug fix, etc.) -- Ensure the summary is not generic (avoid words like "Update" or "Fix" without context) -- Review the draft summary to ensure it accurately reflects the changes and their purpose - - -6. Create PR using gh pr create with the format below. Use a HEREDOC to pass the body to ensure correct formatting. - - gh pr create --title "the pr title" --body "$(cat <<'EOF' - -## Summary - -<1-3 bullet points> - -## Test plan - -[Checklist of TODOs for testing the pull request...] - -{{ .PRAttribution }} -EOF -)" - - -Important: - -- Return an empty response - the user will see the gh output directly -- Never update git config diff --git a/internal/llm/tools/diagnostics.go b/internal/llm/tools/diagnostics.go deleted file mode 100644 index 8e0c332cef76e40d5e24e74ed3260b95aab8b04b..0000000000000000000000000000000000000000 --- a/internal/llm/tools/diagnostics.go +++ /dev/null @@ -1,213 +0,0 @@ -package tools - -import ( - "context" - _ "embed" - "encoding/json" - "fmt" - "log/slog" - "sort" - "strings" - "time" - - "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/lsp" - "github.com/charmbracelet/x/powernap/pkg/lsp/protocol" -) - -type DiagnosticsParams struct { - FilePath string `json:"file_path"` -} - -type diagnosticsTool struct { - lspClients *csync.Map[string, *lsp.Client] -} - -const DiagnosticsToolName = "diagnostics" - -//go:embed diagnostics.md -var diagnosticsDescription []byte - -func NewDiagnosticsTool(lspClients *csync.Map[string, *lsp.Client]) BaseTool { - return &diagnosticsTool{ - lspClients, - } -} - -func (b *diagnosticsTool) Name() string { - return DiagnosticsToolName -} - -func (b *diagnosticsTool) Info() ToolInfo { - return ToolInfo{ - Name: DiagnosticsToolName, - Description: string(diagnosticsDescription), - Parameters: map[string]any{ - "file_path": map[string]any{ - "type": "string", - "description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)", - }, - }, - Required: []string{}, - } -} - -func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { - var params DiagnosticsParams - if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { - return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil - } - - if b.lspClients.Len() == 0 { - return NewTextErrorResponse("no LSP 
clients available"), nil - } - notifyLSPs(ctx, b.lspClients, params.FilePath) - output := getDiagnostics(params.FilePath, b.lspClients) - return NewTextResponse(output), nil -} - -func notifyLSPs(ctx context.Context, lsps *csync.Map[string, *lsp.Client], filepath string) { - if filepath == "" { - return - } - for client := range lsps.Seq() { - if !client.HandlesFile(filepath) { - continue - } - _ = client.OpenFileOnDemand(ctx, filepath) - _ = client.NotifyChange(ctx, filepath) - client.WaitForDiagnostics(ctx, 5*time.Second) - } -} - -func getDiagnostics(filePath string, lsps *csync.Map[string, *lsp.Client]) string { - fileDiagnostics := []string{} - projectDiagnostics := []string{} - - for lspName, client := range lsps.Seq2() { - for location, diags := range client.GetDiagnostics() { - path, err := location.Path() - if err != nil { - slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err) - continue - } - isCurrentFile := path == filePath - for _, diag := range diags { - formattedDiag := formatDiagnostic(path, diag, lspName) - if isCurrentFile { - fileDiagnostics = append(fileDiagnostics, formattedDiag) - } else { - projectDiagnostics = append(projectDiagnostics, formattedDiag) - } - } - } - } - - sortDiagnostics(fileDiagnostics) - sortDiagnostics(projectDiagnostics) - - var output strings.Builder - writeDiagnostics(&output, "file_diagnostics", fileDiagnostics) - writeDiagnostics(&output, "project_diagnostics", projectDiagnostics) - - if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 { - fileErrors := countSeverity(fileDiagnostics, "Error") - fileWarnings := countSeverity(fileDiagnostics, "Warn") - projectErrors := countSeverity(projectDiagnostics, "Error") - projectWarnings := countSeverity(projectDiagnostics, "Warn") - output.WriteString("\n\n") - fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n", fileErrors, fileWarnings) - fmt.Fprintf(&output, "Project: %d errors, %d warnings\n", projectErrors, projectWarnings) - output.WriteString("\n") - } - - out := output.String() - slog.Info("Diagnostics", "output", fmt.Sprintf("%q", out)) - return out -} - -func writeDiagnostics(output *strings.Builder, tag string, in []string) { - if len(in) == 0 { - return - } - output.WriteString("\n<" + tag + ">\n") - if len(in) > 10 { - output.WriteString(strings.Join(in[:10], "\n")) - fmt.Fprintf(output, "\n... 
and %d more diagnostics", len(in)-10) - } else { - output.WriteString(strings.Join(in, "\n")) - } - output.WriteString("\n\n") -} - -func sortDiagnostics(in []string) []string { - sort.Slice(in, func(i, j int) bool { - iIsError := strings.HasPrefix(in[i], "Error") - jIsError := strings.HasPrefix(in[j], "Error") - if iIsError != jIsError { - return iIsError // Errors come first - } - return in[i] < in[j] // Then alphabetically - }) - return in -} - -func formatDiagnostic(pth string, diagnostic protocol.Diagnostic, source string) string { - severity := "Info" - switch diagnostic.Severity { - case protocol.SeverityError: - severity = "Error" - case protocol.SeverityWarning: - severity = "Warn" - case protocol.SeverityHint: - severity = "Hint" - } - - location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1) - - sourceInfo := "" - if diagnostic.Source != "" { - sourceInfo = diagnostic.Source - } else if source != "" { - sourceInfo = source - } - - codeInfo := "" - if diagnostic.Code != nil { - codeInfo = fmt.Sprintf("[%v]", diagnostic.Code) - } - - tagsInfo := "" - if len(diagnostic.Tags) > 0 { - tags := []string{} - for _, tag := range diagnostic.Tags { - switch tag { - case protocol.Unnecessary: - tags = append(tags, "unnecessary") - case protocol.Deprecated: - tags = append(tags, "deprecated") - } - } - if len(tags) > 0 { - tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", ")) - } - } - - return fmt.Sprintf("%s: %s [%s]%s%s %s", - severity, - location, - sourceInfo, - codeInfo, - tagsInfo, - diagnostic.Message) -} - -func countSeverity(diagnostics []string, severity string) int { - count := 0 - for _, diag := range diagnostics { - if strings.HasPrefix(diag, severity) { - count++ - } - } - return count -} diff --git a/internal/llm/tools/diagnostics.md b/internal/llm/tools/diagnostics.md deleted file mode 100644 index 3f00d66039ade58615d46372bbf7f3cd6e619dcf..0000000000000000000000000000000000000000 --- a/internal/llm/tools/diagnostics.md +++ /dev/null @@ -1,21 +0,0 @@ -Get diagnostics for a file and/or project. 
-WHEN TO USE THIS TOOL: - -- Use when you need to check for errors or warnings in your code -- Helpful for debugging and ensuring code quality -- Good for getting a quick overview of issues in a file or project - HOW TO USE: -- Provide a path to a file to get diagnostics for that file -- Leave the path empty to get diagnostics for the entire project -- Results are displayed in a structured format with severity levels - FEATURES: -- Displays errors, warnings, and hints -- Groups diagnostics by severity -- Provides detailed information about each diagnostic - LIMITATIONS: -- Results are limited to the diagnostics provided by the LSP clients -- May not cover all possible issues in the code -- Does not provide suggestions for fixing issues - TIPS: -- Use in conjunction with other tools for a comprehensive code review -- Combine with the LSP client for real-time diagnostics diff --git a/internal/llm/tools/download.go b/internal/llm/tools/download.go deleted file mode 100644 index 63092cf24a0dbb98e7aef58e9d4ee867525ae945..0000000000000000000000000000000000000000 --- a/internal/llm/tools/download.go +++ /dev/null @@ -1,196 +0,0 @@ -package tools - -import ( - "context" - _ "embed" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strings" - "time" - - "github.com/charmbracelet/crush/internal/permission" -) - -type DownloadParams struct { - URL string `json:"url"` - FilePath string `json:"file_path"` - Timeout int `json:"timeout,omitempty"` -} - -type DownloadPermissionsParams struct { - URL string `json:"url"` - FilePath string `json:"file_path"` - Timeout int `json:"timeout,omitempty"` -} - -type downloadTool struct { - client *http.Client - permissions permission.Service - workingDir string -} - -const DownloadToolName = "download" - -//go:embed download.md -var downloadDescription []byte - -func NewDownloadTool(permissions permission.Service, workingDir string) BaseTool { - return &downloadTool{ - client: &http.Client{ - Timeout: 5 * time.Minute, // Default 5 minute timeout for downloads - Transport: &http.Transport{ - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, - }, - }, - permissions: permissions, - workingDir: workingDir, - } -} - -func (t *downloadTool) Name() string { - return DownloadToolName -} - -func (t *downloadTool) Info() ToolInfo { - return ToolInfo{ - Name: DownloadToolName, - Description: string(downloadDescription), - Parameters: map[string]any{ - "url": map[string]any{ - "type": "string", - "description": "The URL to download from", - }, - "file_path": map[string]any{ - "type": "string", - "description": "The local file path where the downloaded content should be saved", - }, - "timeout": map[string]any{ - "type": "number", - "description": "Optional timeout in seconds (max 600)", - }, - }, - Required: []string{"url", "file_path"}, - } -} - -func (t *downloadTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { - var params DownloadParams - if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { - return NewTextErrorResponse("Failed to parse download parameters: " + err.Error()), nil - } - - if params.URL == "" { - return NewTextErrorResponse("URL parameter is required"), nil - } - - if params.FilePath == "" { - return NewTextErrorResponse("file_path parameter is required"), nil - } - - if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") { - return NewTextErrorResponse("URL must start with http:// or https://"), nil - } - - // Convert relative path to 
absolute path - var filePath string - if filepath.IsAbs(params.FilePath) { - filePath = params.FilePath - } else { - filePath = filepath.Join(t.workingDir, params.FilePath) - } - - sessionID, messageID := GetContextValues(ctx) - if sessionID == "" || messageID == "" { - return ToolResponse{}, fmt.Errorf("session ID and message ID are required for downloading files") - } - - p := t.permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - Path: filePath, - ToolName: DownloadToolName, - Action: "download", - Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath), - Params: DownloadPermissionsParams(params), - }, - ) - - if !p { - return ToolResponse{}, permission.ErrorPermissionDenied - } - - // Handle timeout with context - requestCtx := ctx - if params.Timeout > 0 { - maxTimeout := 600 // 10 minutes - if params.Timeout > maxTimeout { - params.Timeout = maxTimeout - } - var cancel context.CancelFunc - requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second) - defer cancel() - } - - req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("User-Agent", "crush/1.0") - - resp, err := t.client.Do(req) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to download from URL: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil - } - - // Check content length if available - maxSize := int64(100 * 1024 * 1024) // 100MB - if resp.ContentLength > maxSize { - return NewTextErrorResponse(fmt.Sprintf("File too large: %d bytes (max %d bytes)", resp.ContentLength, maxSize)), nil - } - - // Create parent directories if they don't exist - if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil { - return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err) - } - - // Create the output file - outFile, err := os.Create(filePath) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to create output file: %w", err) - } - defer outFile.Close() - - // Copy data with size limit - limitedReader := io.LimitReader(resp.Body, maxSize) - bytesWritten, err := io.Copy(outFile, limitedReader) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) - } - - // Check if we hit the size limit - if bytesWritten == maxSize { - // Clean up the file since it might be incomplete - os.Remove(filePath) - return NewTextErrorResponse(fmt.Sprintf("File too large: exceeded %d bytes limit", maxSize)), nil - } - - contentType := resp.Header.Get("Content-Type") - responseMsg := fmt.Sprintf("Successfully downloaded %d bytes to %s", bytesWritten, filePath) - if contentType != "" { - responseMsg += fmt.Sprintf(" (Content-Type: %s)", contentType) - } - - return NewTextResponse(responseMsg), nil -} diff --git a/internal/llm/tools/download.md b/internal/llm/tools/download.md deleted file mode 100644 index 4a9516100dfa2c38f3a4bac588793964569e78b3..0000000000000000000000000000000000000000 --- a/internal/llm/tools/download.md +++ /dev/null @@ -1,34 +0,0 @@ -Downloads binary data from a URL and saves it to a local file. 
- -WHEN TO USE THIS TOOL: - -- Use when you need to download files, images, or other binary data from URLs -- Helpful for downloading assets, documents, or any file type -- Useful for saving remote content locally for processing or storage - -HOW TO USE: - -- Provide the URL to download from -- Specify the local file path where the content should be saved -- Optionally set a timeout for the request - -FEATURES: - -- Downloads any file type (binary or text) -- Automatically creates parent directories if they don't exist -- Handles large files efficiently with streaming -- Sets reasonable timeouts to prevent hanging -- Validates input parameters before making requests - -LIMITATIONS: - -- Maximum file size is 100MB -- Only supports HTTP and HTTPS protocols -- Cannot handle authentication or cookies -- Some websites may block automated requests -- Will overwrite existing files without warning - -TIPS: - -- Use absolute paths or paths relative to the working directory -- Set appropriate timeouts for large files or slow connections diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go deleted file mode 100644 index ae4322e08da087127a3238a1e3edf6c1a9a2e37a..0000000000000000000000000000000000000000 --- a/internal/llm/tools/edit.go +++ /dev/null @@ -1,486 +0,0 @@ -package tools - -import ( - "context" - _ "embed" - "encoding/json" - "fmt" - "log/slog" - "os" - "path/filepath" - "strings" - "time" - - "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/diff" - "github.com/charmbracelet/crush/internal/fsext" - "github.com/charmbracelet/crush/internal/history" - - "github.com/charmbracelet/crush/internal/lsp" - "github.com/charmbracelet/crush/internal/permission" -) - -type EditParams struct { - FilePath string `json:"file_path"` - OldString string `json:"old_string"` - NewString string `json:"new_string"` - ReplaceAll bool `json:"replace_all,omitempty"` -} - -type EditPermissionsParams struct { - FilePath string `json:"file_path"` - OldContent string `json:"old_content,omitempty"` - NewContent string `json:"new_content,omitempty"` -} - -type EditResponseMetadata struct { - Additions int `json:"additions"` - Removals int `json:"removals"` - OldContent string `json:"old_content,omitempty"` - NewContent string `json:"new_content,omitempty"` -} - -type editTool struct { - lspClients *csync.Map[string, *lsp.Client] - permissions permission.Service - files history.Service - workingDir string -} - -const EditToolName = "edit" - -//go:embed edit.md -var editDescription []byte - -func NewEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool { - return &editTool{ - lspClients: lspClients, - permissions: permissions, - files: files, - workingDir: workingDir, - } -} - -func (e *editTool) Name() string { - return EditToolName -} - -func (e *editTool) Info() ToolInfo { - return ToolInfo{ - Name: EditToolName, - Description: string(editDescription), - Parameters: map[string]any{ - "file_path": map[string]any{ - "type": "string", - "description": "The absolute path to the file to modify", - }, - "old_string": map[string]any{ - "type": "string", - "description": "The text to replace", - }, - "new_string": map[string]any{ - "type": "string", - "description": "The text to replace it with", - }, - "replace_all": map[string]any{ - "type": "boolean", - "description": "Replace all occurrences of old_string (default false)", - }, - }, - Required: []string{"file_path", "old_string", 
"new_string"}, - } -} - -func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { - var params EditParams - if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { - return NewTextErrorResponse("invalid parameters"), nil - } - - if params.FilePath == "" { - return NewTextErrorResponse("file_path is required"), nil - } - - if !filepath.IsAbs(params.FilePath) { - params.FilePath = filepath.Join(e.workingDir, params.FilePath) - } - - var response ToolResponse - var err error - - if params.OldString == "" { - response, err = e.createNewFile(ctx, params.FilePath, params.NewString, call) - if err != nil { - return response, err - } - } - - if params.NewString == "" { - response, err = e.deleteContent(ctx, params.FilePath, params.OldString, params.ReplaceAll, call) - if err != nil { - return response, err - } - } - - response, err = e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString, params.ReplaceAll, call) - if err != nil { - return response, err - } - if response.IsError { - // Return early if there was an error during content replacement - // This prevents unnecessary LSP diagnostics processing - return response, nil - } - - notifyLSPs(ctx, e.lspClients, params.FilePath) - - text := fmt.Sprintf("\n%s\n\n", response.Content) - text += getDiagnostics(params.FilePath, e.lspClients) - response.Content = text - return response, nil -} - -func (e *editTool) createNewFile(ctx context.Context, filePath, content string, call ToolCall) (ToolResponse, error) { - fileInfo, err := os.Stat(filePath) - if err == nil { - if fileInfo.IsDir() { - return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil - } - return NewTextErrorResponse(fmt.Sprintf("file already exists: %s", filePath)), nil - } else if !os.IsNotExist(err) { - return ToolResponse{}, fmt.Errorf("failed to access file: %w", err) - } - - dir := filepath.Dir(filePath) - if err = os.MkdirAll(dir, 0o755); err != nil { - return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err) - } - - sessionID, messageID := GetContextValues(ctx) - if sessionID == "" || messageID == "" { - return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file") - } - - _, additions, removals := diff.GenerateDiff( - "", - content, - strings.TrimPrefix(filePath, e.workingDir), - ) - p := e.permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - Path: fsext.PathOrPrefix(filePath, e.workingDir), - ToolCallID: call.ID, - ToolName: EditToolName, - Action: "write", - Description: fmt.Sprintf("Create file %s", filePath), - Params: EditPermissionsParams{ - FilePath: filePath, - OldContent: "", - NewContent: content, - }, - }, - ) - if !p { - return ToolResponse{}, permission.ErrorPermissionDenied - } - - err = os.WriteFile(filePath, []byte(content), 0o644) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) - } - - // File can't be in the history so we create a new file history - _, err = e.files.Create(ctx, sessionID, filePath, "") - if err != nil { - // Log error but don't fail the operation - return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) - } - - // Add the new content to the file history - _, err = e.files.CreateVersion(ctx, sessionID, filePath, content) - if err != nil { - // Log error but don't fail the operation - slog.Debug("Error creating file history version", "error", err) - } - - recordFileWrite(filePath) - recordFileRead(filePath) 
- - return WithResponseMetadata( - NewTextResponse("File created: "+filePath), - EditResponseMetadata{ - OldContent: "", - NewContent: content, - Additions: additions, - Removals: removals, - }, - ), nil -} - -func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string, replaceAll bool, call ToolCall) (ToolResponse, error) { - fileInfo, err := os.Stat(filePath) - if err != nil { - if os.IsNotExist(err) { - return NewTextErrorResponse(fmt.Sprintf("file not found: %s", filePath)), nil - } - return ToolResponse{}, fmt.Errorf("failed to access file: %w", err) - } - - if fileInfo.IsDir() { - return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil - } - - if getLastReadTime(filePath).IsZero() { - return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil - } - - modTime := fileInfo.ModTime() - lastRead := getLastReadTime(filePath) - if modTime.After(lastRead) { - return NewTextErrorResponse( - fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)", - filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339), - )), nil - } - - content, err := os.ReadFile(filePath) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to read file: %w", err) - } - - oldContent, isCrlf := fsext.ToUnixLineEndings(string(content)) - - var newContent string - var deletionCount int - - if replaceAll { - newContent = strings.ReplaceAll(oldContent, oldString, "") - deletionCount = strings.Count(oldContent, oldString) - if deletionCount == 0 { - return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil - } - } else { - index := strings.Index(oldContent, oldString) - if index == -1 { - return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil - } - - lastIndex := strings.LastIndex(oldContent, oldString) - if index != lastIndex { - return NewTextErrorResponse("old_string appears multiple times in the file. 
Please provide more context to ensure a unique match, or set replace_all to true"), nil - } - - newContent = oldContent[:index] + oldContent[index+len(oldString):] - deletionCount = 1 - } - - sessionID, messageID := GetContextValues(ctx) - - if sessionID == "" || messageID == "" { - return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file") - } - - _, additions, removals := diff.GenerateDiff( - oldContent, - newContent, - strings.TrimPrefix(filePath, e.workingDir), - ) - - p := e.permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - Path: fsext.PathOrPrefix(filePath, e.workingDir), - ToolCallID: call.ID, - ToolName: EditToolName, - Action: "write", - Description: fmt.Sprintf("Delete content from file %s", filePath), - Params: EditPermissionsParams{ - FilePath: filePath, - OldContent: oldContent, - NewContent: newContent, - }, - }, - ) - if !p { - return ToolResponse{}, permission.ErrorPermissionDenied - } - - if isCrlf { - newContent, _ = fsext.ToWindowsLineEndings(newContent) - } - - err = os.WriteFile(filePath, []byte(newContent), 0o644) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) - } - - // Check if file exists in history - file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID) - if err != nil { - _, err = e.files.Create(ctx, sessionID, filePath, oldContent) - if err != nil { - // Log error but don't fail the operation - return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) - } - } - if file.Content != oldContent { - // User Manually changed the content store an intermediate version - _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent) - if err != nil { - slog.Debug("Error creating file history version", "error", err) - } - } - // Store the new version - _, err = e.files.CreateVersion(ctx, sessionID, filePath, "") - if err != nil { - slog.Debug("Error creating file history version", "error", err) - } - - recordFileWrite(filePath) - recordFileRead(filePath) - - return WithResponseMetadata( - NewTextResponse("Content deleted from file: "+filePath), - EditResponseMetadata{ - OldContent: oldContent, - NewContent: newContent, - Additions: additions, - Removals: removals, - }, - ), nil -} - -func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newString string, replaceAll bool, call ToolCall) (ToolResponse, error) { - fileInfo, err := os.Stat(filePath) - if err != nil { - if os.IsNotExist(err) { - return NewTextErrorResponse(fmt.Sprintf("file not found: %s", filePath)), nil - } - return ToolResponse{}, fmt.Errorf("failed to access file: %w", err) - } - - if fileInfo.IsDir() { - return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil - } - - if getLastReadTime(filePath).IsZero() { - return NewTextErrorResponse("you must read the file before editing it. 
Use the View tool first"), nil - } - - modTime := fileInfo.ModTime() - lastRead := getLastReadTime(filePath) - if modTime.After(lastRead) { - return NewTextErrorResponse( - fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)", - filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339), - )), nil - } - - content, err := os.ReadFile(filePath) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to read file: %w", err) - } - - oldContent, isCrlf := fsext.ToUnixLineEndings(string(content)) - - var newContent string - var replacementCount int - - if replaceAll { - newContent = strings.ReplaceAll(oldContent, oldString, newString) - replacementCount = strings.Count(oldContent, oldString) - if replacementCount == 0 { - return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil - } - } else { - index := strings.Index(oldContent, oldString) - if index == -1 { - return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil - } - - lastIndex := strings.LastIndex(oldContent, oldString) - if index != lastIndex { - return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match, or set replace_all to true"), nil - } - - newContent = oldContent[:index] + newString + oldContent[index+len(oldString):] - replacementCount = 1 - } - - if oldContent == newContent { - return NewTextErrorResponse("new content is the same as old content. No changes made."), nil - } - sessionID, messageID := GetContextValues(ctx) - - if sessionID == "" || messageID == "" { - return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file") - } - _, additions, removals := diff.GenerateDiff( - oldContent, - newContent, - strings.TrimPrefix(filePath, e.workingDir), - ) - - p := e.permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - Path: fsext.PathOrPrefix(filePath, e.workingDir), - ToolCallID: call.ID, - ToolName: EditToolName, - Action: "write", - Description: fmt.Sprintf("Replace content in file %s", filePath), - Params: EditPermissionsParams{ - FilePath: filePath, - OldContent: oldContent, - NewContent: newContent, - }, - }, - ) - if !p { - return ToolResponse{}, permission.ErrorPermissionDenied - } - - if isCrlf { - newContent, _ = fsext.ToWindowsLineEndings(newContent) - } - - err = os.WriteFile(filePath, []byte(newContent), 0o644) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) - } - - // Check if file exists in history - file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID) - if err != nil { - _, err = e.files.Create(ctx, sessionID, filePath, oldContent) - if err != nil { - // Log error but don't fail the operation - return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) - } - } - if file.Content != oldContent { - // User Manually changed the content store an intermediate version - _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent) - if err != nil { - slog.Debug("Error creating file history version", "error", err) - } - } - // Store the new version - _, err = e.files.CreateVersion(ctx, sessionID, filePath, newContent) - if err != nil { - slog.Debug("Error creating file history version", "error", err) - } - - recordFileWrite(filePath) - recordFileRead(filePath) - - return WithResponseMetadata( - 
NewTextResponse("Content replaced in file: "+filePath), - EditResponseMetadata{ - OldContent: oldContent, - NewContent: newContent, - Additions: additions, - Removals: removals, - }), nil -} diff --git a/internal/llm/tools/edit.md b/internal/llm/tools/edit.md deleted file mode 100644 index 2f9ec73883f3a68546e7b14bd53a12f9f69cc430..0000000000000000000000000000000000000000 --- a/internal/llm/tools/edit.md +++ /dev/null @@ -1,60 +0,0 @@ -Edits files by replacing text, creating new files, or deleting content. For moving or renaming files, use the Bash tool with the 'mv' command instead. For larger file edits, use the FileWrite tool to overwrite files. - -Before using this tool: - -1. Use the FileRead tool to understand the file's contents and context - -2. Verify the directory path is correct (only applicable when creating new files): - - Use the LS tool to verify the parent directory exists and is the correct location - -To make a file edit, provide the following: - -1. file_path: The absolute path to the file to modify (must be absolute, not relative) -2. old_string: The text to replace (must be unique within the file, and must match the file contents exactly, including all whitespace and indentation) -3. new_string: The edited text to replace the old_string -4. replace_all: Replace all occurrences of old_string (default false) - -Special cases: - -- To create a new file: provide file_path and new_string, leave old_string empty -- To delete content: provide file_path and old_string, leave new_string empty - -The tool will replace ONE occurrence of old_string with new_string in the specified file by default. Set replace_all to true to replace all occurrences. - -CRITICAL REQUIREMENTS FOR USING THIS TOOL: - -1. UNIQUENESS: When replace_all is false (default), the old_string MUST uniquely identify the specific instance you want to change. This means: - - Include AT LEAST 3-5 lines of context BEFORE the change point - - Include AT LEAST 3-5 lines of context AFTER the change point - - Include all whitespace, indentation, and surrounding code exactly as it appears in the file - -2. SINGLE INSTANCE: When replace_all is false, this tool can only change ONE instance at a time. If you need to change multiple instances: - - Set replace_all to true to replace all occurrences at once - - Or make separate calls to this tool for each instance - - Each call must uniquely identify its specific instance using extensive context - -3. VERIFICATION: Before using this tool: - - Check how many instances of the target text exist in the file - - If multiple instances exist and replace_all is false, gather enough context to uniquely identify each one - - Plan separate tool calls for each instance or use replace_all - -WARNING: If you do not follow these requirements: - -- The tool will fail if old_string matches multiple locations and replace_all is false -- The tool will fail if old_string doesn't match exactly (including whitespace) -- You may change the wrong instance if you don't include enough context - -When making edits: - -- Ensure the edit results in idiomatic, correct code -- Do not leave the code in a broken state -- Always use absolute file paths (starting with /) - -WINDOWS NOTES: - -- File paths should use forward slashes (/) for cross-platform compatibility -- On Windows, absolute paths start with drive letters (C:/) but forward slashes work throughout -- File permissions are handled automatically by the Go runtime -- Always assumes \n for line endings. 
The tool will handle \r\n conversion automatically if needed. - -Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each. diff --git a/internal/llm/tools/fetch.go b/internal/llm/tools/fetch.go deleted file mode 100644 index e9c5732462651cd9e31337f91f4f4d91a36e70c1..0000000000000000000000000000000000000000 --- a/internal/llm/tools/fetch.go +++ /dev/null @@ -1,236 +0,0 @@ -package tools - -import ( - "context" - _ "embed" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "time" - "unicode/utf8" - - md "github.com/JohannesKaufmann/html-to-markdown" - "github.com/PuerkitoBio/goquery" - "github.com/charmbracelet/crush/internal/permission" -) - -type FetchParams struct { - URL string `json:"url"` - Format string `json:"format"` - Timeout int `json:"timeout,omitempty"` -} - -type FetchPermissionsParams struct { - URL string `json:"url"` - Format string `json:"format"` - Timeout int `json:"timeout,omitempty"` -} - -type fetchTool struct { - client *http.Client - permissions permission.Service - workingDir string -} - -const FetchToolName = "fetch" - -//go:embed fetch.md -var fetchDescription []byte - -func NewFetchTool(permissions permission.Service, workingDir string) BaseTool { - return &fetchTool{ - client: &http.Client{ - Timeout: 30 * time.Second, - Transport: &http.Transport{ - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, - }, - }, - permissions: permissions, - workingDir: workingDir, - } -} - -func (t *fetchTool) Name() string { - return FetchToolName -} - -func (t *fetchTool) Info() ToolInfo { - return ToolInfo{ - Name: FetchToolName, - Description: string(fetchDescription), - Parameters: map[string]any{ - "url": map[string]any{ - "type": "string", - "description": "The URL to fetch content from", - }, - "format": map[string]any{ - "type": "string", - "description": "The format to return the content in (text, markdown, or html)", - "enum": []string{"text", "markdown", "html"}, - }, - "timeout": map[string]any{ - "type": "number", - "description": "Optional timeout in seconds (max 120)", - }, - }, - Required: []string{"url", "format"}, - } -} - -func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { - var params FetchParams - if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { - return NewTextErrorResponse("Failed to parse fetch parameters: " + err.Error()), nil - } - - if params.URL == "" { - return NewTextErrorResponse("URL parameter is required"), nil - } - - format := strings.ToLower(params.Format) - if format != "text" && format != "markdown" && format != "html" { - return NewTextErrorResponse("Format must be one of: text, markdown, html"), nil - } - - if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") { - return NewTextErrorResponse("URL must start with http:// or https://"), nil - } - - sessionID, messageID := GetContextValues(ctx) - if sessionID == "" || messageID == "" { - return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file") - } - - p := t.permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - Path: t.workingDir, - ToolCallID: call.ID, - ToolName: FetchToolName, - Action: "fetch", - Description: fmt.Sprintf("Fetch content from URL: %s", params.URL), - Params: FetchPermissionsParams(params), - }, - ) - - if !p { - return 
ToolResponse{}, permission.ErrorPermissionDenied - } - - // Handle timeout with context - requestCtx := ctx - if params.Timeout > 0 { - maxTimeout := 120 // 2 minutes - if params.Timeout > maxTimeout { - params.Timeout = maxTimeout - } - var cancel context.CancelFunc - requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second) - defer cancel() - } - - req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("User-Agent", "crush/1.0") - - resp, err := t.client.Do(req) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil - } - - maxSize := int64(5 * 1024 * 1024) // 5MB - body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize)) - if err != nil { - return NewTextErrorResponse("Failed to read response body: " + err.Error()), nil - } - - content := string(body) - - isValidUt8 := utf8.ValidString(content) - if !isValidUt8 { - return NewTextErrorResponse("Response content is not valid UTF-8"), nil - } - contentType := resp.Header.Get("Content-Type") - - switch format { - case "text": - if strings.Contains(contentType, "text/html") { - text, err := extractTextFromHTML(content) - if err != nil { - return NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil - } - content = text - } - - case "markdown": - if strings.Contains(contentType, "text/html") { - markdown, err := convertHTMLToMarkdown(content) - if err != nil { - return NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil - } - content = markdown - } - - content = "```\n" + content + "\n```" - - case "html": - // return only the body of the HTML document - if strings.Contains(contentType, "text/html") { - doc, err := goquery.NewDocumentFromReader(strings.NewReader(content)) - if err != nil { - return NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil - } - body, err := doc.Find("body").Html() - if err != nil { - return NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil - } - if body == "" { - return NewTextErrorResponse("No body content found in HTML"), nil - } - content = "\n\n" + body + "\n\n" - } - } - // calculate byte size of content - contentSize := int64(len(content)) - if contentSize > MaxReadSize { - content = content[:MaxReadSize] - content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize) - } - - return NewTextResponse(content), nil -} - -func extractTextFromHTML(html string) (string, error) { - doc, err := goquery.NewDocumentFromReader(strings.NewReader(html)) - if err != nil { - return "", err - } - - text := doc.Find("body").Text() - text = strings.Join(strings.Fields(text), " ") - - return text, nil -} - -func convertHTMLToMarkdown(html string) (string, error) { - converter := md.NewConverter("", true, nil) - - markdown, err := converter.ConvertString(html) - if err != nil { - return "", err - } - - return markdown, nil -} diff --git a/internal/llm/tools/fetch.md b/internal/llm/tools/fetch.md deleted file mode 100644 index ffb7bbaf968e4ae4e338ab1997211dc89a9234d1..0000000000000000000000000000000000000000 --- a/internal/llm/tools/fetch.md +++ /dev/null @@ -1,34 +0,0 @@ -Fetches content from a URL and returns it in the specified format. 
- -WHEN TO USE THIS TOOL: - -- Use when you need to download content from a URL -- Helpful for retrieving documentation, API responses, or web content -- Useful for getting external information to assist with tasks - -HOW TO USE: - -- Provide the URL to fetch content from -- Specify the desired output format (text, markdown, or html) -- Optionally set a timeout for the request - -FEATURES: - -- Supports three output formats: text, markdown, and html -- Automatically handles HTTP redirects -- Sets reasonable timeouts to prevent hanging -- Validates input parameters before making requests - -LIMITATIONS: - -- Maximum response size is 5MB -- Only supports HTTP and HTTPS protocols -- Cannot handle authentication or cookies -- Some websites may block automated requests - -TIPS: - -- Use text format for plain text content or simple API responses -- Use markdown format for content that should be rendered with formatting -- Use html format when you need the raw HTML structure -- Set appropriate timeouts for potentially slow websites diff --git a/internal/llm/tools/file.go b/internal/llm/tools/file.go deleted file mode 100644 index 7f34fdc1f615031decf00706c58aac37a235b57e..0000000000000000000000000000000000000000 --- a/internal/llm/tools/file.go +++ /dev/null @@ -1,53 +0,0 @@ -package tools - -import ( - "sync" - "time" -) - -// File record to track when files were read/written -type fileRecord struct { - path string - readTime time.Time - writeTime time.Time -} - -var ( - fileRecords = make(map[string]fileRecord) - fileRecordMutex sync.RWMutex -) - -func recordFileRead(path string) { - fileRecordMutex.Lock() - defer fileRecordMutex.Unlock() - - record, exists := fileRecords[path] - if !exists { - record = fileRecord{path: path} - } - record.readTime = time.Now() - fileRecords[path] = record -} - -func getLastReadTime(path string) time.Time { - fileRecordMutex.RLock() - defer fileRecordMutex.RUnlock() - - record, exists := fileRecords[path] - if !exists { - return time.Time{} - } - return record.readTime -} - -func recordFileWrite(path string) { - fileRecordMutex.Lock() - defer fileRecordMutex.Unlock() - - record, exists := fileRecords[path] - if !exists { - record = fileRecord{path: path} - } - record.writeTime = time.Now() - fileRecords[path] = record -} diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go deleted file mode 100644 index 85207c28ea3408ab4a3867e983766d35c9474b47..0000000000000000000000000000000000000000 --- a/internal/llm/tools/glob.go +++ /dev/null @@ -1,150 +0,0 @@ -package tools - -import ( - "bytes" - "context" - _ "embed" - "encoding/json" - "fmt" - "log/slog" - "os/exec" - "path/filepath" - "sort" - "strings" - - "github.com/charmbracelet/crush/internal/fsext" -) - -const GlobToolName = "glob" - -//go:embed glob.md -var globDescription []byte - -type GlobParams struct { - Pattern string `json:"pattern"` - Path string `json:"path"` -} - -type GlobResponseMetadata struct { - NumberOfFiles int `json:"number_of_files"` - Truncated bool `json:"truncated"` -} - -type globTool struct { - workingDir string -} - -func NewGlobTool(workingDir string) BaseTool { - return &globTool{ - workingDir: workingDir, - } -} - -func (g *globTool) Name() string { - return GlobToolName -} - -func (g *globTool) Info() ToolInfo { - return ToolInfo{ - Name: GlobToolName, - Description: string(globDescription), - Parameters: map[string]any{ - "pattern": map[string]any{ - "type": "string", - "description": "The glob pattern to match files against", - }, - "path": map[string]any{ - "type": 
"string", - "description": "The directory to search in. Defaults to the current working directory.", - }, - }, - Required: []string{"pattern"}, - } -} - -func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { - var params GlobParams - if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { - return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil - } - - if params.Pattern == "" { - return NewTextErrorResponse("pattern is required"), nil - } - - searchPath := params.Path - if searchPath == "" { - searchPath = g.workingDir - } - - files, truncated, err := globFiles(ctx, params.Pattern, searchPath, 100) - if err != nil { - return ToolResponse{}, fmt.Errorf("error finding files: %w", err) - } - - var output string - if len(files) == 0 { - output = "No files found" - } else { - output = strings.Join(files, "\n") - if truncated { - output += "\n\n(Results are truncated. Consider using a more specific path or pattern.)" - } - } - - return WithResponseMetadata( - NewTextResponse(output), - GlobResponseMetadata{ - NumberOfFiles: len(files), - Truncated: truncated, - }, - ), nil -} - -func globFiles(ctx context.Context, pattern, searchPath string, limit int) ([]string, bool, error) { - cmdRg := getRgCmd(ctx, pattern) - if cmdRg != nil { - cmdRg.Dir = searchPath - matches, err := runRipgrep(cmdRg, searchPath, limit) - if err == nil { - return matches, len(matches) >= limit && limit > 0, nil - } - slog.Warn("Ripgrep execution failed, falling back to doublestar", "error", err) - } - - return fsext.GlobWithDoubleStar(pattern, searchPath, limit) -} - -func runRipgrep(cmd *exec.Cmd, searchRoot string, limit int) ([]string, error) { - out, err := cmd.CombinedOutput() - if err != nil { - if ee, ok := err.(*exec.ExitError); ok && ee.ExitCode() == 1 { - return nil, nil - } - return nil, fmt.Errorf("ripgrep: %w\n%s", err, out) - } - - var matches []string - for p := range bytes.SplitSeq(out, []byte{0}) { - if len(p) == 0 { - continue - } - absPath := string(p) - if !filepath.IsAbs(absPath) { - absPath = filepath.Join(searchRoot, absPath) - } - if fsext.SkipHidden(absPath) { - continue - } - matches = append(matches, absPath) - } - - sort.SliceStable(matches, func(i, j int) bool { - return len(matches[i]) < len(matches[j]) - }) - - if limit > 0 && len(matches) > limit { - matches = matches[:limit] - } - return matches, nil -} diff --git a/internal/llm/tools/glob.md b/internal/llm/tools/glob.md deleted file mode 100644 index 1d73d42c9bc59372368965f8c04ef0c068d3deca..0000000000000000000000000000000000000000 --- a/internal/llm/tools/glob.md +++ /dev/null @@ -1,46 +0,0 @@ -Fast file pattern matching tool that finds files by name and pattern, returning matching paths sorted by modification time (newest first). - -WHEN TO USE THIS TOOL: - -- Use when you need to find files by name patterns or extensions -- Great for finding specific file types across a directory structure -- Useful for discovering files that match certain naming conventions - -HOW TO USE: - -- Provide a glob pattern to match against file paths -- Optionally specify a starting directory (defaults to current working directory) -- Results are sorted with most recently modified files first - -GLOB PATTERN SYNTAX: - -- '\*' matches any sequence of non-separator characters -- '\*\*' matches any sequence of characters, including separators -- '?' 
matches any single non-separator character -- '[...]' matches any character in the brackets -- '[!...]' matches any character not in the brackets - -COMMON PATTERN EXAMPLES: - -- '\*.js' - Find all JavaScript files in the current directory -- '\*_/_.js' - Find all JavaScript files in any subdirectory -- 'src/\*_/_.{ts,tsx}' - Find all TypeScript files in the src directory -- '\*.{html,css,js}' - Find all HTML, CSS, and JS files - -LIMITATIONS: - -- Results are limited to 100 files (newest first) -- Does not search file contents (use Grep tool for that) -- Hidden files (starting with '.') are skipped - -WINDOWS NOTES: - -- Path separators are handled automatically (both / and \ work) -- Uses ripgrep (rg) command if available, otherwise falls back to built-in Go implementation - -TIPS: - -- Patterns should use forward slashes (/) for cross-platform compatibility -- For the most useful results, combine with the Grep tool: first find files with Glob, then search their contents with Grep -- When doing iterative exploration that may require multiple rounds of searching, consider using the Agent tool instead -- Always check if results are truncated and refine your search pattern if needed diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go deleted file mode 100644 index cbf50360b9355c05797690678a99d1310b19556f..0000000000000000000000000000000000000000 --- a/internal/llm/tools/grep.go +++ /dev/null @@ -1,470 +0,0 @@ -package tools - -import ( - "bufio" - "context" - _ "embed" - "encoding/json" - "fmt" - "io" - "os" - "os/exec" - "path/filepath" - "regexp" - "sort" - "strconv" - "strings" - "sync" - "time" - - "github.com/charmbracelet/crush/internal/fsext" -) - -// regexCache provides thread-safe caching of compiled regex patterns -type regexCache struct { - cache map[string]*regexp.Regexp - mu sync.RWMutex -} - -// newRegexCache creates a new regex cache -func newRegexCache() *regexCache { - return ®exCache{ - cache: make(map[string]*regexp.Regexp), - } -} - -// get retrieves a compiled regex from cache or compiles and caches it -func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) { - // Try to get from cache first (read lock) - rc.mu.RLock() - if regex, exists := rc.cache[pattern]; exists { - rc.mu.RUnlock() - return regex, nil - } - rc.mu.RUnlock() - - // Compile the regex (write lock) - rc.mu.Lock() - defer rc.mu.Unlock() - - // Double-check in case another goroutine compiled it while we waited - if regex, exists := rc.cache[pattern]; exists { - return regex, nil - } - - // Compile and cache the regex - regex, err := regexp.Compile(pattern) - if err != nil { - return nil, err - } - - rc.cache[pattern] = regex - return regex, nil -} - -// Global regex cache instances -var ( - searchRegexCache = newRegexCache() - globRegexCache = newRegexCache() - // Pre-compiled regex for glob conversion (used frequently) - globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`) -) - -type GrepParams struct { - Pattern string `json:"pattern"` - Path string `json:"path"` - Include string `json:"include"` - LiteralText bool `json:"literal_text"` -} - -type grepMatch struct { - path string - modTime time.Time - lineNum int - lineText string -} - -type GrepResponseMetadata struct { - NumberOfMatches int `json:"number_of_matches"` - Truncated bool `json:"truncated"` -} - -type grepTool struct { - workingDir string -} - -const GrepToolName = "grep" - -//go:embed grep.md -var grepDescription []byte - -func NewGrepTool(workingDir string) BaseTool { - return &grepTool{ - workingDir: workingDir, 
- } -} - -func (g *grepTool) Name() string { - return GrepToolName -} - -func (g *grepTool) Info() ToolInfo { - return ToolInfo{ - Name: GrepToolName, - Description: string(grepDescription), - Parameters: map[string]any{ - "pattern": map[string]any{ - "type": "string", - "description": "The regex pattern to search for in file contents", - }, - "path": map[string]any{ - "type": "string", - "description": "The directory to search in. Defaults to the current working directory.", - }, - "include": map[string]any{ - "type": "string", - "description": "File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")", - }, - "literal_text": map[string]any{ - "type": "boolean", - "description": "If true, the pattern will be treated as literal text with special regex characters escaped. Default is false.", - }, - }, - Required: []string{"pattern"}, - } -} - -// escapeRegexPattern escapes special regex characters so they're treated as literal characters -func escapeRegexPattern(pattern string) string { - specialChars := []string{"\\", ".", "+", "*", "?", "(", ")", "[", "]", "{", "}", "^", "$", "|"} - escaped := pattern - - for _, char := range specialChars { - escaped = strings.ReplaceAll(escaped, char, "\\"+char) - } - - return escaped -} - -func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { - var params GrepParams - if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { - return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil - } - - if params.Pattern == "" { - return NewTextErrorResponse("pattern is required"), nil - } - - // If literal_text is true, escape the pattern - searchPattern := params.Pattern - if params.LiteralText { - searchPattern = escapeRegexPattern(params.Pattern) - } - - searchPath := params.Path - if searchPath == "" { - searchPath = g.workingDir - } - - matches, truncated, err := searchFiles(ctx, searchPattern, searchPath, params.Include, 100) - if err != nil { - return ToolResponse{}, fmt.Errorf("error searching files: %w", err) - } - - var output strings.Builder - if len(matches) == 0 { - output.WriteString("No files found") - } else { - fmt.Fprintf(&output, "Found %d matches\n", len(matches)) - - currentFile := "" - for _, match := range matches { - if currentFile != match.path { - if currentFile != "" { - output.WriteString("\n") - } - currentFile = match.path - fmt.Fprintf(&output, "%s:\n", match.path) - } - if match.lineNum > 0 { - fmt.Fprintf(&output, " Line %d: %s\n", match.lineNum, match.lineText) - } else { - fmt.Fprintf(&output, " %s\n", match.path) - } - } - - if truncated { - output.WriteString("\n(Results are truncated. 
Consider using a more specific path or pattern.)") - } - } - - return WithResponseMetadata( - NewTextResponse(output.String()), - GrepResponseMetadata{ - NumberOfMatches: len(matches), - Truncated: truncated, - }, - ), nil -} - -func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) { - matches, err := searchWithRipgrep(ctx, pattern, rootPath, include) - if err != nil { - matches, err = searchFilesWithRegex(pattern, rootPath, include) - if err != nil { - return nil, false, err - } - } - - sort.Slice(matches, func(i, j int) bool { - return matches[i].modTime.After(matches[j].modTime) - }) - - truncated := len(matches) > limit - if truncated { - matches = matches[:limit] - } - - return matches, truncated, nil -} - -func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) { - cmd := getRgSearchCmd(ctx, pattern, path, include) - if cmd == nil { - return nil, fmt.Errorf("ripgrep not found in $PATH") - } - - // Only add ignore files if they exist - for _, ignoreFile := range []string{".gitignore", ".crushignore"} { - ignorePath := filepath.Join(path, ignoreFile) - if _, err := os.Stat(ignorePath); err == nil { - cmd.Args = append(cmd.Args, "--ignore-file", ignorePath) - } - } - - output, err := cmd.Output() - if err != nil { - if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 { - return []grepMatch{}, nil - } - return nil, err - } - - lines := strings.Split(strings.TrimSpace(string(output)), "\n") - matches := make([]grepMatch, 0, len(lines)) - - for _, line := range lines { - if line == "" { - continue - } - - // Parse ripgrep output using null separation - filePath, lineNumStr, lineText, ok := parseRipgrepLine(line) - if !ok { - continue - } - - lineNum, err := strconv.Atoi(lineNumStr) - if err != nil { - continue - } - - fileInfo, err := os.Stat(filePath) - if err != nil { - continue // Skip files we can't access - } - - matches = append(matches, grepMatch{ - path: filePath, - modTime: fileInfo.ModTime(), - lineNum: lineNum, - lineText: lineText, - }) - } - - return matches, nil -} - -// parseRipgrepLine parses ripgrep output with null separation to handle Windows paths -func parseRipgrepLine(line string) (filePath, lineNum, lineText string, ok bool) { - // Split on null byte first to separate filename from rest - parts := strings.SplitN(line, "\x00", 2) - if len(parts) != 2 { - return "", "", "", false - } - - filePath = parts[0] - remainder := parts[1] - - // Now split the remainder on first colon: "linenum:content" - colonIndex := strings.Index(remainder, ":") - if colonIndex == -1 { - return "", "", "", false - } - - lineNumStr := remainder[:colonIndex] - lineText = remainder[colonIndex+1:] - - if _, err := strconv.Atoi(lineNumStr); err != nil { - return "", "", "", false - } - - return filePath, lineNumStr, lineText, true -} - -func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) { - matches := []grepMatch{} - - // Use cached regex compilation - regex, err := searchRegexCache.get(pattern) - if err != nil { - return nil, fmt.Errorf("invalid regex pattern: %w", err) - } - - var includePattern *regexp.Regexp - if include != "" { - regexPattern := globToRegex(include) - includePattern, err = globRegexCache.get(regexPattern) - if err != nil { - return nil, fmt.Errorf("invalid include pattern: %w", err) - } - } - - // Create walker with gitignore and crushignore support - walker := fsext.NewFastGlobWalker(rootPath) - - err = filepath.Walk(rootPath, 
func(path string, info os.FileInfo, err error) error { - if err != nil { - return nil // Skip errors - } - - if info.IsDir() { - // Check if directory should be skipped - if walker.ShouldSkip(path) { - return filepath.SkipDir - } - return nil // Continue into directory - } - - // Use walker's shouldSkip method for files - if walker.ShouldSkip(path) { - return nil - } - - // Skip hidden files (starting with a dot) to match ripgrep's default behavior - base := filepath.Base(path) - if base != "." && strings.HasPrefix(base, ".") { - return nil - } - - if includePattern != nil && !includePattern.MatchString(path) { - return nil - } - - match, lineNum, lineText, err := fileContainsPattern(path, regex) - if err != nil { - return nil // Skip files we can't read - } - - if match { - matches = append(matches, grepMatch{ - path: path, - modTime: info.ModTime(), - lineNum: lineNum, - lineText: lineText, - }) - - if len(matches) >= 200 { - return filepath.SkipAll - } - } - - return nil - }) - if err != nil { - return nil, err - } - - return matches, nil -} - -func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, string, error) { - // Quick binary file detection - if isBinaryFile(filePath) { - return false, 0, "", nil - } - - file, err := os.Open(filePath) - if err != nil { - return false, 0, "", err - } - defer file.Close() - - scanner := bufio.NewScanner(file) - lineNum := 0 - for scanner.Scan() { - lineNum++ - line := scanner.Text() - if pattern.MatchString(line) { - return true, lineNum, line, nil - } - } - - return false, 0, "", scanner.Err() -} - -var binaryExts = map[string]struct{}{ - ".exe": {}, ".dll": {}, ".so": {}, ".dylib": {}, - ".bin": {}, ".obj": {}, ".o": {}, ".a": {}, - ".zip": {}, ".tar": {}, ".gz": {}, ".bz2": {}, - ".jpg": {}, ".jpeg": {}, ".png": {}, ".gif": {}, - ".pdf": {}, ".doc": {}, ".docx": {}, ".xls": {}, - ".mp3": {}, ".mp4": {}, ".avi": {}, ".mov": {}, -} - -// isBinaryFile performs a quick check to determine if a file is binary -func isBinaryFile(filePath string) bool { - // Check file extension first (fastest) - ext := strings.ToLower(filepath.Ext(filePath)) - if _, isBinary := binaryExts[ext]; isBinary { - return true - } - - // Quick content check for files without clear extensions - file, err := os.Open(filePath) - if err != nil { - return false // If we can't open it, let the caller handle the error - } - defer file.Close() - - // Read first 512 bytes to check for null bytes - buffer := make([]byte, 512) - n, err := file.Read(buffer) - if err != nil && err != io.EOF { - return false - } - - // Check for null bytes (common in binary files) - for i := range n { - if buffer[i] == 0 { - return true - } - } - - return false -} - -func globToRegex(glob string) string { - regexPattern := strings.ReplaceAll(glob, ".", "\\.") - regexPattern = strings.ReplaceAll(regexPattern, "*", ".*") - regexPattern = strings.ReplaceAll(regexPattern, "?", ".") - - // Use pre-compiled regex instead of compiling each time - regexPattern = globBraceRegex.ReplaceAllStringFunc(regexPattern, func(match string) string { - inner := match[1 : len(match)-1] - return "(" + strings.ReplaceAll(inner, ",", "|") + ")" - }) - - return regexPattern -} diff --git a/internal/llm/tools/grep.md b/internal/llm/tools/grep.md deleted file mode 100644 index c17a0ad1b8531a4fb9fd73b56e89973c749f91b1..0000000000000000000000000000000000000000 --- a/internal/llm/tools/grep.md +++ /dev/null @@ -1,54 +0,0 @@ -Fast content search tool that finds files containing specific text or patterns, returning 
matching file paths sorted by modification time (newest first). - -WHEN TO USE THIS TOOL: - -- Use when you need to find files containing specific text or patterns -- Great for searching code bases for function names, variable declarations, or error messages -- Useful for finding all files that use a particular API or pattern - -HOW TO USE: - -- Provide a regex pattern to search for within file contents -- Set literal_text=true if you want to search for the exact text with special characters (recommended for non-regex users) -- Optionally specify a starting directory (defaults to current working directory) -- Optionally provide an include pattern to filter which files to search -- Results are sorted with most recently modified files first - -REGEX PATTERN SYNTAX (when literal_text=false): - -- Supports standard regular expression syntax -- 'function' searches for the literal text "function" -- 'log\..\*Error' finds text starting with "log." and ending with "Error" -- 'import\s+.\*\s+from' finds import statements in JavaScript/TypeScript - -COMMON INCLUDE PATTERN EXAMPLES: - -- '\*.js' - Only search JavaScript files -- '\*.{ts,tsx}' - Only search TypeScript files -- '\*.go' - Only search Go files - -LIMITATIONS: - -- Results are limited to 100 files (newest first) -- Performance depends on the number of files being searched -- Very large binary files may be skipped -- Hidden files (starting with '.') are skipped - -IGNORE FILE SUPPORT: - -- Respects .gitignore patterns to skip ignored files and directories -- Respects .crushignore patterns for additional ignore rules -- Both ignore files are automatically detected in the search root directory - -CROSS-PLATFORM NOTES: - -- Uses ripgrep (rg) command if available for better performance -- Falls back to built-in Go implementation if ripgrep is not available -- File paths are normalized automatically for cross-platform compatibility - -TIPS: - -- For faster, more targeted searches, first use Glob to find relevant files, then use Grep -- When doing iterative exploration that may require multiple rounds of searching, consider using the Agent tool instead -- Always check if results are truncated and refine your search pattern if needed -- Use literal_text=true when searching for exact text containing special characters like dots, parentheses, etc. 
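
Editorial note, not part of the patch: the `literal_text` behavior described above is easy to misread, so here is a minimal runnable sketch of the idea. The helper name `escapeLiteral` is illustrative; it mirrors the `escapeRegexPattern` function removed in this diff, escaping regex metacharacters so the query is matched verbatim.

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// escapeLiteral escapes regex metacharacters so the pattern matches as plain text.
// The backslash is escaped first so later replacements are not double-escaped.
func escapeLiteral(pattern string) string {
	special := []string{`\`, ".", "+", "*", "?", "(", ")", "[", "]", "{", "}", "^", "$", "|"}
	for _, ch := range special {
		pattern = strings.ReplaceAll(pattern, ch, `\`+ch)
	}
	return pattern
}

func main() {
	query := "logger.Info(" // as a raw regex this would fail to compile (unclosed paren)
	re := regexp.MustCompile(escapeLiteral(query))
	fmt.Println(re.MatchString(`logger.Info("starting")`)) // true: exact text found
	fmt.Println(re.MatchString("loggerXInfo("))            // false: '.' is now literal
}
```

This is why the description recommends `literal_text=true` for queries containing dots or parentheses: the same string would otherwise be interpreted as a (possibly invalid) regular expression.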
diff --git a/internal/llm/tools/grep_test.go b/internal/llm/tools/grep_test.go deleted file mode 100644 index 53c96b22df444adfba59c6b13995a104411a57be..0000000000000000000000000000000000000000 --- a/internal/llm/tools/grep_test.go +++ /dev/null @@ -1,200 +0,0 @@ -package tools - -import ( - "os" - "path/filepath" - "regexp" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestRegexCache(t *testing.T) { - cache := newRegexCache() - - // Test basic caching - pattern := "test.*pattern" - regex1, err := cache.get(pattern) - if err != nil { - t.Fatalf("Failed to compile regex: %v", err) - } - - regex2, err := cache.get(pattern) - if err != nil { - t.Fatalf("Failed to get cached regex: %v", err) - } - - // Should be the same instance (cached) - if regex1 != regex2 { - t.Error("Expected cached regex to be the same instance") - } - - // Test that it actually works - if !regex1.MatchString("test123pattern") { - t.Error("Regex should match test string") - } -} - -func TestGlobToRegexCaching(t *testing.T) { - // Test that globToRegex uses pre-compiled regex - pattern1 := globToRegex("*.{js,ts}") - - // Should not panic and should work correctly - regex1, err := regexp.Compile(pattern1) - if err != nil { - t.Fatalf("Failed to compile glob regex: %v", err) - } - - if !regex1.MatchString("test.js") { - t.Error("Glob regex should match .js files") - } - if !regex1.MatchString("test.ts") { - t.Error("Glob regex should match .ts files") - } - if regex1.MatchString("test.go") { - t.Error("Glob regex should not match .go files") - } -} - -func TestGrepWithIgnoreFiles(t *testing.T) { - t.Parallel() - tempDir := t.TempDir() - - // Create test files - testFiles := map[string]string{ - "file1.txt": "hello world", - "file2.txt": "hello world", - "ignored/file3.txt": "hello world", - "node_modules/lib.js": "hello world", - "secret.key": "hello world", - } - - for path, content := range testFiles { - fullPath := filepath.Join(tempDir, path) - require.NoError(t, os.MkdirAll(filepath.Dir(fullPath), 0o755)) - require.NoError(t, os.WriteFile(fullPath, []byte(content), 0o644)) - } - - // Create .gitignore file - gitignoreContent := "ignored/\n*.key\n" - require.NoError(t, os.WriteFile(filepath.Join(tempDir, ".gitignore"), []byte(gitignoreContent), 0o644)) - - // Create .crushignore file - crushignoreContent := "node_modules/\n" - require.NoError(t, os.WriteFile(filepath.Join(tempDir, ".crushignore"), []byte(crushignoreContent), 0o644)) - - // Test both implementations - for name, fn := range map[string]func(pattern, path, include string) ([]grepMatch, error){ - "regex": searchFilesWithRegex, - "rg": func(pattern, path, include string) ([]grepMatch, error) { - return searchWithRipgrep(t.Context(), pattern, path, include) - }, - } { - t.Run(name, func(t *testing.T) { - t.Parallel() - - if name == "rg" && getRg() == "" { - t.Skip("rg is not in $PATH") - } - - matches, err := fn("hello world", tempDir, "") - require.NoError(t, err) - - // Convert matches to a set of file paths for easier testing - foundFiles := make(map[string]bool) - for _, match := range matches { - foundFiles[filepath.Base(match.path)] = true - } - - // Should find file1.txt and file2.txt - require.True(t, foundFiles["file1.txt"], "Should find file1.txt") - require.True(t, foundFiles["file2.txt"], "Should find file2.txt") - - // Should NOT find ignored files - require.False(t, foundFiles["file3.txt"], "Should not find file3.txt (ignored by .gitignore)") - require.False(t, foundFiles["lib.js"], "Should not find lib.js (ignored by 
.crushignore)") - require.False(t, foundFiles["secret.key"], "Should not find secret.key (ignored by .gitignore)") - - // Should find exactly 2 matches - require.Equal(t, 2, len(matches), "Should find exactly 2 matches") - }) - } -} - -func TestSearchImplementations(t *testing.T) { - t.Parallel() - tempDir := t.TempDir() - - for path, content := range map[string]string{ - "file1.go": "package main\nfunc main() {\n\tfmt.Println(\"hello world\")\n}", - "file2.js": "console.log('hello world');", - "file3.txt": "hello world from text file", - "binary.exe": "\x00\x01\x02\x03", - "empty.txt": "", - "subdir/nested.go": "package nested\n// hello world comment", - ".hidden.txt": "hello world in hidden file", - "file4.txt": "hello world from a banana", - "file5.txt": "hello world from a grape", - } { - fullPath := filepath.Join(tempDir, path) - require.NoError(t, os.MkdirAll(filepath.Dir(fullPath), 0o755)) - require.NoError(t, os.WriteFile(fullPath, []byte(content), 0o644)) - } - - require.NoError(t, os.WriteFile(filepath.Join(tempDir, ".gitignore"), []byte("file4.txt\n"), 0o644)) - require.NoError(t, os.WriteFile(filepath.Join(tempDir, ".crushignore"), []byte("file5.txt\n"), 0o644)) - - for name, fn := range map[string]func(pattern, path, include string) ([]grepMatch, error){ - "regex": searchFilesWithRegex, - "rg": func(pattern, path, include string) ([]grepMatch, error) { - return searchWithRipgrep(t.Context(), pattern, path, include) - }, - } { - t.Run(name, func(t *testing.T) { - t.Parallel() - - if name == "rg" && getRg() == "" { - t.Skip("rg is not in $PATH") - } - - matches, err := fn("hello world", tempDir, "") - require.NoError(t, err) - - require.Equal(t, len(matches), 4) - for _, match := range matches { - require.NotEmpty(t, match.path) - require.NotZero(t, match.lineNum) - require.NotEmpty(t, match.lineText) - require.NotZero(t, match.modTime) - require.NotContains(t, match.path, ".hidden.txt") - require.NotContains(t, match.path, "file4.txt") - require.NotContains(t, match.path, "file5.txt") - require.NotContains(t, match.path, "binary.exe") - } - }) - } -} - -// Benchmark to show performance improvement -func BenchmarkRegexCacheVsCompile(b *testing.B) { - cache := newRegexCache() - pattern := "test.*pattern.*[0-9]+" - - b.Run("WithCache", func(b *testing.B) { - for b.Loop() { - _, err := cache.get(pattern) - if err != nil { - b.Fatal(err) - } - } - }) - - b.Run("WithoutCache", func(b *testing.B) { - for b.Loop() { - _, err := regexp.Compile(pattern) - if err != nil { - b.Fatal(err) - } - } - }) -} diff --git a/internal/llm/tools/ls.go b/internal/llm/tools/ls.go deleted file mode 100644 index 305f7f10249594ff06ac008a8bf81145d7d834de..0000000000000000000000000000000000000000 --- a/internal/llm/tools/ls.go +++ /dev/null @@ -1,283 +0,0 @@ -package tools - -import ( - "cmp" - "context" - _ "embed" - "encoding/json" - "fmt" - "os" - "path/filepath" - "strings" - - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fsext" - "github.com/charmbracelet/crush/internal/permission" -) - -type LSParams struct { - Path string `json:"path"` - Ignore []string `json:"ignore"` - Depth int `json:"depth"` -} - -type LSPermissionsParams struct { - Path string `json:"path"` - Ignore []string `json:"ignore"` - Depth int `json:"depth"` -} - -type TreeNode struct { - Name string `json:"name"` - Path string `json:"path"` - Type string `json:"type"` // "file" or "directory" - Children []*TreeNode `json:"children,omitempty"` -} - -type LSResponseMetadata struct { - 
NumberOfFiles int `json:"number_of_files"` - Truncated bool `json:"truncated"` -} - -type lsTool struct { - workingDir string - permissions permission.Service -} - -const ( - LSToolName = "ls" - maxLSFiles = 1000 -) - -//go:embed ls.md -var lsDescription []byte - -func NewLsTool(permissions permission.Service, workingDir string) BaseTool { - return &lsTool{ - workingDir: workingDir, - permissions: permissions, - } -} - -func (l *lsTool) Name() string { - return LSToolName -} - -func (l *lsTool) Info() ToolInfo { - return ToolInfo{ - Name: LSToolName, - Description: string(lsDescription), - Parameters: map[string]any{ - "path": map[string]any{ - "type": "string", - "description": "The path to the directory to list (defaults to current working directory)", - }, - "depth": map[string]any{ - "type": "integer", - "description": "The maximum depth to traverse", - }, - "ignore": map[string]any{ - "type": "array", - "description": "List of glob patterns to ignore", - "items": map[string]any{ - "type": "string", - }, - }, - }, - Required: []string{"path"}, - } -} - -func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { - var params LSParams - if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { - return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil - } - - searchPath, err := fsext.Expand(cmp.Or(params.Path, l.workingDir)) - if err != nil { - return ToolResponse{}, fmt.Errorf("error expanding path: %w", err) - } - - if !filepath.IsAbs(searchPath) { - searchPath = filepath.Join(l.workingDir, searchPath) - } - - // Check if directory is outside working directory and request permission if needed - absWorkingDir, err := filepath.Abs(l.workingDir) - if err != nil { - return ToolResponse{}, fmt.Errorf("error resolving working directory: %w", err) - } - - absSearchPath, err := filepath.Abs(searchPath) - if err != nil { - return ToolResponse{}, fmt.Errorf("error resolving search path: %w", err) - } - - relPath, err := filepath.Rel(absWorkingDir, absSearchPath) - if err != nil || strings.HasPrefix(relPath, "..") { - // Directory is outside working directory, request permission - sessionID, messageID := GetContextValues(ctx) - if sessionID == "" || messageID == "" { - return ToolResponse{}, fmt.Errorf("session ID and message ID are required for accessing directories outside working directory") - } - - granted := l.permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - Path: absSearchPath, - ToolCallID: call.ID, - ToolName: LSToolName, - Action: "list", - Description: fmt.Sprintf("List directory outside working directory: %s", absSearchPath), - Params: LSPermissionsParams(params), - }, - ) - - if !granted { - return ToolResponse{}, permission.ErrorPermissionDenied - } - } - - output, metadata, err := ListDirectoryTree(searchPath, params) - if err != nil { - return ToolResponse{}, err - } - - return WithResponseMetadata( - NewTextResponse(output), - metadata, - ), nil -} - -func ListDirectoryTree(searchPath string, params LSParams) (string, LSResponseMetadata, error) { - if _, err := os.Stat(searchPath); os.IsNotExist(err) { - return "", LSResponseMetadata{}, fmt.Errorf("path does not exist: %s", searchPath) - } - - ls := config.Get().Tools.Ls - depth, limit := ls.Limits() - maxFiles := min(limit, maxLSFiles) - files, truncated, err := fsext.ListDirectory( - searchPath, - params.Ignore, - cmp.Or(params.Depth, depth), - maxFiles, - ) - if err != nil { - return "", LSResponseMetadata{}, fmt.Errorf("error listing 
directory: %w", err) - } - - metadata := LSResponseMetadata{ - NumberOfFiles: len(files), - Truncated: truncated, - } - tree := createFileTree(files, searchPath) - - var output string - if truncated { - output = fmt.Sprintf("There are more than %d files in the directory. Use a more specific path or use the Glob tool to find specific files. The first %[1]d files and directories are included below.\n", maxFiles) - } - if depth > 0 { - output = fmt.Sprintf("The directory tree is shown up to a depth of %d. Use a higher depth and a specific path to see more levels.\n", cmp.Or(params.Depth, depth)) - } - return output + "\n" + printTree(tree, searchPath), metadata, nil -} - -func createFileTree(sortedPaths []string, rootPath string) []*TreeNode { - root := []*TreeNode{} - pathMap := make(map[string]*TreeNode) - - for _, path := range sortedPaths { - relativePath := strings.TrimPrefix(path, rootPath) - parts := strings.Split(relativePath, string(filepath.Separator)) - currentPath := "" - var parentPath string - - var cleanParts []string - for _, part := range parts { - if part != "" { - cleanParts = append(cleanParts, part) - } - } - parts = cleanParts - - if len(parts) == 0 { - continue - } - - for i, part := range parts { - if currentPath == "" { - currentPath = part - } else { - currentPath = filepath.Join(currentPath, part) - } - - if _, exists := pathMap[currentPath]; exists { - parentPath = currentPath - continue - } - - isLastPart := i == len(parts)-1 - isDir := !isLastPart || strings.HasSuffix(relativePath, string(filepath.Separator)) - nodeType := "file" - if isDir { - nodeType = "directory" - } - newNode := &TreeNode{ - Name: part, - Path: currentPath, - Type: nodeType, - Children: []*TreeNode{}, - } - - pathMap[currentPath] = newNode - - if i > 0 && parentPath != "" { - if parent, ok := pathMap[parentPath]; ok { - parent.Children = append(parent.Children, newNode) - } - } else { - root = append(root, newNode) - } - - parentPath = currentPath - } - } - - return root -} - -func printTree(tree []*TreeNode, rootPath string) string { - var result strings.Builder - - result.WriteString("- ") - result.WriteString(rootPath) - if rootPath[len(rootPath)-1] != '/' { - result.WriteByte(filepath.Separator) - } - result.WriteByte('\n') - - for _, node := range tree { - printNode(&result, node, 1) - } - - return result.String() -} - -func printNode(builder *strings.Builder, node *TreeNode, level int) { - indent := strings.Repeat(" ", level) - - nodeName := node.Name - if node.Type == "directory" { - nodeName = nodeName + string(filepath.Separator) - } - - fmt.Fprintf(builder, "%s- %s\n", indent, nodeName) - - if node.Type == "directory" && len(node.Children) > 0 { - for _, child := range node.Children { - printNode(builder, child, level+1) - } - } -} diff --git a/internal/llm/tools/ls.md b/internal/llm/tools/ls.md deleted file mode 100644 index 798ddcc90f9659032870a3b893129a7ec2edad50..0000000000000000000000000000000000000000 --- a/internal/llm/tools/ls.md +++ /dev/null @@ -1,40 +0,0 @@ -Directory listing tool that shows files and subdirectories in a tree structure, helping you explore and understand the project organization. 
- -WHEN TO USE THIS TOOL: - -- Use when you need to explore the structure of a directory -- Helpful for understanding the organization of a project -- Good first step when getting familiar with a new codebase - -HOW TO USE: - -- Provide a path to list (defaults to current working directory) -- Optionally specify glob patterns to ignore -- Results are displayed in a tree structure - -FEATURES: - -- Displays a hierarchical view of files and directories -- Automatically skips hidden files/directories (starting with '.') -- Skips common system directories like **pycache** -- Can filter out files matching specific patterns - -LIMITATIONS: - -- Results are limited to 1000 files -- Very large directories will be truncated -- Does not show file sizes or permissions -- Cannot recursively list all directories in a large project - -WINDOWS NOTES: - -- Hidden file detection uses Unix convention (files starting with '.') -- Windows-specific hidden files (with hidden attribute) are not automatically skipped -- Common Windows directories like System32, Program Files are not in default ignore list -- Path separators are handled automatically (both / and \ work) - -TIPS: - -- Use Glob tool for finding files by name patterns instead of browsing -- Use Grep tool for searching file contents -- Combine with other tools for more effective exploration diff --git a/internal/llm/tools/multiedit.go b/internal/llm/tools/multiedit.go deleted file mode 100644 index 1d6e415ceb8cc10bcc32788cddbbdc40fc781907..0000000000000000000000000000000000000000 --- a/internal/llm/tools/multiedit.go +++ /dev/null @@ -1,424 +0,0 @@ -package tools - -import ( - "context" - _ "embed" - "encoding/json" - "fmt" - "log/slog" - "os" - "path/filepath" - "strings" - "time" - - "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/diff" - "github.com/charmbracelet/crush/internal/fsext" - "github.com/charmbracelet/crush/internal/history" - "github.com/charmbracelet/crush/internal/lsp" - "github.com/charmbracelet/crush/internal/permission" -) - -type MultiEditOperation struct { - OldString string `json:"old_string"` - NewString string `json:"new_string"` - ReplaceAll bool `json:"replace_all,omitempty"` -} - -type MultiEditParams struct { - FilePath string `json:"file_path"` - Edits []MultiEditOperation `json:"edits"` -} - -type MultiEditPermissionsParams struct { - FilePath string `json:"file_path"` - OldContent string `json:"old_content,omitempty"` - NewContent string `json:"new_content,omitempty"` -} - -type MultiEditResponseMetadata struct { - Additions int `json:"additions"` - Removals int `json:"removals"` - OldContent string `json:"old_content,omitempty"` - NewContent string `json:"new_content,omitempty"` - EditsApplied int `json:"edits_applied"` -} - -type multiEditTool struct { - lspClients *csync.Map[string, *lsp.Client] - permissions permission.Service - files history.Service - workingDir string -} - -const MultiEditToolName = "multiedit" - -//go:embed multiedit.md -var multieditDescription []byte - -func NewMultiEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool { - return &multiEditTool{ - lspClients: lspClients, - permissions: permissions, - files: files, - workingDir: workingDir, - } -} - -func (m *multiEditTool) Name() string { - return MultiEditToolName -} - -func (m *multiEditTool) Info() ToolInfo { - return ToolInfo{ - Name: MultiEditToolName, - Description: string(multieditDescription), - Parameters: 
map[string]any{ - "file_path": map[string]any{ - "type": "string", - "description": "The absolute path to the file to modify", - }, - "edits": map[string]any{ - "type": "array", - "items": map[string]any{ - "type": "object", - "properties": map[string]any{ - "old_string": map[string]any{ - "type": "string", - "description": "The text to replace", - }, - "new_string": map[string]any{ - "type": "string", - "description": "The text to replace it with", - }, - "replace_all": map[string]any{ - "type": "boolean", - "default": false, - "description": "Replace all occurrences of old_string (default false).", - }, - }, - "required": []string{"old_string", "new_string"}, - "additionalProperties": false, - }, - "minItems": 1, - "description": "Array of edit operations to perform sequentially on the file", - }, - }, - Required: []string{"file_path", "edits"}, - } -} - -func (m *multiEditTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { - var params MultiEditParams - if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { - return NewTextErrorResponse("invalid parameters"), nil - } - - if params.FilePath == "" { - return NewTextErrorResponse("file_path is required"), nil - } - - if len(params.Edits) == 0 { - return NewTextErrorResponse("at least one edit operation is required"), nil - } - - if !filepath.IsAbs(params.FilePath) { - params.FilePath = filepath.Join(m.workingDir, params.FilePath) - } - - // Validate all edits before applying any - if err := m.validateEdits(params.Edits); err != nil { - return NewTextErrorResponse(err.Error()), nil - } - - var response ToolResponse - var err error - - // Handle file creation case (first edit has empty old_string) - if len(params.Edits) > 0 && params.Edits[0].OldString == "" { - response, err = m.processMultiEditWithCreation(ctx, params, call) - } else { - response, err = m.processMultiEditExistingFile(ctx, params, call) - } - - if err != nil { - return response, err - } - - if response.IsError { - return response, nil - } - - // Notify LSP clients about the change - notifyLSPs(ctx, m.lspClients, params.FilePath) - - // Wait for LSP diagnostics and add them to the response - text := fmt.Sprintf("\n%s\n\n", response.Content) - text += getDiagnostics(params.FilePath, m.lspClients) - response.Content = text - return response, nil -} - -func (m *multiEditTool) validateEdits(edits []MultiEditOperation) error { - for i, edit := range edits { - if edit.OldString == edit.NewString { - return fmt.Errorf("edit %d: old_string and new_string are identical", i+1) - } - // Only the first edit can have empty old_string (for file creation) - if i > 0 && edit.OldString == "" { - return fmt.Errorf("edit %d: only the first edit can have empty old_string (for file creation)", i+1) - } - } - return nil -} - -func (m *multiEditTool) processMultiEditWithCreation(ctx context.Context, params MultiEditParams, call ToolCall) (ToolResponse, error) { - // First edit creates the file - firstEdit := params.Edits[0] - if firstEdit.OldString != "" { - return NewTextErrorResponse("first edit must have empty old_string for file creation"), nil - } - - // Check if file already exists - if _, err := os.Stat(params.FilePath); err == nil { - return NewTextErrorResponse(fmt.Sprintf("file already exists: %s", params.FilePath)), nil - } else if !os.IsNotExist(err) { - return ToolResponse{}, fmt.Errorf("failed to access file: %w", err) - } - - // Create parent directories - dir := filepath.Dir(params.FilePath) - if err := os.MkdirAll(dir, 0o755); err != nil { - return 
ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err) - } - - // Start with the content from the first edit - currentContent := firstEdit.NewString - - // Apply remaining edits to the content - for i := 1; i < len(params.Edits); i++ { - edit := params.Edits[i] - newContent, err := m.applyEditToContent(currentContent, edit) - if err != nil { - return NewTextErrorResponse(fmt.Sprintf("edit %d failed: %s", i+1, err.Error())), nil - } - currentContent = newContent - } - - // Get session and message IDs - sessionID, messageID := GetContextValues(ctx) - if sessionID == "" || messageID == "" { - return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file") - } - - // Check permissions - _, additions, removals := diff.GenerateDiff("", currentContent, strings.TrimPrefix(params.FilePath, m.workingDir)) - - p := m.permissions.Request(permission.CreatePermissionRequest{ - SessionID: sessionID, - Path: fsext.PathOrPrefix(params.FilePath, m.workingDir), - ToolCallID: call.ID, - ToolName: MultiEditToolName, - Action: "write", - Description: fmt.Sprintf("Create file %s with %d edits", params.FilePath, len(params.Edits)), - Params: MultiEditPermissionsParams{ - FilePath: params.FilePath, - OldContent: "", - NewContent: currentContent, - }, - }) - if !p { - return ToolResponse{}, permission.ErrorPermissionDenied - } - - // Write the file - err := os.WriteFile(params.FilePath, []byte(currentContent), 0o644) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) - } - - // Update file history - _, err = m.files.Create(ctx, sessionID, params.FilePath, "") - if err != nil { - return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) - } - - _, err = m.files.CreateVersion(ctx, sessionID, params.FilePath, currentContent) - if err != nil { - slog.Debug("Error creating file history version", "error", err) - } - - recordFileWrite(params.FilePath) - recordFileRead(params.FilePath) - - return WithResponseMetadata( - NewTextResponse(fmt.Sprintf("File created with %d edits: %s", len(params.Edits), params.FilePath)), - MultiEditResponseMetadata{ - OldContent: "", - NewContent: currentContent, - Additions: additions, - Removals: removals, - EditsApplied: len(params.Edits), - }, - ), nil -} - -func (m *multiEditTool) processMultiEditExistingFile(ctx context.Context, params MultiEditParams, call ToolCall) (ToolResponse, error) { - // Validate file exists and is readable - fileInfo, err := os.Stat(params.FilePath) - if err != nil { - if os.IsNotExist(err) { - return NewTextErrorResponse(fmt.Sprintf("file not found: %s", params.FilePath)), nil - } - return ToolResponse{}, fmt.Errorf("failed to access file: %w", err) - } - - if fileInfo.IsDir() { - return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", params.FilePath)), nil - } - - // Check if file was read before editing - if getLastReadTime(params.FilePath).IsZero() { - return NewTextErrorResponse("you must read the file before editing it. 
Use the View tool first"), nil - } - - // Check if file was modified since last read - modTime := fileInfo.ModTime() - lastRead := getLastReadTime(params.FilePath) - if modTime.After(lastRead) { - return NewTextErrorResponse( - fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)", - params.FilePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339), - )), nil - } - - // Read current file content - content, err := os.ReadFile(params.FilePath) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to read file: %w", err) - } - - oldContent, isCrlf := fsext.ToUnixLineEndings(string(content)) - currentContent := oldContent - - // Apply all edits sequentially - for i, edit := range params.Edits { - newContent, err := m.applyEditToContent(currentContent, edit) - if err != nil { - return NewTextErrorResponse(fmt.Sprintf("edit %d failed: %s", i+1, err.Error())), nil - } - currentContent = newContent - } - - // Check if content actually changed - if oldContent == currentContent { - return NewTextErrorResponse("no changes made - all edits resulted in identical content"), nil - } - - // Get session and message IDs - sessionID, messageID := GetContextValues(ctx) - if sessionID == "" || messageID == "" { - return ToolResponse{}, fmt.Errorf("session ID and message ID are required for editing file") - } - - // Generate diff and check permissions - _, additions, removals := diff.GenerateDiff(oldContent, currentContent, strings.TrimPrefix(params.FilePath, m.workingDir)) - p := m.permissions.Request(permission.CreatePermissionRequest{ - SessionID: sessionID, - Path: fsext.PathOrPrefix(params.FilePath, m.workingDir), - ToolCallID: call.ID, - ToolName: MultiEditToolName, - Action: "write", - Description: fmt.Sprintf("Apply %d edits to file %s", len(params.Edits), params.FilePath), - Params: MultiEditPermissionsParams{ - FilePath: params.FilePath, - OldContent: oldContent, - NewContent: currentContent, - }, - }) - if !p { - return ToolResponse{}, permission.ErrorPermissionDenied - } - - if isCrlf { - currentContent, _ = fsext.ToWindowsLineEndings(currentContent) - } - - // Write the updated content - err = os.WriteFile(params.FilePath, []byte(currentContent), 0o644) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) - } - - // Update file history - file, err := m.files.GetByPathAndSession(ctx, params.FilePath, sessionID) - if err != nil { - _, err = m.files.Create(ctx, sessionID, params.FilePath, oldContent) - if err != nil { - return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) - } - } - if file.Content != oldContent { - // User manually changed the content, store an intermediate version - _, err = m.files.CreateVersion(ctx, sessionID, params.FilePath, oldContent) - if err != nil { - slog.Debug("Error creating file history version", "error", err) - } - } - - // Store the new version - _, err = m.files.CreateVersion(ctx, sessionID, params.FilePath, currentContent) - if err != nil { - slog.Debug("Error creating file history version", "error", err) - } - - recordFileWrite(params.FilePath) - recordFileRead(params.FilePath) - - return WithResponseMetadata( - NewTextResponse(fmt.Sprintf("Applied %d edits to file: %s", len(params.Edits), params.FilePath)), - MultiEditResponseMetadata{ - OldContent: oldContent, - NewContent: currentContent, - Additions: additions, - Removals: removals, - EditsApplied: len(params.Edits), - }, - ), nil -} - -func (m *multiEditTool) applyEditToContent(content string, edit 
MultiEditOperation) (string, error) { - if edit.OldString == "" && edit.NewString == "" { - return content, nil - } - - if edit.OldString == "" { - return "", fmt.Errorf("old_string cannot be empty for content replacement") - } - - var newContent string - var replacementCount int - - if edit.ReplaceAll { - newContent = strings.ReplaceAll(content, edit.OldString, edit.NewString) - replacementCount = strings.Count(content, edit.OldString) - if replacementCount == 0 { - return "", fmt.Errorf("old_string not found in content. Make sure it matches exactly, including whitespace and line breaks") - } - } else { - index := strings.Index(content, edit.OldString) - if index == -1 { - return "", fmt.Errorf("old_string not found in content. Make sure it matches exactly, including whitespace and line breaks") - } - - lastIndex := strings.LastIndex(content, edit.OldString) - if index != lastIndex { - return "", fmt.Errorf("old_string appears multiple times in the content. Please provide more context to ensure a unique match, or set replace_all to true") - } - - newContent = content[:index] + edit.NewString + content[index+len(edit.OldString):] - replacementCount = 1 - } - - return newContent, nil -} diff --git a/internal/llm/tools/multiedit.md b/internal/llm/tools/multiedit.md deleted file mode 100644 index 4cc64f7a82c46140b79491a897fa8b691e29cc40..0000000000000000000000000000000000000000 --- a/internal/llm/tools/multiedit.md +++ /dev/null @@ -1,48 +0,0 @@ -This is a tool for making multiple edits to a single file in one operation. It is built on top of the Edit tool and allows you to perform multiple find-and-replace operations efficiently. Prefer this tool over the Edit tool when you need to make multiple edits to the same file. - -Before using this tool: - -1. Use the Read tool to understand the file's contents and context - -2. Verify the directory path is correct - -To make multiple file edits, provide the following: - -1. file_path: The absolute path to the file to modify (must be absolute, not relative) -2. edits: An array of edit operations to perform, where each edit contains: - - old_string: The text to replace (must match the file contents exactly, including all whitespace and indentation) - - new_string: The edited text to replace the old_string - - replace_all: Replace all occurrences of old_string. This parameter is optional and defaults to false. - -IMPORTANT: - -- All edits are applied in sequence, in the order they are provided -- Each edit operates on the result of the previous edit -- All edits must be valid for the operation to succeed - if any edit fails, none will be applied -- This tool is ideal when you need to make several changes to different parts of the same file - -CRITICAL REQUIREMENTS: - -1. All edits follow the same requirements as the single Edit tool -2. The edits are atomic - either all succeed or none are applied -3. Plan your edits carefully to avoid conflicts between sequential operations - -WARNING: - -- The tool will fail if edits.old_string doesn't match the file contents exactly (including whitespace) -- The tool will fail if edits.old_string and edits.new_string are the same -- Since edits are applied in sequence, ensure that earlier edits don't affect the text that later edits are trying to find - -When making edits: - -- Ensure all edits result in idiomatic, correct code -- Do not leave the code in a broken state -- Always use absolute file paths (starting with /) -- Only use emojis if the user explicitly requests it. 
Avoid adding emojis to files unless asked. -- Use replace_all for replacing and renaming strings across the file. This parameter is useful if you want to rename a variable for instance. - -If you want to create a new file, use: - -- A new file path, including dir name if needed -- First edit: empty old_string and the new file's contents as new_string -- Subsequent edits: normal edit operations on the created content diff --git a/internal/llm/tools/rg.go b/internal/llm/tools/rg.go deleted file mode 100644 index 8809b57c8db30b4ac1ed6c070df5a7218c59e233..0000000000000000000000000000000000000000 --- a/internal/llm/tools/rg.go +++ /dev/null @@ -1,53 +0,0 @@ -package tools - -import ( - "context" - "log/slog" - "os/exec" - "path/filepath" - "strings" - "sync" - - "github.com/charmbracelet/crush/internal/log" -) - -var getRg = sync.OnceValue(func() string { - path, err := exec.LookPath("rg") - if err != nil { - if log.Initialized() { - slog.Warn("Ripgrep (rg) not found in $PATH. Some grep features might be limited or slower.") - } - return "" - } - return path -}) - -func getRgCmd(ctx context.Context, globPattern string) *exec.Cmd { - name := getRg() - if name == "" { - return nil - } - args := []string{"--files", "-L", "--null"} - if globPattern != "" { - if !filepath.IsAbs(globPattern) && !strings.HasPrefix(globPattern, "/") { - globPattern = "/" + globPattern - } - args = append(args, "--glob", globPattern) - } - return exec.CommandContext(ctx, name, args...) -} - -func getRgSearchCmd(ctx context.Context, pattern, path, include string) *exec.Cmd { - name := getRg() - if name == "" { - return nil - } - // Use -n to show line numbers, -0 for null separation to handle Windows paths - args := []string{"-H", "-n", "-0", pattern} - if include != "" { - args = append(args, "--glob", include) - } - args = append(args, path) - - return exec.CommandContext(ctx, name, args...) 
-} diff --git a/internal/llm/tools/safe.go b/internal/llm/tools/safe.go deleted file mode 100644 index b0e6635393632390cba1e09d1d5df336fb1979cb..0000000000000000000000000000000000000000 --- a/internal/llm/tools/safe.go +++ /dev/null @@ -1,70 +0,0 @@ -package tools - -import "runtime" - -var safeCommands = []string{ - // Bash builtins and core utils - "cal", - "date", - "df", - "du", - "echo", - "env", - "free", - "groups", - "hostname", - "id", - "kill", - "killall", - "ls", - "nice", - "nohup", - "printenv", - "ps", - "pwd", - "set", - "time", - "timeout", - "top", - "type", - "uname", - "unset", - "uptime", - "whatis", - "whereis", - "which", - "whoami", - - // Git - "git blame", - "git branch", - "git config --get", - "git config --list", - "git describe", - "git diff", - "git grep", - "git log", - "git ls-files", - "git ls-remote", - "git remote", - "git rev-parse", - "git shortlog", - "git show", - "git status", - "git tag", -} - -func init() { - if runtime.GOOS == "windows" { - safeCommands = append( - safeCommands, - // Windows-specific commands - "ipconfig", - "nslookup", - "ping", - "systeminfo", - "tasklist", - "where", - ) - } -} diff --git a/internal/llm/tools/sourcegraph.go b/internal/llm/tools/sourcegraph.go deleted file mode 100644 index aea6b1ae8f0d52d10c083debf5bd3780957f261a..0000000000000000000000000000000000000000 --- a/internal/llm/tools/sourcegraph.go +++ /dev/null @@ -1,302 +0,0 @@ -package tools - -import ( - "bytes" - "context" - _ "embed" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "time" -) - -type SourcegraphParams struct { - Query string `json:"query"` - Count int `json:"count,omitempty"` - ContextWindow int `json:"context_window,omitempty"` - Timeout int `json:"timeout,omitempty"` -} - -type SourcegraphResponseMetadata struct { - NumberOfMatches int `json:"number_of_matches"` - Truncated bool `json:"truncated"` -} - -type sourcegraphTool struct { - client *http.Client -} - -const SourcegraphToolName = "sourcegraph" - -//go:embed sourcegraph.md -var sourcegraphDescription []byte - -func NewSourcegraphTool() BaseTool { - return &sourcegraphTool{ - client: &http.Client{ - Timeout: 30 * time.Second, - Transport: &http.Transport{ - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, - }, - }, - } -} - -func (t *sourcegraphTool) Name() string { - return SourcegraphToolName -} - -func (t *sourcegraphTool) Info() ToolInfo { - return ToolInfo{ - Name: SourcegraphToolName, - Description: string(sourcegraphDescription), - Parameters: map[string]any{ - "query": map[string]any{ - "type": "string", - "description": "The Sourcegraph search query", - }, - "count": map[string]any{ - "type": "number", - "description": "Optional number of results to return (default: 10, max: 20)", - }, - "context_window": map[string]any{ - "type": "number", - "description": "The context around the match to return (default: 10 lines)", - }, - "timeout": map[string]any{ - "type": "number", - "description": "Optional timeout in seconds (max 120)", - }, - }, - Required: []string{"query"}, - } -} - -func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { - var params SourcegraphParams - if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { - return NewTextErrorResponse("Failed to parse sourcegraph parameters: " + err.Error()), nil - } - - if params.Query == "" { - return NewTextErrorResponse("Query parameter is required"), nil - } - - if params.Count <= 0 { - params.Count = 10 - } else if params.Count > 20 { - 
params.Count = 20 // Limit to 20 results - } - - if params.ContextWindow <= 0 { - params.ContextWindow = 10 // Default context window - } - - // Handle timeout with context - requestCtx := ctx - if params.Timeout > 0 { - maxTimeout := 120 // 2 minutes - if params.Timeout > maxTimeout { - params.Timeout = maxTimeout - } - var cancel context.CancelFunc - requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second) - defer cancel() - } - - type graphqlRequest struct { - Query string `json:"query"` - Variables struct { - Query string `json:"query"` - } `json:"variables"` - } - - request := graphqlRequest{ - Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: keyword ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }", - } - request.Variables.Query = params.Query - - graphqlQueryBytes, err := json.Marshal(request) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to marshal GraphQL request: %w", err) - } - graphqlQuery := string(graphqlQueryBytes) - - req, err := http.NewRequestWithContext( - requestCtx, - "POST", - "https://sourcegraph.com/.api/graphql", - bytes.NewBuffer([]byte(graphqlQuery)), - ) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", "crush/1.0") - - resp, err := t.client.Do(req) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - if len(body) > 0 { - return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d, response: %s", resp.StatusCode, string(body))), nil - } - - return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil - } - body, err := io.ReadAll(resp.Body) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to read response body: %w", err) - } - - var result map[string]any - if err = json.Unmarshal(body, &result); err != nil { - return ToolResponse{}, fmt.Errorf("failed to unmarshal response: %w", err) - } - - formattedResults, err := formatSourcegraphResults(result, params.ContextWindow) - if err != nil { - return NewTextErrorResponse("Failed to format results: " + err.Error()), nil - } - - return NewTextResponse(formattedResults), nil -} - -func formatSourcegraphResults(result map[string]any, contextWindow int) (string, error) { - var buffer strings.Builder - - if errors, ok := result["errors"].([]any); ok && len(errors) > 0 { - buffer.WriteString("## Sourcegraph API Error\n\n") - for _, err := range errors { - if errMap, ok := err.(map[string]any); ok { - if message, ok := errMap["message"].(string); ok { - buffer.WriteString(fmt.Sprintf("- %s\n", message)) - } - } - } - return buffer.String(), nil - } - - data, ok := result["data"].(map[string]any) - if !ok { - return "", fmt.Errorf("invalid response format: missing data field") - } - - search, ok := data["search"].(map[string]any) - if !ok { - return "", fmt.Errorf("invalid response format: missing search field") - } - - searchResults, ok := search["results"].(map[string]any) - if !ok { - return "", fmt.Errorf("invalid response format: missing results field") - } - - 
matchCount, _ := searchResults["matchCount"].(float64) - resultCount, _ := searchResults["resultCount"].(float64) - limitHit, _ := searchResults["limitHit"].(bool) - - buffer.WriteString("# Sourcegraph Search Results\n\n") - buffer.WriteString(fmt.Sprintf("Found %d matches across %d results\n", int(matchCount), int(resultCount))) - - if limitHit { - buffer.WriteString("(Result limit reached, try a more specific query)\n") - } - - buffer.WriteString("\n") - - results, ok := searchResults["results"].([]any) - if !ok || len(results) == 0 { - buffer.WriteString("No results found. Try a different query.\n") - return buffer.String(), nil - } - - maxResults := 10 - if len(results) > maxResults { - results = results[:maxResults] - } - - for i, res := range results { - fileMatch, ok := res.(map[string]any) - if !ok { - continue - } - - typeName, _ := fileMatch["__typename"].(string) - if typeName != "FileMatch" { - continue - } - - repo, _ := fileMatch["repository"].(map[string]any) - file, _ := fileMatch["file"].(map[string]any) - lineMatches, _ := fileMatch["lineMatches"].([]any) - - if repo == nil || file == nil { - continue - } - - repoName, _ := repo["name"].(string) - filePath, _ := file["path"].(string) - fileURL, _ := file["url"].(string) - fileContent, _ := file["content"].(string) - - buffer.WriteString(fmt.Sprintf("## Result %d: %s/%s\n\n", i+1, repoName, filePath)) - - if fileURL != "" { - buffer.WriteString(fmt.Sprintf("URL: %s\n\n", fileURL)) - } - - if len(lineMatches) > 0 { - for _, lm := range lineMatches { - lineMatch, ok := lm.(map[string]any) - if !ok { - continue - } - - lineNumber, _ := lineMatch["lineNumber"].(float64) - preview, _ := lineMatch["preview"].(string) - - if fileContent != "" { - lines := strings.Split(fileContent, "\n") - - buffer.WriteString("```\n") - - startLine := max(1, int(lineNumber)-contextWindow) - - for j := startLine - 1; j < int(lineNumber)-1 && j < len(lines); j++ { - if j >= 0 { - buffer.WriteString(fmt.Sprintf("%d| %s\n", j+1, lines[j])) - } - } - - buffer.WriteString(fmt.Sprintf("%d| %s\n", int(lineNumber), preview)) - - endLine := int(lineNumber) + contextWindow - - for j := int(lineNumber); j < endLine && j < len(lines); j++ { - if j < len(lines) { - buffer.WriteString(fmt.Sprintf("%d| %s\n", j+1, lines[j])) - } - } - - buffer.WriteString("```\n\n") - } else { - buffer.WriteString("```\n") - buffer.WriteString(fmt.Sprintf("%d| %s\n", int(lineNumber), preview)) - buffer.WriteString("```\n\n") - } - } - } - } - - return buffer.String(), nil -} diff --git a/internal/llm/tools/sourcegraph.md b/internal/llm/tools/sourcegraph.md deleted file mode 100644 index ec6610d7af429454783282061252a5b40191f5e5..0000000000000000000000000000000000000000 --- a/internal/llm/tools/sourcegraph.md +++ /dev/null @@ -1,102 +0,0 @@ -Search code across public repositories using Sourcegraph's GraphQL API. 
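
Editorial note, not part of the patch: the request the removed tool sends is an ordinary GraphQL POST. A minimal sketch under stated assumptions: the endpoint and query shape are taken from the deleted sourcegraph.go above, but the GraphQL selection is trimmed to `matchCount` only, and error handling is reduced for brevity.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Build the GraphQL payload: the search string is passed as a variable.
	payload, _ := json.Marshal(map[string]any{
		"query":     `query Search($query: String!) { search(query: $query, version: V2, patternType: keyword) { results { matchCount } } }`,
		"variables": map[string]string{"query": "lang:go context.WithTimeout count:3"},
	})

	req, _ := http.NewRequest(http.MethodPost, "https://sourcegraph.com/.api/graphql", bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status)
	fmt.Println(string(body)) // JSON containing data.search.results.matchCount
}
```

The deleted implementation wraps the same call with a pooled `http.Client`, an optional per-request timeout, and the result formatting shown above.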
- -WHEN TO USE THIS TOOL: - -- Use when you need to find code examples or implementations across public repositories -- Helpful for researching how others have solved similar problems -- Useful for discovering patterns and best practices in open source code - -HOW TO USE: - -- Provide a search query using Sourcegraph's query syntax -- Optionally specify the number of results to return (default: 10) -- Optionally set a timeout for the request - -QUERY SYNTAX: - -- Basic search: "fmt.Println" searches for exact matches -- File filters: "file:.go fmt.Println" limits to Go files -- Repository filters: "repo:^github\.com/golang/go$ fmt.Println" limits to specific repos -- Language filters: "lang:go fmt.Println" limits to Go code -- Boolean operators: "fmt.Println AND log.Fatal" for combined terms -- Regular expressions: "fmt\.(Print|Printf|Println)" for pattern matching -- Quoted strings: "\"exact phrase\"" for exact phrase matching -- Exclude filters: "-file:test" or "-repo:forks" to exclude matches - -ADVANCED FILTERS: - -- Repository filters: - - "repo:name" - Match repositories with name containing "name" - - "repo:^github\.com/org/repo$" - Exact repository match - - "repo:org/repo@branch" - Search specific branch - - "repo:org/repo rev:branch" - Alternative branch syntax - - "-repo:name" - Exclude repositories - - "fork:yes" or "fork:only" - Include or only show forks - - "archived:yes" or "archived:only" - Include or only show archived repos - - "visibility:public" or "visibility:private" - Filter by visibility - -- File filters: - - "file:\.js$" - Files with .js extension - - "file:internal/" - Files in internal directory - - "-file:test" - Exclude test files - - "file:has.content(Copyright)" - Files containing "Copyright" - - "file:has.contributor([email protected])" - Files with specific contributor - -- Content filters: - - "content:\"exact string\"" - Search for exact string - - "-content:\"unwanted\"" - Exclude files with unwanted content - - "case:yes" - Case-sensitive search - -- Type filters: - - "type:symbol" - Search for symbols (functions, classes, etc.) 
- - "type:file" - Search file content only - - "type:path" - Search filenames only - - "type:diff" - Search code changes - - "type:commit" - Search commit messages - -- Commit/diff search: - - "after:\"1 month ago\"" - Commits after date - - "before:\"2023-01-01\"" - Commits before date - - "author:name" - Commits by author - - "message:\"fix bug\"" - Commits with message - -- Result selection: - - "select:repo" - Show only repository names - - "select:file" - Show only file paths - - "select:content" - Show only matching content - - "select:symbol" - Show only matching symbols - -- Result control: - - "count:100" - Return up to 100 results - - "count:all" - Return all results - - "timeout:30s" - Set search timeout - -EXAMPLES: - -- "file:.go context.WithTimeout" - Find Go code using context.WithTimeout -- "lang:typescript useState type:symbol" - Find TypeScript React useState hooks -- "repo:^github\.com/kubernetes/kubernetes$ pod list type:file" - Find Kubernetes files related to pod listing -- "repo:sourcegraph/sourcegraph$ after:\"3 months ago\" type:diff database" - Recent changes to database code -- "file:Dockerfile (alpine OR ubuntu) -content:alpine:latest" - Dockerfiles with specific base images -- "repo:has.path(\.py) file:requirements.txt tensorflow" - Python projects using TensorFlow - -BOOLEAN OPERATORS: - -- "term1 AND term2" - Results containing both terms -- "term1 OR term2" - Results containing either term -- "term1 NOT term2" - Results with term1 but not term2 -- "term1 and (term2 or term3)" - Grouping with parentheses - -LIMITATIONS: - -- Only searches public repositories -- Rate limits may apply -- Complex queries may take longer to execute -- Maximum of 20 results per query - -TIPS: - -- Use specific file extensions to narrow results -- Add repo: filters for more targeted searches -- Use type:symbol to find function/method definitions -- Use type:file to find relevant files diff --git a/internal/llm/tools/tools.go b/internal/llm/tools/tools.go deleted file mode 100644 index 41c0515616032b117f3c09a0056cac9e86b62c66..0000000000000000000000000000000000000000 --- a/internal/llm/tools/tools.go +++ /dev/null @@ -1,85 +0,0 @@ -package tools - -import ( - "context" - "encoding/json" -) - -type ToolInfo struct { - Name string - Description string - Parameters map[string]any - Required []string -} - -type toolResponseType string - -type ( - sessionIDContextKey string - messageIDContextKey string -) - -const ( - ToolResponseTypeText toolResponseType = "text" - ToolResponseTypeImage toolResponseType = "image" - - SessionIDContextKey sessionIDContextKey = "session_id" - MessageIDContextKey messageIDContextKey = "message_id" -) - -type ToolResponse struct { - Type toolResponseType `json:"type"` - Content string `json:"content"` - Metadata string `json:"metadata,omitempty"` - IsError bool `json:"is_error"` -} - -func NewTextResponse(content string) ToolResponse { - return ToolResponse{ - Type: ToolResponseTypeText, - Content: content, - } -} - -func WithResponseMetadata(response ToolResponse, metadata any) ToolResponse { - if metadata != nil { - metadataBytes, err := json.Marshal(metadata) - if err != nil { - return response - } - response.Metadata = string(metadataBytes) - } - return response -} - -func NewTextErrorResponse(content string) ToolResponse { - return ToolResponse{ - Type: ToolResponseTypeText, - Content: content, - IsError: true, - } -} - -type ToolCall struct { - ID string `json:"id"` - Name string `json:"name"` - Input string `json:"input"` -} - -type BaseTool interface 
{ - Info() ToolInfo - Name() string - Run(ctx context.Context, params ToolCall) (ToolResponse, error) -} - -func GetContextValues(ctx context.Context) (string, string) { - sessionID := ctx.Value(SessionIDContextKey) - messageID := ctx.Value(MessageIDContextKey) - if sessionID == nil { - return "", "" - } - if messageID == nil { - return sessionID.(string), "" - } - return sessionID.(string), messageID.(string) -} diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go deleted file mode 100644 index c3824bf1b1cc11dc22c57d60eb72654999a5e970..0000000000000000000000000000000000000000 --- a/internal/llm/tools/view.go +++ /dev/null @@ -1,343 +0,0 @@ -package tools - -import ( - "bufio" - "context" - _ "embed" - "encoding/json" - "fmt" - "io" - "os" - "path/filepath" - "strings" - "unicode/utf8" - - "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/lsp" - "github.com/charmbracelet/crush/internal/permission" -) - -//go:embed view.md -var viewDescription []byte - -type ViewParams struct { - FilePath string `json:"file_path"` - Offset int `json:"offset"` - Limit int `json:"limit"` -} - -type ViewPermissionsParams struct { - FilePath string `json:"file_path"` - Offset int `json:"offset"` - Limit int `json:"limit"` -} - -type viewTool struct { - lspClients *csync.Map[string, *lsp.Client] - workingDir string - permissions permission.Service -} - -type ViewResponseMetadata struct { - FilePath string `json:"file_path"` - Content string `json:"content"` -} - -const ( - ViewToolName = "view" - MaxReadSize = 250 * 1024 - DefaultReadLimit = 2000 - MaxLineLength = 2000 -) - -func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, workingDir string) BaseTool { - return &viewTool{ - lspClients: lspClients, - workingDir: workingDir, - permissions: permissions, - } -} - -func (v *viewTool) Name() string { - return ViewToolName -} - -func (v *viewTool) Info() ToolInfo { - return ToolInfo{ - Name: ViewToolName, - Description: string(viewDescription), - Parameters: map[string]any{ - "file_path": map[string]any{ - "type": "string", - "description": "The path to the file to read", - }, - "offset": map[string]any{ - "type": "integer", - "description": "The line number to start reading from (0-based)", - }, - "limit": map[string]any{ - "type": "integer", - "description": "The number of lines to read (defaults to 2000)", - }, - }, - Required: []string{"file_path"}, - } -} - -// Run implements Tool. 
-func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { - var params ViewParams - if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { - return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil - } - - if params.FilePath == "" { - return NewTextErrorResponse("file_path is required"), nil - } - - // Handle relative paths - filePath := params.FilePath - if !filepath.IsAbs(filePath) { - filePath = filepath.Join(v.workingDir, filePath) - } - - // Check if file is outside working directory and request permission if needed - absWorkingDir, err := filepath.Abs(v.workingDir) - if err != nil { - return ToolResponse{}, fmt.Errorf("error resolving working directory: %w", err) - } - - absFilePath, err := filepath.Abs(filePath) - if err != nil { - return ToolResponse{}, fmt.Errorf("error resolving file path: %w", err) - } - - relPath, err := filepath.Rel(absWorkingDir, absFilePath) - if err != nil || strings.HasPrefix(relPath, "..") { - // File is outside working directory, request permission - sessionID, messageID := GetContextValues(ctx) - if sessionID == "" || messageID == "" { - return ToolResponse{}, fmt.Errorf("session ID and message ID are required for accessing files outside working directory") - } - - granted := v.permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - Path: absFilePath, - ToolCallID: call.ID, - ToolName: ViewToolName, - Action: "read", - Description: fmt.Sprintf("Read file outside working directory: %s", absFilePath), - Params: ViewPermissionsParams(params), - }, - ) - - if !granted { - return ToolResponse{}, permission.ErrorPermissionDenied - } - } - - // Check if file exists - fileInfo, err := os.Stat(filePath) - if err != nil { - if os.IsNotExist(err) { - // Try to offer suggestions for similarly named files - dir := filepath.Dir(filePath) - base := filepath.Base(filePath) - - dirEntries, dirErr := os.ReadDir(dir) - if dirErr == nil { - var suggestions []string - for _, entry := range dirEntries { - if strings.Contains(strings.ToLower(entry.Name()), strings.ToLower(base)) || - strings.Contains(strings.ToLower(base), strings.ToLower(entry.Name())) { - suggestions = append(suggestions, filepath.Join(dir, entry.Name())) - if len(suggestions) >= 3 { - break - } - } - } - - if len(suggestions) > 0 { - return NewTextErrorResponse(fmt.Sprintf("File not found: %s\n\nDid you mean one of these?\n%s", - filePath, strings.Join(suggestions, "\n"))), nil - } - } - - return NewTextErrorResponse(fmt.Sprintf("File not found: %s", filePath)), nil - } - return ToolResponse{}, fmt.Errorf("error accessing file: %w", err) - } - - // Check if it's a directory - if fileInfo.IsDir() { - return NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil - } - - // Check file size - if fileInfo.Size() > MaxReadSize { - return NewTextErrorResponse(fmt.Sprintf("File is too large (%d bytes). 
Maximum size is %d bytes", - fileInfo.Size(), MaxReadSize)), nil - } - - // Set default limit if not provided - if params.Limit <= 0 { - params.Limit = DefaultReadLimit - } - - // Check if it's an image file - isImage, imageType := isImageFile(filePath) - // TODO: handle images - if isImage { - return NewTextErrorResponse(fmt.Sprintf("This is an image file of type: %s\n", imageType)), nil - } - - // Read the file content - content, lineCount, err := readTextFile(filePath, params.Offset, params.Limit) - isValidUt8 := utf8.ValidString(content) - if !isValidUt8 { - return NewTextErrorResponse("File content is not valid UTF-8"), nil - } - if err != nil { - return ToolResponse{}, fmt.Errorf("error reading file: %w", err) - } - - notifyLSPs(ctx, v.lspClients, filePath) - output := "\n" - // Format the output with line numbers - output += addLineNumbers(content, params.Offset+1) - - // Add a note if the content was truncated - if lineCount > params.Offset+len(strings.Split(content, "\n")) { - output += fmt.Sprintf("\n\n(File has more lines. Use 'offset' parameter to read beyond line %d)", - params.Offset+len(strings.Split(content, "\n"))) - } - output += "\n\n" - output += getDiagnostics(filePath, v.lspClients) - recordFileRead(filePath) - return WithResponseMetadata( - NewTextResponse(output), - ViewResponseMetadata{ - FilePath: filePath, - Content: content, - }, - ), nil -} - -func addLineNumbers(content string, startLine int) string { - if content == "" { - return "" - } - - lines := strings.Split(content, "\n") - - var result []string - for i, line := range lines { - line = strings.TrimSuffix(line, "\r") - - lineNum := i + startLine - numStr := fmt.Sprintf("%d", lineNum) - - if len(numStr) >= 6 { - result = append(result, fmt.Sprintf("%s|%s", numStr, line)) - } else { - paddedNum := fmt.Sprintf("%6s", numStr) - result = append(result, fmt.Sprintf("%s|%s", paddedNum, line)) - } - } - - return strings.Join(result, "\n") -} - -func readTextFile(filePath string, offset, limit int) (string, int, error) { - file, err := os.Open(filePath) - if err != nil { - return "", 0, err - } - defer file.Close() - - lineCount := 0 - - scanner := NewLineScanner(file) - if offset > 0 { - for lineCount < offset && scanner.Scan() { - lineCount++ - } - if err = scanner.Err(); err != nil { - return "", 0, err - } - } - - if offset == 0 { - _, err = file.Seek(0, io.SeekStart) - if err != nil { - return "", 0, err - } - } - - // Pre-allocate slice with expected capacity - lines := make([]string, 0, limit) - lineCount = offset - - for scanner.Scan() && len(lines) < limit { - lineCount++ - lineText := scanner.Text() - if len(lineText) > MaxLineLength { - lineText = lineText[:MaxLineLength] + "..." 
- } - lines = append(lines, lineText) - } - - // Continue scanning to get total line count - for scanner.Scan() { - lineCount++ - } - - if err := scanner.Err(); err != nil { - return "", 0, err - } - - return strings.Join(lines, "\n"), lineCount, nil -} - -func isImageFile(filePath string) (bool, string) { - ext := strings.ToLower(filepath.Ext(filePath)) - switch ext { - case ".jpg", ".jpeg": - return true, "JPEG" - case ".png": - return true, "PNG" - case ".gif": - return true, "GIF" - case ".bmp": - return true, "BMP" - case ".svg": - return true, "SVG" - case ".webp": - return true, "WebP" - default: - return false, "" - } -} - -type LineScanner struct { - scanner *bufio.Scanner -} - -func NewLineScanner(r io.Reader) *LineScanner { - return &LineScanner{ - scanner: bufio.NewScanner(r), - } -} - -func (s *LineScanner) Scan() bool { - return s.scanner.Scan() -} - -func (s *LineScanner) Text() string { - return s.scanner.Text() -} - -func (s *LineScanner) Err() error { - return s.scanner.Err() -} diff --git a/internal/llm/tools/view.md b/internal/llm/tools/view.md deleted file mode 100644 index 2179b77817b3e9a98ca7077bf8430abdde597392..0000000000000000000000000000000000000000 --- a/internal/llm/tools/view.md +++ /dev/null @@ -1,42 +0,0 @@ -File viewing tool that reads and displays the contents of files with line numbers, allowing you to examine code, logs, or text data. - -WHEN TO USE THIS TOOL: - -- Use when you need to read the contents of a specific file -- Helpful for examining source code, configuration files, or log files -- Perfect for looking at text-based file formats - -HOW TO USE: - -- Provide the path to the file you want to view -- Optionally specify an offset to start reading from a specific line -- Optionally specify a limit to control how many lines are read -- Do not use this for directories use the ls tool instead - -FEATURES: - -- Displays file contents with line numbers for easy reference -- Can read from any position in a file using the offset parameter -- Handles large files by limiting the number of lines read -- Automatically truncates very long lines for better display -- Suggests similar file names when the requested file isn't found - -LIMITATIONS: - -- Maximum file size is 250KB -- Default reading limit is 2000 lines -- Lines longer than 2000 characters are truncated -- Cannot display binary files or images -- Images can be identified but not displayed - -WINDOWS NOTES: - -- Handles both Windows (CRLF) and Unix (LF) line endings automatically -- File paths work with both forward slashes (/) and backslashes (\) -- Text encoding is detected automatically for most common formats - -TIPS: - -- Use with Glob tool to first find files you want to view -- For code exploration, first use Grep to find relevant files, then View to examine them -- When viewing large files, use the offset parameter to read specific sections diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go deleted file mode 100644 index 36ea13d2dd6e1930eb645ab036c950b9a4d6fdc8..0000000000000000000000000000000000000000 --- a/internal/llm/tools/write.go +++ /dev/null @@ -1,208 +0,0 @@ -package tools - -import ( - "context" - _ "embed" - "encoding/json" - "fmt" - "log/slog" - "os" - "path/filepath" - "strings" - "time" - - "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/diff" - "github.com/charmbracelet/crush/internal/fsext" - "github.com/charmbracelet/crush/internal/history" - - "github.com/charmbracelet/crush/internal/lsp" - 
"github.com/charmbracelet/crush/internal/permission" -) - -//go:embed write.md -var writeDescription []byte - -type WriteParams struct { - FilePath string `json:"file_path"` - Content string `json:"content"` -} - -type WritePermissionsParams struct { - FilePath string `json:"file_path"` - OldContent string `json:"old_content,omitempty"` - NewContent string `json:"new_content,omitempty"` -} - -type writeTool struct { - lspClients *csync.Map[string, *lsp.Client] - permissions permission.Service - files history.Service - workingDir string -} - -type WriteResponseMetadata struct { - Diff string `json:"diff"` - Additions int `json:"additions"` - Removals int `json:"removals"` -} - -const WriteToolName = "write" - -func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool { - return &writeTool{ - lspClients: lspClients, - permissions: permissions, - files: files, - workingDir: workingDir, - } -} - -func (w *writeTool) Name() string { - return WriteToolName -} - -func (w *writeTool) Info() ToolInfo { - return ToolInfo{ - Name: WriteToolName, - Description: string(writeDescription), - Parameters: map[string]any{ - "file_path": map[string]any{ - "type": "string", - "description": "The path to the file to write", - }, - "content": map[string]any{ - "type": "string", - "description": "The content to write to the file", - }, - }, - Required: []string{"file_path", "content"}, - } -} - -func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { - var params WriteParams - if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { - return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil - } - - if params.FilePath == "" { - return NewTextErrorResponse("file_path is required"), nil - } - - if params.Content == "" { - return NewTextErrorResponse("content is required"), nil - } - - filePath := params.FilePath - if !filepath.IsAbs(filePath) { - filePath = filepath.Join(w.workingDir, filePath) - } - - fileInfo, err := os.Stat(filePath) - if err == nil { - if fileInfo.IsDir() { - return NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil - } - - modTime := fileInfo.ModTime() - lastRead := getLastReadTime(filePath) - if modTime.After(lastRead) { - return NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.", - filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil - } - - oldContent, readErr := os.ReadFile(filePath) - if readErr == nil && string(oldContent) == params.Content { - return NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. 
No changes made.", filePath)), nil - } - } else if !os.IsNotExist(err) { - return ToolResponse{}, fmt.Errorf("error checking file: %w", err) - } - - dir := filepath.Dir(filePath) - if err = os.MkdirAll(dir, 0o755); err != nil { - return ToolResponse{}, fmt.Errorf("error creating directory: %w", err) - } - - oldContent := "" - if fileInfo != nil && !fileInfo.IsDir() { - oldBytes, readErr := os.ReadFile(filePath) - if readErr == nil { - oldContent = string(oldBytes) - } - } - - sessionID, messageID := GetContextValues(ctx) - if sessionID == "" || messageID == "" { - return ToolResponse{}, fmt.Errorf("session_id and message_id are required") - } - - diff, additions, removals := diff.GenerateDiff( - oldContent, - params.Content, - strings.TrimPrefix(filePath, w.workingDir), - ) - - p := w.permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - Path: fsext.PathOrPrefix(filePath, w.workingDir), - ToolCallID: call.ID, - ToolName: WriteToolName, - Action: "write", - Description: fmt.Sprintf("Create file %s", filePath), - Params: WritePermissionsParams{ - FilePath: filePath, - OldContent: oldContent, - NewContent: params.Content, - }, - }, - ) - if !p { - return ToolResponse{}, permission.ErrorPermissionDenied - } - - err = os.WriteFile(filePath, []byte(params.Content), 0o644) - if err != nil { - return ToolResponse{}, fmt.Errorf("error writing file: %w", err) - } - - // Check if file exists in history - file, err := w.files.GetByPathAndSession(ctx, filePath, sessionID) - if err != nil { - _, err = w.files.Create(ctx, sessionID, filePath, oldContent) - if err != nil { - // Log error but don't fail the operation - return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) - } - } - if file.Content != oldContent { - // User Manually changed the content store an intermediate version - _, err = w.files.CreateVersion(ctx, sessionID, filePath, oldContent) - if err != nil { - slog.Debug("Error creating file history version", "error", err) - } - } - // Store the new version - _, err = w.files.CreateVersion(ctx, sessionID, filePath, params.Content) - if err != nil { - slog.Debug("Error creating file history version", "error", err) - } - - recordFileWrite(filePath) - recordFileRead(filePath) - - notifyLSPs(ctx, w.lspClients, params.FilePath) - - result := fmt.Sprintf("File successfully written: %s", filePath) - result = fmt.Sprintf("\n%s\n", result) - result += getDiagnostics(filePath, w.lspClients) - return WithResponseMetadata(NewTextResponse(result), - WriteResponseMetadata{ - Diff: diff, - Additions: additions, - Removals: removals, - }, - ), nil -} diff --git a/internal/llm/tools/write.md b/internal/llm/tools/write.md deleted file mode 100644 index b8cc811ffe6234e7ebfd34037de57bd483305738..0000000000000000000000000000000000000000 --- a/internal/llm/tools/write.md +++ /dev/null @@ -1,38 +0,0 @@ -File writing tool that creates or updates files in the filesystem, allowing you to save or modify text content. 
- -WHEN TO USE THIS TOOL: - -- Use when you need to create a new file -- Helpful for updating existing files with modified content -- Perfect for saving generated code, configurations, or text data - -HOW TO USE: - -- Provide the path to the file you want to write -- Include the content to be written to the file -- The tool will create any necessary parent directories - -FEATURES: - -- Can create new files or overwrite existing ones -- Creates parent directories automatically if they don't exist -- Checks if the file has been modified since last read for safety -- Avoids unnecessary writes when content hasn't changed - -LIMITATIONS: - -- You should read a file before writing to it to avoid conflicts -- Cannot append to files (rewrites the entire file) - -WINDOWS NOTES: - -- File permissions (0o755, 0o644) are Unix-style but work on Windows with appropriate translations -- Use forward slashes (/) in paths for cross-platform compatibility -- Windows file attributes and permissions are handled automatically by the Go runtime - -TIPS: - -- Use the View tool first to examine existing files before modifying them -- Use the LS tool to verify the correct location when creating new files -- Combine with Glob and Grep tools to find and modify multiple files -- Always include descriptive comments when making changes to existing code diff --git a/internal/tui/components/chat/chat.go b/internal/tui/components/chat/chat.go index 8688f7e24c94290c74ae4344499acff61b43ac39..1d97915e11f8f98e1545f5b4b186fb032904c3fb 100644 --- a/internal/tui/components/chat/chat.go +++ b/internal/tui/components/chat/chat.go @@ -8,8 +8,8 @@ import ( "github.com/atotto/clipboard" "github.com/charmbracelet/bubbles/v2/key" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/crush/internal/agent" "github.com/charmbracelet/crush/internal/app" - "github.com/charmbracelet/crush/internal/llm/agent" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/pubsub" @@ -103,8 +103,8 @@ func (m *messageListCmp) Init() tea.Cmd { // Update handles incoming messages and updates the component state. 
func (m *messageListCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmds []tea.Cmd - if m.session.ID != "" && m.app.CoderAgent != nil { - queueSize := m.app.CoderAgent.QueuedPrompts(m.session.ID) + if m.session.ID != "" && m.app.AgentCoordinator != nil { + queueSize := m.app.AgentCoordinator.QueuedPrompts(m.session.ID) if queueSize != m.promptQueue { m.promptQueue = queueSize cmds = append(cmds, m.SetSize(m.width, m.height)) @@ -235,7 +235,7 @@ func (m *messageListCmp) View() string { m.listCmp.View(), ), } - if m.app.CoderAgent != nil && m.promptQueue > 0 { + if m.app.AgentCoordinator != nil && m.promptQueue > 0 { queuePill := queuePill(m.promptQueue, t) view = append(view, t.S().Base.PaddingLeft(4).PaddingTop(1).Render(queuePill)) } diff --git a/internal/tui/components/chat/editor/editor.go b/internal/tui/components/chat/editor/editor.go index f70a0a3dbe63a9473f552efa233e03bd4efc0ee1..7ae42defd0dfbf767585c23d45351c23950d9848 100644 --- a/internal/tui/components/chat/editor/editor.go +++ b/internal/tui/components/chat/editor/editor.go @@ -211,7 +211,7 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } case commands.OpenExternalEditorMsg: - if m.app.CoderAgent.IsSessionBusy(m.session.ID) { + if m.app.AgentCoordinator.IsSessionBusy(m.session.ID) { return m, util.ReportWarn("Agent is working, please wait...") } return m, m.openEditor(m.textarea.Value()) @@ -297,7 +297,7 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } if key.Matches(msg, m.keyMap.OpenEditor) { - if m.app.CoderAgent.IsSessionBusy(m.session.ID) { + if m.app.AgentCoordinator.IsSessionBusy(m.session.ID) { return m, util.ReportWarn("Agent is working, please wait...") } return m, m.openEditor(m.textarea.Value()) @@ -415,7 +415,7 @@ func (m *editorCmp) randomizePlaceholders() { func (m *editorCmp) View() string { t := styles.CurrentTheme() // Update placeholder - if m.app.CoderAgent != nil && m.app.CoderAgent.IsBusy() { + if m.app.AgentCoordinator != nil && m.app.AgentCoordinator.IsBusy() { m.textarea.Placeholder = m.workingPlaceholder } else { m.textarea.Placeholder = m.readyPlaceholder diff --git a/internal/tui/components/chat/header/header.go b/internal/tui/components/chat/header/header.go index 21861a4a2eda1340f6e01c0748f24cb713f15398..58776a7f223d01d0e9476d843fb32f82926e6476 100644 --- a/internal/tui/components/chat/header/header.go +++ b/internal/tui/components/chat/header/header.go @@ -119,7 +119,7 @@ func (h *header) details(availWidth int) string { parts = append(parts, s.Error.Render(fmt.Sprintf("%s%d", styles.ErrorIcon, errorCount))) } - agentCfg := config.Get().Agents["coder"] + agentCfg := config.Get().Agents[config.AgentCoder] model := config.Get().GetModelByType(agentCfg.Model) percentage := (float64(h.session.CompletionTokens+h.session.PromptTokens) / float64(model.ContextWindow)) * 100 formattedPercentage := s.Muted.Render(fmt.Sprintf("%d%%", int(percentage))) diff --git a/internal/tui/components/chat/messages/renderer.go b/internal/tui/components/chat/messages/renderer.go index ffe3301ce9c5ccbbbdc9188485e082d4b8156f66..ce53ad3e763fc297d4ed72ad7164b52a64f7972e 100644 --- a/internal/tui/components/chat/messages/renderer.go +++ b/internal/tui/components/chat/messages/renderer.go @@ -6,10 +6,10 @@ import ( "strings" "time" + "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/agent/tools" "github.com/charmbracelet/crush/internal/ansiext" "github.com/charmbracelet/crush/internal/fsext" - "github.com/charmbracelet/crush/internal/llm/agent" 
- "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/highlight" "github.com/charmbracelet/crush/internal/tui/styles" diff --git a/internal/tui/components/chat/messages/tool.go b/internal/tui/components/chat/messages/tool.go index 7e03674f97243e7d9e569b341fe1c6f1d2450b93..4e3820b4afab6d6e067e0f379d6febee6d008bc6 100644 --- a/internal/tui/components/chat/messages/tool.go +++ b/internal/tui/components/chat/messages/tool.go @@ -10,10 +10,10 @@ import ( "github.com/atotto/clipboard" "github.com/charmbracelet/bubbles/v2/key" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/agent/tools" "github.com/charmbracelet/crush/internal/diff" "github.com/charmbracelet/crush/internal/fsext" - "github.com/charmbracelet/crush/internal/llm/agent" - "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/tui/components/anim" diff --git a/internal/tui/components/chat/sidebar/sidebar.go b/internal/tui/components/chat/sidebar/sidebar.go index b50a78c7f8697e4f4db19649a01794cfe7a23bac..7b32cda78bde99ea92ec4472a0e817521b748c34 100644 --- a/internal/tui/components/chat/sidebar/sidebar.go +++ b/internal/tui/components/chat/sidebar/sidebar.go @@ -545,7 +545,7 @@ func formatTokensAndCost(tokens, contextWindow int64, cost float64) string { func (s *sidebarCmp) currentModelBlock() string { cfg := config.Get() - agentCfg := cfg.Agents["coder"] + agentCfg := cfg.Agents[config.AgentCoder] selectedModel := cfg.Models[agentCfg.Model] diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go index 187fc35e6ec47a858b99f35e135a8cef3500fbf1..0781b51c54c7c413af2fa0d4aab018d06105a811 100644 --- a/internal/tui/components/chat/splash/splash.go +++ b/internal/tui/components/chat/splash/splash.go @@ -9,9 +9,9 @@ import ( "github.com/charmbracelet/bubbles/v2/spinner" tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/agent" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/home" - "github.com/charmbracelet/crush/internal/llm/prompt" "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/core/layout" @@ -334,7 +334,7 @@ func (s *splashCmp) initializeProject() tea.Cmd { cmds = append(cmds, util.CmdHandler(chat.SessionClearedMsg{}), util.CmdHandler(chat.SendMsg{ - Text: prompt.Initialize(), + Text: agent.InitializePrompt(), }), ) } @@ -695,7 +695,7 @@ func (s *splashCmp) mcpBlock() string { func (s *splashCmp) currentModelBlock() string { cfg := config.Get() - agentCfg := cfg.Agents["coder"] + agentCfg := cfg.Agents[config.AgentCoder] model := config.Get().GetModelByType(agentCfg.Model) if model == nil { return "" diff --git a/internal/tui/components/dialogs/commands/commands.go b/internal/tui/components/dialogs/commands/commands.go index 664158fc392a87d8a7725bfa964748f7ef4f8e67..55f25c0d3f2da7c590abf76f7f533f837e4d52f2 100644 --- a/internal/tui/components/dialogs/commands/commands.go +++ b/internal/tui/components/dialogs/commands/commands.go @@ -9,8 +9,8 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" 
"github.com/charmbracelet/lipgloss/v2" + "github.com/charmbracelet/crush/internal/agent" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/prompt" "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/dialogs" @@ -303,7 +303,7 @@ func (c *commandDialogCmp) defaultCommands() []Command { // Add reasoning toggle for models that support it cfg := config.Get() - if agentCfg, ok := cfg.Agents["coder"]; ok { + if agentCfg, ok := cfg.Agents[config.AgentCoder]; ok { providerCfg := cfg.GetProviderForModel(agentCfg.Model) model := cfg.GetModelByType(agentCfg.Model) if providerCfg != nil && model != nil && model.CanReason { @@ -350,7 +350,7 @@ func (c *commandDialogCmp) defaultCommands() []Command { }) } if c.sessionID != "" { - agentCfg := config.Get().Agents["coder"] + agentCfg := config.Get().Agents[config.AgentCoder] model := config.Get().GetModelByType(agentCfg.Model) if model.SupportsImages { commands = append(commands, Command{ @@ -402,7 +402,7 @@ func (c *commandDialogCmp) defaultCommands() []Command { Description: "Create/Update the CRUSH.md memory file", Handler: func(cmd Command) tea.Cmd { return util.CmdHandler(chat.SendMsg{ - Text: prompt.Initialize(), + Text: agent.InitializePrompt(), }) }, }, diff --git a/internal/tui/components/dialogs/compact/compact.go b/internal/tui/components/dialogs/compact/compact.go index ecde402fd8dfe1f31791834cd4e4bae13ec45e00..6321bb8e53e183feb4280fad4bd451e1ae37d8ba 100644 --- a/internal/tui/components/dialogs/compact/compact.go +++ b/internal/tui/components/dialogs/compact/compact.go @@ -7,7 +7,7 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/lipgloss/v2" - "github.com/charmbracelet/crush/internal/llm/agent" + "github.com/charmbracelet/crush/internal/agent" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/dialogs" "github.com/charmbracelet/crush/internal/tui/styles" @@ -27,26 +27,16 @@ type compactDialogCmp struct { selected int keyMap KeyMap sessionID string - state compactState progress string - agent agent.Service + agent agent.Coordinator noAsk bool // If true, skip confirmation dialog } -type compactState int - -const ( - stateConfirm compactState = iota - stateCompacting - stateError -) - // NewCompactDialogCmp creates a new session compact dialog -func NewCompactDialogCmp(agent agent.Service, sessionID string, noAsk bool) CompactDialog { +func NewCompactDialogCmp(agent agent.Coordinator, sessionID string, noAsk bool) CompactDialog { return &compactDialogCmp{ sessionID: sessionID, keyMap: DefaultKeyMap(), - state: stateConfirm, selected: 0, agent: agent, noAsk: noAsk, @@ -56,7 +46,8 @@ func NewCompactDialogCmp(agent agent.Service, sessionID string, noAsk bool) Comp func (c *compactDialogCmp) Init() tea.Cmd { if c.noAsk { // If noAsk is true, skip confirmation and start compaction immediately - return c.startCompaction() + c.agent.Summarize(context.Background(), c.sessionID) + return util.CmdHandler(dialogs.CloseDialogMsg{}) } return nil } @@ -68,80 +59,30 @@ func (c *compactDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { c.wHeight = msg.Height cmd := c.SetSize() return c, cmd - case tea.KeyPressMsg: - switch c.state { - case stateConfirm: - switch { - case key.Matches(msg, c.keyMap.ChangeSelection): - c.selected = (c.selected + 1) % 2 - return c, nil - case key.Matches(msg, 
c.keyMap.Select): - if c.selected == 0 { - return c, c.startCompaction() - } else { - return c, util.CmdHandler(dialogs.CloseDialogMsg{}) - } - case key.Matches(msg, c.keyMap.Y): - return c, c.startCompaction() - case key.Matches(msg, c.keyMap.N): - return c, util.CmdHandler(dialogs.CloseDialogMsg{}) - case key.Matches(msg, c.keyMap.Close): - return c, util.CmdHandler(dialogs.CloseDialogMsg{}) - } - case stateCompacting: - switch { - case key.Matches(msg, c.keyMap.Close): - return c, util.CmdHandler(dialogs.CloseDialogMsg{}) - } - case stateError: - switch { - case key.Matches(msg, c.keyMap.Select): + switch { + case key.Matches(msg, c.keyMap.ChangeSelection): + c.selected = (c.selected + 1) % 2 + return c, nil + case key.Matches(msg, c.keyMap.Select): + if c.selected == 0 { + c.agent.Summarize(context.Background(), c.sessionID) return c, util.CmdHandler(dialogs.CloseDialogMsg{}) - case key.Matches(msg, c.keyMap.Close): - return c, util.CmdHandler(dialogs.CloseDialogMsg{}) - } - } - - case agent.AgentEvent: - switch msg.Type { - case agent.AgentEventTypeSummarize: - if msg.Error != nil { - c.state = stateError - c.progress = "Error: " + msg.Error.Error() - } else if msg.Done { - return c, util.CmdHandler(dialogs.CloseDialogMsg{}) - } else { - c.progress = msg.Progress - } - case agent.AgentEventTypeError: - // Handle errors that occur during summarization but are sent as separate error events. - c.state = stateError - if msg.Error != nil { - c.progress = "Error: " + msg.Error.Error() } else { - c.progress = "An unknown error occurred" + return c, util.CmdHandler(dialogs.CloseDialogMsg{}) } + case key.Matches(msg, c.keyMap.Y): + c.agent.Summarize(context.Background(), c.sessionID) + return c, util.CmdHandler(dialogs.CloseDialogMsg{}) + case key.Matches(msg, c.keyMap.N): + return c, util.CmdHandler(dialogs.CloseDialogMsg{}) + case key.Matches(msg, c.keyMap.Close): + return c, util.CmdHandler(dialogs.CloseDialogMsg{}) } - return c, nil } - return c, nil } -func (c *compactDialogCmp) startCompaction() tea.Cmd { - c.state = stateCompacting - c.progress = "Starting summarization..." - return func() tea.Msg { - err := c.agent.Summarize(context.Background(), c.sessionID) - if err != nil { - c.state = stateError - c.progress = "Error: " + err.Error() - } - return nil - } -} - func (c *compactDialogCmp) renderButtons() string { t := styles.CurrentTheme() baseStyle := t.S().Base @@ -164,82 +105,37 @@ func (c *compactDialogCmp) renderButtons() string { return baseStyle.AlignHorizontal(lipgloss.Right).Width(c.width - 4).Render(content) } -func (c *compactDialogCmp) renderContent() string { - t := styles.CurrentTheme() - baseStyle := t.S().Base - - switch c.state { - case stateConfirm: - explanation := t.S().Text. - Width(c.width - 4). - Render("This will summarize the current session and reset the context. The conversation history will be condensed into a summary to free up context space while preserving important information.") - - question := t.S().Text. - Width(c.width - 4). 
- Render("Do you want to continue?") - - return baseStyle.Render(lipgloss.JoinVertical( - lipgloss.Left, - explanation, - "", - question, - )) - case stateCompacting: - return baseStyle.Render(lipgloss.JoinVertical( - lipgloss.Left, - c.progress, - "", - "Please wait...", - )) - case stateError: - return baseStyle.Render(lipgloss.JoinVertical( - lipgloss.Left, - c.progress, - "", - "Press Enter to close", - )) - } - return "" -} - func (c *compactDialogCmp) render() string { t := styles.CurrentTheme() baseStyle := t.S().Base - var title string - switch c.state { - case stateConfirm: - title = "Compact Session" - case stateCompacting: - title = "Compacting Session" - case stateError: - title = "Compact Failed" - } - + title := "Compact Session" titleView := core.Title(title, c.width-4) - content := c.renderContent() - - var dialogContent string - if c.state == stateConfirm { - buttons := c.renderButtons() - dialogContent = lipgloss.JoinVertical( - lipgloss.Top, - titleView, - "", - content, - "", - buttons, - "", - ) - } else { - dialogContent = lipgloss.JoinVertical( - lipgloss.Top, - titleView, - "", - content, - "", - ) - } + explanation := t.S().Text. + Width(c.width - 4). + Render("This will summarize the current session and reset the context. The conversation history will be condensed into a summary to free up context space while preserving important information.") + + question := t.S().Text. + Width(c.width - 4). + Render("Do you want to continue?") + + content := baseStyle.Render(lipgloss.JoinVertical( + lipgloss.Left, + explanation, + "", + question, + )) + + buttons := c.renderButtons() + dialogContent := lipgloss.JoinVertical( + lipgloss.Top, + titleView, + "", + content, + "", + buttons, + "", + ) return baseStyle. Padding(0, 1). diff --git a/internal/tui/components/dialogs/permissions/permissions.go b/internal/tui/components/dialogs/permissions/permissions.go index 9e0a6b05d7385c354f8faba3110b1c0951f9a97d..d1b412aaaf2c791ad7a8361dd7adb14a86247eaf 100644 --- a/internal/tui/components/dialogs/permissions/permissions.go +++ b/internal/tui/components/dialogs/permissions/permissions.go @@ -9,8 +9,8 @@ import ( "github.com/charmbracelet/bubbles/v2/key" "github.com/charmbracelet/bubbles/v2/viewport" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/crush/internal/agent/tools" "github.com/charmbracelet/crush/internal/fsext" - "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/dialogs" diff --git a/internal/tui/components/dialogs/reasoning/reasoning.go b/internal/tui/components/dialogs/reasoning/reasoning.go index ba49abd8c58a0e7eb84235e7b68f5f5193a96b1b..a53819d1644528159158bfc3f4765d95a1cc91d3 100644 --- a/internal/tui/components/dialogs/reasoning/reasoning.go +++ b/internal/tui/components/dialogs/reasoning/reasoning.go @@ -120,7 +120,7 @@ func (r *reasoningDialogCmp) Init() tea.Cmd { func (r *reasoningDialogCmp) populateEffortOptions() tea.Cmd { cfg := config.Get() - if agentCfg, ok := cfg.Agents["coder"]; ok { + if agentCfg, ok := cfg.Agents[config.AgentCoder]; ok { selectedModel := cfg.Models[agentCfg.Model] model := cfg.GetModelByType(agentCfg.Model) diff --git a/internal/tui/components/mcp/mcp.go b/internal/tui/components/mcp/mcp.go index d11826b77749ba65276b5336a5d88cdbc8552881..22dc86b2c824fb5230e0b0fea2d850bccc0273ba 100644 --- a/internal/tui/components/mcp/mcp.go +++ 
b/internal/tui/components/mcp/mcp.go @@ -5,8 +5,8 @@ import ( "github.com/charmbracelet/lipgloss/v2" + "github.com/charmbracelet/crush/internal/agent/tools" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/agent" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/styles" ) @@ -40,7 +40,7 @@ func RenderMCPList(opts RenderOptions) []string { } // Get MCP states - mcpStates := agent.GetMCPStates() + mcpStates := tools.GetMCPStates() // Determine how many items to show maxItems := len(mcps) @@ -60,17 +60,17 @@ func RenderMCPList(opts RenderOptions) []string { if state, exists := mcpStates[l.Name]; exists { switch state.State { - case agent.MCPStateDisabled: + case tools.MCPStateDisabled: description = t.S().Subtle.Render("disabled") - case agent.MCPStateStarting: + case tools.MCPStateStarting: icon = t.ItemBusyIcon description = t.S().Subtle.Render("starting...") - case agent.MCPStateConnected: + case tools.MCPStateConnected: icon = t.ItemOnlineIcon if state.ToolCount > 0 { extraContent = t.S().Subtle.Render(fmt.Sprintf("%d tools", state.ToolCount)) } - case agent.MCPStateError: + case tools.MCPStateError: icon = t.ItemErrorIcon if state.Error != nil { description = t.S().Subtle.Render(fmt.Sprintf("error: %s", state.Error.Error())) diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index 2918925068cb2f012bead47bbf44260c6255288c..a8b4e30c00e21ff7f8d020d48ce9e848665654ef 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -331,7 +331,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return p, tea.Batch(cmds...) case commands.CommandRunCustomMsg: - if p.app.CoderAgent.IsBusy() { + if p.app.AgentCoordinator.IsBusy() { return p, util.ReportWarn("Agent is busy, please wait before executing a command...") } @@ -355,7 +355,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { p.focusedPane = PanelTypeEditor return p, p.SetSize(p.width, p.height) case commands.NewSessionsMsg: - if p.app.CoderAgent.IsBusy() { + if p.app.AgentCoordinator.IsBusy() { return p, util.ReportWarn("Agent is busy, please wait before starting a new session...") } return p, p.newSession() @@ -363,15 +363,15 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch { case key.Matches(msg, p.keyMap.NewSession): // if we have no agent do nothing - if p.app.CoderAgent == nil { + if p.app.AgentCoordinator == nil { return p, nil } - if p.app.CoderAgent.IsBusy() { + if p.app.AgentCoordinator.IsBusy() { return p, util.ReportWarn("Agent is busy, please wait before starting a new session...") } return p, p.newSession() case key.Matches(msg, p.keyMap.AddAttachment): - agentCfg := config.Get().Agents["coder"] + agentCfg := config.Get().Agents[config.AgentCoder] model := config.Get().GetModelByType(agentCfg.Model) if model.SupportsImages { return p, util.CmdHandler(commands.OpenFilePickerMsg{}) @@ -387,7 +387,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { p.changeFocus() return p, nil case key.Matches(msg, p.keyMap.Cancel): - if p.session.ID != "" && p.app.CoderAgent.IsBusy() { + if p.session.ID != "" && p.app.AgentCoordinator.IsBusy() { return p, p.cancel() } case key.Matches(msg, p.keyMap.Details): @@ -530,7 +530,7 @@ func (p *chatPage) updateCompactConfig(compact bool) tea.Cmd { func (p *chatPage) toggleThinking() tea.Cmd { return func() tea.Msg { cfg := config.Get() - agentCfg := cfg.Agents["coder"] + agentCfg := 
cfg.Agents[config.AgentCoder] currentModel := cfg.Models[agentCfg.Model] // Toggle the thinking mode @@ -559,7 +559,7 @@ func (p *chatPage) toggleThinking() tea.Cmd { func (p *chatPage) openReasoningDialog() tea.Cmd { return func() tea.Msg { cfg := config.Get() - agentCfg := cfg.Agents["coder"] + agentCfg := cfg.Agents[config.AgentCoder] model := cfg.GetModelByType(agentCfg.Model) providerCfg := cfg.GetProviderForModel(agentCfg.Model) @@ -577,7 +577,7 @@ func (p *chatPage) openReasoningDialog() tea.Cmd { func (p *chatPage) handleReasoningEffortSelected(effort string) tea.Cmd { return func() tea.Msg { cfg := config.Get() - agentCfg := cfg.Agents["coder"] + agentCfg := cfg.Agents[config.AgentCoder] currentModel := cfg.Models[agentCfg.Model] // Update the model configuration @@ -706,14 +706,14 @@ func (p *chatPage) changeFocus() { func (p *chatPage) cancel() tea.Cmd { if p.isCanceling { p.isCanceling = false - if p.app.CoderAgent != nil { - p.app.CoderAgent.Cancel(p.session.ID) + if p.app.AgentCoordinator != nil { + p.app.AgentCoordinator.Cancel(p.session.ID) } return nil } - if p.app.CoderAgent != nil && p.app.CoderAgent.QueuedPrompts(p.session.ID) > 0 { - p.app.CoderAgent.ClearQueue(p.session.ID) + if p.app.AgentCoordinator != nil && p.app.AgentCoordinator.QueuedPrompts(p.session.ID) > 0 { + p.app.AgentCoordinator.ClearQueue(p.session.ID) return nil } p.isCanceling = true @@ -746,14 +746,20 @@ func (p *chatPage) sendMessage(text string, attachments []message.Attachment) te session = newSession cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(session))) } - if p.app.CoderAgent == nil { + if p.app.AgentCoordinator == nil { return util.ReportError(fmt.Errorf("coder agent is not initialized")) } - _, err := p.app.CoderAgent.Run(context.Background(), session.ID, text, attachments...) - if err != nil { - return util.ReportError(err) - } cmds = append(cmds, p.chat.GoToBottom()) + cmds = append(cmds, func() tea.Msg { + _, err := p.app.AgentCoordinator.Run(context.Background(), session.ID, text, attachments...) + if err != nil { + return util.InfoMsg{ + Type: util.InfoTypeError, + Msg: err.Error(), + } + } + return nil + }) return tea.Batch(cmds...) 
} @@ -762,7 +768,7 @@ func (p *chatPage) Bindings() []key.Binding { p.keyMap.NewSession, p.keyMap.AddAttachment, } - if p.app.CoderAgent != nil && p.app.CoderAgent.IsBusy() { + if p.app.AgentCoordinator != nil && p.app.AgentCoordinator.IsBusy() { cancelBinding := p.keyMap.Cancel if p.isCanceling { cancelBinding = key.NewBinding( @@ -883,7 +889,7 @@ func (p *chatPage) Help() help.KeyMap { } return core.NewSimpleHelp(shortList, fullList) } - if p.app.CoderAgent != nil && p.app.CoderAgent.IsBusy() { + if p.app.AgentCoordinator != nil && p.app.AgentCoordinator.IsBusy() { cancelBinding := key.NewBinding( key.WithKeys("esc", "alt+esc"), key.WithHelp("esc", "cancel"), @@ -894,7 +900,7 @@ func (p *chatPage) Help() help.KeyMap { key.WithHelp("esc", "press again to cancel"), ) } - if p.app.CoderAgent != nil && p.app.CoderAgent.QueuedPrompts(p.session.ID) > 0 { + if p.app.AgentCoordinator != nil && p.app.AgentCoordinator.QueuedPrompts(p.session.ID) > 0 { cancelBinding = key.NewBinding( key.WithKeys("esc", "alt+esc"), key.WithHelp("esc", "clear queue"), diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 2c935810b833af01c582866ec38d5f7b277bc203..dbe3467adeab2619e884b8f0078471538d1af902 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -11,7 +11,6 @@ import ( "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/event" - "github.com/charmbracelet/crush/internal/llm/agent" "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/pubsub" cmpChat "github.com/charmbracelet/crush/internal/tui/components/chat" @@ -180,7 +179,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Compact case commands.CompactMsg: return a, util.CmdHandler(dialogs.OpenDialogMsg{ - Model: compact.NewCompactDialogCmp(a.app.CoderAgent, msg.SessionID, true), + Model: compact.NewCompactDialogCmp(a.app.AgentCoordinator, msg.SessionID, true), }) case commands.QuitMsg: return a, util.CmdHandler(dialogs.OpenDialogMsg{ @@ -194,7 +193,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return a, a.handleWindowResize(a.wWidth, a.wHeight) // Model Switch case models.ModelSelectedMsg: - if a.app.CoderAgent.IsBusy() { + if a.app.AgentCoordinator.IsBusy() { return a, util.ReportWarn("Agent is busy, please wait...") } @@ -253,36 +252,37 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } return a, nil // Agent Events - case pubsub.Event[agent.AgentEvent]: - payload := msg.Payload - - // Forward agent events to dialogs - if a.dialog.HasDialogs() && a.dialog.ActiveDialogID() == compact.CompactDialogID { - u, dialogCmd := a.dialog.Update(payload) - if model, ok := u.(dialogs.DialogCmp); ok { - a.dialog = model - } - - cmds = append(cmds, dialogCmd) - } - - // Handle auto-compact logic - if payload.Done && payload.Type == agent.AgentEventTypeResponse && a.selectedSessionID != "" { - // Get current session to check token usage - session, err := a.app.Sessions.Get(context.Background(), a.selectedSessionID) - if err == nil { - model := a.app.CoderAgent.Model() - contextWindow := model.ContextWindow - tokens := session.CompletionTokens + session.PromptTokens - if (tokens >= int64(float64(contextWindow)*0.95)) && !config.Get().Options.DisableAutoSummarize { // Show compact confirmation dialog - cmds = append(cmds, util.CmdHandler(dialogs.OpenDialogMsg{ - Model: compact.NewCompactDialogCmp(a.app.CoderAgent, a.selectedSessionID, false), - })) - } - } - } - - return a, 
tea.Batch(cmds...) + // TODO: HANDLE AUTO COMPACT + // case pubsub.Event[agent.AgentEvent]: + // payload := msg.Payload + // + // // Forward agent events to dialogs + // if a.dialog.HasDialogs() && a.dialog.ActiveDialogID() == compact.CompactDialogID { + // u, dialogCmd := a.dialog.Update(payload) + // if model, ok := u.(dialogs.DialogCmp); ok { + // a.dialog = model + // } + // + // cmds = append(cmds, dialogCmd) + // } + // + // // Handle auto-compact logic + // if payload.Done && payload.Type == agent.AgentEventTypeResponse && a.selectedSessionID != "" { + // // Get current session to check token usage + // session, err := a.app.Sessions.Get(context.Background(), a.selectedSessionID) + // if err == nil { + // model := a.app.AgentCoordinator.Model() + // contextWindow := model.CatwalkCfg.ContextWindow + // tokens := session.CompletionTokens + session.PromptTokens + // if (tokens >= int64(float64(contextWindow)*0.95)) && !config.Get().Options.DisableAutoSummarize { // Show compact confirmation dialog + // cmds = append(cmds, util.CmdHandler(dialogs.OpenDialogMsg{ + // Model: compact.NewCompactDialogCmp(a.app.AgentCoordinator, a.selectedSessionID, false), + // })) + // } + // } + // } + // + // return a, tea.Batch(cmds...) case splash.OnboardingCompleteMsg: item, ok := a.pages[a.currentPage] if !ok { @@ -484,7 +484,7 @@ func (a *appModel) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd { ) return tea.Sequence(cmds...) case key.Matches(msg, a.keyMap.Suspend): - if a.app.CoderAgent != nil && a.app.CoderAgent.IsBusy() { + if a.app.AgentCoordinator != nil && a.app.AgentCoordinator.IsBusy() { return util.ReportWarn("Agent is busy, please wait...") } return tea.Suspend @@ -504,7 +504,7 @@ func (a *appModel) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd { // moveToPage handles navigation between different pages in the application. func (a *appModel) moveToPage(pageID page.PageID) tea.Cmd { - if a.app.CoderAgent.IsBusy() { + if a.app.AgentCoordinator.IsBusy() { // TODO: maybe remove this : For now we don't move to any page if the agent is busy return util.ReportWarn("Agent is busy, please wait...") }
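The TUI hunks above all apply the same substitution: app.CoderAgent becomes app.AgentCoordinator, and the blocking Run call in sendMessage now executes inside a tea.Cmd so it no longer stalls the Update loop. The sketch below condenses that pattern; the interface is reduced to methods the chat page actually calls, return types are simplified, attachments are omitted, and the error message type is a stand-in for the util.InfoMsg report used in the patch.

package sketch

import (
	"context"
	"fmt"

	tea "github.com/charmbracelet/bubbletea/v2"
)

// coordinator lists only the methods the chat page uses in the hunks above,
// with simplified signatures for illustration.
type coordinator interface {
	Run(ctx context.Context, sessionID, prompt string) (string, error)
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	ClearQueue(sessionID string)
	Cancel(sessionID string)
}

// runErrMsg stands in for the util.InfoMsg error report returned by the new
// sendMessage command when Run fails.
type runErrMsg struct{ err error }

// sendPrompt mirrors the reworked sendMessage: the coordinator call runs inside
// the returned tea.Cmd, off the Update loop, and failures come back as a message.
func sendPrompt(coord coordinator, sessionID, text string) tea.Cmd {
	return func() tea.Msg {
		if _, err := coord.Run(context.Background(), sessionID, text); err != nil {
			return runErrMsg{err: fmt.Errorf("run failed: %w", err)}
		}
		return nil
	}
}

// cancelOrClear mirrors the cancel flow shown above: a pending queue is cleared
// first, and an in-flight run is only cancelled on the second escape press.
// It reports whether the caller should arm the "press again to cancel" state.
func cancelOrClear(coord coordinator, sessionID string, isCanceling bool) bool {
	if isCanceling {
		coord.Cancel(sessionID)
		return false
	}
	if coord.QueuedPrompts(sessionID) > 0 {
		coord.ClearQueue(sessionID)
		return false
	}
	return true
}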