streaming_edit_file_tool.rs

  1use crate::{
  2    Templates,
  3    edit_agent::{EditAgent, EditAgentOutputEvent},
  4    edit_file_tool::EditFileToolCard,
  5    schema::json_schema_for,
  6};
  7use anyhow::{Context as _, Result, anyhow};
  8use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult, ToolResultOutput};
  9use futures::StreamExt;
 10use gpui::{AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task};
 11use indoc::formatdoc;
 12use language_model::{
 13    LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolSchemaFormat,
 14};
 15use project::Project;
 16use schemars::JsonSchema;
 17use serde::{Deserialize, Serialize};
 18use std::{path::PathBuf, sync::Arc};
 19use ui::prelude::*;
 20use util::ResultExt;
 21
/// Agent tool that creates, overwrites, or edits a single project file.
///
/// The actual text changes are produced by an [`EditAgent`] driven by the
/// default language model; intermediate and final diffs are shown in an
/// [`EditFileToolCard`] when a window is available.
pub struct StreamingEditFileTool;
 23
 24#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 25pub struct StreamingEditFileToolInput {
 26    /// A one-line, user-friendly markdown description of the edit. This will be
 27    /// shown in the UI and also passed to another model to perform the edit.
 28    ///
 29    /// Be terse, but also descriptive in what you want to achieve with this
 30    /// edit. Avoid generic instructions.
 31    ///
 32    /// NEVER mention the file path in this description.
 33    ///
 34    /// <example>Fix API endpoint URLs</example>
 35    /// <example>Update copyright year in `page_footer`</example>
 36    ///
 37    /// Make sure to include this field before all the others in the input object
 38    /// so that we can display it immediately.
 39    pub display_description: String,
 40
 41    /// The full path of the file to create or modify in the project.
 42    ///
 43    /// WARNING: When specifying which file path need changing, you MUST
 44    /// start each path with one of the project's root directories.
 45    ///
 46    /// The following examples assume we have two root directories in the project:
 47    /// - backend
 48    /// - frontend
 49    ///
 50    /// <example>
 51    /// `backend/src/main.rs`
 52    ///
 53    /// Notice how the file path starts with root-1. Without that, the path
 54    /// would be ambiguous and the call would fail!
 55    /// </example>
 56    ///
 57    /// <example>
 58    /// `frontend/db.js`
 59    /// </example>
 60    pub path: PathBuf,
 61
 62    /// If true, this tool will recreate the file from scratch.
 63    /// If false, this tool will produce granular edits to an existing file.
 64    ///
 65    /// When a file already exists or you just created it, always prefer editing
 66    /// it as opposed to recreating it from scratch.
 67    pub create_or_overwrite: bool,
 68}
 69
// Structured result of an edit that changed the file. `run` serializes it into
// the tool result's `output` field, and `deserialize_card` reads it back to
// rebuild the diff card.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct StreamingEditFileToolOutput {
    // Set by `run` from the resolved project path (`project_path.path`).
    // NOTE(review): this looks worktree-relative rather than the root-prefixed
    // path the model supplied — confirm before relying on it for display.
    pub original_path: PathBuf,
    // Buffer contents after the edit was applied and the buffer saved.
    pub new_text: String,
    // Buffer contents captured before the edit started.
    pub old_text: String,
}
 76
// Lenient view of a tool input that may still be streaming in, used only to
// compute the in-progress UI label. Each field defaults to an empty string when
// absent; other keys in the JSON object are ignored.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
    #[serde(default)]
    path: String,
    #[serde(default)]
    display_description: String,
}
 84
// Fallback UI label shown when no description or path can be extracted from
// the tool input.
const DEFAULT_UI_TEXT: &str = "Editing file";
 86
 87impl Tool for StreamingEditFileTool {
 88    fn name(&self) -> String {
 89        "edit_file".into()
 90    }
 91
 92    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 93        false
 94    }
 95
 96    fn description(&self) -> String {
 97        include_str!("streaming_edit_file_tool/description.md").to_string()
 98    }
 99
100    fn icon(&self) -> IconName {
101        IconName::Pencil
102    }
103
104    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
105        json_schema_for::<StreamingEditFileToolInput>(format)
106    }
107
108    fn ui_text(&self, input: &serde_json::Value) -> String {
109        match serde_json::from_value::<StreamingEditFileToolInput>(input.clone()) {
110            Ok(input) => input.display_description,
111            Err(_) => "Editing file".to_string(),
112        }
113    }
114
115    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
116        if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
117            let description = input.display_description.trim();
118            if !description.is_empty() {
119                return description.to_string();
120            }
121
122            let path = input.path.trim();
123            if !path.is_empty() {
124                return path.to_string();
125            }
126        }
127
128        DEFAULT_UI_TEXT.to_string()
129    }
130
131    fn run(
132        self: Arc<Self>,
133        input: serde_json::Value,
134        messages: &[LanguageModelRequestMessage],
135        project: Entity<Project>,
136        action_log: Entity<ActionLog>,
137        window: Option<AnyWindowHandle>,
138        cx: &mut App,
139    ) -> ToolResult {
140        let input = match serde_json::from_value::<StreamingEditFileToolInput>(input) {
141            Ok(input) => input,
142            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
143        };
144
145        let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
146            return Task::ready(Err(anyhow!(
147                "Path {} not found in project",
148                input.path.display()
149            )))
150            .into();
151        };
152        let Some(worktree) = project
153            .read(cx)
154            .worktree_for_id(project_path.worktree_id, cx)
155        else {
156            return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
157        };
158        let exists = worktree.update(cx, |worktree, cx| {
159            worktree.file_exists(&project_path.path, cx)
160        });
161
162        let card = window.and_then(|window| {
163            window
164                .update(cx, |_, window, cx| {
165                    cx.new(|cx| {
166                        EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
167                    })
168                })
169                .ok()
170        });
171
172        let card_clone = card.clone();
173        let messages = messages.to_vec();
174        let task = cx.spawn(async move |cx: &mut AsyncApp| {
175            if !input.create_or_overwrite && !exists.await? {
176                return Err(anyhow!("{} not found", input.path.display()));
177            }
178
179            let model = cx
180                .update(|cx| LanguageModelRegistry::read_global(cx).default_model())?
181                .context("default model not set")?
182                .model;
183            let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new());
184
185            let buffer = project
186                .update(cx, |project, cx| {
187                    project.open_buffer(project_path.clone(), cx)
188                })?
189                .await?;
190
191            let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
192            let old_text = cx
193                .background_spawn({
194                    let old_snapshot = old_snapshot.clone();
195                    async move { old_snapshot.text() }
196                })
197                .await;
198
199            let (output, mut events) = if input.create_or_overwrite {
200                edit_agent.overwrite(
201                    buffer.clone(),
202                    input.display_description.clone(),
203                    messages,
204                    cx,
205                )
206            } else {
207                edit_agent.edit(
208                    buffer.clone(),
209                    input.display_description.clone(),
210                    messages,
211                    cx,
212                )
213            };
214
215            let mut hallucinated_old_text = false;
216            while let Some(event) = events.next().await {
217                match event {
218                    EditAgentOutputEvent::Edited => {
219                        if let Some(card) = card_clone.as_ref() {
220                            let new_snapshot =
221                                buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
222                            let new_text = cx
223                                .background_spawn({
224                                    let new_snapshot = new_snapshot.clone();
225                                    async move { new_snapshot.text() }
226                                })
227                                .await;
228                            card.update(cx, |card, cx| {
229                                card.set_diff(
230                                    project_path.path.clone(),
231                                    old_text.clone(),
232                                    new_text,
233                                    cx,
234                                );
235                            })
236                            .log_err();
237                        }
238                    }
239                    EditAgentOutputEvent::OldTextNotFound(_) => hallucinated_old_text = true,
240                }
241            }
242            output.await?;
243
244            project
245                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
246                .await?;
247
248            let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
249            let new_text = cx.background_spawn({
250                let new_snapshot = new_snapshot.clone();
251                async move { new_snapshot.text() }
252            });
253            let diff = cx.background_spawn(async move {
254                language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
255            });
256            let (new_text, diff) = futures::join!(new_text, diff);
257
258            let output = StreamingEditFileToolOutput {
259                original_path: project_path.path.to_path_buf(),
260                new_text: new_text.clone(),
261                old_text: old_text.clone(),
262            };
263
264            if let Some(card) = card_clone {
265                card.update(cx, |card, cx| {
266                    card.set_diff(project_path.path.clone(), old_text, new_text, cx);
267                })
268                .log_err();
269            }
270
271            let input_path = input.path.display();
272            if diff.is_empty() {
273                if hallucinated_old_text {
274                    Err(anyhow!(formatdoc! {"
275                        Some edits were produced but none of them could be applied.
276                        Read the relevant sections of {input_path} again so that
277                        I can perform the requested edits.
278                    "}))
279                } else {
280                    Ok("No edits were made.".to_string().into())
281                }
282            } else {
283                Ok(ToolResultOutput {
284                    content: format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff),
285                    output: serde_json::to_value(output).ok(),
286                })
287            }
288        });
289
290        ToolResult {
291            output: task,
292            card: card.map(AnyToolCard::from),
293        }
294    }
295
296    fn deserialize_card(
297        self: Arc<Self>,
298        output: serde_json::Value,
299        project: Entity<Project>,
300        window: &mut Window,
301        cx: &mut App,
302    ) -> Option<AnyToolCard> {
303        let output = match serde_json::from_value::<StreamingEditFileToolOutput>(output) {
304            Ok(output) => output,
305            Err(_) => return None,
306        };
307
308        let card = cx.new(|cx| {
309            let mut card = EditFileToolCard::new(output.original_path.clone(), project, window, cx);
310            card.set_diff(
311                output.original_path.into(),
312                output.old_text,
313                output.new_text,
314                cx,
315            );
316            card
317        });
318
319        Some(card.into())
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // The extra `old_string`/`new_string` keys below are not part of this
    // tool's schema; they check that unknown keys do not break label parsing.

    #[test]
    fn still_streaming_ui_text_with_path() {
        let input = json!({
            "path": "src/main.rs",
            "display_description": "",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            "src/main.rs"
        );
    }

    #[test]
    fn still_streaming_ui_text_with_description() {
        let input = json!({
            "path": "",
            "display_description": "Fix error handling",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            "Fix error handling",
        );
    }

    #[test]
    fn still_streaming_ui_text_with_path_and_description() {
        let input = json!({
            "path": "src/main.rs",
            "display_description": "Fix error handling",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            "Fix error handling",
        );
    }

    #[test]
    fn still_streaming_ui_text_no_path_or_description() {
        let input = json!({
            "path": "",
            "display_description": "",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            DEFAULT_UI_TEXT,
        );
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        let input = serde_json::Value::Null;

        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            DEFAULT_UI_TEXT,
        );
    }

    #[test]
    fn still_streaming_ui_text_trims_whitespace_description() {
        // A whitespace-only description must fall through to the path.
        let input = json!({
            "path": "  src/main.rs  ",
            "display_description": "   ",
        });

        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            "src/main.rs",
        );
    }

    #[test]
    fn still_streaming_ui_text_with_missing_description_field() {
        // Early in streaming only some keys may have arrived.
        let input = json!({ "path": "src/lib.rs" });

        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            "src/lib.rs",
        );
    }

    #[test]
    fn ui_text_with_complete_input() {
        let input = json!({
            "display_description": "Fix error handling",
            "path": "src/main.rs",
            "create_or_overwrite": false
        });

        assert_eq!(StreamingEditFileTool.ui_text(&input), "Fix error handling");
    }

    #[test]
    fn ui_text_with_incomplete_input_falls_back_to_default() {
        let input = json!({ "path": "src/main.rs" });

        assert_eq!(StreamingEditFileTool.ui_text(&input), DEFAULT_UI_TEXT);
    }
}