tools.rs

  1use std::path::PathBuf;
  2
  3use agent_client_protocol as acp;
  4use itertools::Itertools;
  5use schemars::JsonSchema;
  6use serde::{Deserialize, Serialize};
  7use util::ResultExt;
  8
  9pub enum ClaudeTool {
 10    Task(Option<TaskToolParams>),
 11    NotebookRead(Option<NotebookReadToolParams>),
 12    NotebookEdit(Option<NotebookEditToolParams>),
 13    Edit(Option<EditToolParams>),
 14    MultiEdit(Option<MultiEditToolParams>),
 15    ReadFile(Option<ReadToolParams>),
 16    Write(Option<WriteToolParams>),
 17    Ls(Option<LsToolParams>),
 18    Glob(Option<GlobToolParams>),
 19    Grep(Option<GrepToolParams>),
 20    Terminal(Option<BashToolParams>),
 21    WebFetch(Option<WebFetchToolParams>),
 22    WebSearch(Option<WebSearchToolParams>),
 23    TodoWrite(Option<TodoWriteToolParams>),
 24    ExitPlanMode(Option<ExitPlanModeToolParams>),
 25    Other {
 26        name: String,
 27        input: serde_json::Value,
 28    },
 29}
 30
 31impl ClaudeTool {
 32    pub fn infer(tool_name: &str, input: serde_json::Value) -> Self {
 33        match tool_name {
 34            // Known tools
 35            "mcp__zed__Read" => Self::ReadFile(serde_json::from_value(input).log_err()),
 36            "mcp__zed__Edit" => Self::Edit(serde_json::from_value(input).log_err()),
 37            "MultiEdit" => Self::MultiEdit(serde_json::from_value(input).log_err()),
 38            "Write" => Self::Write(serde_json::from_value(input).log_err()),
 39            "LS" => Self::Ls(serde_json::from_value(input).log_err()),
 40            "Glob" => Self::Glob(serde_json::from_value(input).log_err()),
 41            "Grep" => Self::Grep(serde_json::from_value(input).log_err()),
 42            "Bash" => Self::Terminal(serde_json::from_value(input).log_err()),
 43            "WebFetch" => Self::WebFetch(serde_json::from_value(input).log_err()),
 44            "WebSearch" => Self::WebSearch(serde_json::from_value(input).log_err()),
 45            "TodoWrite" => Self::TodoWrite(serde_json::from_value(input).log_err()),
 46            "exit_plan_mode" => Self::ExitPlanMode(serde_json::from_value(input).log_err()),
 47            "Task" => Self::Task(serde_json::from_value(input).log_err()),
 48            "NotebookRead" => Self::NotebookRead(serde_json::from_value(input).log_err()),
 49            "NotebookEdit" => Self::NotebookEdit(serde_json::from_value(input).log_err()),
 50            // Inferred from name
 51            _ => {
 52                let tool_name = tool_name.to_lowercase();
 53
 54                if tool_name.contains("edit") || tool_name.contains("write") {
 55                    Self::Edit(None)
 56                } else if tool_name.contains("terminal") {
 57                    Self::Terminal(None)
 58                } else {
 59                    Self::Other {
 60                        name: tool_name.to_string(),
 61                        input,
 62                    }
 63                }
 64            }
 65        }
 66    }
 67
 68    pub fn label(&self) -> String {
 69        match &self {
 70            Self::Task(Some(params)) => params.description.clone(),
 71            Self::Task(None) => "Task".into(),
 72            Self::NotebookRead(Some(params)) => {
 73                format!("Read Notebook {}", params.notebook_path.display())
 74            }
 75            Self::NotebookRead(None) => "Read Notebook".into(),
 76            Self::NotebookEdit(Some(params)) => {
 77                format!("Edit Notebook {}", params.notebook_path.display())
 78            }
 79            Self::NotebookEdit(None) => "Edit Notebook".into(),
 80            Self::Terminal(Some(params)) => format!("`{}`", params.command),
 81            Self::Terminal(None) => "Terminal".into(),
 82            Self::ReadFile(_) => "Read File".into(),
 83            Self::Ls(Some(params)) => {
 84                format!("List Directory {}", params.path.display())
 85            }
 86            Self::Ls(None) => "List Directory".into(),
 87            Self::Edit(Some(params)) => {
 88                format!("Edit {}", params.abs_path.display())
 89            }
 90            Self::Edit(None) => "Edit".into(),
 91            Self::MultiEdit(Some(params)) => {
 92                format!("Multi Edit {}", params.file_path.display())
 93            }
 94            Self::MultiEdit(None) => "Multi Edit".into(),
 95            Self::Write(Some(params)) => {
 96                format!("Write {}", params.file_path.display())
 97            }
 98            Self::Write(None) => "Write".into(),
 99            Self::Glob(Some(params)) => {
100                format!("Glob `{params}`")
101            }
102            Self::Glob(None) => "Glob".into(),
103            Self::Grep(Some(params)) => format!("`{params}`"),
104            Self::Grep(None) => "Grep".into(),
105            Self::WebFetch(Some(params)) => format!("Fetch {}", params.url),
106            Self::WebFetch(None) => "Fetch".into(),
107            Self::WebSearch(Some(params)) => format!("Web Search: {}", params),
108            Self::WebSearch(None) => "Web Search".into(),
109            Self::TodoWrite(Some(params)) => format!(
110                "Update TODOs: {}",
111                params.todos.iter().map(|todo| &todo.content).join(", ")
112            ),
113            Self::TodoWrite(None) => "Update TODOs".into(),
114            Self::ExitPlanMode(_) => "Exit Plan Mode".into(),
115            Self::Other { name, .. } => name.clone(),
116        }
117    }
118    pub fn content(&self) -> Vec<acp::ToolCallContent> {
119        match &self {
120            Self::Other { input, .. } => vec![
121                format!(
122                    "```json\n{}```",
123                    serde_json::to_string_pretty(&input).unwrap_or("{}".to_string())
124                )
125                .into(),
126            ],
127            Self::Task(Some(params)) => vec![params.prompt.clone().into()],
128            Self::NotebookRead(Some(params)) => {
129                vec![params.notebook_path.display().to_string().into()]
130            }
131            Self::NotebookEdit(Some(params)) => vec![params.new_source.clone().into()],
132            Self::Terminal(Some(params)) => vec![
133                format!(
134                    "`{}`\n\n{}",
135                    params.command,
136                    params.description.as_deref().unwrap_or_default()
137                )
138                .into(),
139            ],
140            Self::ReadFile(Some(params)) => vec![params.abs_path.display().to_string().into()],
141            Self::Ls(Some(params)) => vec![params.path.display().to_string().into()],
142            Self::Glob(Some(params)) => vec![params.to_string().into()],
143            Self::Grep(Some(params)) => vec![format!("`{params}`").into()],
144            Self::WebFetch(Some(params)) => vec![params.prompt.clone().into()],
145            Self::WebSearch(Some(params)) => vec![params.to_string().into()],
146            Self::ExitPlanMode(Some(params)) => vec![params.plan.clone().into()],
147            Self::Edit(Some(params)) => vec![acp::ToolCallContent::Diff {
148                diff: acp::Diff {
149                    path: params.abs_path.clone(),
150                    old_text: Some(params.old_text.clone()),
151                    new_text: params.new_text.clone(),
152                },
153            }],
154            Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff {
155                diff: acp::Diff {
156                    path: params.file_path.clone(),
157                    old_text: None,
158                    new_text: params.content.clone(),
159                },
160            }],
161            Self::MultiEdit(Some(params)) => {
162                // todo: show multiple edits in a multibuffer?
163                params
164                    .edits
165                    .first()
166                    .map(|edit| {
167                        vec![acp::ToolCallContent::Diff {
168                            diff: acp::Diff {
169                                path: params.file_path.clone(),
170                                old_text: Some(edit.old_string.clone()),
171                                new_text: edit.new_string.clone(),
172                            },
173                        }]
174                    })
175                    .unwrap_or_default()
176            }
177            Self::TodoWrite(Some(_)) => {
178                // These are mapped to plan updates later
179                vec![]
180            }
181            Self::Task(None)
182            | Self::NotebookRead(None)
183            | Self::NotebookEdit(None)
184            | Self::Terminal(None)
185            | Self::ReadFile(None)
186            | Self::Ls(None)
187            | Self::Glob(None)
188            | Self::Grep(None)
189            | Self::WebFetch(None)
190            | Self::WebSearch(None)
191            | Self::TodoWrite(None)
192            | Self::ExitPlanMode(None)
193            | Self::Edit(None)
194            | Self::Write(None)
195            | Self::MultiEdit(None) => vec![],
196        }
197    }
198
199    pub fn kind(&self) -> acp::ToolKind {
200        match self {
201            Self::Task(_) => acp::ToolKind::Think,
202            Self::NotebookRead(_) => acp::ToolKind::Read,
203            Self::NotebookEdit(_) => acp::ToolKind::Edit,
204            Self::Edit(_) => acp::ToolKind::Edit,
205            Self::MultiEdit(_) => acp::ToolKind::Edit,
206            Self::Write(_) => acp::ToolKind::Edit,
207            Self::ReadFile(_) => acp::ToolKind::Read,
208            Self::Ls(_) => acp::ToolKind::Search,
209            Self::Glob(_) => acp::ToolKind::Search,
210            Self::Grep(_) => acp::ToolKind::Search,
211            Self::Terminal(_) => acp::ToolKind::Execute,
212            Self::WebSearch(_) => acp::ToolKind::Search,
213            Self::WebFetch(_) => acp::ToolKind::Fetch,
214            Self::TodoWrite(_) => acp::ToolKind::Think,
215            Self::ExitPlanMode(_) => acp::ToolKind::Think,
216            Self::Other { .. } => acp::ToolKind::Other,
217        }
218    }
219
220    pub fn locations(&self) -> Vec<acp::ToolCallLocation> {
221        match &self {
222            Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![acp::ToolCallLocation {
223                path: abs_path.clone(),
224                line: None,
225            }],
226            Self::MultiEdit(Some(MultiEditToolParams { file_path, .. })) => {
227                vec![acp::ToolCallLocation {
228                    path: file_path.clone(),
229                    line: None,
230                }]
231            }
232            Self::Write(Some(WriteToolParams { file_path, .. })) => {
233                vec![acp::ToolCallLocation {
234                    path: file_path.clone(),
235                    line: None,
236                }]
237            }
238            Self::ReadFile(Some(ReadToolParams {
239                abs_path, offset, ..
240            })) => vec![acp::ToolCallLocation {
241                path: abs_path.clone(),
242                line: *offset,
243            }],
244            Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => {
245                vec![acp::ToolCallLocation {
246                    path: notebook_path.clone(),
247                    line: None,
248                }]
249            }
250            Self::NotebookEdit(Some(NotebookEditToolParams { notebook_path, .. })) => {
251                vec![acp::ToolCallLocation {
252                    path: notebook_path.clone(),
253                    line: None,
254                }]
255            }
256            Self::Glob(Some(GlobToolParams {
257                path: Some(path), ..
258            })) => vec![acp::ToolCallLocation {
259                path: path.clone(),
260                line: None,
261            }],
262            Self::Ls(Some(LsToolParams { path, .. })) => vec![acp::ToolCallLocation {
263                path: path.clone(),
264                line: None,
265            }],
266            Self::Grep(Some(GrepToolParams {
267                path: Some(path), ..
268            })) => vec![acp::ToolCallLocation {
269                path: PathBuf::from(path),
270                line: None,
271            }],
272            Self::Task(_)
273            | Self::NotebookRead(None)
274            | Self::NotebookEdit(None)
275            | Self::Edit(None)
276            | Self::MultiEdit(None)
277            | Self::Write(None)
278            | Self::ReadFile(None)
279            | Self::Ls(None)
280            | Self::Glob(_)
281            | Self::Grep(_)
282            | Self::Terminal(_)
283            | Self::WebFetch(_)
284            | Self::WebSearch(_)
285            | Self::TodoWrite(_)
286            | Self::ExitPlanMode(_)
287            | Self::Other { .. } => vec![],
288        }
289    }
290
291    pub fn as_acp(&self, id: acp::ToolCallId) -> acp::ToolCall {
292        acp::ToolCall {
293            id,
294            kind: self.kind(),
295            status: acp::ToolCallStatus::InProgress,
296            title: self.label(),
297            content: self.content(),
298            locations: self.locations(),
299            raw_input: None,
300            raw_output: None,
301        }
302    }
303}
304
305#[derive(Deserialize, JsonSchema, Debug)]
306pub struct EditToolParams {
307    /// The absolute path to the file to read.
308    pub abs_path: PathBuf,
309    /// The old text to replace (must be unique in the file)
310    pub old_text: String,
311    /// The new text.
312    pub new_text: String,
313}
314
315#[derive(Deserialize, JsonSchema, Debug)]
316pub struct ReadToolParams {
317    /// The absolute path to the file to read.
318    pub abs_path: PathBuf,
319    /// Which line to start reading from. Omit to start from the beginning.
320    #[serde(skip_serializing_if = "Option::is_none")]
321    pub offset: Option<u32>,
322    /// How many lines to read. Omit for the whole file.
323    #[serde(skip_serializing_if = "Option::is_none")]
324    pub limit: Option<u32>,
325}
326
327#[derive(Deserialize, JsonSchema, Debug)]
328pub struct WriteToolParams {
329    /// Absolute path for new file
330    pub file_path: PathBuf,
331    /// File content
332    pub content: String,
333}
334
335#[derive(Deserialize, JsonSchema, Debug)]
336pub struct BashToolParams {
337    /// Shell command to execute
338    pub command: String,
339    /// 5-10 word description of what command does
340    #[serde(skip_serializing_if = "Option::is_none")]
341    pub description: Option<String>,
342    /// Timeout in ms (max 600000ms/10min, default 120000ms)
343    #[serde(skip_serializing_if = "Option::is_none")]
344    pub timeout: Option<u32>,
345}
346
347#[derive(Deserialize, JsonSchema, Debug)]
348pub struct GlobToolParams {
349    /// Glob pattern like **/*.js or src/**/*.ts
350    pub pattern: String,
351    /// Directory to search in (omit for current directory)
352    #[serde(skip_serializing_if = "Option::is_none")]
353    pub path: Option<PathBuf>,
354}
355
356impl std::fmt::Display for GlobToolParams {
357    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358        if let Some(path) = &self.path {
359            write!(f, "{}", path.display())?;
360        }
361        write!(f, "{}", self.pattern)
362    }
363}
364
365#[derive(Deserialize, JsonSchema, Debug)]
366pub struct LsToolParams {
367    /// Absolute path to directory
368    pub path: PathBuf,
369    /// Array of glob patterns to ignore
370    #[serde(default, skip_serializing_if = "Vec::is_empty")]
371    pub ignore: Vec<String>,
372}
373
374#[derive(Deserialize, JsonSchema, Debug)]
375pub struct GrepToolParams {
376    /// Regex pattern to search for
377    pub pattern: String,
378    /// File/directory to search (defaults to current directory)
379    #[serde(skip_serializing_if = "Option::is_none")]
380    pub path: Option<String>,
381    /// "content" (shows lines), "files_with_matches" (default), "count"
382    #[serde(skip_serializing_if = "Option::is_none")]
383    pub output_mode: Option<GrepOutputMode>,
384    /// Filter files with glob pattern like "*.js"
385    #[serde(skip_serializing_if = "Option::is_none")]
386    pub glob: Option<String>,
387    /// File type filter like "js", "py", "rust"
388    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
389    pub file_type: Option<String>,
390    /// Case insensitive search
391    #[serde(rename = "-i", default, skip_serializing_if = "is_false")]
392    pub case_insensitive: bool,
393    /// Show line numbers (content mode only)
394    #[serde(rename = "-n", default, skip_serializing_if = "is_false")]
395    pub line_numbers: bool,
396    /// Lines after match (content mode only)
397    #[serde(rename = "-A", skip_serializing_if = "Option::is_none")]
398    pub after_context: Option<u32>,
399    /// Lines before match (content mode only)
400    #[serde(rename = "-B", skip_serializing_if = "Option::is_none")]
401    pub before_context: Option<u32>,
402    /// Lines before and after match (content mode only)
403    #[serde(rename = "-C", skip_serializing_if = "Option::is_none")]
404    pub context: Option<u32>,
405    /// Enable multiline/cross-line matching
406    #[serde(default, skip_serializing_if = "is_false")]
407    pub multiline: bool,
408    /// Limit output to first N results
409    #[serde(skip_serializing_if = "Option::is_none")]
410    pub head_limit: Option<u32>,
411}
412
413impl std::fmt::Display for GrepToolParams {
414    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
415        write!(f, "grep")?;
416
417        // Boolean flags
418        if self.case_insensitive {
419            write!(f, " -i")?;
420        }
421        if self.line_numbers {
422            write!(f, " -n")?;
423        }
424
425        // Context options
426        if let Some(after) = self.after_context {
427            write!(f, " -A {}", after)?;
428        }
429        if let Some(before) = self.before_context {
430            write!(f, " -B {}", before)?;
431        }
432        if let Some(context) = self.context {
433            write!(f, " -C {}", context)?;
434        }
435
436        // Output mode
437        if let Some(mode) = &self.output_mode {
438            match mode {
439                GrepOutputMode::FilesWithMatches => write!(f, " -l")?,
440                GrepOutputMode::Count => write!(f, " -c")?,
441                GrepOutputMode::Content => {} // Default mode
442            }
443        }
444
445        // Head limit
446        if let Some(limit) = self.head_limit {
447            write!(f, " | head -{}", limit)?;
448        }
449
450        // Glob pattern
451        if let Some(glob) = &self.glob {
452            write!(f, " --include=\"{}\"", glob)?;
453        }
454
455        // File type
456        if let Some(file_type) = &self.file_type {
457            write!(f, " --type={}", file_type)?;
458        }
459
460        // Multiline
461        if self.multiline {
462            write!(f, " -P")?; // Perl-compatible regex for multiline
463        }
464
465        // Pattern (escaped if contains special characters)
466        write!(f, " \"{}\"", self.pattern)?;
467
468        // Path
469        if let Some(path) = &self.path {
470            write!(f, " {}", path)?;
471        }
472
473        Ok(())
474    }
475}
476
477#[derive(Default, Deserialize, Serialize, JsonSchema, strum::Display, Debug)]
478#[serde(rename_all = "snake_case")]
479pub enum TodoPriority {
480    High,
481    #[default]
482    Medium,
483    Low,
484}
485
486impl Into<acp::PlanEntryPriority> for TodoPriority {
487    fn into(self) -> acp::PlanEntryPriority {
488        match self {
489            TodoPriority::High => acp::PlanEntryPriority::High,
490            TodoPriority::Medium => acp::PlanEntryPriority::Medium,
491            TodoPriority::Low => acp::PlanEntryPriority::Low,
492        }
493    }
494}
495
496#[derive(Deserialize, Serialize, JsonSchema, Debug)]
497#[serde(rename_all = "snake_case")]
498pub enum TodoStatus {
499    Pending,
500    InProgress,
501    Completed,
502}
503
504impl Into<acp::PlanEntryStatus> for TodoStatus {
505    fn into(self) -> acp::PlanEntryStatus {
506        match self {
507            TodoStatus::Pending => acp::PlanEntryStatus::Pending,
508            TodoStatus::InProgress => acp::PlanEntryStatus::InProgress,
509            TodoStatus::Completed => acp::PlanEntryStatus::Completed,
510        }
511    }
512}
513
514#[derive(Deserialize, Serialize, JsonSchema, Debug)]
515pub struct Todo {
516    /// Task description
517    pub content: String,
518    /// Current status of the todo
519    pub status: TodoStatus,
520    /// Priority level of the todo
521    #[serde(default)]
522    pub priority: TodoPriority,
523}
524
525impl Into<acp::PlanEntry> for Todo {
526    fn into(self) -> acp::PlanEntry {
527        acp::PlanEntry {
528            content: self.content,
529            priority: self.priority.into(),
530            status: self.status.into(),
531        }
532    }
533}
534
535#[derive(Deserialize, JsonSchema, Debug)]
536pub struct TodoWriteToolParams {
537    pub todos: Vec<Todo>,
538}
539
540#[derive(Deserialize, JsonSchema, Debug)]
541pub struct ExitPlanModeToolParams {
542    /// Implementation plan in markdown format
543    pub plan: String,
544}
545
546#[derive(Deserialize, JsonSchema, Debug)]
547pub struct TaskToolParams {
548    /// Short 3-5 word description of task
549    pub description: String,
550    /// Detailed task for agent to perform
551    pub prompt: String,
552}
553
554#[derive(Deserialize, JsonSchema, Debug)]
555pub struct NotebookReadToolParams {
556    /// Absolute path to .ipynb file
557    pub notebook_path: PathBuf,
558    /// Specific cell ID to read
559    #[serde(skip_serializing_if = "Option::is_none")]
560    pub cell_id: Option<String>,
561}
562
563#[derive(Deserialize, Serialize, JsonSchema, Debug)]
564#[serde(rename_all = "snake_case")]
565pub enum CellType {
566    Code,
567    Markdown,
568}
569
570#[derive(Deserialize, Serialize, JsonSchema, Debug)]
571#[serde(rename_all = "snake_case")]
572pub enum EditMode {
573    Replace,
574    Insert,
575    Delete,
576}
577
578#[derive(Deserialize, JsonSchema, Debug)]
579pub struct NotebookEditToolParams {
580    /// Absolute path to .ipynb file
581    pub notebook_path: PathBuf,
582    /// New cell content
583    pub new_source: String,
584    /// Cell ID to edit
585    #[serde(skip_serializing_if = "Option::is_none")]
586    pub cell_id: Option<String>,
587    /// Type of cell (code or markdown)
588    #[serde(skip_serializing_if = "Option::is_none")]
589    pub cell_type: Option<CellType>,
590    /// Edit operation mode
591    #[serde(skip_serializing_if = "Option::is_none")]
592    pub edit_mode: Option<EditMode>,
593}
594
595#[derive(Deserialize, Serialize, JsonSchema, Debug)]
596pub struct MultiEditItem {
597    /// The text to search for and replace
598    pub old_string: String,
599    /// The replacement text
600    pub new_string: String,
601    /// Whether to replace all occurrences or just the first
602    #[serde(default, skip_serializing_if = "is_false")]
603    pub replace_all: bool,
604}
605
606#[derive(Deserialize, JsonSchema, Debug)]
607pub struct MultiEditToolParams {
608    /// Absolute path to file
609    pub file_path: PathBuf,
610    /// List of edits to apply
611    pub edits: Vec<MultiEditItem>,
612}
613
/// Returns `true` when the flag is unset.
///
/// Takes `&bool` because it is used as a serde `skip_serializing_if`
/// predicate, which is called with a reference to the field.
fn is_false(v: &bool) -> bool {
    matches!(*v, false)
}
617
618#[derive(Deserialize, JsonSchema, Debug)]
619#[serde(rename_all = "snake_case")]
620pub enum GrepOutputMode {
621    Content,
622    FilesWithMatches,
623    Count,
624}
625
626#[derive(Deserialize, JsonSchema, Debug)]
627pub struct WebFetchToolParams {
628    /// Valid URL to fetch
629    #[serde(rename = "url")]
630    pub url: String,
631    /// What to extract from content
632    pub prompt: String,
633}
634
635#[derive(Deserialize, JsonSchema, Debug)]
636pub struct WebSearchToolParams {
637    /// Search query (min 2 chars)
638    pub query: String,
639    /// Only include these domains
640    #[serde(default, skip_serializing_if = "Vec::is_empty")]
641    pub allowed_domains: Vec<String>,
642    /// Exclude these domains
643    #[serde(default, skip_serializing_if = "Vec::is_empty")]
644    pub blocked_domains: Vec<String>,
645}
646
647impl std::fmt::Display for WebSearchToolParams {
648    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
649        write!(f, "\"{}\"", self.query)?;
650
651        if !self.allowed_domains.is_empty() {
652            write!(f, " (allowed: {})", self.allowed_domains.join(", "))?;
653        }
654
655        if !self.blocked_domains.is_empty() {
656            write!(f, " (blocked: {})", self.blocked_domains.join(", "))?;
657        }
658
659        Ok(())
660    }
661}