streaming_edit_file_tool.rs

  1use crate::{
  2    Templates,
  3    edit_agent::{EditAgent, EditAgentOutputEvent},
  4    edit_file_tool::EditFileToolCard,
  5    schema::json_schema_for,
  6};
  7use anyhow::{Context as _, Result, anyhow};
  8use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult};
  9use futures::StreamExt;
 10use gpui::{AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task};
 11use indoc::formatdoc;
 12use language_model::{
 13    LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolSchemaFormat,
 14};
 15use project::Project;
 16use schemars::JsonSchema;
 17use serde::{Deserialize, Serialize};
 18use std::{path::PathBuf, sync::Arc};
 19use ui::prelude::*;
 20use util::ResultExt;
 21
 22pub struct StreamingEditFileTool;
 23
 24#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 25pub struct StreamingEditFileToolInput {
 26    /// A one-line, user-friendly markdown description of the edit. This will be
 27    /// shown in the UI and also passed to another model to perform the edit.
 28    ///
 29    /// Be terse, but also descriptive in what you want to achieve with this
 30    /// edit. Avoid generic instructions.
 31    ///
 32    /// NEVER mention the file path in this description.
 33    ///
 34    /// <example>Fix API endpoint URLs</example>
 35    /// <example>Update copyright year in `page_footer`</example>
 36    ///
 37    /// Make sure to include this field before all the others in the input object
 38    /// so that we can display it immediately.
 39    pub display_description: String,
 40
 41    /// The full path of the file to create or modify in the project.
 42    ///
 43    /// WARNING: When specifying which file path need changing, you MUST
 44    /// start each path with one of the project's root directories.
 45    ///
 46    /// The following examples assume we have two root directories in the project:
 47    /// - backend
 48    /// - frontend
 49    ///
 50    /// <example>
 51    /// `backend/src/main.rs`
 52    ///
 53    /// Notice how the file path starts with root-1. Without that, the path
 54    /// would be ambiguous and the call would fail!
 55    /// </example>
 56    ///
 57    /// <example>
 58    /// `frontend/db.js`
 59    /// </example>
 60    pub path: PathBuf,
 61
 62    /// If true, this tool will recreate the file from scratch.
 63    /// If false, this tool will produce granular edits to an existing file.
 64    ///
 65    /// When a file already exists or you just created it, always prefer editing
 66    /// it as opposed to recreating it from scratch.
 67    pub create_or_overwrite: bool,
 68}
 69
 70#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 71struct PartialInput {
 72    #[serde(default)]
 73    path: String,
 74    #[serde(default)]
 75    display_description: String,
 76}
 77
 78const DEFAULT_UI_TEXT: &str = "Editing file";
 79
 80impl Tool for StreamingEditFileTool {
 81    fn name(&self) -> String {
 82        "edit_file".into()
 83    }
 84
 85    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 86        false
 87    }
 88
 89    fn description(&self) -> String {
 90        include_str!("streaming_edit_file_tool/description.md").to_string()
 91    }
 92
 93    fn icon(&self) -> IconName {
 94        IconName::Pencil
 95    }
 96
 97    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 98        json_schema_for::<StreamingEditFileToolInput>(format)
 99    }
100
101    fn ui_text(&self, input: &serde_json::Value) -> String {
102        match serde_json::from_value::<StreamingEditFileToolInput>(input.clone()) {
103            Ok(input) => input.display_description,
104            Err(_) => "Editing file".to_string(),
105        }
106    }
107
108    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
109        if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
110            let description = input.display_description.trim();
111            if !description.is_empty() {
112                return description.to_string();
113            }
114
115            let path = input.path.trim();
116            if !path.is_empty() {
117                return path.to_string();
118            }
119        }
120
121        DEFAULT_UI_TEXT.to_string()
122    }
123
124    fn run(
125        self: Arc<Self>,
126        input: serde_json::Value,
127        messages: &[LanguageModelRequestMessage],
128        project: Entity<Project>,
129        action_log: Entity<ActionLog>,
130        window: Option<AnyWindowHandle>,
131        cx: &mut App,
132    ) -> ToolResult {
133        let input = match serde_json::from_value::<StreamingEditFileToolInput>(input) {
134            Ok(input) => input,
135            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
136        };
137
138        let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
139            return Task::ready(Err(anyhow!(
140                "Path {} not found in project",
141                input.path.display()
142            )))
143            .into();
144        };
145        let Some(worktree) = project
146            .read(cx)
147            .worktree_for_id(project_path.worktree_id, cx)
148        else {
149            return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
150        };
151        let exists = worktree.update(cx, |worktree, cx| {
152            worktree.file_exists(&project_path.path, cx)
153        });
154
155        let card = window.and_then(|window| {
156            window
157                .update(cx, |_, window, cx| {
158                    cx.new(|cx| {
159                        EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
160                    })
161                })
162                .ok()
163        });
164
165        let card_clone = card.clone();
166        let messages = messages.to_vec();
167        let task = cx.spawn(async move |cx: &mut AsyncApp| {
168            if !input.create_or_overwrite && !exists.await? {
169                return Err(anyhow!("{} not found", input.path.display()));
170            }
171
172            let model = cx
173                .update(|cx| LanguageModelRegistry::read_global(cx).default_model())?
174                .context("default model not set")?
175                .model;
176            let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new());
177
178            let buffer = project
179                .update(cx, |project, cx| {
180                    project.open_buffer(project_path.clone(), cx)
181                })?
182                .await?;
183
184            let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
185            let old_text = cx
186                .background_spawn({
187                    let old_snapshot = old_snapshot.clone();
188                    async move { old_snapshot.text() }
189                })
190                .await;
191
192            let (output, mut events) = if input.create_or_overwrite {
193                edit_agent.overwrite(
194                    buffer.clone(),
195                    input.display_description.clone(),
196                    messages,
197                    cx,
198                )
199            } else {
200                edit_agent.edit(
201                    buffer.clone(),
202                    input.display_description.clone(),
203                    messages,
204                    cx,
205                )
206            };
207
208            let mut hallucinated_old_text = false;
209            while let Some(event) = events.next().await {
210                match event {
211                    EditAgentOutputEvent::Edited => {
212                        if let Some(card) = card_clone.as_ref() {
213                            let new_snapshot =
214                                buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
215                            let new_text = cx
216                                .background_spawn({
217                                    let new_snapshot = new_snapshot.clone();
218                                    async move { new_snapshot.text() }
219                                })
220                                .await;
221                            card.update(cx, |card, cx| {
222                                card.set_diff(
223                                    project_path.path.clone(),
224                                    old_text.clone(),
225                                    new_text,
226                                    cx,
227                                );
228                            })
229                            .log_err();
230                        }
231                    }
232                    EditAgentOutputEvent::OldTextNotFound(_) => hallucinated_old_text = true,
233                }
234            }
235            output.await?;
236
237            project
238                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
239                .await?;
240
241            let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
242            let new_text = cx.background_spawn({
243                let new_snapshot = new_snapshot.clone();
244                async move { new_snapshot.text() }
245            });
246            let diff = cx.background_spawn(async move {
247                language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
248            });
249            let (new_text, diff) = futures::join!(new_text, diff);
250
251            if let Some(card) = card_clone {
252                card.update(cx, |card, cx| {
253                    card.set_diff(project_path.path.clone(), old_text, new_text, cx);
254                })
255                .log_err();
256            }
257
258            let input_path = input.path.display();
259            if diff.is_empty() {
260                if hallucinated_old_text {
261                    Err(anyhow!(formatdoc! {"
262                        Some edits were produced but none of them could be applied.
263                        Read the relevant sections of {input_path} again so that
264                        I can perform the requested edits.
265                    "}))
266                } else {
267                    Ok("No edits were made.".to_string())
268                }
269            } else {
270                Ok(format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff))
271            }
272        });
273
274        ToolResult {
275            output: task,
276            card: card.map(AnyToolCard::from),
277        }
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a tool input with the given `path` and `display_description`,
    /// shaped like this tool's real schema (`create_or_overwrite` included).
    fn input_with(path: &str, description: &str) -> serde_json::Value {
        json!({
            "display_description": description,
            "path": path,
            "create_or_overwrite": false,
        })
    }

    #[test]
    fn still_streaming_ui_text_with_path() {
        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input_with("src/main.rs", "")),
            "src/main.rs"
        );
    }

    #[test]
    fn still_streaming_ui_text_with_description() {
        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input_with("", "Fix error handling")),
            "Fix error handling",
        );
    }

    #[test]
    fn still_streaming_ui_text_with_path_and_description() {
        assert_eq!(
            StreamingEditFileTool
                .still_streaming_ui_text(&input_with("src/main.rs", "Fix error handling")),
            "Fix error handling",
        );
    }

    #[test]
    fn still_streaming_ui_text_no_path_or_description() {
        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input_with("", "")),
            DEFAULT_UI_TEXT,
        );
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        let input = serde_json::Value::Null;

        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            DEFAULT_UI_TEXT,
        );
    }

    #[test]
    fn still_streaming_ui_text_trims_whitespace() {
        // A whitespace-only description is treated as missing, so the
        // (trimmed) path is shown instead.
        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input_with("  src/main.rs ", "   ")),
            "src/main.rs"
        );
        assert_eq!(
            StreamingEditFileTool
                .still_streaming_ui_text(&input_with("src/main.rs", "  Fix error handling\n")),
            "Fix error handling"
        );
    }

    #[test]
    fn still_streaming_ui_text_with_partially_streamed_object() {
        // Early in the stream only the first field may have arrived.
        let input = json!({ "display_description": "Fix error handling" });

        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            "Fix error handling"
        );
    }

    #[test]
    fn ui_text_uses_description_or_falls_back() {
        assert_eq!(
            StreamingEditFileTool.ui_text(&input_with("src/main.rs", "Fix error handling")),
            "Fix error handling"
        );

        // Missing required fields: the full input does not deserialize.
        let incomplete = json!({ "path": "src/main.rs" });
        assert_eq!(StreamingEditFileTool.ui_text(&incomplete), DEFAULT_UI_TEXT);
    }
}