1use crate::{
2 Templates,
3 edit_agent::{EditAgent, EditAgentOutputEvent},
4 edit_file_tool::EditFileToolCard,
5 schema::json_schema_for,
6};
7use anyhow::{Context as _, Result, anyhow};
8use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult, ToolResultOutput};
9use futures::StreamExt;
10use gpui::{AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task};
11use indoc::formatdoc;
12use language_model::{
13 LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolSchemaFormat,
14};
15use project::Project;
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use std::{path::PathBuf, sync::Arc};
19use ui::prelude::*;
20use util::ResultExt;
21
/// The `edit_file` tool.
///
/// It delegates the edit to an [`EditAgent`] that streams changes into the
/// target buffer. It then saves the buffer and reports a unified diff of the
/// result back to the model.
pub struct StreamingEditFileTool;
23
24#[derive(Debug, Serialize, Deserialize, JsonSchema)]
25pub struct StreamingEditFileToolInput {
26 /// A one-line, user-friendly markdown description of the edit. This will be
27 /// shown in the UI and also passed to another model to perform the edit.
28 ///
29 /// Be terse, but also descriptive in what you want to achieve with this
30 /// edit. Avoid generic instructions.
31 ///
32 /// NEVER mention the file path in this description.
33 ///
34 /// <example>Fix API endpoint URLs</example>
35 /// <example>Update copyright year in `page_footer`</example>
36 ///
37 /// Make sure to include this field before all the others in the input object
38 /// so that we can display it immediately.
39 pub display_description: String,
40
41 /// The full path of the file to create or modify in the project.
42 ///
43 /// WARNING: When specifying which file path need changing, you MUST
44 /// start each path with one of the project's root directories.
45 ///
46 /// The following examples assume we have two root directories in the project:
47 /// - backend
48 /// - frontend
49 ///
50 /// <example>
51 /// `backend/src/main.rs`
52 ///
53 /// Notice how the file path starts with root-1. Without that, the path
54 /// would be ambiguous and the call would fail!
55 /// </example>
56 ///
57 /// <example>
58 /// `frontend/db.js`
59 /// </example>
60 pub path: PathBuf,
61
62 /// If true, this tool will recreate the file from scratch.
63 /// If false, this tool will produce granular edits to an existing file.
64 ///
65 /// When a file already exists or you just created it, always prefer editing
66 /// it as opposed to recreating it from scratch.
67 pub create_or_overwrite: bool,
68}
69
/// Structured result of a successful edit.
///
/// It is serialized into the tool output and later read back by
/// `deserialize_card` to rebuild the diff card.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct StreamingEditFileToolOutput {
    /// The edited file's `ProjectPath::path`.
    // NOTE(review): presumably worktree-relative rather than the model-supplied
    // path — confirm against `ProjectPath`.
    pub original_path: PathBuf,
    /// Full buffer text after the edits were applied and the buffer was saved.
    pub new_text: String,
    /// Full buffer text captured before any edits were applied.
    pub old_text: String,
}
76
77#[derive(Debug, Serialize, Deserialize, JsonSchema)]
78struct PartialInput {
79 #[serde(default)]
80 path: String,
81 #[serde(default)]
82 display_description: String,
83}
84
/// Fallback label for this tool.
///
/// Used when no description or path can be extracted from the (possibly
/// partial) input.
const DEFAULT_UI_TEXT: &str = "Editing file";
86
87impl Tool for StreamingEditFileTool {
88 fn name(&self) -> String {
89 "edit_file".into()
90 }
91
92 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
93 false
94 }
95
96 fn description(&self) -> String {
97 include_str!("streaming_edit_file_tool/description.md").to_string()
98 }
99
100 fn icon(&self) -> IconName {
101 IconName::Pencil
102 }
103
104 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
105 json_schema_for::<StreamingEditFileToolInput>(format)
106 }
107
108 fn ui_text(&self, input: &serde_json::Value) -> String {
109 match serde_json::from_value::<StreamingEditFileToolInput>(input.clone()) {
110 Ok(input) => input.display_description,
111 Err(_) => "Editing file".to_string(),
112 }
113 }
114
115 fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
116 if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
117 let description = input.display_description.trim();
118 if !description.is_empty() {
119 return description.to_string();
120 }
121
122 let path = input.path.trim();
123 if !path.is_empty() {
124 return path.to_string();
125 }
126 }
127
128 DEFAULT_UI_TEXT.to_string()
129 }
130
131 fn run(
132 self: Arc<Self>,
133 input: serde_json::Value,
134 messages: &[LanguageModelRequestMessage],
135 project: Entity<Project>,
136 action_log: Entity<ActionLog>,
137 window: Option<AnyWindowHandle>,
138 cx: &mut App,
139 ) -> ToolResult {
140 let input = match serde_json::from_value::<StreamingEditFileToolInput>(input) {
141 Ok(input) => input,
142 Err(err) => return Task::ready(Err(anyhow!(err))).into(),
143 };
144
145 let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
146 return Task::ready(Err(anyhow!(
147 "Path {} not found in project",
148 input.path.display()
149 )))
150 .into();
151 };
152 let Some(worktree) = project
153 .read(cx)
154 .worktree_for_id(project_path.worktree_id, cx)
155 else {
156 return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
157 };
158 let exists = worktree.update(cx, |worktree, cx| {
159 worktree.file_exists(&project_path.path, cx)
160 });
161
162 let card = window.and_then(|window| {
163 window
164 .update(cx, |_, window, cx| {
165 cx.new(|cx| {
166 EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
167 })
168 })
169 .ok()
170 });
171
172 let card_clone = card.clone();
173 let messages = messages.to_vec();
174 let task = cx.spawn(async move |cx: &mut AsyncApp| {
175 if !input.create_or_overwrite && !exists.await? {
176 return Err(anyhow!("{} not found", input.path.display()));
177 }
178
179 let model = cx
180 .update(|cx| LanguageModelRegistry::read_global(cx).default_model())?
181 .context("default model not set")?
182 .model;
183 let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new());
184
185 let buffer = project
186 .update(cx, |project, cx| {
187 project.open_buffer(project_path.clone(), cx)
188 })?
189 .await?;
190
191 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
192 let old_text = cx
193 .background_spawn({
194 let old_snapshot = old_snapshot.clone();
195 async move { old_snapshot.text() }
196 })
197 .await;
198
199 let (output, mut events) = if input.create_or_overwrite {
200 edit_agent.overwrite(
201 buffer.clone(),
202 input.display_description.clone(),
203 messages,
204 cx,
205 )
206 } else {
207 edit_agent.edit(
208 buffer.clone(),
209 input.display_description.clone(),
210 messages,
211 cx,
212 )
213 };
214
215 let mut hallucinated_old_text = false;
216 while let Some(event) = events.next().await {
217 match event {
218 EditAgentOutputEvent::Edited => {
219 if let Some(card) = card_clone.as_ref() {
220 let new_snapshot =
221 buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
222 let new_text = cx
223 .background_spawn({
224 let new_snapshot = new_snapshot.clone();
225 async move { new_snapshot.text() }
226 })
227 .await;
228 card.update(cx, |card, cx| {
229 card.set_diff(
230 project_path.path.clone(),
231 old_text.clone(),
232 new_text,
233 cx,
234 );
235 })
236 .log_err();
237 }
238 }
239 EditAgentOutputEvent::OldTextNotFound(_) => hallucinated_old_text = true,
240 }
241 }
242 output.await?;
243
244 project
245 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
246 .await?;
247
248 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
249 let new_text = cx.background_spawn({
250 let new_snapshot = new_snapshot.clone();
251 async move { new_snapshot.text() }
252 });
253 let diff = cx.background_spawn(async move {
254 language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
255 });
256 let (new_text, diff) = futures::join!(new_text, diff);
257
258 let output = StreamingEditFileToolOutput {
259 original_path: project_path.path.to_path_buf(),
260 new_text: new_text.clone(),
261 old_text: old_text.clone(),
262 };
263
264 if let Some(card) = card_clone {
265 card.update(cx, |card, cx| {
266 card.set_diff(project_path.path.clone(), old_text, new_text, cx);
267 })
268 .log_err();
269 }
270
271 let input_path = input.path.display();
272 if diff.is_empty() {
273 if hallucinated_old_text {
274 Err(anyhow!(formatdoc! {"
275 Some edits were produced but none of them could be applied.
276 Read the relevant sections of {input_path} again so that
277 I can perform the requested edits.
278 "}))
279 } else {
280 Ok("No edits were made.".to_string().into())
281 }
282 } else {
283 Ok(ToolResultOutput {
284 content: format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff),
285 output: serde_json::to_value(output).ok(),
286 })
287 }
288 });
289
290 ToolResult {
291 output: task,
292 card: card.map(AnyToolCard::from),
293 }
294 }
295
296 fn deserialize_card(
297 self: Arc<Self>,
298 output: serde_json::Value,
299 project: Entity<Project>,
300 window: &mut Window,
301 cx: &mut App,
302 ) -> Option<AnyToolCard> {
303 let output = match serde_json::from_value::<StreamingEditFileToolOutput>(output) {
304 Ok(output) => output,
305 Err(_) => return None,
306 };
307
308 let card = cx.new(|cx| {
309 let mut card = EditFileToolCard::new(output.original_path.clone(), project, window, cx);
310 card.set_diff(
311 output.original_path.into(),
312 output.old_text,
313 output.new_text,
314 cx,
315 );
316 card
317 });
318
319 Some(card.into())
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a partial input with the given `path` and `display_description`.
    ///
    /// The input also carries placeholder `old_string`/`new_string` keys. The
    /// helper returns the label the tool shows while that input is streaming.
    fn streaming_label(path: &str, display_description: &str) -> String {
        let partial = json!({
            "path": path,
            "display_description": display_description,
            "old_string": "old code",
            "new_string": "new code"
        });
        StreamingEditFileTool.still_streaming_ui_text(&partial)
    }

    #[test]
    fn still_streaming_ui_text_with_path() {
        assert_eq!(streaming_label("src/main.rs", ""), "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_with_description() {
        assert_eq!(streaming_label("", "Fix error handling"), "Fix error handling");
    }

    #[test]
    fn still_streaming_ui_text_with_path_and_description() {
        // The description wins over the path when both are present.
        assert_eq!(
            streaming_label("src/main.rs", "Fix error handling"),
            "Fix error handling"
        );
    }

    #[test]
    fn still_streaming_ui_text_no_path_or_description() {
        assert_eq!(streaming_label("", ""), DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        let label = StreamingEditFileTool.still_streaming_ui_text(&serde_json::Value::Null);
        assert_eq!(label, DEFAULT_UI_TEXT);
    }
}