1use std::{collections::HashSet, path::Path};
2
3use anyhow::Result;
4use assistant_tools::{CreateFileToolInput, EditFileToolInput, ReadFileToolInput};
5use async_trait::async_trait;
6
7use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion, LanguageServer};
8
/// Eval example that asks the agent to add a `window: Option<gpui::AnyWindowHandle>`
/// argument to the `Tool::run` trait method and to update every implementation of the
/// trait and every call site accordingly.
pub struct AddArgToTraitMethod;
10
11#[async_trait(?Send)]
12impl Example for AddArgToTraitMethod {
13 fn meta(&self) -> ExampleMetadata {
14 ExampleMetadata {
15 name: "add_arg_to_trait_method".to_string(),
16 url: "https://github.com/zed-industries/zed.git".to_string(),
17 revision: "f69aeb6311dde3c0b8979c293d019d66498d54f2".to_string(),
18 language_server: Some(LanguageServer {
19 file_extension: "rs".to_string(),
20 allow_preexisting_diagnostics: false,
21 }),
22 max_assertions: None,
23 }
24 }
25
26 async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
27 const FILENAME: &str = "assistant_tool.rs";
28 cx.push_user_message(format!(
29 r#"
30 Add a `window: Option<gpui::AnyWindowHandle>` argument to the `Tool::run` trait method in {FILENAME},
31 and update all the implementations of the trait and call sites accordingly.
32 "#
33 ));
34
35 let response = cx.run_to_end().await?;
36
37 // Reads files before it edits them
38
39 let mut read_files = HashSet::new();
40
41 for tool_use in response.tool_uses() {
42 match tool_use.name.as_str() {
43 "read_file" => {
44 if let Ok(input) = tool_use.parse_input::<ReadFileToolInput>() {
45 read_files.insert(input.path);
46 }
47 }
48 "create_file" => {
49 if let Ok(input) = tool_use.parse_input::<CreateFileToolInput>() {
50 read_files.insert(input.path);
51 }
52 }
53 "edit_file" => {
54 if let Ok(input) = tool_use.parse_input::<EditFileToolInput>() {
55 cx.assert(
56 read_files.contains(input.path.to_str().unwrap()),
57 format!(
58 "Read before edit: {}",
59 &input.path.file_stem().unwrap().to_str().unwrap()
60 ),
61 )
62 .ok();
63 }
64 }
65 _ => {}
66 }
67 }
68
69 // Adds ignored argument to all but `batch_tool`
70
71 let add_ignored_window_paths = &[
72 "code_action_tool",
73 "code_symbols_tool",
74 "contents_tool",
75 "copy_path_tool",
76 "create_directory_tool",
77 "create_file_tool",
78 "delete_path_tool",
79 "diagnostics_tool",
80 "edit_file_tool",
81 "fetch_tool",
82 "grep_tool",
83 "list_directory_tool",
84 "move_path_tool",
85 "now_tool",
86 "open_tool",
87 "path_search_tool",
88 "read_file_tool",
89 "rename_tool",
90 "symbol_info_tool",
91 "terminal_tool",
92 "thinking_tool",
93 "web_search_tool",
94 ];
95
96 let edits = cx.edits();
97
98 for tool_name in add_ignored_window_paths {
99 let path_str = format!("crates/assistant_tools/src/{}.rs", tool_name);
100 let edits = edits.get(Path::new(&path_str));
101
102 let ignored = edits.map_or(false, |edits| {
103 edits.has_added_line(" _window: Option<gpui::AnyWindowHandle>,\n")
104 });
105 let uningored = edits.map_or(false, |edits| {
106 edits.has_added_line(" window: Option<gpui::AnyWindowHandle>,\n")
107 });
108
109 cx.assert(ignored || uningored, format!("Argument: {}", tool_name))
110 .ok();
111
112 cx.assert(ignored, format!("`_` prefix: {}", tool_name))
113 .ok();
114 }
115
116 // Adds unignored argument to `batch_tool`
117
118 let batch_tool_edits = edits.get(Path::new("crates/assistant_tools/src/batch_tool.rs"));
119
120 cx.assert(
121 batch_tool_edits.map_or(false, |edits| {
122 edits.has_added_line(" window: Option<gpui::AnyWindowHandle>,\n")
123 }),
124 "Argument: batch_tool",
125 )
126 .ok();
127
128 Ok(())
129 }
130
131 fn diff_assertions(&self) -> Vec<JudgeAssertion> {
132 vec![
133 JudgeAssertion {
134 id: "batch tool passes window to each".to_string(),
135 description:
136 "batch_tool is modified to pass a clone of the window to each tool it calls."
137 .to_string(),
138 },
139 JudgeAssertion {
140 id: "tool tests updated".to_string(),
141 description:
142 "tool tests are updated to pass the new `window` argument (`None` is ok)."
143 .to_string(),
144 },
145 ]
146 }
147}