streaming_edit_file_tool.rs

  1use crate::{
  2    Templates,
  3    edit_agent::{EditAgent, EditAgentOutputEvent},
  4    edit_file_tool::EditFileToolCard,
  5    schema::json_schema_for,
  6};
  7use anyhow::{Context as _, Result, anyhow};
  8use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult};
  9use futures::StreamExt;
 10use gpui::{AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task};
 11use indoc::formatdoc;
 12use language_model::{
 13    LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolSchemaFormat,
 14};
 15use project::Project;
 16use schemars::JsonSchema;
 17use serde::{Deserialize, Serialize};
 18use std::{path::PathBuf, sync::Arc};
 19use ui::prelude::*;
 20use util::ResultExt;
 21
/// Tool exposed under the name `edit_file`. Its `Tool::run` implementation
/// hands the requested change to an `EditAgent` and, when a window is
/// available, mirrors progress into an `EditFileToolCard`.
pub struct StreamingEditFileTool;
 23
 24#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 25pub struct StreamingEditFileToolInput {
 26    /// A one-line, user-friendly markdown description of the edit. This will be
 27    /// shown in the UI and also passed to another model to perform the edit.
 28    ///
 29    /// Be terse, but also descriptive in what you want to achieve with this
 30    /// edit. Avoid generic instructions.
 31    ///
 32    /// NEVER mention the file path in this description.
 33    ///
 34    /// <example>Fix API endpoint URLs</example>
 35    /// <example>Update copyright year in `page_footer`</example>
 36    ///
 37    /// Make sure to include this field before all the others in the input object
 38    /// so that we can display it immediately.
 39    pub display_description: String,
 40
 41    /// The full path of the file to modify in the project.
 42    ///
 43    /// WARNING: When specifying which file path need changing, you MUST
 44    /// start each path with one of the project's root directories.
 45    ///
 46    /// The following examples assume we have two root directories in the project:
 47    /// - backend
 48    /// - frontend
 49    ///
 50    /// <example>
 51    /// `backend/src/main.rs`
 52    ///
 53    /// Notice how the file path starts with root-1. Without that, the path
 54    /// would be ambiguous and the call would fail!
 55    /// </example>
 56    ///
 57    /// <example>
 58    /// `frontend/db.js`
 59    /// </example>
 60    pub path: PathBuf,
 61}
 62
// Lenient counterpart of `StreamingEditFileToolInput`, used by
// `still_streaming_ui_text`. Both fields are plain strings marked
// `#[serde(default)]`, so a missing field deserializes to an empty string
// instead of producing an error.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
    // Empty when the field is absent from the input.
    #[serde(default)]
    path: String,
    // Empty when the field is absent from the input.
    #[serde(default)]
    display_description: String,
}
 70
/// Label returned by `still_streaming_ui_text` when the input yields neither a
/// description nor a path.
const DEFAULT_UI_TEXT: &str = "Editing file";
 72
 73impl Tool for StreamingEditFileTool {
 74    fn name(&self) -> String {
 75        "edit_file".into()
 76    }
 77
 78    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 79        false
 80    }
 81
 82    fn description(&self) -> String {
 83        include_str!("streaming_edit_file_tool/description.md").to_string()
 84    }
 85
 86    fn icon(&self) -> IconName {
 87        IconName::Pencil
 88    }
 89
 90    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 91        json_schema_for::<StreamingEditFileToolInput>(format)
 92    }
 93
 94    fn ui_text(&self, input: &serde_json::Value) -> String {
 95        match serde_json::from_value::<StreamingEditFileToolInput>(input.clone()) {
 96            Ok(input) => input.display_description,
 97            Err(_) => "Editing file".to_string(),
 98        }
 99    }
100
101    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
102        if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
103            let description = input.display_description.trim();
104            if !description.is_empty() {
105                return description.to_string();
106            }
107
108            let path = input.path.trim();
109            if !path.is_empty() {
110                return path.to_string();
111            }
112        }
113
114        DEFAULT_UI_TEXT.to_string()
115    }
116
117    fn run(
118        self: Arc<Self>,
119        input: serde_json::Value,
120        messages: &[LanguageModelRequestMessage],
121        project: Entity<Project>,
122        action_log: Entity<ActionLog>,
123        window: Option<AnyWindowHandle>,
124        cx: &mut App,
125    ) -> ToolResult {
126        let input = match serde_json::from_value::<StreamingEditFileToolInput>(input) {
127            Ok(input) => input,
128            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
129        };
130
131        let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
132            return Task::ready(Err(anyhow!(
133                "Path {} not found in project",
134                input.path.display()
135            )))
136            .into();
137        };
138        let Some(worktree) = project
139            .read(cx)
140            .worktree_for_id(project_path.worktree_id, cx)
141        else {
142            return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
143        };
144        let exists = worktree.update(cx, |worktree, cx| {
145            worktree.file_exists(&project_path.path, cx)
146        });
147
148        let card = window.and_then(|window| {
149            window
150                .update(cx, |_, window, cx| {
151                    cx.new(|cx| {
152                        EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
153                    })
154                })
155                .ok()
156        });
157
158        let card_clone = card.clone();
159        let messages = messages.to_vec();
160        let task = cx.spawn(async move |cx: &mut AsyncApp| {
161            if !exists.await? {
162                return Err(anyhow!("{} not found", input.path.display()));
163            }
164
165            let model = cx
166                .update(|cx| LanguageModelRegistry::read_global(cx).default_model())?
167                .context("default model not set")?
168                .model;
169            let edit_agent = EditAgent::new(model, action_log, Templates::new());
170
171            let buffer = project
172                .update(cx, |project, cx| {
173                    project.open_buffer(project_path.clone(), cx)
174                })?
175                .await?;
176
177            let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
178            let old_text = cx
179                .background_spawn({
180                    let old_snapshot = old_snapshot.clone();
181                    async move { old_snapshot.text() }
182                })
183                .await;
184
185            let (output, mut events) = edit_agent.edit(
186                buffer.clone(),
187                input.display_description.clone(),
188                messages,
189                cx,
190            );
191
192            let mut hallucinated_old_text = false;
193            while let Some(event) = events.next().await {
194                match event {
195                    EditAgentOutputEvent::Edited => {
196                        if let Some(card) = card_clone.as_ref() {
197                            let new_snapshot =
198                                buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
199                            let new_text = cx
200                                .background_spawn({
201                                    let new_snapshot = new_snapshot.clone();
202                                    async move { new_snapshot.text() }
203                                })
204                                .await;
205                            card.update(cx, |card, cx| {
206                                card.set_diff(
207                                    project_path.path.clone(),
208                                    old_text.clone(),
209                                    new_text,
210                                    cx,
211                                );
212                            })
213                            .log_err();
214                        }
215                    }
216                    EditAgentOutputEvent::HallucinatedOldText(_) => hallucinated_old_text = true,
217                }
218            }
219            output.await?;
220
221            project
222                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
223                .await?;
224
225            let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
226            let new_text = cx.background_spawn({
227                let new_snapshot = new_snapshot.clone();
228                async move { new_snapshot.text() }
229            });
230            let diff = cx.background_spawn(async move {
231                language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
232            });
233            let (new_text, diff) = futures::join!(new_text, diff);
234
235            if let Some(card) = card_clone {
236                card.update(cx, |card, cx| {
237                    card.set_diff(project_path.path.clone(), old_text, new_text, cx);
238                })
239                .log_err();
240            }
241
242            let input_path = input.path.display();
243            if diff.is_empty() {
244                if hallucinated_old_text {
245                    Err(anyhow!(formatdoc! {"
246                        Some edits were produced but none of them could be applied.
247                        Read the relevant sections of {input_path} again so that
248                        I can perform the requested edits.
249                    "}))
250                } else {
251                    Ok("No edits were made.".to_string())
252                }
253            } else {
254                Ok(format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff))
255            }
256        });
257
258        ToolResult {
259            output: task,
260            card: card.map(AnyToolCard::from),
261        }
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Shorthand for calling `still_streaming_ui_text` on the tool under test.
    fn streaming_label(partial: &serde_json::Value) -> String {
        StreamingEditFileTool.still_streaming_ui_text(partial)
    }

    /// With an empty description, the path is used as the label.
    #[test]
    fn still_streaming_ui_text_with_path() {
        let partial = json!({
            "path": "src/main.rs",
            "display_description": "",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(streaming_label(&partial), "src/main.rs");
    }

    /// With an empty path, the description is used as the label.
    #[test]
    fn still_streaming_ui_text_with_description() {
        let partial = json!({
            "path": "",
            "display_description": "Fix error handling",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(streaming_label(&partial), "Fix error handling");
    }

    /// When both are present, the description takes precedence over the path.
    #[test]
    fn still_streaming_ui_text_with_path_and_description() {
        let partial = json!({
            "path": "src/main.rs",
            "display_description": "Fix error handling",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(streaming_label(&partial), "Fix error handling");
    }

    /// When both are empty, the default label is returned.
    #[test]
    fn still_streaming_ui_text_no_path_or_description() {
        let partial = json!({
            "path": "",
            "display_description": "",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(streaming_label(&partial), DEFAULT_UI_TEXT);
    }

    /// A `null` input cannot be parsed and also yields the default label.
    #[test]
    fn still_streaming_ui_text_with_null() {
        let partial = serde_json::Value::Null;

        assert_eq!(streaming_label(&partial), DEFAULT_UI_TEXT);
    }
}