tools.rs

  1use std::path::PathBuf;
  2
  3use agent_client_protocol as acp;
  4use itertools::Itertools;
  5use schemars::JsonSchema;
  6use serde::{Deserialize, Serialize};
  7use util::ResultExt;
  8
/// A tool call, classified by tool name (see `ClaudeTool::infer`).
///
/// Every known variant wraps the parsed parameters for that tool. The payload
/// is `None` when `infer` could not deserialize the raw input into the
/// expected parameter type, or when the variant was only guessed from the
/// tool's name.
pub enum ClaudeTool {
    /// The `Task` tool.
    Task(Option<TaskToolParams>),
    /// The `NotebookRead` tool.
    NotebookRead(Option<NotebookReadToolParams>),
    /// The `NotebookEdit` tool.
    NotebookEdit(Option<NotebookEditToolParams>),
    /// The `mcp__zed__Edit` tool, or any unknown tool whose name contains
    /// "edit" or "write".
    Edit(Option<EditToolParams>),
    /// The `MultiEdit` tool.
    MultiEdit(Option<MultiEditToolParams>),
    /// The `mcp__zed__Read` tool.
    ReadFile(Option<ReadToolParams>),
    /// The `Write` tool.
    Write(Option<WriteToolParams>),
    /// The `LS` tool.
    Ls(Option<LsToolParams>),
    /// The `Glob` tool.
    Glob(Option<GlobToolParams>),
    /// The `Grep` tool.
    Grep(Option<GrepToolParams>),
    /// The `Bash` tool, or any unknown tool whose name contains "terminal".
    Terminal(Option<BashToolParams>),
    /// The `WebFetch` tool.
    WebFetch(Option<WebFetchToolParams>),
    /// The `WebSearch` tool.
    WebSearch(Option<WebSearchToolParams>),
    /// The `TodoWrite` tool.
    TodoWrite(Option<TodoWriteToolParams>),
    /// The `exit_plan_mode` tool.
    ExitPlanMode(Option<ExitPlanModeToolParams>),
    /// Any tool that matched none of the above; keeps the raw JSON input.
    Other {
        /// Tool name as stored by `infer` (lowercased there).
        name: String,
        /// The unparsed tool input.
        input: serde_json::Value,
    },
}
 30
 31impl ClaudeTool {
 32    pub fn infer(tool_name: &str, input: serde_json::Value) -> Self {
 33        match tool_name {
 34            // Known tools
 35            "mcp__zed__Read" => Self::ReadFile(serde_json::from_value(input).log_err()),
 36            "mcp__zed__Edit" => Self::Edit(serde_json::from_value(input).log_err()),
 37            "MultiEdit" => Self::MultiEdit(serde_json::from_value(input).log_err()),
 38            "Write" => Self::Write(serde_json::from_value(input).log_err()),
 39            "LS" => Self::Ls(serde_json::from_value(input).log_err()),
 40            "Glob" => Self::Glob(serde_json::from_value(input).log_err()),
 41            "Grep" => Self::Grep(serde_json::from_value(input).log_err()),
 42            "Bash" => Self::Terminal(serde_json::from_value(input).log_err()),
 43            "WebFetch" => Self::WebFetch(serde_json::from_value(input).log_err()),
 44            "WebSearch" => Self::WebSearch(serde_json::from_value(input).log_err()),
 45            "TodoWrite" => Self::TodoWrite(serde_json::from_value(input).log_err()),
 46            "exit_plan_mode" => Self::ExitPlanMode(serde_json::from_value(input).log_err()),
 47            "Task" => Self::Task(serde_json::from_value(input).log_err()),
 48            "NotebookRead" => Self::NotebookRead(serde_json::from_value(input).log_err()),
 49            "NotebookEdit" => Self::NotebookEdit(serde_json::from_value(input).log_err()),
 50            // Inferred from name
 51            _ => {
 52                let tool_name = tool_name.to_lowercase();
 53
 54                if tool_name.contains("edit") || tool_name.contains("write") {
 55                    Self::Edit(None)
 56                } else if tool_name.contains("terminal") {
 57                    Self::Terminal(None)
 58                } else {
 59                    Self::Other {
 60                        name: tool_name.to_string(),
 61                        input,
 62                    }
 63                }
 64            }
 65        }
 66    }
 67
 68    pub fn label(&self) -> String {
 69        match &self {
 70            Self::Task(Some(params)) => params.description.clone(),
 71            Self::Task(None) => "Task".into(),
 72            Self::NotebookRead(Some(params)) => {
 73                format!("Read Notebook {}", params.notebook_path.display())
 74            }
 75            Self::NotebookRead(None) => "Read Notebook".into(),
 76            Self::NotebookEdit(Some(params)) => {
 77                format!("Edit Notebook {}", params.notebook_path.display())
 78            }
 79            Self::NotebookEdit(None) => "Edit Notebook".into(),
 80            Self::Terminal(Some(params)) => format!("`{}`", params.command),
 81            Self::Terminal(None) => "Terminal".into(),
 82            Self::ReadFile(_) => "Read File".into(),
 83            Self::Ls(Some(params)) => {
 84                format!("List Directory {}", params.path.display())
 85            }
 86            Self::Ls(None) => "List Directory".into(),
 87            Self::Edit(Some(params)) => {
 88                format!("Edit {}", params.abs_path.display())
 89            }
 90            Self::Edit(None) => "Edit".into(),
 91            Self::MultiEdit(Some(params)) => {
 92                format!("Multi Edit {}", params.file_path.display())
 93            }
 94            Self::MultiEdit(None) => "Multi Edit".into(),
 95            Self::Write(Some(params)) => {
 96                format!("Write {}", params.file_path.display())
 97            }
 98            Self::Write(None) => "Write".into(),
 99            Self::Glob(Some(params)) => {
100                format!("Glob `{params}`")
101            }
102            Self::Glob(None) => "Glob".into(),
103            Self::Grep(Some(params)) => format!("`{params}`"),
104            Self::Grep(None) => "Grep".into(),
105            Self::WebFetch(Some(params)) => format!("Fetch {}", params.url),
106            Self::WebFetch(None) => "Fetch".into(),
107            Self::WebSearch(Some(params)) => format!("Web Search: {}", params),
108            Self::WebSearch(None) => "Web Search".into(),
109            Self::TodoWrite(Some(params)) => format!(
110                "Update TODOs: {}",
111                params.todos.iter().map(|todo| &todo.content).join(", ")
112            ),
113            Self::TodoWrite(None) => "Update TODOs".into(),
114            Self::ExitPlanMode(_) => "Exit Plan Mode".into(),
115            Self::Other { name, .. } => name.clone(),
116        }
117    }
118    pub fn content(&self) -> Vec<acp::ToolCallContent> {
119        match &self {
120            Self::Other { input, .. } => vec![
121                format!(
122                    "```json\n{}```",
123                    serde_json::to_string_pretty(&input).unwrap_or("{}".to_string())
124                )
125                .into(),
126            ],
127            Self::Task(Some(params)) => vec![params.prompt.clone().into()],
128            Self::NotebookRead(Some(params)) => {
129                vec![params.notebook_path.display().to_string().into()]
130            }
131            Self::NotebookEdit(Some(params)) => vec![params.new_source.clone().into()],
132            Self::Terminal(Some(params)) => vec![
133                format!(
134                    "`{}`\n\n{}",
135                    params.command,
136                    params.description.as_deref().unwrap_or_default()
137                )
138                .into(),
139            ],
140            Self::ReadFile(Some(params)) => vec![params.abs_path.display().to_string().into()],
141            Self::Ls(Some(params)) => vec![params.path.display().to_string().into()],
142            Self::Glob(Some(params)) => vec![params.to_string().into()],
143            Self::Grep(Some(params)) => vec![format!("`{params}`").into()],
144            Self::WebFetch(Some(params)) => vec![params.prompt.clone().into()],
145            Self::WebSearch(Some(params)) => vec![params.to_string().into()],
146            Self::ExitPlanMode(Some(params)) => vec![params.plan.clone().into()],
147            Self::Edit(Some(params)) => vec![acp::ToolCallContent::Diff {
148                diff: acp::Diff {
149                    path: params.abs_path.clone(),
150                    old_text: Some(params.old_text.clone()),
151                    new_text: params.new_text.clone(),
152                },
153            }],
154            Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff {
155                diff: acp::Diff {
156                    path: params.file_path.clone(),
157                    old_text: None,
158                    new_text: params.content.clone(),
159                },
160            }],
161            Self::MultiEdit(Some(params)) => {
162                // todo: show multiple edits in a multibuffer?
163                params
164                    .edits
165                    .first()
166                    .map(|edit| {
167                        vec![acp::ToolCallContent::Diff {
168                            diff: acp::Diff {
169                                path: params.file_path.clone(),
170                                old_text: Some(edit.old_string.clone()),
171                                new_text: edit.new_string.clone(),
172                            },
173                        }]
174                    })
175                    .unwrap_or_default()
176            }
177            Self::TodoWrite(Some(_)) => {
178                // These are mapped to plan updates later
179                vec![]
180            }
181            Self::Task(None)
182            | Self::NotebookRead(None)
183            | Self::NotebookEdit(None)
184            | Self::Terminal(None)
185            | Self::ReadFile(None)
186            | Self::Ls(None)
187            | Self::Glob(None)
188            | Self::Grep(None)
189            | Self::WebFetch(None)
190            | Self::WebSearch(None)
191            | Self::TodoWrite(None)
192            | Self::ExitPlanMode(None)
193            | Self::Edit(None)
194            | Self::Write(None)
195            | Self::MultiEdit(None) => vec![],
196        }
197    }
198
199    pub fn kind(&self) -> acp::ToolKind {
200        match self {
201            Self::Task(_) => acp::ToolKind::Think,
202            Self::NotebookRead(_) => acp::ToolKind::Read,
203            Self::NotebookEdit(_) => acp::ToolKind::Edit,
204            Self::Edit(_) => acp::ToolKind::Edit,
205            Self::MultiEdit(_) => acp::ToolKind::Edit,
206            Self::Write(_) => acp::ToolKind::Edit,
207            Self::ReadFile(_) => acp::ToolKind::Read,
208            Self::Ls(_) => acp::ToolKind::Search,
209            Self::Glob(_) => acp::ToolKind::Search,
210            Self::Grep(_) => acp::ToolKind::Search,
211            Self::Terminal(_) => acp::ToolKind::Execute,
212            Self::WebSearch(_) => acp::ToolKind::Search,
213            Self::WebFetch(_) => acp::ToolKind::Fetch,
214            Self::TodoWrite(_) => acp::ToolKind::Think,
215            Self::ExitPlanMode(_) => acp::ToolKind::Think,
216            Self::Other { .. } => acp::ToolKind::Other,
217        }
218    }
219
220    pub fn locations(&self) -> Vec<acp::ToolCallLocation> {
221        match &self {
222            Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![acp::ToolCallLocation {
223                path: abs_path.clone(),
224                line: None,
225            }],
226            Self::MultiEdit(Some(MultiEditToolParams { file_path, .. })) => {
227                vec![acp::ToolCallLocation {
228                    path: file_path.clone(),
229                    line: None,
230                }]
231            }
232            Self::Write(Some(WriteToolParams { file_path, .. })) => {
233                vec![acp::ToolCallLocation {
234                    path: file_path.clone(),
235                    line: None,
236                }]
237            }
238            Self::ReadFile(Some(ReadToolParams {
239                abs_path, offset, ..
240            })) => vec![acp::ToolCallLocation {
241                path: abs_path.clone(),
242                line: *offset,
243            }],
244            Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => {
245                vec![acp::ToolCallLocation {
246                    path: notebook_path.clone(),
247                    line: None,
248                }]
249            }
250            Self::NotebookEdit(Some(NotebookEditToolParams { notebook_path, .. })) => {
251                vec![acp::ToolCallLocation {
252                    path: notebook_path.clone(),
253                    line: None,
254                }]
255            }
256            Self::Glob(Some(GlobToolParams {
257                path: Some(path), ..
258            })) => vec![acp::ToolCallLocation {
259                path: path.clone(),
260                line: None,
261            }],
262            Self::Ls(Some(LsToolParams { path, .. })) => vec![acp::ToolCallLocation {
263                path: path.clone(),
264                line: None,
265            }],
266            Self::Grep(Some(GrepToolParams {
267                path: Some(path), ..
268            })) => vec![acp::ToolCallLocation {
269                path: PathBuf::from(path),
270                line: None,
271            }],
272            Self::Task(_)
273            | Self::NotebookRead(None)
274            | Self::NotebookEdit(None)
275            | Self::Edit(None)
276            | Self::MultiEdit(None)
277            | Self::Write(None)
278            | Self::ReadFile(None)
279            | Self::Ls(None)
280            | Self::Glob(_)
281            | Self::Grep(_)
282            | Self::Terminal(_)
283            | Self::WebFetch(_)
284            | Self::WebSearch(_)
285            | Self::TodoWrite(_)
286            | Self::ExitPlanMode(_)
287            | Self::Other { .. } => vec![],
288        }
289    }
290
291    pub fn as_acp(&self, id: acp::ToolCallId) -> acp::ToolCall {
292        acp::ToolCall {
293            id,
294            kind: self.kind(),
295            status: acp::ToolCallStatus::InProgress,
296            title: self.label(),
297            content: self.content(),
298            locations: self.locations(),
299            raw_input: None,
300        }
301    }
302}
303
304#[derive(Deserialize, JsonSchema, Debug)]
305pub struct EditToolParams {
306    /// The absolute path to the file to read.
307    pub abs_path: PathBuf,
308    /// The old text to replace (must be unique in the file)
309    pub old_text: String,
310    /// The new text.
311    pub new_text: String,
312}
313
// Parameters parsed from the input of the `mcp__zed__Read` tool.
// `offset` doubles as the line reported by `ClaudeTool::locations`.
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, JsonSchema, Debug)]
pub struct ReadToolParams {
    /// The absolute path to the file to read.
    pub abs_path: PathBuf,
    /// Which line to start reading from. Omit to start from the beginning.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub offset: Option<u32>,
    /// How many lines to read. Omit for the whole file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limit: Option<u32>,
}
325
// Parameters parsed from the input of the `Write` tool.
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, JsonSchema, Debug)]
pub struct WriteToolParams {
    /// Absolute path for new file
    pub file_path: PathBuf,
    /// File content
    pub content: String,
}
333
// Parameters parsed from the input of the `Bash` tool (the `Terminal` variant).
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, JsonSchema, Debug)]
pub struct BashToolParams {
    /// Shell command to execute
    pub command: String,
    /// 5-10 word description of what command does
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Timeout in ms (max 600000ms/10min, default 120000ms)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout: Option<u32>,
}
345
// Parameters parsed from the input of the `Glob` tool.
// Its `Display` impl renders the search directory followed by the pattern.
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, JsonSchema, Debug)]
pub struct GlobToolParams {
    /// Glob pattern like **/*.js or src/**/*.ts
    pub pattern: String,
    /// Directory to search in (omit for current directory)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub path: Option<PathBuf>,
}
354
355impl std::fmt::Display for GlobToolParams {
356    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357        if let Some(path) = &self.path {
358            write!(f, "{}", path.display())?;
359        }
360        write!(f, "{}", self.pattern)
361    }
362}
363
// Parameters parsed from the input of the `LS` tool.
// `ignore` falls back to an empty list when absent from the input (`serde(default)`).
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, JsonSchema, Debug)]
pub struct LsToolParams {
    /// Absolute path to directory
    pub path: PathBuf,
    /// Array of glob patterns to ignore
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub ignore: Vec<String>,
}
372
// Parameters parsed from the input of the `Grep` tool.
// Several fields are renamed to flag-style keys in the input (`-i`, `-n`, `-A`,
// `-B`, `-C`, `type`); the boolean flags default to `false` when absent.
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, JsonSchema, Debug)]
pub struct GrepToolParams {
    /// Regex pattern to search for
    pub pattern: String,
    /// File/directory to search (defaults to current directory)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub path: Option<String>,
    /// "content" (shows lines), "files_with_matches" (default), "count"
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_mode: Option<GrepOutputMode>,
    /// Filter files with glob pattern like "*.js"
    #[serde(skip_serializing_if = "Option::is_none")]
    pub glob: Option<String>,
    /// File type filter like "js", "py", "rust"
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub file_type: Option<String>,
    /// Case insensitive search
    #[serde(rename = "-i", default, skip_serializing_if = "is_false")]
    pub case_insensitive: bool,
    /// Show line numbers (content mode only)
    #[serde(rename = "-n", default, skip_serializing_if = "is_false")]
    pub line_numbers: bool,
    /// Lines after match (content mode only)
    #[serde(rename = "-A", skip_serializing_if = "Option::is_none")]
    pub after_context: Option<u32>,
    /// Lines before match (content mode only)
    #[serde(rename = "-B", skip_serializing_if = "Option::is_none")]
    pub before_context: Option<u32>,
    /// Lines before and after match (content mode only)
    #[serde(rename = "-C", skip_serializing_if = "Option::is_none")]
    pub context: Option<u32>,
    /// Enable multiline/cross-line matching
    #[serde(default, skip_serializing_if = "is_false")]
    pub multiline: bool,
    /// Limit output to first N results
    #[serde(skip_serializing_if = "Option::is_none")]
    pub head_limit: Option<u32>,
}
411
412impl std::fmt::Display for GrepToolParams {
413    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
414        write!(f, "grep")?;
415
416        // Boolean flags
417        if self.case_insensitive {
418            write!(f, " -i")?;
419        }
420        if self.line_numbers {
421            write!(f, " -n")?;
422        }
423
424        // Context options
425        if let Some(after) = self.after_context {
426            write!(f, " -A {}", after)?;
427        }
428        if let Some(before) = self.before_context {
429            write!(f, " -B {}", before)?;
430        }
431        if let Some(context) = self.context {
432            write!(f, " -C {}", context)?;
433        }
434
435        // Output mode
436        if let Some(mode) = &self.output_mode {
437            match mode {
438                GrepOutputMode::FilesWithMatches => write!(f, " -l")?,
439                GrepOutputMode::Count => write!(f, " -c")?,
440                GrepOutputMode::Content => {} // Default mode
441            }
442        }
443
444        // Head limit
445        if let Some(limit) = self.head_limit {
446            write!(f, " | head -{}", limit)?;
447        }
448
449        // Glob pattern
450        if let Some(glob) = &self.glob {
451            write!(f, " --include=\"{}\"", glob)?;
452        }
453
454        // File type
455        if let Some(file_type) = &self.file_type {
456            write!(f, " --type={}", file_type)?;
457        }
458
459        // Multiline
460        if self.multiline {
461            write!(f, " -P")?; // Perl-compatible regex for multiline
462        }
463
464        // Pattern (escaped if contains special characters)
465        write!(f, " \"{}\"", self.pattern)?;
466
467        // Path
468        if let Some(path) = &self.path {
469            write!(f, " {}", path)?;
470        }
471
472        Ok(())
473    }
474}
475
// Priority of a todo item. Serialized in snake_case (`high`, `medium`, `low`);
// `Medium` is the default and is used when a todo omits its priority.
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Default, Deserialize, Serialize, JsonSchema, strum::Display, Debug)]
#[serde(rename_all = "snake_case")]
pub enum TodoPriority {
    High,
    #[default]
    Medium,
    Low,
}
484
485impl Into<acp::PlanEntryPriority> for TodoPriority {
486    fn into(self) -> acp::PlanEntryPriority {
487        match self {
488            TodoPriority::High => acp::PlanEntryPriority::High,
489            TodoPriority::Medium => acp::PlanEntryPriority::Medium,
490            TodoPriority::Low => acp::PlanEntryPriority::Low,
491        }
492    }
493}
494
// Progress state of a todo item. Serialized in snake_case
// (`pending`, `in_progress`, `completed`).
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
#[serde(rename_all = "snake_case")]
pub enum TodoStatus {
    Pending,
    InProgress,
    Completed,
}
502
503impl Into<acp::PlanEntryStatus> for TodoStatus {
504    fn into(self) -> acp::PlanEntryStatus {
505        match self {
506            TodoStatus::Pending => acp::PlanEntryStatus::Pending,
507            TodoStatus::InProgress => acp::PlanEntryStatus::InProgress,
508            TodoStatus::Completed => acp::PlanEntryStatus::Completed,
509        }
510    }
511}
512
// A single todo item as carried by the `TodoWrite` tool; convertible into an
// `acp::PlanEntry`. `priority` falls back to its default when absent.
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
pub struct Todo {
    /// Task description
    pub content: String,
    /// Current status of the todo
    pub status: TodoStatus,
    /// Priority level of the todo
    #[serde(default)]
    pub priority: TodoPriority,
}
523
524impl Into<acp::PlanEntry> for Todo {
525    fn into(self) -> acp::PlanEntry {
526        acp::PlanEntry {
527            content: self.content,
528            priority: self.priority.into(),
529            status: self.status.into(),
530        }
531    }
532}
533
// Parameters parsed from the input of the `TodoWrite` tool: the list of todos.
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, JsonSchema, Debug)]
pub struct TodoWriteToolParams {
    pub todos: Vec<Todo>,
}
538
// Parameters parsed from the input of the `exit_plan_mode` tool.
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, JsonSchema, Debug)]
pub struct ExitPlanModeToolParams {
    /// Implementation plan in markdown format
    pub plan: String,
}
544
// Parameters parsed from the input of the `Task` tool. `description` is used
// as the tool call's title and `prompt` as its content.
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, JsonSchema, Debug)]
pub struct TaskToolParams {
    /// Short 3-5 word description of task
    pub description: String,
    /// Detailed task for agent to perform
    pub prompt: String,
}
552
// Parameters parsed from the input of the `NotebookRead` tool.
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, JsonSchema, Debug)]
pub struct NotebookReadToolParams {
    /// Absolute path to .ipynb file
    pub notebook_path: PathBuf,
    /// Specific cell ID to read
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cell_id: Option<String>,
}
561
// Kind of notebook cell, serialized in snake_case (`code`, `markdown`).
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
#[serde(rename_all = "snake_case")]
pub enum CellType {
    Code,
    Markdown,
}
568
// Notebook edit operation, serialized in snake_case (`replace`, `insert`, `delete`).
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
#[serde(rename_all = "snake_case")]
pub enum EditMode {
    Replace,
    Insert,
    Delete,
}
576
// Parameters parsed from the input of the `NotebookEdit` tool.
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, JsonSchema, Debug)]
pub struct NotebookEditToolParams {
    /// Absolute path to .ipynb file
    pub notebook_path: PathBuf,
    /// New cell content
    pub new_source: String,
    /// Cell ID to edit
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cell_id: Option<String>,
    /// Type of cell (code or markdown)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cell_type: Option<CellType>,
    /// Edit operation mode
    #[serde(skip_serializing_if = "Option::is_none")]
    pub edit_mode: Option<EditMode>,
}
593
// One replacement within a `MultiEdit` tool call. `replace_all` defaults to
// `false` when absent and is omitted from serialized output when `false`.
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
pub struct MultiEditItem {
    /// The text to search for and replace
    pub old_string: String,
    /// The replacement text
    pub new_string: String,
    /// Whether to replace all occurrences or just the first
    #[serde(default, skip_serializing_if = "is_false")]
    pub replace_all: bool,
}
604
// Parameters parsed from the input of the `MultiEdit` tool: one target file
// and the edits to apply to it.
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, JsonSchema, Debug)]
pub struct MultiEditToolParams {
    /// Absolute path to file
    pub file_path: PathBuf,
    /// List of edits to apply
    pub edits: Vec<MultiEditItem>,
}
612
/// Returns `true` when the flag is `false`.
///
/// Referenced by name from `#[serde(skip_serializing_if = "is_false")]`, which
/// is why it takes the value by reference.
fn is_false(v: &bool) -> bool {
    matches!(*v, false)
}
616
// Output mode of the `Grep` tool, deserialized from snake_case
// (`content`, `files_with_matches`, `count`). When displayed as a command,
// `FilesWithMatches` renders as `-l`, `Count` as `-c`, and `Content` adds no flag.
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, JsonSchema, Debug)]
#[serde(rename_all = "snake_case")]
pub enum GrepOutputMode {
    Content,
    FilesWithMatches,
    Count,
}
624
// Parameters parsed from the input of the `WebFetch` tool. The URL appears in
// the tool call's title and the prompt in its content.
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, JsonSchema, Debug)]
pub struct WebFetchToolParams {
    /// Valid URL to fetch
    #[serde(rename = "url")]
    pub url: String,
    /// What to extract from content
    pub prompt: String,
}
633
// Parameters parsed from the input of the `WebSearch` tool. Both domain lists
// default to empty when absent from the input.
// (Plain `//` on purpose: `///` text on a `JsonSchema` type ends up in the derived schema.)
#[derive(Deserialize, JsonSchema, Debug)]
pub struct WebSearchToolParams {
    /// Search query (min 2 chars)
    pub query: String,
    /// Only include these domains
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub allowed_domains: Vec<String>,
    /// Exclude these domains
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub blocked_domains: Vec<String>,
}
645
646impl std::fmt::Display for WebSearchToolParams {
647    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
648        write!(f, "\"{}\"", self.query)?;
649
650        if !self.allowed_domains.is_empty() {
651            write!(f, " (allowed: {})", self.allowed_domains.join(", "))?;
652        }
653
654        if !self.blocked_domains.is_empty() {
655            write!(f, " (blocked: {})", self.blocked_domains.join(", "))?;
656        }
657
658        Ok(())
659    }
660}