1use crate::{
2 Templates,
3 edit_agent::{EditAgent, EditAgentOutputEvent},
4 edit_file_tool::EditFileToolCard,
5 schema::json_schema_for,
6};
7use anyhow::{Context as _, Result, anyhow};
8use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult};
9use futures::StreamExt;
10use gpui::{AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task};
11use indoc::formatdoc;
12use language_model::{
13 LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolSchemaFormat,
14};
15use project::Project;
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use std::{path::PathBuf, sync::Arc};
19use ui::prelude::*;
20use util::ResultExt;
21
/// Tool, exposed under the name `edit_file`, that edits or recreates a file in the
/// project by delegating the change to an `EditAgent` and, when a window is
/// available, mirroring progress into an `EditFileToolCard`.
pub struct StreamingEditFileTool;
23
24#[derive(Debug, Serialize, Deserialize, JsonSchema)]
25pub struct StreamingEditFileToolInput {
26 /// A one-line, user-friendly markdown description of the edit. This will be
27 /// shown in the UI and also passed to another model to perform the edit.
28 ///
29 /// Be terse, but also descriptive in what you want to achieve with this
30 /// edit. Avoid generic instructions.
31 ///
32 /// NEVER mention the file path in this description.
33 ///
34 /// <example>Fix API endpoint URLs</example>
35 /// <example>Update copyright year in `page_footer`</example>
36 ///
37 /// Make sure to include this field before all the others in the input object
38 /// so that we can display it immediately.
39 pub display_description: String,
40
41 /// The full path of the file to create or modify in the project.
42 ///
43 /// WARNING: When specifying which file path need changing, you MUST
44 /// start each path with one of the project's root directories.
45 ///
46 /// The following examples assume we have two root directories in the project:
47 /// - backend
48 /// - frontend
49 ///
50 /// <example>
51 /// `backend/src/main.rs`
52 ///
53 /// Notice how the file path starts with root-1. Without that, the path
54 /// would be ambiguous and the call would fail!
55 /// </example>
56 ///
57 /// <example>
58 /// `frontend/db.js`
59 /// </example>
60 pub path: PathBuf,
61
62 /// If true, this tool will recreate the file from scratch.
63 /// If false, this tool will produce granular edits to an existing file.
64 ///
65 /// When a file already exists or you just created it, always prefer editing
66 /// it as opposed to recreating it from scratch.
67 pub create_or_overwrite: bool,
68}
69
// Lenient view of the tool input, used by `still_streaming_ui_text` while the tool
// call is still being streamed. Both fields carry `#[serde(default)]`, so an object
// that does not contain them yet still deserializes, yielding empty strings.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
    // Target path as received so far; empty when the key is absent.
    #[serde(default)]
    path: String,
    // Description as received so far; empty when the key is absent.
    #[serde(default)]
    display_description: String,
}
77
/// Fallback UI label used when the tool input offers no usable description or path.
const DEFAULT_UI_TEXT: &str = "Editing file";
79
80impl Tool for StreamingEditFileTool {
81 fn name(&self) -> String {
82 "edit_file".into()
83 }
84
85 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
86 false
87 }
88
89 fn description(&self) -> String {
90 include_str!("streaming_edit_file_tool/description.md").to_string()
91 }
92
93 fn icon(&self) -> IconName {
94 IconName::Pencil
95 }
96
97 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
98 json_schema_for::<StreamingEditFileToolInput>(format)
99 }
100
101 fn ui_text(&self, input: &serde_json::Value) -> String {
102 match serde_json::from_value::<StreamingEditFileToolInput>(input.clone()) {
103 Ok(input) => input.display_description,
104 Err(_) => "Editing file".to_string(),
105 }
106 }
107
108 fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
109 if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
110 let description = input.display_description.trim();
111 if !description.is_empty() {
112 return description.to_string();
113 }
114
115 let path = input.path.trim();
116 if !path.is_empty() {
117 return path.to_string();
118 }
119 }
120
121 DEFAULT_UI_TEXT.to_string()
122 }
123
124 fn run(
125 self: Arc<Self>,
126 input: serde_json::Value,
127 messages: &[LanguageModelRequestMessage],
128 project: Entity<Project>,
129 action_log: Entity<ActionLog>,
130 window: Option<AnyWindowHandle>,
131 cx: &mut App,
132 ) -> ToolResult {
133 let input = match serde_json::from_value::<StreamingEditFileToolInput>(input) {
134 Ok(input) => input,
135 Err(err) => return Task::ready(Err(anyhow!(err))).into(),
136 };
137
138 let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
139 return Task::ready(Err(anyhow!(
140 "Path {} not found in project",
141 input.path.display()
142 )))
143 .into();
144 };
145 let Some(worktree) = project
146 .read(cx)
147 .worktree_for_id(project_path.worktree_id, cx)
148 else {
149 return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
150 };
151 let exists = worktree.update(cx, |worktree, cx| {
152 worktree.file_exists(&project_path.path, cx)
153 });
154
155 let card = window.and_then(|window| {
156 window
157 .update(cx, |_, window, cx| {
158 cx.new(|cx| {
159 EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
160 })
161 })
162 .ok()
163 });
164
165 let card_clone = card.clone();
166 let messages = messages.to_vec();
167 let task = cx.spawn(async move |cx: &mut AsyncApp| {
168 if !input.create_or_overwrite && !exists.await? {
169 return Err(anyhow!("{} not found", input.path.display()));
170 }
171
172 let model = cx
173 .update(|cx| LanguageModelRegistry::read_global(cx).default_model())?
174 .context("default model not set")?
175 .model;
176 let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new());
177
178 let buffer = project
179 .update(cx, |project, cx| {
180 project.open_buffer(project_path.clone(), cx)
181 })?
182 .await?;
183
184 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
185 let old_text = cx
186 .background_spawn({
187 let old_snapshot = old_snapshot.clone();
188 async move { old_snapshot.text() }
189 })
190 .await;
191
192 let (output, mut events) = if input.create_or_overwrite {
193 edit_agent.overwrite(
194 buffer.clone(),
195 input.display_description.clone(),
196 messages,
197 cx,
198 )
199 } else {
200 edit_agent.edit(
201 buffer.clone(),
202 input.display_description.clone(),
203 messages,
204 cx,
205 )
206 };
207
208 let mut hallucinated_old_text = false;
209 while let Some(event) = events.next().await {
210 match event {
211 EditAgentOutputEvent::Edited => {
212 if let Some(card) = card_clone.as_ref() {
213 let new_snapshot =
214 buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
215 let new_text = cx
216 .background_spawn({
217 let new_snapshot = new_snapshot.clone();
218 async move { new_snapshot.text() }
219 })
220 .await;
221 card.update(cx, |card, cx| {
222 card.set_diff(
223 project_path.path.clone(),
224 old_text.clone(),
225 new_text,
226 cx,
227 );
228 })
229 .log_err();
230 }
231 }
232 EditAgentOutputEvent::OldTextNotFound(_) => hallucinated_old_text = true,
233 }
234 }
235 output.await?;
236
237 project
238 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
239 .await?;
240
241 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
242 let new_text = cx.background_spawn({
243 let new_snapshot = new_snapshot.clone();
244 async move { new_snapshot.text() }
245 });
246 let diff = cx.background_spawn(async move {
247 language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
248 });
249 let (new_text, diff) = futures::join!(new_text, diff);
250
251 if let Some(card) = card_clone {
252 card.update(cx, |card, cx| {
253 card.set_diff(project_path.path.clone(), old_text, new_text, cx);
254 })
255 .log_err();
256 }
257
258 let input_path = input.path.display();
259 if diff.is_empty() {
260 if hallucinated_old_text {
261 Err(anyhow!(formatdoc! {"
262 Some edits were produced but none of them could be applied.
263 Read the relevant sections of {input_path} again so that
264 I can perform the requested edits.
265 "}))
266 } else {
267 Ok("No edits were made.".to_string())
268 }
269 } else {
270 Ok(format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff))
271 }
272 });
273
274 ToolResult {
275 output: task,
276 card: card.map(AnyToolCard::from),
277 }
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Shorthand for calling `still_streaming_ui_text` on an owned JSON value.
    fn streaming_text(input: serde_json::Value) -> String {
        StreamingEditFileTool.still_streaming_ui_text(&input)
    }

    #[test]
    fn still_streaming_ui_text_with_path() {
        let input = json!({
            "path": "src/main.rs",
            "display_description": ""
        });

        assert_eq!(streaming_text(input), "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_with_description() {
        let input = json!({
            "path": "",
            "display_description": "Fix error handling"
        });

        assert_eq!(streaming_text(input), "Fix error handling");
    }

    #[test]
    fn still_streaming_ui_text_with_path_and_description() {
        let input = json!({
            "path": "src/main.rs",
            "display_description": "Fix error handling"
        });

        // The description wins over the path when both are present.
        assert_eq!(streaming_text(input), "Fix error handling");
    }

    #[test]
    fn still_streaming_ui_text_no_path_or_description() {
        let input = json!({
            "path": "",
            "display_description": ""
        });

        assert_eq!(streaming_text(input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        assert_eq!(streaming_text(serde_json::Value::Null), DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_with_missing_fields() {
        // Objects that lack one or both keys exercise the `#[serde(default)]` fields.
        assert_eq!(streaming_text(json!({})), DEFAULT_UI_TEXT);
        assert_eq!(
            streaming_text(json!({ "display_description": "Fix error handling" })),
            "Fix error handling"
        );
        assert_eq!(
            streaming_text(json!({ "path": "src/main.rs" })),
            "src/main.rs"
        );
    }

    #[test]
    fn still_streaming_ui_text_trims_whitespace() {
        // A whitespace-only description counts as empty, and the path is trimmed.
        let input = json!({
            "path": "  src/main.rs\n",
            "display_description": "   "
        });

        assert_eq!(streaming_text(input), "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_ignores_unknown_fields() {
        // Keys that are not part of `PartialInput` must not prevent parsing.
        let input = json!({
            "path": "src/main.rs",
            "display_description": "Fix error handling",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(streaming_text(input), "Fix error handling");
    }

    #[test]
    fn ui_text_uses_description_or_default() {
        let complete = json!({
            "display_description": "Fix error handling",
            "path": "src/main.rs",
            "create_or_overwrite": false
        });
        assert_eq!(StreamingEditFileTool.ui_text(&complete), "Fix error handling");

        // An input that is not a complete `StreamingEditFileToolInput` yields the default.
        assert_eq!(
            StreamingEditFileTool.ui_text(&serde_json::Value::Null),
            DEFAULT_UI_TEXT
        );
    }
}