1use crate::{
2 Templates,
3 edit_agent::{EditAgent, EditAgentOutputEvent},
4 edit_file_tool::EditFileToolCard,
5 schema::json_schema_for,
6};
7use anyhow::{Context as _, Result, anyhow};
8use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult};
9use futures::StreamExt;
10use gpui::{AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task};
11use indoc::formatdoc;
12use language_model::{
13 LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolSchemaFormat,
14};
15use project::Project;
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use std::{path::PathBuf, sync::Arc};
19use ui::prelude::*;
20use util::ResultExt;
21
/// Unit struct implementing [`Tool`] under the name `edit_file`.
///
/// Its `run` implementation hands the requested edit to an [`EditAgent`],
/// saves the buffer, and reports the outcome as a unified diff.
pub struct StreamingEditFileTool;
23
24#[derive(Debug, Serialize, Deserialize, JsonSchema)]
25pub struct StreamingEditFileToolInput {
26 /// A one-line, user-friendly markdown description of the edit. This will be
27 /// shown in the UI and also passed to another model to perform the edit.
28 ///
29 /// Be terse, but also descriptive in what you want to achieve with this
30 /// edit. Avoid generic instructions.
31 ///
32 /// NEVER mention the file path in this description.
33 ///
34 /// <example>Fix API endpoint URLs</example>
35 /// <example>Update copyright year in `page_footer`</example>
36 ///
37 /// Make sure to include this field before all the others in the input object
38 /// so that we can display it immediately.
39 pub display_description: String,
40
41 /// The full path of the file to modify in the project.
42 ///
43 /// WARNING: When specifying which file path need changing, you MUST
44 /// start each path with one of the project's root directories.
45 ///
46 /// The following examples assume we have two root directories in the project:
47 /// - backend
48 /// - frontend
49 ///
50 /// <example>
51 /// `backend/src/main.rs`
52 ///
53 /// Notice how the file path starts with root-1. Without that, the path
54 /// would be ambiguous and the call would fail!
55 /// </example>
56 ///
57 /// <example>
58 /// `frontend/db.js`
59 /// </example>
60 pub path: PathBuf,
61}
62
// Lenient view of the tool input, deserialized by `still_streaming_ui_text`
// while the arguments may still be incomplete. Both fields are plain strings
// marked `#[serde(default)]`, so a missing key deserializes to an empty string
// instead of failing.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
    // Path as received so far; empty when the key is absent.
    #[serde(default)]
    path: String,
    // Description as received so far; empty when the key is absent.
    #[serde(default)]
    display_description: String,
}
70
/// Fallback label returned by `still_streaming_ui_text` when the input yields
/// neither a non-empty description nor a non-empty path.
const DEFAULT_UI_TEXT: &str = "Editing file";
72
73impl Tool for StreamingEditFileTool {
74 fn name(&self) -> String {
75 "edit_file".into()
76 }
77
78 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
79 false
80 }
81
82 fn description(&self) -> String {
83 include_str!("streaming_edit_file_tool/description.md").to_string()
84 }
85
86 fn icon(&self) -> IconName {
87 IconName::Pencil
88 }
89
90 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
91 json_schema_for::<StreamingEditFileToolInput>(format)
92 }
93
94 fn ui_text(&self, input: &serde_json::Value) -> String {
95 match serde_json::from_value::<StreamingEditFileToolInput>(input.clone()) {
96 Ok(input) => input.display_description,
97 Err(_) => "Editing file".to_string(),
98 }
99 }
100
101 fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
102 if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
103 let description = input.display_description.trim();
104 if !description.is_empty() {
105 return description.to_string();
106 }
107
108 let path = input.path.trim();
109 if !path.is_empty() {
110 return path.to_string();
111 }
112 }
113
114 DEFAULT_UI_TEXT.to_string()
115 }
116
117 fn run(
118 self: Arc<Self>,
119 input: serde_json::Value,
120 messages: &[LanguageModelRequestMessage],
121 project: Entity<Project>,
122 action_log: Entity<ActionLog>,
123 window: Option<AnyWindowHandle>,
124 cx: &mut App,
125 ) -> ToolResult {
126 let input = match serde_json::from_value::<StreamingEditFileToolInput>(input) {
127 Ok(input) => input,
128 Err(err) => return Task::ready(Err(anyhow!(err))).into(),
129 };
130
131 let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
132 return Task::ready(Err(anyhow!(
133 "Path {} not found in project",
134 input.path.display()
135 )))
136 .into();
137 };
138 let Some(worktree) = project
139 .read(cx)
140 .worktree_for_id(project_path.worktree_id, cx)
141 else {
142 return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
143 };
144 let exists = worktree.update(cx, |worktree, cx| {
145 worktree.file_exists(&project_path.path, cx)
146 });
147
148 let card = window.and_then(|window| {
149 window
150 .update(cx, |_, window, cx| {
151 cx.new(|cx| {
152 EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
153 })
154 })
155 .ok()
156 });
157
158 let card_clone = card.clone();
159 let messages = messages.to_vec();
160 let task = cx.spawn(async move |cx: &mut AsyncApp| {
161 if !exists.await? {
162 return Err(anyhow!("{} not found", input.path.display()));
163 }
164
165 let model = cx
166 .update(|cx| LanguageModelRegistry::read_global(cx).default_model())?
167 .context("default model not set")?
168 .model;
169 let edit_agent = EditAgent::new(model, action_log, Templates::new());
170
171 let buffer = project
172 .update(cx, |project, cx| {
173 project.open_buffer(project_path.clone(), cx)
174 })?
175 .await?;
176
177 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
178 let old_text = cx
179 .background_spawn({
180 let old_snapshot = old_snapshot.clone();
181 async move { old_snapshot.text() }
182 })
183 .await;
184
185 let (output, mut events) = edit_agent.edit(
186 buffer.clone(),
187 input.display_description.clone(),
188 messages,
189 cx,
190 );
191
192 let mut hallucinated_old_text = false;
193 while let Some(event) = events.next().await {
194 match event {
195 EditAgentOutputEvent::Edited => {
196 if let Some(card) = card_clone.as_ref() {
197 let new_snapshot =
198 buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
199 let new_text = cx
200 .background_spawn({
201 let new_snapshot = new_snapshot.clone();
202 async move { new_snapshot.text() }
203 })
204 .await;
205 card.update(cx, |card, cx| {
206 card.set_diff(
207 project_path.path.clone(),
208 old_text.clone(),
209 new_text,
210 cx,
211 );
212 })
213 .log_err();
214 }
215 }
216 EditAgentOutputEvent::HallucinatedOldText(_) => hallucinated_old_text = true,
217 }
218 }
219 output.await?;
220
221 project
222 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
223 .await?;
224
225 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
226 let new_text = cx.background_spawn({
227 let new_snapshot = new_snapshot.clone();
228 async move { new_snapshot.text() }
229 });
230 let diff = cx.background_spawn(async move {
231 language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
232 });
233 let (new_text, diff) = futures::join!(new_text, diff);
234
235 if let Some(card) = card_clone {
236 card.update(cx, |card, cx| {
237 card.set_diff(project_path.path.clone(), old_text, new_text, cx);
238 })
239 .log_err();
240 }
241
242 let input_path = input.path.display();
243 if diff.is_empty() {
244 if hallucinated_old_text {
245 Err(anyhow!(formatdoc! {"
246 Some edits were produced but none of them could be applied.
247 Read the relevant sections of {input_path} again so that
248 I can perform the requested edits.
249 "}))
250 } else {
251 Ok("No edits were made.".to_string())
252 }
253 } else {
254 Ok(format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff))
255 }
256 });
257
258 ToolResult {
259 output: task,
260 card: card.map(AnyToolCard::from),
261 }
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a tool-call payload with the given path and description; the
    /// `old_string` / `new_string` keys are extra fields not present on
    /// `PartialInput`.
    fn payload(path: &str, description: &str) -> serde_json::Value {
        json!({
            "path": path,
            "display_description": description,
            "old_string": "old code",
            "new_string": "new code"
        })
    }

    /// Shorthand for calling `still_streaming_ui_text` on the tool under test.
    fn streaming_label(input: &serde_json::Value) -> String {
        StreamingEditFileTool.still_streaming_ui_text(input)
    }

    // With an empty description, the path is shown.
    #[test]
    fn still_streaming_ui_text_with_path() {
        let label = streaming_label(&payload("src/main.rs", ""));
        assert_eq!(label, "src/main.rs");
    }

    // With an empty path, the description is shown.
    #[test]
    fn still_streaming_ui_text_with_description() {
        let label = streaming_label(&payload("", "Fix error handling"));
        assert_eq!(label, "Fix error handling");
    }

    // When both are present, the description takes precedence over the path.
    #[test]
    fn still_streaming_ui_text_with_path_and_description() {
        let label = streaming_label(&payload("src/main.rs", "Fix error handling"));
        assert_eq!(label, "Fix error handling");
    }

    // When both are empty, the default label is used.
    #[test]
    fn still_streaming_ui_text_no_path_or_description() {
        let label = streaming_label(&payload("", ""));
        assert_eq!(label, DEFAULT_UI_TEXT);
    }

    // A JSON `null` input also yields the default label.
    #[test]
    fn still_streaming_ui_text_with_null() {
        let label = streaming_label(&serde_json::Value::Null);
        assert_eq!(label, DEFAULT_UI_TEXT);
    }
}