tools.rs

  1use std::path::PathBuf;
  2
  3use agent_client_protocol as acp;
  4use itertools::Itertools;
  5use schemars::JsonSchema;
  6use serde::{Deserialize, Serialize};
  7use util::ResultExt;
  8
  9pub enum ClaudeTool {
 10    Task(Option<TaskToolParams>),
 11    NotebookRead(Option<NotebookReadToolParams>),
 12    NotebookEdit(Option<NotebookEditToolParams>),
 13    Edit(Option<EditToolParams>),
 14    MultiEdit(Option<MultiEditToolParams>),
 15    ReadFile(Option<ReadToolParams>),
 16    Write(Option<WriteToolParams>),
 17    Ls(Option<LsToolParams>),
 18    Glob(Option<GlobToolParams>),
 19    Grep(Option<GrepToolParams>),
 20    Terminal(Option<BashToolParams>),
 21    WebFetch(Option<WebFetchToolParams>),
 22    WebSearch(Option<WebSearchToolParams>),
 23    TodoWrite(Option<TodoWriteToolParams>),
 24    ExitPlanMode(Option<ExitPlanModeToolParams>),
 25    Other {
 26        name: String,
 27        input: serde_json::Value,
 28    },
 29}
 30
 31impl ClaudeTool {
 32    pub fn infer(tool_name: &str, input: serde_json::Value) -> Self {
 33        match tool_name {
 34            // Known tools
 35            "mcp__zed__Read" => Self::ReadFile(serde_json::from_value(input).log_err()),
 36            "mcp__zed__Edit" => Self::Edit(serde_json::from_value(input).log_err()),
 37            "MultiEdit" => Self::MultiEdit(serde_json::from_value(input).log_err()),
 38            "Write" => Self::Write(serde_json::from_value(input).log_err()),
 39            "LS" => Self::Ls(serde_json::from_value(input).log_err()),
 40            "Glob" => Self::Glob(serde_json::from_value(input).log_err()),
 41            "Grep" => Self::Grep(serde_json::from_value(input).log_err()),
 42            "Bash" => Self::Terminal(serde_json::from_value(input).log_err()),
 43            "WebFetch" => Self::WebFetch(serde_json::from_value(input).log_err()),
 44            "WebSearch" => Self::WebSearch(serde_json::from_value(input).log_err()),
 45            "TodoWrite" => Self::TodoWrite(serde_json::from_value(input).log_err()),
 46            "exit_plan_mode" => Self::ExitPlanMode(serde_json::from_value(input).log_err()),
 47            "Task" => Self::Task(serde_json::from_value(input).log_err()),
 48            "NotebookRead" => Self::NotebookRead(serde_json::from_value(input).log_err()),
 49            "NotebookEdit" => Self::NotebookEdit(serde_json::from_value(input).log_err()),
 50            // Inferred from name
 51            _ => {
 52                let tool_name = tool_name.to_lowercase();
 53
 54                if tool_name.contains("edit") || tool_name.contains("write") {
 55                    Self::Edit(None)
 56                } else if tool_name.contains("terminal") {
 57                    Self::Terminal(None)
 58                } else {
 59                    Self::Other {
 60                        name: tool_name.to_string(),
 61                        input,
 62                    }
 63                }
 64            }
 65        }
 66    }
 67
 68    pub fn label(&self) -> String {
 69        match &self {
 70            Self::Task(Some(params)) => params.description.clone(),
 71            Self::Task(None) => "Task".into(),
 72            Self::NotebookRead(Some(params)) => {
 73                format!("Read Notebook {}", params.notebook_path.display())
 74            }
 75            Self::NotebookRead(None) => "Read Notebook".into(),
 76            Self::NotebookEdit(Some(params)) => {
 77                format!("Edit Notebook {}", params.notebook_path.display())
 78            }
 79            Self::NotebookEdit(None) => "Edit Notebook".into(),
 80            Self::Terminal(Some(params)) => format!("`{}`", params.command),
 81            Self::Terminal(None) => "Terminal".into(),
 82            Self::ReadFile(_) => "Read File".into(),
 83            Self::Ls(Some(params)) => {
 84                format!("List Directory {}", params.path.display())
 85            }
 86            Self::Ls(None) => "List Directory".into(),
 87            Self::Edit(Some(params)) => {
 88                format!("Edit {}", params.abs_path.display())
 89            }
 90            Self::Edit(None) => "Edit".into(),
 91            Self::MultiEdit(Some(params)) => {
 92                format!("Multi Edit {}", params.file_path.display())
 93            }
 94            Self::MultiEdit(None) => "Multi Edit".into(),
 95            Self::Write(Some(params)) => {
 96                format!("Write {}", params.file_path.display())
 97            }
 98            Self::Write(None) => "Write".into(),
 99            Self::Glob(Some(params)) => {
100                format!("Glob `{params}`")
101            }
102            Self::Glob(None) => "Glob".into(),
103            Self::Grep(Some(params)) => format!("`{params}`"),
104            Self::Grep(None) => "Grep".into(),
105            Self::WebFetch(Some(params)) => format!("Fetch {}", params.url),
106            Self::WebFetch(None) => "Fetch".into(),
107            Self::WebSearch(Some(params)) => format!("Web Search: {}", params),
108            Self::WebSearch(None) => "Web Search".into(),
109            Self::TodoWrite(Some(params)) => format!(
110                "Update TODOs: {}",
111                params.todos.iter().map(|todo| &todo.content).join(", ")
112            ),
113            Self::TodoWrite(None) => "Update TODOs".into(),
114            Self::ExitPlanMode(_) => "Exit Plan Mode".into(),
115            Self::Other { name, .. } => name.clone(),
116        }
117    }
118    pub fn content(&self) -> Vec<acp::ToolCallContent> {
119        match &self {
120            Self::Other { input, .. } => vec![
121                format!(
122                    "```json\n{}```",
123                    serde_json::to_string_pretty(&input).unwrap_or("{}".to_string())
124                )
125                .into(),
126            ],
127            Self::Task(Some(params)) => vec![params.prompt.clone().into()],
128            Self::NotebookRead(Some(params)) => {
129                vec![params.notebook_path.display().to_string().into()]
130            }
131            Self::NotebookEdit(Some(params)) => vec![params.new_source.clone().into()],
132            Self::Terminal(Some(params)) => vec![
133                format!(
134                    "`{}`\n\n{}",
135                    params.command,
136                    params.description.as_deref().unwrap_or_default()
137                )
138                .into(),
139            ],
140            Self::ReadFile(Some(params)) => vec![params.abs_path.display().to_string().into()],
141            Self::Ls(Some(params)) => vec![params.path.display().to_string().into()],
142            Self::Glob(Some(params)) => vec![params.to_string().into()],
143            Self::Grep(Some(params)) => vec![format!("`{params}`").into()],
144            Self::WebFetch(Some(params)) => vec![params.prompt.clone().into()],
145            Self::WebSearch(Some(params)) => vec![params.to_string().into()],
146            Self::TodoWrite(Some(params)) => vec![
147                params
148                    .todos
149                    .iter()
150                    .map(|todo| {
151                        format!(
152                            "- {} {}: {}",
153                            match todo.status {
154                                TodoStatus::Completed => "",
155                                TodoStatus::InProgress => "🚧",
156                                TodoStatus::Pending => "",
157                            },
158                            todo.priority,
159                            todo.content
160                        )
161                    })
162                    .join("\n")
163                    .into(),
164            ],
165            Self::ExitPlanMode(Some(params)) => vec![params.plan.clone().into()],
166            Self::Edit(Some(params)) => vec![acp::ToolCallContent::Diff {
167                diff: acp::Diff {
168                    path: params.abs_path.clone(),
169                    old_text: Some(params.old_text.clone()),
170                    new_text: params.new_text.clone(),
171                },
172            }],
173            Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff {
174                diff: acp::Diff {
175                    path: params.file_path.clone(),
176                    old_text: None,
177                    new_text: params.content.clone(),
178                },
179            }],
180            Self::MultiEdit(Some(params)) => {
181                // todo: show multiple edits in a multibuffer?
182                params
183                    .edits
184                    .first()
185                    .map(|edit| {
186                        vec![acp::ToolCallContent::Diff {
187                            diff: acp::Diff {
188                                path: params.file_path.clone(),
189                                old_text: Some(edit.old_string.clone()),
190                                new_text: edit.new_string.clone(),
191                            },
192                        }]
193                    })
194                    .unwrap_or_default()
195            }
196            Self::Task(None)
197            | Self::NotebookRead(None)
198            | Self::NotebookEdit(None)
199            | Self::Terminal(None)
200            | Self::ReadFile(None)
201            | Self::Ls(None)
202            | Self::Glob(None)
203            | Self::Grep(None)
204            | Self::WebFetch(None)
205            | Self::WebSearch(None)
206            | Self::TodoWrite(None)
207            | Self::ExitPlanMode(None)
208            | Self::Edit(None)
209            | Self::Write(None)
210            | Self::MultiEdit(None) => vec![],
211        }
212    }
213
214    pub fn kind(&self) -> acp::ToolKind {
215        match self {
216            Self::Task(_) => acp::ToolKind::Think,
217            Self::NotebookRead(_) => acp::ToolKind::Read,
218            Self::NotebookEdit(_) => acp::ToolKind::Edit,
219            Self::Edit(_) => acp::ToolKind::Edit,
220            Self::MultiEdit(_) => acp::ToolKind::Edit,
221            Self::Write(_) => acp::ToolKind::Edit,
222            Self::ReadFile(_) => acp::ToolKind::Read,
223            Self::Ls(_) => acp::ToolKind::Search,
224            Self::Glob(_) => acp::ToolKind::Search,
225            Self::Grep(_) => acp::ToolKind::Search,
226            Self::Terminal(_) => acp::ToolKind::Execute,
227            Self::WebSearch(_) => acp::ToolKind::Search,
228            Self::WebFetch(_) => acp::ToolKind::Fetch,
229            Self::TodoWrite(_) => acp::ToolKind::Think,
230            Self::ExitPlanMode(_) => acp::ToolKind::Think,
231            Self::Other { .. } => acp::ToolKind::Other,
232        }
233    }
234
235    pub fn locations(&self) -> Vec<acp::ToolCallLocation> {
236        match &self {
237            Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![acp::ToolCallLocation {
238                path: abs_path.clone(),
239                line: None,
240            }],
241            Self::MultiEdit(Some(MultiEditToolParams { file_path, .. })) => {
242                vec![acp::ToolCallLocation {
243                    path: file_path.clone(),
244                    line: None,
245                }]
246            }
247            Self::Write(Some(WriteToolParams { file_path, .. })) => {
248                vec![acp::ToolCallLocation {
249                    path: file_path.clone(),
250                    line: None,
251                }]
252            }
253            Self::ReadFile(Some(ReadToolParams {
254                abs_path, offset, ..
255            })) => vec![acp::ToolCallLocation {
256                path: abs_path.clone(),
257                line: *offset,
258            }],
259            Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => {
260                vec![acp::ToolCallLocation {
261                    path: notebook_path.clone(),
262                    line: None,
263                }]
264            }
265            Self::NotebookEdit(Some(NotebookEditToolParams { notebook_path, .. })) => {
266                vec![acp::ToolCallLocation {
267                    path: notebook_path.clone(),
268                    line: None,
269                }]
270            }
271            Self::Glob(Some(GlobToolParams {
272                path: Some(path), ..
273            })) => vec![acp::ToolCallLocation {
274                path: path.clone(),
275                line: None,
276            }],
277            Self::Ls(Some(LsToolParams { path, .. })) => vec![acp::ToolCallLocation {
278                path: path.clone(),
279                line: None,
280            }],
281            Self::Grep(Some(GrepToolParams {
282                path: Some(path), ..
283            })) => vec![acp::ToolCallLocation {
284                path: PathBuf::from(path),
285                line: None,
286            }],
287            Self::Task(_)
288            | Self::NotebookRead(None)
289            | Self::NotebookEdit(None)
290            | Self::Edit(None)
291            | Self::MultiEdit(None)
292            | Self::Write(None)
293            | Self::ReadFile(None)
294            | Self::Ls(None)
295            | Self::Glob(_)
296            | Self::Grep(_)
297            | Self::Terminal(_)
298            | Self::WebFetch(_)
299            | Self::WebSearch(_)
300            | Self::TodoWrite(_)
301            | Self::ExitPlanMode(_)
302            | Self::Other { .. } => vec![],
303        }
304    }
305
306    pub fn as_acp(&self, id: acp::ToolCallId) -> acp::ToolCall {
307        acp::ToolCall {
308            id,
309            kind: self.kind(),
310            status: acp::ToolCallStatus::InProgress,
311            label: self.label(),
312            content: self.content(),
313            locations: self.locations(),
314        }
315    }
316}
317
318#[derive(Deserialize, JsonSchema, Debug)]
319pub struct EditToolParams {
320    /// The absolute path to the file to read.
321    pub abs_path: PathBuf,
322    /// The old text to replace (must be unique in the file)
323    pub old_text: String,
324    /// The new text.
325    pub new_text: String,
326}
327
328#[derive(Deserialize, JsonSchema, Debug)]
329pub struct ReadToolParams {
330    /// The absolute path to the file to read.
331    pub abs_path: PathBuf,
332    /// Which line to start reading from. Omit to start from the beginning.
333    #[serde(skip_serializing_if = "Option::is_none")]
334    pub offset: Option<u32>,
335    /// How many lines to read. Omit for the whole file.
336    #[serde(skip_serializing_if = "Option::is_none")]
337    pub limit: Option<u32>,
338}
339
340#[derive(Deserialize, JsonSchema, Debug)]
341pub struct WriteToolParams {
342    /// Absolute path for new file
343    pub file_path: PathBuf,
344    /// File content
345    pub content: String,
346}
347
348#[derive(Deserialize, JsonSchema, Debug)]
349pub struct BashToolParams {
350    /// Shell command to execute
351    pub command: String,
352    /// 5-10 word description of what command does
353    #[serde(skip_serializing_if = "Option::is_none")]
354    pub description: Option<String>,
355    /// Timeout in ms (max 600000ms/10min, default 120000ms)
356    #[serde(skip_serializing_if = "Option::is_none")]
357    pub timeout: Option<u32>,
358}
359
360#[derive(Deserialize, JsonSchema, Debug)]
361pub struct GlobToolParams {
362    /// Glob pattern like **/*.js or src/**/*.ts
363    pub pattern: String,
364    /// Directory to search in (omit for current directory)
365    #[serde(skip_serializing_if = "Option::is_none")]
366    pub path: Option<PathBuf>,
367}
368
369impl std::fmt::Display for GlobToolParams {
370    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371        if let Some(path) = &self.path {
372            write!(f, "{}", path.display())?;
373        }
374        write!(f, "{}", self.pattern)
375    }
376}
377
378#[derive(Deserialize, JsonSchema, Debug)]
379pub struct LsToolParams {
380    /// Absolute path to directory
381    pub path: PathBuf,
382    /// Array of glob patterns to ignore
383    #[serde(default, skip_serializing_if = "Vec::is_empty")]
384    pub ignore: Vec<String>,
385}
386
387#[derive(Deserialize, JsonSchema, Debug)]
388pub struct GrepToolParams {
389    /// Regex pattern to search for
390    pub pattern: String,
391    /// File/directory to search (defaults to current directory)
392    #[serde(skip_serializing_if = "Option::is_none")]
393    pub path: Option<String>,
394    /// "content" (shows lines), "files_with_matches" (default), "count"
395    #[serde(skip_serializing_if = "Option::is_none")]
396    pub output_mode: Option<GrepOutputMode>,
397    /// Filter files with glob pattern like "*.js"
398    #[serde(skip_serializing_if = "Option::is_none")]
399    pub glob: Option<String>,
400    /// File type filter like "js", "py", "rust"
401    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
402    pub file_type: Option<String>,
403    /// Case insensitive search
404    #[serde(rename = "-i", default, skip_serializing_if = "is_false")]
405    pub case_insensitive: bool,
406    /// Show line numbers (content mode only)
407    #[serde(rename = "-n", default, skip_serializing_if = "is_false")]
408    pub line_numbers: bool,
409    /// Lines after match (content mode only)
410    #[serde(rename = "-A", skip_serializing_if = "Option::is_none")]
411    pub after_context: Option<u32>,
412    /// Lines before match (content mode only)
413    #[serde(rename = "-B", skip_serializing_if = "Option::is_none")]
414    pub before_context: Option<u32>,
415    /// Lines before and after match (content mode only)
416    #[serde(rename = "-C", skip_serializing_if = "Option::is_none")]
417    pub context: Option<u32>,
418    /// Enable multiline/cross-line matching
419    #[serde(default, skip_serializing_if = "is_false")]
420    pub multiline: bool,
421    /// Limit output to first N results
422    #[serde(skip_serializing_if = "Option::is_none")]
423    pub head_limit: Option<u32>,
424}
425
426impl std::fmt::Display for GrepToolParams {
427    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
428        write!(f, "grep")?;
429
430        // Boolean flags
431        if self.case_insensitive {
432            write!(f, " -i")?;
433        }
434        if self.line_numbers {
435            write!(f, " -n")?;
436        }
437
438        // Context options
439        if let Some(after) = self.after_context {
440            write!(f, " -A {}", after)?;
441        }
442        if let Some(before) = self.before_context {
443            write!(f, " -B {}", before)?;
444        }
445        if let Some(context) = self.context {
446            write!(f, " -C {}", context)?;
447        }
448
449        // Output mode
450        if let Some(mode) = &self.output_mode {
451            match mode {
452                GrepOutputMode::FilesWithMatches => write!(f, " -l")?,
453                GrepOutputMode::Count => write!(f, " -c")?,
454                GrepOutputMode::Content => {} // Default mode
455            }
456        }
457
458        // Head limit
459        if let Some(limit) = self.head_limit {
460            write!(f, " | head -{}", limit)?;
461        }
462
463        // Glob pattern
464        if let Some(glob) = &self.glob {
465            write!(f, " --include=\"{}\"", glob)?;
466        }
467
468        // File type
469        if let Some(file_type) = &self.file_type {
470            write!(f, " --type={}", file_type)?;
471        }
472
473        // Multiline
474        if self.multiline {
475            write!(f, " -P")?; // Perl-compatible regex for multiline
476        }
477
478        // Pattern (escaped if contains special characters)
479        write!(f, " \"{}\"", self.pattern)?;
480
481        // Path
482        if let Some(path) = &self.path {
483            write!(f, " {}", path)?;
484        }
485
486        Ok(())
487    }
488}
489
490#[derive(Deserialize, Serialize, JsonSchema, strum::Display, Debug)]
491#[serde(rename_all = "snake_case")]
492pub enum TodoPriority {
493    High,
494    Medium,
495    Low,
496}
497
498impl Into<acp::PlanEntryPriority> for TodoPriority {
499    fn into(self) -> acp::PlanEntryPriority {
500        match self {
501            TodoPriority::High => acp::PlanEntryPriority::High,
502            TodoPriority::Medium => acp::PlanEntryPriority::Medium,
503            TodoPriority::Low => acp::PlanEntryPriority::Low,
504        }
505    }
506}
507
508#[derive(Deserialize, Serialize, JsonSchema, Debug)]
509#[serde(rename_all = "snake_case")]
510pub enum TodoStatus {
511    Pending,
512    InProgress,
513    Completed,
514}
515
516impl Into<acp::PlanEntryStatus> for TodoStatus {
517    fn into(self) -> acp::PlanEntryStatus {
518        match self {
519            TodoStatus::Pending => acp::PlanEntryStatus::Pending,
520            TodoStatus::InProgress => acp::PlanEntryStatus::InProgress,
521            TodoStatus::Completed => acp::PlanEntryStatus::Completed,
522        }
523    }
524}
525
526#[derive(Deserialize, Serialize, JsonSchema, Debug)]
527pub struct Todo {
528    /// Unique identifier
529    pub id: String,
530    /// Task description
531    pub content: String,
532    /// Priority level of the todo
533    pub priority: TodoPriority,
534    /// Current status of the todo
535    pub status: TodoStatus,
536}
537
538impl Into<acp::PlanEntry> for Todo {
539    fn into(self) -> acp::PlanEntry {
540        acp::PlanEntry {
541            content: self.content,
542            priority: self.priority.into(),
543            status: self.status.into(),
544        }
545    }
546}
547
548#[derive(Deserialize, JsonSchema, Debug)]
549pub struct TodoWriteToolParams {
550    pub todos: Vec<Todo>,
551}
552
553#[derive(Deserialize, JsonSchema, Debug)]
554pub struct ExitPlanModeToolParams {
555    /// Implementation plan in markdown format
556    pub plan: String,
557}
558
559#[derive(Deserialize, JsonSchema, Debug)]
560pub struct TaskToolParams {
561    /// Short 3-5 word description of task
562    pub description: String,
563    /// Detailed task for agent to perform
564    pub prompt: String,
565}
566
567#[derive(Deserialize, JsonSchema, Debug)]
568pub struct NotebookReadToolParams {
569    /// Absolute path to .ipynb file
570    pub notebook_path: PathBuf,
571    /// Specific cell ID to read
572    #[serde(skip_serializing_if = "Option::is_none")]
573    pub cell_id: Option<String>,
574}
575
576#[derive(Deserialize, Serialize, JsonSchema, Debug)]
577#[serde(rename_all = "snake_case")]
578pub enum CellType {
579    Code,
580    Markdown,
581}
582
583#[derive(Deserialize, Serialize, JsonSchema, Debug)]
584#[serde(rename_all = "snake_case")]
585pub enum EditMode {
586    Replace,
587    Insert,
588    Delete,
589}
590
591#[derive(Deserialize, JsonSchema, Debug)]
592pub struct NotebookEditToolParams {
593    /// Absolute path to .ipynb file
594    pub notebook_path: PathBuf,
595    /// New cell content
596    pub new_source: String,
597    /// Cell ID to edit
598    #[serde(skip_serializing_if = "Option::is_none")]
599    pub cell_id: Option<String>,
600    /// Type of cell (code or markdown)
601    #[serde(skip_serializing_if = "Option::is_none")]
602    pub cell_type: Option<CellType>,
603    /// Edit operation mode
604    #[serde(skip_serializing_if = "Option::is_none")]
605    pub edit_mode: Option<EditMode>,
606}
607
608#[derive(Deserialize, Serialize, JsonSchema, Debug)]
609pub struct MultiEditItem {
610    /// The text to search for and replace
611    pub old_string: String,
612    /// The replacement text
613    pub new_string: String,
614    /// Whether to replace all occurrences or just the first
615    #[serde(default, skip_serializing_if = "is_false")]
616    pub replace_all: bool,
617}
618
619#[derive(Deserialize, JsonSchema, Debug)]
620pub struct MultiEditToolParams {
621    /// Absolute path to file
622    pub file_path: PathBuf,
623    /// List of edits to apply
624    pub edits: Vec<MultiEditItem>,
625}
626
/// Returns `true` when the flag is unset.
///
/// Takes `&bool` because it is referenced from `#[serde(skip_serializing_if = "is_false")]`.
fn is_false(v: &bool) -> bool {
    !v
}
630
631#[derive(Deserialize, JsonSchema, Debug)]
632#[serde(rename_all = "snake_case")]
633pub enum GrepOutputMode {
634    Content,
635    FilesWithMatches,
636    Count,
637}
638
639#[derive(Deserialize, JsonSchema, Debug)]
640pub struct WebFetchToolParams {
641    /// Valid URL to fetch
642    #[serde(rename = "url")]
643    pub url: String,
644    /// What to extract from content
645    pub prompt: String,
646}
647
648#[derive(Deserialize, JsonSchema, Debug)]
649pub struct WebSearchToolParams {
650    /// Search query (min 2 chars)
651    pub query: String,
652    /// Only include these domains
653    #[serde(default, skip_serializing_if = "Vec::is_empty")]
654    pub allowed_domains: Vec<String>,
655    /// Exclude these domains
656    #[serde(default, skip_serializing_if = "Vec::is_empty")]
657    pub blocked_domains: Vec<String>,
658}
659
660impl std::fmt::Display for WebSearchToolParams {
661    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
662        write!(f, "\"{}\"", self.query)?;
663
664        if !self.allowed_domains.is_empty() {
665            write!(f, " (allowed: {})", self.allowed_domains.join(", "))?;
666        }
667
668        if !self.blocked_domains.is_empty() {
669            write!(f, " (blocked: {})", self.blocked_domains.join(", "))?;
670        }
671
672        Ok(())
673    }
674}