1use crate::{
2 Templates,
3 edit_agent::{EditAgent, EditAgentOutputEvent},
4 edit_file_tool::EditFileToolCard,
5 schema::json_schema_for,
6};
7use anyhow::{Context as _, Result, anyhow};
8use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult};
9use futures::StreamExt;
10use gpui::{AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task};
11use indoc::formatdoc;
12use language_model::{
13 LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolSchemaFormat,
14};
15use project::Project;
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use std::{path::PathBuf, sync::Arc};
19use ui::prelude::*;
20use util::ResultExt;
21
/// Tool, registered under the name `edit_file`, that edits or recreates a
/// project file by delegating the actual text changes to an [`EditAgent`].
pub struct StreamingEditFileTool;
23
24#[derive(Debug, Serialize, Deserialize, JsonSchema)]
25pub struct StreamingEditFileToolInput {
26 /// A one-line, user-friendly markdown description of the edit. This will be
27 /// shown in the UI and also passed to another model to perform the edit.
28 ///
29 /// Be terse, but also descriptive in what you want to achieve with this
30 /// edit. Avoid generic instructions.
31 ///
32 /// NEVER mention the file path in this description.
33 ///
34 /// <example>Fix API endpoint URLs</example>
35 /// <example>Update copyright year in `page_footer`</example>
36 ///
37 /// Make sure to include this field before all the others in the input object
38 /// so that we can display it immediately.
39 pub display_description: String,
40
41 /// The full path of the file to create or modify in the project.
42 ///
43 /// WARNING: When specifying which file path need changing, you MUST
44 /// start each path with one of the project's root directories.
45 ///
46 /// The following examples assume we have two root directories in the project:
47 /// - backend
48 /// - frontend
49 ///
50 /// <example>
51 /// `backend/src/main.rs`
52 ///
53 /// Notice how the file path starts with root-1. Without that, the path
54 /// would be ambiguous and the call would fail!
55 /// </example>
56 ///
57 /// <example>
58 /// `frontend/db.js`
59 /// </example>
60 pub path: PathBuf,
61
62 /// If true, this tool will recreate the file from scratch.
63 /// If false, this tool will produce granular edits to an existing file.
64 pub create_or_overwrite: bool,
65}
66
// Lenient view of the tool input, used by `still_streaming_ui_text` while the
// arguments may still be incomplete: each field falls back to an empty string
// when its key is absent, so a partially populated object still deserializes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
    // File path received so far; empty when the `path` key is absent.
    #[serde(default)]
    path: String,
    // Description received so far; empty when `display_description` is absent.
    #[serde(default)]
    display_description: String,
}
74
/// Fallback label returned by `still_streaming_ui_text` when the partial input
/// provides neither a description nor a path.
const DEFAULT_UI_TEXT: &str = "Editing file";
76
77impl Tool for StreamingEditFileTool {
78 fn name(&self) -> String {
79 "edit_file".into()
80 }
81
82 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
83 false
84 }
85
86 fn description(&self) -> String {
87 include_str!("streaming_edit_file_tool/description.md").to_string()
88 }
89
90 fn icon(&self) -> IconName {
91 IconName::Pencil
92 }
93
94 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
95 json_schema_for::<StreamingEditFileToolInput>(format)
96 }
97
98 fn ui_text(&self, input: &serde_json::Value) -> String {
99 match serde_json::from_value::<StreamingEditFileToolInput>(input.clone()) {
100 Ok(input) => input.display_description,
101 Err(_) => "Editing file".to_string(),
102 }
103 }
104
105 fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
106 if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
107 let description = input.display_description.trim();
108 if !description.is_empty() {
109 return description.to_string();
110 }
111
112 let path = input.path.trim();
113 if !path.is_empty() {
114 return path.to_string();
115 }
116 }
117
118 DEFAULT_UI_TEXT.to_string()
119 }
120
121 fn run(
122 self: Arc<Self>,
123 input: serde_json::Value,
124 messages: &[LanguageModelRequestMessage],
125 project: Entity<Project>,
126 action_log: Entity<ActionLog>,
127 window: Option<AnyWindowHandle>,
128 cx: &mut App,
129 ) -> ToolResult {
130 let input = match serde_json::from_value::<StreamingEditFileToolInput>(input) {
131 Ok(input) => input,
132 Err(err) => return Task::ready(Err(anyhow!(err))).into(),
133 };
134
135 let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
136 return Task::ready(Err(anyhow!(
137 "Path {} not found in project",
138 input.path.display()
139 )))
140 .into();
141 };
142 let Some(worktree) = project
143 .read(cx)
144 .worktree_for_id(project_path.worktree_id, cx)
145 else {
146 return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
147 };
148 let exists = worktree.update(cx, |worktree, cx| {
149 worktree.file_exists(&project_path.path, cx)
150 });
151
152 let card = window.and_then(|window| {
153 window
154 .update(cx, |_, window, cx| {
155 cx.new(|cx| {
156 EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
157 })
158 })
159 .ok()
160 });
161
162 let card_clone = card.clone();
163 let messages = messages.to_vec();
164 let task = cx.spawn(async move |cx: &mut AsyncApp| {
165 if !input.create_or_overwrite && !exists.await? {
166 return Err(anyhow!("{} not found", input.path.display()));
167 }
168
169 let model = cx
170 .update(|cx| LanguageModelRegistry::read_global(cx).default_model())?
171 .context("default model not set")?
172 .model;
173 let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new());
174
175 let buffer = project
176 .update(cx, |project, cx| {
177 project.open_buffer(project_path.clone(), cx)
178 })?
179 .await?;
180
181 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
182 let old_text = cx
183 .background_spawn({
184 let old_snapshot = old_snapshot.clone();
185 async move { old_snapshot.text() }
186 })
187 .await;
188
189 let (output, mut events) = if input.create_or_overwrite {
190 edit_agent.overwrite(
191 buffer.clone(),
192 input.display_description.clone(),
193 messages,
194 cx,
195 )
196 } else {
197 edit_agent.edit(
198 buffer.clone(),
199 input.display_description.clone(),
200 messages,
201 cx,
202 )
203 };
204
205 let mut hallucinated_old_text = false;
206 while let Some(event) = events.next().await {
207 match event {
208 EditAgentOutputEvent::Edited => {
209 if let Some(card) = card_clone.as_ref() {
210 let new_snapshot =
211 buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
212 let new_text = cx
213 .background_spawn({
214 let new_snapshot = new_snapshot.clone();
215 async move { new_snapshot.text() }
216 })
217 .await;
218 card.update(cx, |card, cx| {
219 card.set_diff(
220 project_path.path.clone(),
221 old_text.clone(),
222 new_text,
223 cx,
224 );
225 })
226 .log_err();
227 }
228 }
229 EditAgentOutputEvent::OldTextNotFound(_) => hallucinated_old_text = true,
230 }
231 }
232 output.await?;
233
234 project
235 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
236 .await?;
237
238 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
239 let new_text = cx.background_spawn({
240 let new_snapshot = new_snapshot.clone();
241 async move { new_snapshot.text() }
242 });
243 let diff = cx.background_spawn(async move {
244 language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
245 });
246 let (new_text, diff) = futures::join!(new_text, diff);
247
248 if let Some(card) = card_clone {
249 card.update(cx, |card, cx| {
250 card.set_diff(project_path.path.clone(), old_text, new_text, cx);
251 })
252 .log_err();
253 }
254
255 let input_path = input.path.display();
256 if diff.is_empty() {
257 if hallucinated_old_text {
258 Err(anyhow!(formatdoc! {"
259 Some edits were produced but none of them could be applied.
260 Read the relevant sections of {input_path} again so that
261 I can perform the requested edits.
262 "}))
263 } else {
264 Ok("No edits were made.".to_string())
265 }
266 } else {
267 Ok(format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff))
268 }
269 });
270
271 ToolResult {
272 output: task,
273 card: card.map(AnyToolCard::from),
274 }
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Shorthand for calling the method under test on an owned JSON value.
    fn streaming_text(input: serde_json::Value) -> String {
        StreamingEditFileTool.still_streaming_ui_text(&input)
    }

    // The `old_string` / `new_string` keys below are not fields of
    // `PartialInput`; they check that unknown keys are tolerated.

    #[test]
    fn still_streaming_ui_text_with_path() {
        let input = json!({
            "path": "src/main.rs",
            "display_description": "",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(streaming_text(input), "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_with_description() {
        let input = json!({
            "path": "",
            "display_description": "Fix error handling",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(streaming_text(input), "Fix error handling");
    }

    #[test]
    fn still_streaming_ui_text_with_path_and_description() {
        let input = json!({
            "path": "src/main.rs",
            "display_description": "Fix error handling",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(streaming_text(input), "Fix error handling");
    }

    #[test]
    fn still_streaming_ui_text_no_path_or_description() {
        let input = json!({
            "path": "",
            "display_description": "",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(streaming_text(input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        assert_eq!(streaming_text(serde_json::Value::Null), DEFAULT_UI_TEXT);
    }

    /// Keys that are absent altogether must fall back to their defaults.
    #[test]
    fn still_streaming_ui_text_with_missing_fields() {
        assert_eq!(streaming_text(json!({})), DEFAULT_UI_TEXT);
        assert_eq!(
            streaming_text(json!({ "display_description": "Fix error handling" })),
            "Fix error handling",
        );
        assert_eq!(
            streaming_text(json!({ "path": "src/main.rs" })),
            "src/main.rs",
        );
    }

    /// Surrounding whitespace is stripped, and a whitespace-only description
    /// counts as empty so the path is used instead.
    #[test]
    fn still_streaming_ui_text_trims_whitespace() {
        assert_eq!(
            streaming_text(json!({ "display_description": "  Fix error handling \n" })),
            "Fix error handling",
        );
        assert_eq!(
            streaming_text(json!({ "display_description": "   ", "path": " src/main.rs " })),
            "src/main.rs",
        );
        assert_eq!(
            streaming_text(json!({ "display_description": " ", "path": "\t" })),
            DEFAULT_UI_TEXT,
        );
    }

    /// A value of the wrong JSON type makes deserialization fail, which must
    /// yield the default label rather than a panic.
    #[test]
    fn still_streaming_ui_text_with_mistyped_field() {
        assert_eq!(streaming_text(json!({ "path": 42 })), DEFAULT_UI_TEXT);
    }
}