streaming_edit_file_tool.rs

  1use crate::{
  2    Templates,
  3    edit_agent::{EditAgent, EditAgentOutputEvent},
  4    edit_file_tool::EditFileToolCard,
  5    schema::json_schema_for,
  6};
  7use anyhow::{Context as _, Result, anyhow};
  8use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult};
  9use futures::StreamExt;
 10use gpui::{AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task};
 11use indoc::formatdoc;
 12use language_model::{
 13    LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolSchemaFormat,
 14};
 15use project::Project;
 16use schemars::JsonSchema;
 17use serde::{Deserialize, Serialize};
 18use std::{path::PathBuf, sync::Arc};
 19use ui::prelude::*;
 20use util::ResultExt;
 21
/// The `edit_file` tool: creates, overwrites, or edits a single project file by
/// delegating the actual text changes to an [`EditAgent`] and streaming the
/// resulting diff into an [`EditFileToolCard`] when a window is available.
pub struct StreamingEditFileTool;
 23
 24#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 25pub struct StreamingEditFileToolInput {
 26    /// A one-line, user-friendly markdown description of the edit. This will be
 27    /// shown in the UI and also passed to another model to perform the edit.
 28    ///
 29    /// Be terse, but also descriptive in what you want to achieve with this
 30    /// edit. Avoid generic instructions.
 31    ///
 32    /// NEVER mention the file path in this description.
 33    ///
 34    /// <example>Fix API endpoint URLs</example>
 35    /// <example>Update copyright year in `page_footer`</example>
 36    ///
 37    /// Make sure to include this field before all the others in the input object
 38    /// so that we can display it immediately.
 39    pub display_description: String,
 40
 41    /// The full path of the file to create or modify in the project.
 42    ///
 43    /// WARNING: When specifying which file path need changing, you MUST
 44    /// start each path with one of the project's root directories.
 45    ///
 46    /// The following examples assume we have two root directories in the project:
 47    /// - backend
 48    /// - frontend
 49    ///
 50    /// <example>
 51    /// `backend/src/main.rs`
 52    ///
 53    /// Notice how the file path starts with root-1. Without that, the path
 54    /// would be ambiguous and the call would fail!
 55    /// </example>
 56    ///
 57    /// <example>
 58    /// `frontend/db.js`
 59    /// </example>
 60    pub path: PathBuf,
 61
 62    /// If true, this tool will recreate the file from scratch.
 63    /// If false, this tool will produce granular edits to an existing file.
 64    pub create_or_overwrite: bool,
 65}
 66
// Lenient view of the tool input, used by `still_streaming_ui_text` while the
// model is still streaming the JSON arguments. Both fields are
// `#[serde(default)]`, so an object that is missing either key (because it has
// not arrived yet) still deserializes, with the missing field set to "".
// Plain `//` comments are used deliberately: `///` doc comments would be picked
// up by the `JsonSchema` derive and alter the generated schema text.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
    #[serde(default)]
    path: String,
    #[serde(default)]
    display_description: String,
}
 74
/// Fallback label returned by `still_streaming_ui_text` when the partial input
/// has neither a non-blank description nor a non-blank path.
const DEFAULT_UI_TEXT: &str = "Editing file";
 76
 77impl Tool for StreamingEditFileTool {
 78    fn name(&self) -> String {
 79        "edit_file".into()
 80    }
 81
 82    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 83        false
 84    }
 85
 86    fn description(&self) -> String {
 87        include_str!("streaming_edit_file_tool/description.md").to_string()
 88    }
 89
 90    fn icon(&self) -> IconName {
 91        IconName::Pencil
 92    }
 93
 94    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 95        json_schema_for::<StreamingEditFileToolInput>(format)
 96    }
 97
 98    fn ui_text(&self, input: &serde_json::Value) -> String {
 99        match serde_json::from_value::<StreamingEditFileToolInput>(input.clone()) {
100            Ok(input) => input.display_description,
101            Err(_) => "Editing file".to_string(),
102        }
103    }
104
105    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
106        if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
107            let description = input.display_description.trim();
108            if !description.is_empty() {
109                return description.to_string();
110            }
111
112            let path = input.path.trim();
113            if !path.is_empty() {
114                return path.to_string();
115            }
116        }
117
118        DEFAULT_UI_TEXT.to_string()
119    }
120
121    fn run(
122        self: Arc<Self>,
123        input: serde_json::Value,
124        messages: &[LanguageModelRequestMessage],
125        project: Entity<Project>,
126        action_log: Entity<ActionLog>,
127        window: Option<AnyWindowHandle>,
128        cx: &mut App,
129    ) -> ToolResult {
130        let input = match serde_json::from_value::<StreamingEditFileToolInput>(input) {
131            Ok(input) => input,
132            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
133        };
134
135        let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
136            return Task::ready(Err(anyhow!(
137                "Path {} not found in project",
138                input.path.display()
139            )))
140            .into();
141        };
142        let Some(worktree) = project
143            .read(cx)
144            .worktree_for_id(project_path.worktree_id, cx)
145        else {
146            return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
147        };
148        let exists = worktree.update(cx, |worktree, cx| {
149            worktree.file_exists(&project_path.path, cx)
150        });
151
152        let card = window.and_then(|window| {
153            window
154                .update(cx, |_, window, cx| {
155                    cx.new(|cx| {
156                        EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
157                    })
158                })
159                .ok()
160        });
161
162        let card_clone = card.clone();
163        let messages = messages.to_vec();
164        let task = cx.spawn(async move |cx: &mut AsyncApp| {
165            if !input.create_or_overwrite && !exists.await? {
166                return Err(anyhow!("{} not found", input.path.display()));
167            }
168
169            let model = cx
170                .update(|cx| LanguageModelRegistry::read_global(cx).default_model())?
171                .context("default model not set")?
172                .model;
173            let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new());
174
175            let buffer = project
176                .update(cx, |project, cx| {
177                    project.open_buffer(project_path.clone(), cx)
178                })?
179                .await?;
180
181            let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
182            let old_text = cx
183                .background_spawn({
184                    let old_snapshot = old_snapshot.clone();
185                    async move { old_snapshot.text() }
186                })
187                .await;
188
189            let (output, mut events) = if input.create_or_overwrite {
190                edit_agent.overwrite(
191                    buffer.clone(),
192                    input.display_description.clone(),
193                    messages,
194                    cx,
195                )
196            } else {
197                edit_agent.edit(
198                    buffer.clone(),
199                    input.display_description.clone(),
200                    messages,
201                    cx,
202                )
203            };
204
205            let mut hallucinated_old_text = false;
206            while let Some(event) = events.next().await {
207                match event {
208                    EditAgentOutputEvent::Edited => {
209                        if let Some(card) = card_clone.as_ref() {
210                            let new_snapshot =
211                                buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
212                            let new_text = cx
213                                .background_spawn({
214                                    let new_snapshot = new_snapshot.clone();
215                                    async move { new_snapshot.text() }
216                                })
217                                .await;
218                            card.update(cx, |card, cx| {
219                                card.set_diff(
220                                    project_path.path.clone(),
221                                    old_text.clone(),
222                                    new_text,
223                                    cx,
224                                );
225                            })
226                            .log_err();
227                        }
228                    }
229                    EditAgentOutputEvent::OldTextNotFound(_) => hallucinated_old_text = true,
230                }
231            }
232            output.await?;
233
234            project
235                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
236                .await?;
237
238            let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
239            let new_text = cx.background_spawn({
240                let new_snapshot = new_snapshot.clone();
241                async move { new_snapshot.text() }
242            });
243            let diff = cx.background_spawn(async move {
244                language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
245            });
246            let (new_text, diff) = futures::join!(new_text, diff);
247
248            if let Some(card) = card_clone {
249                card.update(cx, |card, cx| {
250                    card.set_diff(project_path.path.clone(), old_text, new_text, cx);
251                })
252                .log_err();
253            }
254
255            let input_path = input.path.display();
256            if diff.is_empty() {
257                if hallucinated_old_text {
258                    Err(anyhow!(formatdoc! {"
259                        Some edits were produced but none of them could be applied.
260                        Read the relevant sections of {input_path} again so that
261                        I can perform the requested edits.
262                    "}))
263                } else {
264                    Ok("No edits were made.".to_string())
265                }
266            } else {
267                Ok(format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff))
268            }
269        });
270
271        ToolResult {
272            output: task,
273            card: card.map(AnyToolCard::from),
274        }
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // The `old_string`/`new_string` keys in the inputs below are not part of
    // this tool's schema; they check that unknown keys are ignored.

    #[test]
    fn still_streaming_ui_text_with_path() {
        let input = json!({
            "path": "src/main.rs",
            "display_description": "",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            "src/main.rs"
        );
    }

    #[test]
    fn still_streaming_ui_text_with_description() {
        let input = json!({
            "path": "",
            "display_description": "Fix error handling",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            "Fix error handling",
        );
    }

    #[test]
    fn still_streaming_ui_text_with_path_and_description() {
        let input = json!({
            "path": "src/main.rs",
            "display_description": "Fix error handling",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            "Fix error handling",
        );
    }

    #[test]
    fn still_streaming_ui_text_no_path_or_description() {
        let input = json!({
            "path": "",
            "display_description": "",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            DEFAULT_UI_TEXT,
        );
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        let input = serde_json::Value::Null;

        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            DEFAULT_UI_TEXT,
        );
    }

    /// Early in the stream only some keys have arrived; `#[serde(default)]`
    /// on `PartialInput` must let these still deserialize.
    #[test]
    fn still_streaming_ui_text_with_missing_fields() {
        let input = json!({ "path": "src/main.rs" });
        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            "src/main.rs"
        );

        let input = json!({ "display_description": "Fix error handling" });
        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            "Fix error handling"
        );

        let input = json!({});
        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            DEFAULT_UI_TEXT
        );
    }

    #[test]
    fn still_streaming_ui_text_trims_whitespace() {
        let input = json!({
            "path": "  src/lib.rs  ",
            "display_description": "   "
        });

        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            "src/lib.rs"
        );
    }

    #[test]
    fn still_streaming_ui_text_with_wrongly_typed_field() {
        let input = json!({ "path": 42 });

        assert_eq!(
            StreamingEditFileTool.still_streaming_ui_text(&input),
            DEFAULT_UI_TEXT
        );
    }

    #[test]
    fn ui_text_with_complete_input() {
        let input = json!({
            "display_description": "Fix error handling",
            "path": "src/main.rs",
            "create_or_overwrite": false
        });

        assert_eq!(
            StreamingEditFileTool.ui_text(&input),
            "Fix error handling"
        );
    }

    #[test]
    fn ui_text_with_incomplete_input() {
        // `path` and `create_or_overwrite` are required by the full schema.
        let input = json!({ "display_description": "Fix error handling" });

        assert_eq!(StreamingEditFileTool.ui_text(&input), DEFAULT_UI_TEXT);
    }
}