Streaming tools (#11629)

Kyle Kelley , Max Brunsfeld , Marshall , and Max created

Stream characters in for tool calls to allow rendering partial input.


https://github.com/zed-industries/zed/assets/836375/0f023a4b-9c46-4449-ae69-8b6bcab41673

Release Notes:

- N/A

---------

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Max <max@zed.dev>

Change summary

Cargo.lock                                          |  20 
Cargo.toml                                          |   1 
crates/assistant2/Cargo.toml                        |   1 
crates/assistant2/src/assistant2.rs                 |  68 -
crates/assistant2/src/attachments/active_file.rs    |   4 
crates/assistant2/src/tools/annotate_code.rs        | 259 ++++--
crates/assistant2/src/tools/create_buffer.rs        |  96 +
crates/assistant2/src/tools/project_index.rs        | 360 +++++----
crates/assistant_tooling/Cargo.toml                 |   2 
crates/assistant_tooling/src/assistant_tooling.rs   |   8 
crates/assistant_tooling/src/attachment_registry.rs |   8 
crates/assistant_tooling/src/tool_registry.rs       | 572 ++++++--------
crates/multi_buffer/src/multi_buffer.rs             |  24 
13 files changed, 778 insertions(+), 645 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -390,7 +390,6 @@ dependencies = [
  "language",
  "languages",
  "log",
- "nanoid",
  "node_runtime",
  "open_ai",
  "picker",
@@ -419,7 +418,9 @@ dependencies = [
  "collections",
  "futures 0.3.28",
  "gpui",
+ "log",
  "project",
+ "repair_json",
  "schemars",
  "serde",
  "serde_json",
@@ -8050,6 +8051,15 @@ dependencies = [
  "bytecheck",
 ]
 
+[[package]]
+name = "repair_json"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5ee191e184125fe72cb59b74160e25584e3908f2aaa84cbda1e161347102aa15"
+dependencies = [
+ "thiserror",
+]
+
 [[package]]
 name = "reqwest"
 version = "0.11.20"
@@ -10185,18 +10195,18 @@ dependencies = [
 
 [[package]]
 name = "thiserror"
-version = "1.0.48"
+version = "1.0.60"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7"
+checksum = "579e9083ca58dd9dcf91a9923bb9054071b9ebbd800b342194c9feb0ee89fc18"
 dependencies = [
  "thiserror-impl",
 ]
 
 [[package]]
 name = "thiserror-impl"
-version = "1.0.48"
+version = "1.0.60"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35"
+checksum = "e2470041c06ec3ac1ab38d0356a6119054dedaea53e12fbefc0de730a1c08524"
 dependencies = [
  "proc-macro2",
  "quote",

Cargo.toml 🔗

@@ -307,6 +307,7 @@ pulldown-cmark = { version = "0.10.0", default-features = false }
 rand = "0.8.5"
 refineable = { path = "./crates/refineable" }
 regex = "1.5"
+repair_json = "0.1.0"
 rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
 rust-embed = { version = "8.0", features = ["include-exclude"] }
 schemars = "0.8"

crates/assistant2/Cargo.toml 🔗

@@ -29,7 +29,6 @@ fuzzy.workspace = true
 gpui.workspace = true
 language.workspace = true
 log.workspace = true
-nanoid.workspace = true
 open_ai.workspace = true
 picker.workspace = true
 project.workspace = true

crates/assistant2/src/assistant2.rs 🔗

@@ -536,25 +536,27 @@ impl AssistantChat {
                                 body.push_str(content);
                             }
 
-                            for tool_call in delta.tool_calls {
-                                let index = tool_call.index as usize;
+                            for tool_call_delta in delta.tool_calls {
+                                let index = tool_call_delta.index as usize;
                                 if index >= message.tool_calls.len() {
                                     message.tool_calls.resize_with(index + 1, Default::default);
                                 }
-                                let call = &mut message.tool_calls[index];
+                                let tool_call = &mut message.tool_calls[index];
 
-                                if let Some(id) = &tool_call.id {
-                                    call.id.push_str(id);
+                                if let Some(id) = &tool_call_delta.id {
+                                    tool_call.id.push_str(id);
                                 }
 
-                                match tool_call.variant {
-                                    Some(proto::tool_call_delta::Variant::Function(tool_call)) => {
-                                        if let Some(name) = &tool_call.name {
-                                            call.name.push_str(name);
-                                        }
-                                        if let Some(arguments) = &tool_call.arguments {
-                                            call.arguments.push_str(arguments);
-                                        }
+                                match tool_call_delta.variant {
+                                    Some(proto::tool_call_delta::Variant::Function(
+                                        tool_call_delta,
+                                    )) => {
+                                        this.tool_registry.update_tool_call(
+                                            tool_call,
+                                            tool_call_delta.name.as_deref(),
+                                            tool_call_delta.arguments.as_deref(),
+                                            cx,
+                                        );
                                     }
                                     None => {}
                                 }
@@ -587,34 +589,20 @@ impl AssistantChat {
                     } else {
                         if let Some(current_message) = messages.last_mut() {
                             for tool_call in current_message.tool_calls.iter() {
-                                tool_tasks.push(this.tool_registry.call(tool_call, cx));
+                                tool_tasks
+                                    .extend(this.tool_registry.execute_tool_call(&tool_call, cx));
                             }
                         }
                     }
                 }
             })?;
 
+            // This ends recursion on calling for responses after tools
             if tool_tasks.is_empty() {
                 return Ok(());
             }
 
-            let tools = join_all(tool_tasks.into_iter()).await;
-            // If the WindowContext went away for any tool's view we don't include it
-            // especially since the below call would fail for the same reason.
-            let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();
-
-            this.update(cx, |this, cx| {
-                if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
-                    this.messages.last_mut()
-                {
-                    if let Some(current_message) = messages.last_mut() {
-                        current_message.tool_calls = tools;
-                        cx.notify();
-                    } else {
-                        unreachable!()
-                    }
-                }
-            })?;
+            join_all(tool_tasks.into_iter()).await;
         }
     }
 
@@ -948,13 +936,11 @@ impl AssistantChat {
 
                         for tool_call in &message.tool_calls {
                             // Every tool call _must_ have a result by ID, otherwise OpenAI will error.
-                            let content = match &tool_call.result {
-                                Some(result) => {
-                                    result.generate(&tool_call.name, &mut project_context, cx)
-                                }
-                                None => "".to_string(),
-                            };
-
+                            let content = self.tool_registry.content_for_tool_call(
+                                tool_call,
+                                &mut project_context,
+                                cx,
+                            );
                             completion_messages.push(CompletionMessage::Tool {
                                 content,
                                 tool_call_id: tool_call.id.clone(),
@@ -1003,7 +989,11 @@ impl AssistantChat {
                         tool_calls: message
                             .tool_calls
                             .iter()
-                            .map(|tool_call| self.tool_registry.serialize_tool_call(tool_call))
+                            .filter_map(|tool_call| {
+                                self.tool_registry
+                                    .serialize_tool_call(tool_call, cx)
+                                    .log_err()
+                            })
                             .collect(),
                     })
                     .collect(),

crates/assistant2/src/attachments/active_file.rs 🔗

@@ -1,7 +1,7 @@
 use std::{path::PathBuf, sync::Arc};
 
 use anyhow::{anyhow, Result};
-use assistant_tooling::{LanguageModelAttachment, ProjectContext, ToolOutput};
+use assistant_tooling::{AttachmentOutput, LanguageModelAttachment, ProjectContext};
 use editor::Editor;
 use gpui::{Render, Task, View, WeakModel, WeakView};
 use language::Buffer;
@@ -52,7 +52,7 @@ impl Render for FileAttachmentView {
     }
 }
 
-impl ToolOutput for FileAttachmentView {
+impl AttachmentOutput for FileAttachmentView {
     fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
         if let Some(path) = &self.project_path {
             project.add_file(path.clone());

crates/assistant2/src/tools/annotate_code.rs 🔗

@@ -4,7 +4,8 @@ use editor::{
     display_map::{BlockContext, BlockDisposition, BlockProperties, BlockStyle},
     Editor, MultiBuffer,
 };
-use gpui::{prelude::*, AnyElement, Model, Task, View, WeakView};
+use futures::{channel::mpsc::UnboundedSender, StreamExt as _};
+use gpui::{prelude::*, AnyElement, AsyncWindowContext, Model, Task, View, WeakView};
 use language::ToPoint;
 use project::{search::SearchQuery, Project, ProjectPath};
 use schemars::JsonSchema;
@@ -25,14 +26,19 @@ impl AnnotationTool {
     }
 }
 
-#[derive(Debug, Deserialize, JsonSchema, Clone)]
+#[derive(Default, Debug, Deserialize, JsonSchema, Clone)]
 pub struct AnnotationInput {
     /// Name for this set of annotations
+    #[serde(default = "default_title")]
     title: String,
     /// Excerpts from the file to show to the user.
     excerpts: Vec<Excerpt>,
 }
 
+fn default_title() -> String {
+    "Untitled".to_string()
+}
+
 #[derive(Debug, Deserialize, JsonSchema, Clone)]
 struct Excerpt {
     /// Path to the file
@@ -44,8 +50,6 @@ struct Excerpt {
 }
 
 impl LanguageModelTool for AnnotationTool {
-    type Input = AnnotationInput;
-    type Output = String;
     type View = AnnotationResultView;
 
     fn name(&self) -> String {
@@ -56,67 +60,100 @@ impl LanguageModelTool for AnnotationTool {
         "Dynamically annotate symbols in the current codebase. Opens a buffer in a panel in their editor, to the side of the conversation. The annotations are shown in the editor as a block decoration.".to_string()
     }
 
-    fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
-        let workspace = self.workspace.clone();
-        let project = self.project.clone();
-        let excerpts = input.excerpts.clone();
-        let title = input.title.clone();
+    fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
+        cx.new_view(|cx| {
+            let (tx, mut rx) = futures::channel::mpsc::unbounded();
+            cx.spawn(|view, mut cx| async move {
+                while let Some(excerpt) = rx.next().await {
+                    AnnotationResultView::add_excerpt(view.clone(), excerpt, &mut cx).await?;
+                }
+                anyhow::Ok(())
+            })
+            .detach();
+
+            AnnotationResultView {
+                project: self.project.clone(),
+                workspace: self.workspace.clone(),
+                tx,
+                pending_excerpt: None,
+                added_editor_to_workspace: false,
+                editor: None,
+                error: None,
+                rendered_excerpt_count: 0,
+            }
+        })
+    }
+}
+
+pub struct AnnotationResultView {
+    workspace: WeakView<Workspace>,
+    project: Model<Project>,
+    pending_excerpt: Option<Excerpt>,
+    added_editor_to_workspace: bool,
+    editor: Option<View<Editor>>,
+    tx: UnboundedSender<Excerpt>,
+    error: Option<anyhow::Error>,
+    rendered_excerpt_count: usize,
+}
+
+impl AnnotationResultView {
+    async fn add_excerpt(
+        this: WeakView<Self>,
+        excerpt: Excerpt,
+        cx: &mut AsyncWindowContext,
+    ) -> Result<()> {
+        let project = this.update(cx, |this, _cx| this.project.clone())?;
 
         let worktree_id = project.update(cx, |project, cx| {
             let worktree = project.worktrees().next()?;
             let worktree_id = worktree.read(cx).id();
             Some(worktree_id)
-        });
+        })?;
 
         let worktree_id = if let Some(worktree_id) = worktree_id {
             worktree_id
         } else {
-            return Task::ready(Err(anyhow::anyhow!("No worktree found")));
+            return Err(anyhow::anyhow!("No worktree found"));
         };
 
-        let buffer_tasks = project.update(cx, |project, cx| {
-            excerpts
-                .iter()
-                .map(|excerpt| {
-                    project.open_buffer(
-                        ProjectPath {
-                            worktree_id,
-                            path: Path::new(&excerpt.path).into(),
-                        },
-                        cx,
-                    )
+        let buffer_task = project.update(cx, |project, cx| {
+            project.open_buffer(
+                ProjectPath {
+                    worktree_id,
+                    path: Path::new(&excerpt.path).into(),
+                },
+                cx,
+            )
+        })?;
+
+        let buffer = match buffer_task.await {
+            Ok(buffer) => buffer,
+            Err(error) => {
+                return this.update(cx, |this, cx| {
+                    this.error = Some(error);
+                    cx.notify();
                 })
-                .collect::<Vec<_>>()
-        });
+            }
+        };
 
-        cx.spawn(move |mut cx| async move {
-            let buffers = futures::future::try_join_all(buffer_tasks).await?;
+        let snapshot = buffer.update(cx, |buffer, _cx| buffer.snapshot())?;
+        let query = SearchQuery::text(&excerpt.text_passage, false, false, false, vec![], vec![])?;
+        let matches = query.search(&snapshot, None).await;
+        let Some(first_match) = matches.first() else {
+            log::warn!(
+                "text {:?} does not appear in '{}'",
+                excerpt.text_passage,
+                excerpt.path
+            );
+            return Ok(());
+        };
 
-            let multibuffer = cx.new_model(|_cx| {
-                MultiBuffer::new(0, language::Capability::ReadWrite).with_title(title)
-            })?;
-            let editor =
-                cx.new_view(|cx| Editor::for_multibuffer(multibuffer, Some(project), cx))?;
-
-            for (excerpt, buffer) in excerpts.iter().zip(buffers.iter()) {
-                let snapshot = buffer.update(&mut cx, |buffer, _cx| buffer.snapshot())?;
-
-                let query =
-                    SearchQuery::text(&excerpt.text_passage, false, false, false, vec![], vec![])?;
-
-                let matches = query.search(&snapshot, None).await;
-                let Some(first_match) = matches.first() else {
-                    log::warn!(
-                        "text {:?} does not appear in '{}'",
-                        excerpt.text_passage,
-                        excerpt.path
-                    );
-                    continue;
-                };
-                let mut start = first_match.start.to_point(&snapshot);
-                start.column = 0;
+        this.update(cx, |this, cx| {
+            let mut start = first_match.start.to_point(&snapshot);
+            start.column = 0;
 
-                editor.update(&mut cx, |editor, cx| {
+            if let Some(editor) = &this.editor {
+                editor.update(cx, |editor, cx| {
                     let ranges = editor.buffer().update(cx, |multibuffer, cx| {
                         multibuffer.push_excerpts_with_context_lines(
                             buffer.clone(),
@@ -125,7 +162,8 @@ impl LanguageModelTool for AnnotationTool {
                             cx,
                         )
                     });
-                    let annotation = SharedString::from(excerpt.annotation.clone());
+
+                    let annotation = SharedString::from(excerpt.annotation);
                     editor.insert_blocks(
                         [BlockProperties {
                             position: ranges[0].start,
@@ -137,30 +175,22 @@ impl LanguageModelTool for AnnotationTool {
                         None,
                         cx,
                     );
-                })?;
-            }
+                });
 
-            workspace
-                .update(&mut cx, |workspace, cx| {
-                    workspace.add_item_to_active_pane(Box::new(editor.clone()), None, cx);
-                })
-                .log_err();
+                if !this.added_editor_to_workspace {
+                    this.added_editor_to_workspace = true;
+                    this.workspace
+                        .update(cx, |workspace, cx| {
+                            workspace.add_item_to_active_pane(Box::new(editor.clone()), None, cx);
+                        })
+                        .log_err();
+                }
+            }
+        })?;
 
-            anyhow::Ok("showed comments to users in a new view".into())
-        })
+        Ok(())
     }
 
-    fn view(
-        &self,
-        _: Self::Input,
-        output: Result<Self::Output>,
-        cx: &mut WindowContext,
-    ) -> View<Self::View> {
-        cx.new_view(|_cx| AnnotationResultView { output })
-    }
-}
-
-impl AnnotationTool {
     fn render_note_block(explanation: &SharedString, cx: &mut BlockContext) -> AnyElement {
         let anchor_x = cx.anchor_x;
         let gutter_width = cx.gutter_dimensions.width;
@@ -186,24 +216,89 @@ impl AnnotationTool {
     }
 }
 
-pub struct AnnotationResultView {
-    output: Result<String>,
-}
-
 impl Render for AnnotationResultView {
     fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
-        match &self.output {
-            Ok(output) => div().child(output.clone().into_any_element()),
-            Err(error) => div().child(format!("failed to open path: {:?}", error)),
+        if let Some(error) = &self.error {
+            ui::Label::new(error.to_string()).into_any_element()
+        } else {
+            ui::Label::new(SharedString::from(format!(
+                "Opened a buffer with {} excerpts",
+                self.rendered_excerpt_count
+            )))
+            .into_any_element()
         }
     }
 }
 
 impl ToolOutput for AnnotationResultView {
-    fn generate(&self, _: &mut ProjectContext, _: &mut WindowContext) -> String {
-        match &self.output {
-            Ok(output) => output.clone(),
-            Err(err) => format!("Failed to create buffer: {err:?}"),
+    type Input = AnnotationInput;
+    type SerializedState = Option<String>;
+
+    fn generate(&self, _: &mut ProjectContext, _: &mut ViewContext<Self>) -> String {
+        if let Some(error) = &self.error {
+            format!("Failed to create buffer: {error:?}")
+        } else {
+            format!(
+                "opened {} excerpts in a buffer",
+                self.rendered_excerpt_count
+            )
+        }
+    }
+
+    fn set_input(&mut self, mut input: Self::Input, cx: &mut ViewContext<Self>) {
+        let editor = if let Some(editor) = &self.editor {
+            editor.clone()
+        } else {
+            let multibuffer = cx.new_model(|_cx| {
+                MultiBuffer::new(0, language::Capability::ReadWrite).with_title(String::new())
+            });
+            let editor = cx.new_view(|cx| {
+                Editor::for_multibuffer(multibuffer.clone(), Some(self.project.clone()), cx)
+            });
+
+            self.editor = Some(editor.clone());
+            editor
+        };
+
+        editor.update(cx, |editor, cx| {
+            editor.buffer().update(cx, |multibuffer, cx| {
+                if multibuffer.title(cx) != input.title {
+                    multibuffer.set_title(input.title.clone(), cx);
+                }
+            });
+
+            self.pending_excerpt = input.excerpts.pop();
+            for excerpt in input.excerpts.iter().skip(self.rendered_excerpt_count) {
+                self.tx.unbounded_send(excerpt.clone()).ok();
+            }
+            self.rendered_excerpt_count = input.excerpts.len();
+        });
+
+        cx.notify();
+    }
+
+    fn execute(&mut self, _cx: &mut ViewContext<Self>) -> Task<Result<()>> {
+        if let Some(excerpt) = self.pending_excerpt.take() {
+            self.rendered_excerpt_count += 1;
+            self.tx.unbounded_send(excerpt.clone()).ok();
+        }
+
+        self.tx.close_channel();
+        Task::ready(Ok(()))
+    }
+
+    fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
+        self.error.as_ref().map(|error| error.to_string())
+    }
+
+    fn deserialize(
+        &mut self,
+        output: Self::SerializedState,
+        _cx: &mut ViewContext<Self>,
+    ) -> Result<()> {
+        if let Some(error_message) = output {
+            self.error = Some(anyhow::anyhow!("{}", error_message));
         }
+        Ok(())
     }
 }

crates/assistant2/src/tools/create_buffer.rs 🔗

@@ -1,4 +1,4 @@
-use anyhow::Result;
+use anyhow::{anyhow, Result};
 use assistant_tooling::{LanguageModelTool, ProjectContext, ToolOutput};
 use editor::Editor;
 use gpui::{prelude::*, Model, Task, View, WeakView};
@@ -20,7 +20,7 @@ impl CreateBufferTool {
     }
 }
 
-#[derive(Debug, Deserialize, JsonSchema)]
+#[derive(Debug, Clone, Deserialize, JsonSchema)]
 pub struct CreateBufferInput {
     /// The contents of the buffer.
     text: String,
@@ -32,8 +32,6 @@ pub struct CreateBufferInput {
 }
 
 impl LanguageModelTool for CreateBufferTool {
-    type Input = CreateBufferInput;
-    type Output = ();
     type View = CreateBufferView;
 
     fn name(&self) -> String {
@@ -44,13 +42,59 @@ impl LanguageModelTool for CreateBufferTool {
         "Create a new buffer in the current codebase".to_string()
     }
 
-    fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
+    fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
+        cx.new_view(|_cx| CreateBufferView {
+            workspace: self.workspace.clone(),
+            project: self.project.clone(),
+            input: None,
+            error: None,
+        })
+    }
+}
+
+pub struct CreateBufferView {
+    workspace: WeakView<Workspace>,
+    project: Model<Project>,
+    input: Option<CreateBufferInput>,
+    error: Option<anyhow::Error>,
+}
+
+impl Render for CreateBufferView {
+    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
+        div().child("Opening a buffer")
+    }
+}
+
+impl ToolOutput for CreateBufferView {
+    type Input = CreateBufferInput;
+
+    type SerializedState = ();
+
+    fn generate(&self, _project: &mut ProjectContext, _cx: &mut ViewContext<Self>) -> String {
+        let Some(input) = self.input.as_ref() else {
+            return "No input".to_string();
+        };
+
+        match &self.error {
+            None => format!("Created a new {} buffer", input.language),
+            Some(err) => format!("Failed to create buffer: {err:?}"),
+        }
+    }
+
+    fn set_input(&mut self, input: Self::Input, _cx: &mut ViewContext<Self>) {
+        self.input = Some(input);
+    }
+
+    fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
         cx.spawn({
             let workspace = self.workspace.clone();
             let project = self.project.clone();
-            let text = input.text.clone();
-            let language_name = input.language.clone();
-            |mut cx| async move {
+            let input = self.input.clone();
+            |_this, mut cx| async move {
+                let input = input.ok_or_else(|| anyhow!("no input"))?;
+
+                let text = input.text.clone();
+                let language_name = input.language.clone();
                 let language = cx
                     .update(|cx| {
                         project
@@ -86,35 +130,15 @@ impl LanguageModelTool for CreateBufferTool {
         })
     }
 
-    fn view(
-        &self,
-        input: Self::Input,
-        output: Result<Self::Output>,
-        cx: &mut WindowContext,
-    ) -> View<Self::View> {
-        cx.new_view(|_cx| CreateBufferView {
-            language: input.language,
-            output,
-        })
+    fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
+        ()
     }
-}
 
-pub struct CreateBufferView {
-    language: String,
-    output: Result<()>,
-}
-
-impl Render for CreateBufferView {
-    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
-        div().child("Opening a buffer")
-    }
-}
-
-impl ToolOutput for CreateBufferView {
-    fn generate(&self, _: &mut ProjectContext, _: &mut WindowContext) -> String {
-        match &self.output {
-            Ok(_) => format!("Created a new {} buffer", self.language),
-            Err(err) => format!("Failed to create buffer: {err:?}"),
-        }
+    fn deserialize(
+        &mut self,
+        _output: Self::SerializedState,
+        _cx: &mut ViewContext<Self>,
+    ) -> Result<()> {
+        Ok(())
     }
 }

crates/assistant2/src/tools/project_index.rs 🔗

@@ -1,4 +1,4 @@
-use anyhow::{anyhow, Result};
+use anyhow::Result;
 use assistant_tooling::{LanguageModelTool, ToolOutput};
 use collections::BTreeMap;
 use gpui::{prelude::*, Model, Task};
@@ -6,9 +6,8 @@ use project::ProjectPath;
 use schemars::JsonSchema;
 use semantic_index::{ProjectIndex, Status};
 use serde::{Deserialize, Serialize};
-use serde_json::Value;
 use std::{fmt::Write as _, ops::Range, path::Path, sync::Arc};
-use ui::{div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext};
+use ui::{prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext};
 
 const DEFAULT_SEARCH_LIMIT: usize = 20;
 
@@ -16,10 +15,26 @@ pub struct ProjectIndexTool {
     project_index: Model<ProjectIndex>,
 }
 
-// Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model.
-// Any changes or deletions to the `CodebaseQuery` comments will change model behavior.
+#[derive(Default)]
+enum ProjectIndexToolState {
+    #[default]
+    CollectingQuery,
+    Searching,
+    Error(anyhow::Error),
+    Finished {
+        excerpts: BTreeMap<ProjectPath, Vec<Range<usize>>>,
+        index_status: Status,
+    },
+}
+
+pub struct ProjectIndexView {
+    project_index: Model<ProjectIndex>,
+    input: CodebaseQuery,
+    expanded_header: bool,
+    state: ProjectIndexToolState,
+}
 
-#[derive(Deserialize, JsonSchema)]
+#[derive(Default, Deserialize, JsonSchema)]
 pub struct CodebaseQuery {
     /// Semantic search query
     query: String,
@@ -27,21 +42,14 @@ pub struct CodebaseQuery {
     limit: Option<usize>,
 }
 
-pub struct ProjectIndexView {
-    input: CodebaseQuery,
-    status: Status,
-    excerpts: Result<BTreeMap<ProjectPath, Vec<Range<usize>>>>,
-    element_id: ElementId,
-    expanded_header: bool,
-}
-
 #[derive(Serialize, Deserialize)]
-pub struct ProjectIndexOutput {
-    status: Status,
+pub struct SerializedState {
+    index_status: Status,
+    error_message: Option<String>,
     worktrees: BTreeMap<Arc<Path>, WorktreeIndexOutput>,
 }
 
-#[derive(Serialize, Deserialize)]
+#[derive(Default, Serialize, Deserialize)]
 struct WorktreeIndexOutput {
     excerpts: BTreeMap<Arc<Path>, Vec<Range<usize>>>,
 }
@@ -56,58 +64,80 @@ impl ProjectIndexView {
 impl Render for ProjectIndexView {
     fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
         let query = self.input.query.clone();
-        let excerpts = match &self.excerpts {
-            Err(err) => {
-                return div().child(Label::new(format!("Error: {}", err)).color(Color::Error));
+
+        let (header_text, content) = match &self.state {
+            ProjectIndexToolState::Error(error) => {
+                return format!("failed to search: {error:?}").into_any_element()
+            }
+            ProjectIndexToolState::CollectingQuery | ProjectIndexToolState::Searching => {
+                ("Searching...".to_string(), div())
+            }
+            ProjectIndexToolState::Finished { excerpts, .. } => {
+                let file_count = excerpts.len();
+
+                let header_text = format!(
+                    "Read {} {}",
+                    file_count,
+                    if file_count == 1 { "file" } else { "files" }
+                );
+
+                let el = v_flex().gap_2().children(excerpts.keys().map(|path| {
+                    h_flex().gap_2().child(Icon::new(IconName::File)).child(
+                        Label::new(path.path.to_string_lossy().to_string()).color(Color::Muted),
+                    )
+                }));
+
+                (header_text, el)
             }
-            Ok(excerpts) => excerpts,
         };
 
-        let file_count = excerpts.len();
         let header = h_flex()
             .gap_2()
             .child(Icon::new(IconName::File))
-            .child(format!(
-                "Read {} {}",
-                file_count,
-                if file_count == 1 { "file" } else { "files" }
-            ));
-
-        v_flex().gap_3().child(
-            CollapsibleContainer::new(self.element_id.clone(), self.expanded_header)
-                .start_slot(header)
-                .on_click(cx.listener(move |this, _, cx| {
-                    this.toggle_header(cx);
-                }))
-                .child(
-                    v_flex()
-                        .gap_3()
-                        .p_3()
-                        .child(
-                            h_flex()
-                                .gap_2()
-                                .child(Icon::new(IconName::MagnifyingGlass))
-                                .child(Label::new(format!("`{}`", query)).color(Color::Muted)),
-                        )
-                        .child(v_flex().gap_2().children(excerpts.keys().map(|path| {
-                            h_flex().gap_2().child(Icon::new(IconName::File)).child(
-                                Label::new(path.path.to_string_lossy().to_string())
-                                    .color(Color::Muted),
+            .child(header_text);
+
+        v_flex()
+            .gap_3()
+            .child(
+                CollapsibleContainer::new("collapsible-container", self.expanded_header)
+                    .start_slot(header)
+                    .on_click(cx.listener(move |this, _, cx| {
+                        this.toggle_header(cx);
+                    }))
+                    .child(
+                        v_flex()
+                            .gap_3()
+                            .p_3()
+                            .child(
+                                h_flex()
+                                    .gap_2()
+                                    .child(Icon::new(IconName::MagnifyingGlass))
+                                    .child(Label::new(format!("`{}`", query)).color(Color::Muted)),
                             )
-                        }))),
-                ),
-        )
+                            .child(content),
+                    ),
+            )
+            .into_any_element()
     }
 }
 
 impl ToolOutput for ProjectIndexView {
+    type Input = CodebaseQuery;
+    type SerializedState = SerializedState;
+
     fn generate(
         &self,
         context: &mut assistant_tooling::ProjectContext,
-        _: &mut WindowContext,
+        _: &mut ViewContext<Self>,
     ) -> String {
-        match &self.excerpts {
-            Ok(excerpts) => {
+        match &self.state {
+            ProjectIndexToolState::CollectingQuery => String::new(),
+            ProjectIndexToolState::Searching => String::new(),
+            ProjectIndexToolState::Error(error) => format!("failed to search: {error:?}"),
+            ProjectIndexToolState::Finished {
+                excerpts,
+                index_status,
+            } => {
                 let mut body = "found results in the following paths:\n".to_string();
 
                 for (project_path, ranges) in excerpts {
@@ -115,139 +145,151 @@ impl ToolOutput for ProjectIndexView {
                     writeln!(&mut body, "* {}", &project_path.path.display()).unwrap();
                 }
 
-                if self.status != Status::Idle {
+                if *index_status != Status::Idle {
                     body.push_str("Still indexing. Results may be incomplete.\n");
                 }
 
                 body
             }
-            Err(err) => format!("Error: {}", err),
         }
     }
-}
-
-impl ProjectIndexTool {
-    pub fn new(project_index: Model<ProjectIndex>) -> Self {
-        Self { project_index }
-    }
-}
 
-impl LanguageModelTool for ProjectIndexTool {
-    type Input = CodebaseQuery;
-    type Output = ProjectIndexOutput;
-    type View = ProjectIndexView;
-
-    fn name(&self) -> String {
-        "query_codebase".to_string()
+    fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
+        self.input = input;
+        cx.notify();
     }
 
-    fn description(&self) -> String {
-        "Semantic search against the user's current codebase, returning excerpts related to the query by computing a dot product against embeddings of code chunks in the code base and an embedding of the query.".to_string()
-    }
+    fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
+        self.state = ProjectIndexToolState::Searching;
+        cx.notify();
 
-    fn execute(&self, query: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
         let project_index = self.project_index.read(cx);
-        let status = project_index.status();
+        let index_status = project_index.status();
         let search = project_index.search(
-            query.query.clone(),
-            query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
+            self.input.query.clone(),
+            self.input.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
             cx,
         );
 
-        cx.spawn(|mut cx| async move {
-            let search_results = search.await?;
-
-            cx.update(|cx| {
-                let mut output = ProjectIndexOutput {
-                    status,
-                    worktrees: Default::default(),
-                };
-
-                for search_result in search_results {
-                    let worktree_path = search_result.worktree.read(cx).abs_path();
-                    let excerpts = &mut output
-                        .worktrees
-                        .entry(worktree_path)
-                        .or_insert(WorktreeIndexOutput {
-                            excerpts: Default::default(),
-                        })
-                        .excerpts;
-
-                    let excerpts_for_path = excerpts.entry(search_result.path).or_default();
-                    let ix = match excerpts_for_path
-                        .binary_search_by_key(&search_result.range.start, |r| r.start)
-                    {
-                        Ok(ix) | Err(ix) => ix,
-                    };
-                    excerpts_for_path.insert(ix, search_result.range);
+        cx.spawn(|this, mut cx| async move {
+            let search_result = search.await;
+            this.update(&mut cx, |this, cx| {
+                match search_result {
+                    Ok(search_results) => {
+                        let mut excerpts = BTreeMap::<ProjectPath, Vec<Range<usize>>>::new();
+                        for search_result in search_results {
+                            let project_path = ProjectPath {
+                                worktree_id: search_result.worktree.read(cx).id(),
+                                path: search_result.path,
+                            };
+                            excerpts
+                                .entry(project_path)
+                                .or_default()
+                                .push(search_result.range);
+                        }
+                        this.state = ProjectIndexToolState::Finished {
+                            excerpts,
+                            index_status,
+                        };
+                    }
+                    Err(error) => {
+                        this.state = ProjectIndexToolState::Error(error);
+                    }
                 }
-
-                output
+                cx.notify();
             })
         })
     }
 
-    fn view(
-        &self,
-        input: Self::Input,
-        output: Result<Self::Output>,
-        cx: &mut WindowContext,
-    ) -> gpui::View<Self::View> {
-        cx.new_view(|cx| {
-            let status;
-            let excerpts;
-            match output {
-                Ok(output) => {
-                    status = output.status;
-                    let project_index = self.project_index.read(cx);
-                    if let Some(project) = project_index.project().upgrade() {
-                        let project = project.read(cx);
-                        excerpts = Ok(output
-                            .worktrees
-                            .into_iter()
-                            .filter_map(|(abs_path, output)| {
-                                for worktree in project.worktrees() {
-                                    let worktree = worktree.read(cx);
-                                    if worktree.abs_path() == abs_path {
-                                        return Some((worktree.id(), output.excerpts));
-                                    }
-                                }
-                                None
-                            })
-                            .flat_map(|(worktree_id, excerpts)| {
-                                excerpts.into_iter().map(move |(path, ranges)| {
-                                    (ProjectPath { worktree_id, path }, ranges)
-                                })
-                            })
-                            .collect::<BTreeMap<_, _>>());
-                    } else {
-                        excerpts = Err(anyhow!("project was dropped"));
+    fn serialize(&self, cx: &mut ViewContext<Self>) -> Self::SerializedState {
+        let mut serialized = SerializedState {
+            error_message: None,
+            index_status: Status::Idle,
+            worktrees: Default::default(),
+        };
+        match &self.state {
+            ProjectIndexToolState::Error(err) => serialized.error_message = Some(err.to_string()),
+            ProjectIndexToolState::Finished {
+                excerpts,
+                index_status,
+            } => {
+                serialized.index_status = *index_status;
+                if let Some(project) = self.project_index.read(cx).project().upgrade() {
+                    let project = project.read(cx);
+                    for (project_path, excerpts) in excerpts {
+                        if let Some(worktree) =
+                            project.worktree_for_id(project_path.worktree_id, cx)
+                        {
+                            let worktree_path = worktree.read(cx).abs_path();
+                            serialized
+                                .worktrees
+                                .entry(worktree_path)
+                                .or_default()
+                                .excerpts
+                                .insert(project_path.path.clone(), excerpts.clone());
+                        }
                     }
                 }
-                Err(err) => {
-                    status = Status::Idle;
-                    excerpts = Err(err);
+            }
+            _ => {}
+        }
+        serialized
+    }
+
+    fn deserialize(
+        &mut self,
+        serialized: Self::SerializedState,
+        cx: &mut ViewContext<Self>,
+    ) -> Result<()> {
+        if !serialized.worktrees.is_empty() {
+            let mut excerpts = BTreeMap::<ProjectPath, Vec<Range<usize>>>::new();
+            if let Some(project) = self.project_index.read(cx).project().upgrade() {
+                let project = project.read(cx);
+                for (worktree_path, worktree_state) in serialized.worktrees {
+                    if let Some(worktree) = project
+                        .worktrees()
+                        .find(|worktree| worktree.read(cx).abs_path() == worktree_path)
+                    {
+                        let worktree_id = worktree.read(cx).id();
+                        for (path, serialized_excerpts) in worktree_state.excerpts {
+                            excerpts.insert(ProjectPath { worktree_id, path }, serialized_excerpts);
+                        }
+                    }
                 }
+            }
+            self.state = ProjectIndexToolState::Finished {
+                excerpts,
+                index_status: serialized.index_status,
             };
+        }
+        cx.notify();
+        Ok(())
+    }
+}
 
-            ProjectIndexView {
-                input,
-                status,
-                excerpts,
-                element_id: ElementId::Name(nanoid::nanoid!().into()),
-                expanded_header: false,
-            }
-        })
+impl ProjectIndexTool {
+    pub fn new(project_index: Model<ProjectIndex>) -> Self {
+        Self { project_index }
     }
+}
 
-    fn render_running(arguments: &Option<Value>, _: &mut WindowContext) -> impl IntoElement {
-        let text: String = arguments
-            .as_ref()
-            .and_then(|arguments| arguments.get("query"))
-            .and_then(|query| query.as_str())
-            .map(|query| format!("Searching for: {}", query))
-            .unwrap_or_else(|| "Preparing search...".to_string());
+impl LanguageModelTool for ProjectIndexTool {
+    type View = ProjectIndexView;
 
-        CollapsibleContainer::new(ElementId::Name(nanoid::nanoid!().into()), false).start_slot(text)
+    fn name(&self) -> String {
+        "query_codebase".to_string()
+    }
+
+    fn description(&self) -> String {
+        "Semantic search against the user's current codebase, returning excerpts related to the query by computing a dot product against embeddings of code chunks in the code base and an embedding of the query.".to_string()
+    }
+
+    fn view(&self, cx: &mut WindowContext) -> gpui::View<Self::View> {
+        cx.new_view(|_| ProjectIndexView {
+            state: ProjectIndexToolState::CollectingQuery,
+            input: Default::default(),
+            expanded_header: false,
+            project_index: self.project_index.clone(),
+        })
     }
 }

crates/assistant_tooling/Cargo.toml 🔗

@@ -16,7 +16,9 @@ anyhow.workspace = true
 collections.workspace = true
 futures.workspace = true
 gpui.workspace = true
+log.workspace = true
 project.workspace = true
+repair_json.workspace = true
 schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true

crates/assistant_tooling/src/assistant_tooling.rs 🔗

@@ -3,11 +3,11 @@ mod project_context;
 mod tool_registry;
 
 pub use attachment_registry::{
-    AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment, UserAttachment,
+    AttachmentOutput, AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment,
+    UserAttachment,
 };
 pub use project_context::ProjectContext;
 pub use tool_registry::{
-    tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall,
-    SavedToolFunctionCallResult, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
-    ToolOutput, ToolRegistry,
+    tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall, SavedToolFunctionCallState,
+    ToolFunctionCall, ToolFunctionCallState, ToolFunctionDefinition, ToolOutput, ToolRegistry,
 };

crates/assistant_tooling/src/attachment_registry.rs 🔗

@@ -1,4 +1,4 @@
-use crate::{ProjectContext, ToolOutput};
+use crate::ProjectContext;
 use anyhow::{anyhow, Result};
 use collections::HashMap;
 use futures::future::join_all;
@@ -18,9 +18,13 @@ pub struct AttachmentRegistry {
     registered_attachments: HashMap<TypeId, RegisteredAttachment>,
 }
 
+pub trait AttachmentOutput {
+    fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
+}
+
 pub trait LanguageModelAttachment {
     type Output: DeserializeOwned + Serialize + 'static;
-    type View: Render + ToolOutput;
+    type View: Render + AttachmentOutput;
 
     fn name(&self) -> Arc<str>;
     fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;

crates/assistant_tooling/src/tool_registry.rs 🔗

@@ -1,11 +1,10 @@
 use crate::ProjectContext;
 use anyhow::{anyhow, Result};
-use gpui::{
-    div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
-};
+use gpui::{AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext};
+use repair_json::repair;
 use schemars::{schema::RootSchema, schema_for, JsonSchema};
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
-use serde_json::{value::RawValue, Value};
+use serde_json::value::RawValue;
 use std::{
     any::TypeId,
     collections::HashMap,
@@ -15,6 +14,7 @@ use std::{
         Arc,
     },
 };
+use ui::ViewContext;
 
 pub struct ToolRegistry {
     registered_tools: HashMap<String, RegisteredTool>,
@@ -25,7 +25,25 @@ pub struct ToolFunctionCall {
     pub id: String,
     pub name: String,
     pub arguments: String,
-    pub result: Option<ToolFunctionCallResult>,
+    state: ToolFunctionCallState,
+}
+
+#[derive(Default)]
+pub enum ToolFunctionCallState {
+    #[default]
+    Initializing,
+    NoSuchTool,
+    KnownTool(Box<dyn ToolView>),
+    ExecutedTool(Box<dyn ToolView>),
+}
+
+pub trait ToolView {
+    fn view(&self) -> AnyView;
+    fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
+    fn set_input(&self, input: &str, cx: &mut WindowContext);
+    fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
+    fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
+    fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
 }
 
 #[derive(Default, Serialize, Deserialize)]
@@ -33,29 +51,19 @@ pub struct SavedToolFunctionCall {
     pub id: String,
     pub name: String,
     pub arguments: String,
-    pub result: Option<SavedToolFunctionCallResult>,
-}
-
-pub enum ToolFunctionCallResult {
-    NoSuchTool,
-    ParsingFailed,
-    Finished {
-        view: AnyView,
-        serialized_output: Result<Box<RawValue>, String>,
-        generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
-    },
+    pub state: SavedToolFunctionCallState,
 }
 
-#[derive(Serialize, Deserialize)]
-pub enum SavedToolFunctionCallResult {
+#[derive(Default, Serialize, Deserialize)]
+pub enum SavedToolFunctionCallState {
+    #[default]
+    Initializing,
     NoSuchTool,
-    ParsingFailed,
-    Finished {
-        serialized_output: Result<Box<RawValue>, String>,
-    },
+    KnownTool,
+    ExecutedTool(Box<RawValue>),
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct ToolFunctionDefinition {
     pub name: String,
     pub description: String,
@@ -63,14 +71,7 @@ pub struct ToolFunctionDefinition {
 }
 
 pub trait LanguageModelTool {
-    /// The input type that will be passed in to `execute` when the tool is called
-    /// by the language model.
-    type Input: DeserializeOwned + JsonSchema;
-
-    /// The output returned by executing the tool.
-    type Output: DeserializeOwned + Serialize + 'static;
-
-    type View: Render + ToolOutput;
+    type View: ToolOutput;
 
     /// Returns the name of the tool.
     ///
@@ -86,7 +87,7 @@ pub trait LanguageModelTool {
 
     /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
     fn definition(&self) -> ToolFunctionDefinition {
-        let root_schema = schema_for!(Self::Input);
+        let root_schema = schema_for!(<Self::View as ToolOutput>::Input);
 
         ToolFunctionDefinition {
             name: self.name(),
@@ -95,36 +96,46 @@ pub trait LanguageModelTool {
         }
     }
 
-    /// Executes the tool with the given input.
-    fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
-
     /// A view of the output of running the tool, for displaying to the user.
-    fn view(
-        &self,
-        input: Self::Input,
-        output: Result<Self::Output>,
-        cx: &mut WindowContext,
-    ) -> View<Self::View>;
-
-    fn render_running(_arguments: &Option<Value>, _cx: &mut WindowContext) -> impl IntoElement {
-        tool_running_placeholder()
-    }
+    fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
 }
 
 pub fn tool_running_placeholder() -> AnyElement {
     ui::Label::new("Researching...").into_any_element()
 }
 
-pub trait ToolOutput: Sized {
-    fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
+pub fn unknown_tool_placeholder() -> AnyElement {
+    ui::Label::new("Unknown tool").into_any_element()
+}
+
+pub fn no_such_tool_placeholder() -> AnyElement {
+    ui::Label::new("No such tool").into_any_element()
+}
+
+pub trait ToolOutput: Render {
+    /// The input type that will be passed in to `execute` when the tool is called
+    /// by the language model.
+    type Input: DeserializeOwned + JsonSchema;
+
+    /// The output returned by executing the tool.
+    type SerializedState: DeserializeOwned + Serialize;
+
+    fn generate(&self, project: &mut ProjectContext, cx: &mut ViewContext<Self>) -> String;
+    fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>);
+    fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>>;
+
+    fn serialize(&self, cx: &mut ViewContext<Self>) -> Self::SerializedState;
+    fn deserialize(
+        &mut self,
+        output: Self::SerializedState,
+        cx: &mut ViewContext<Self>,
+    ) -> Result<()>;
 }
 
 struct RegisteredTool {
     enabled: AtomicBool,
     type_id: TypeId,
-    execute: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
-    deserialize: Box<dyn Fn(&SavedToolFunctionCall, &mut WindowContext) -> ToolFunctionCall>,
-    render_running: fn(&ToolFunctionCall, &mut WindowContext) -> gpui::AnyElement,
+    build_view: Box<dyn Fn(&mut WindowContext) -> Box<dyn ToolView>>,
     definition: ToolFunctionDefinition,
 }
 
@@ -161,63 +172,132 @@ impl ToolRegistry {
             .collect()
     }
 
-    pub fn render_tool_call(
+    pub fn view_for_tool(&self, name: &str, cx: &mut WindowContext) -> Option<Box<dyn ToolView>> {
+        let tool = self.registered_tools.get(name)?;
+        Some((tool.build_view)(cx))
+    }
+
+    pub fn update_tool_call(
         &self,
-        tool_call: &ToolFunctionCall,
+        call: &mut ToolFunctionCall,
+        name: Option<&str>,
+        arguments: Option<&str>,
         cx: &mut WindowContext,
-    ) -> AnyElement {
-        match &tool_call.result {
-            Some(result) => div()
-                .p_2()
-                .child(result.into_any_element(&tool_call.name))
-                .into_any_element(),
-            None => {
-                let tool = self.registered_tools.get(&tool_call.name);
-
-                if let Some(tool) = tool {
-                    (tool.render_running)(&tool_call, cx)
+    ) {
+        if let Some(name) = name {
+            call.name.push_str(name);
+        }
+        if let Some(arguments) = arguments {
+            if call.arguments.is_empty() {
+                if let Some(view) = self.view_for_tool(&call.name, cx) {
+                    call.state = ToolFunctionCallState::KnownTool(view);
                 } else {
-                    tool_running_placeholder()
+                    call.state = ToolFunctionCallState::NoSuchTool;
+                }
+            }
+            call.arguments.push_str(arguments);
+
+            if let ToolFunctionCallState::KnownTool(view) = &call.state {
+                if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
+                    view.set_input(&repaired_arguments, cx)
                 }
             }
         }
     }
 
-    pub fn serialize_tool_call(&self, call: &ToolFunctionCall) -> SavedToolFunctionCall {
-        SavedToolFunctionCall {
+    pub fn execute_tool_call(
+        &self,
+        tool_call: &ToolFunctionCall,
+        cx: &mut WindowContext,
+    ) -> Option<Task<Result<()>>> {
+        if let ToolFunctionCallState::KnownTool(view) = &tool_call.state {
+            Some(view.execute(cx))
+        } else {
+            None
+        }
+    }
+
+    pub fn render_tool_call(
+        &self,
+        tool_call: &ToolFunctionCall,
+        _cx: &mut WindowContext,
+    ) -> AnyElement {
+        match &tool_call.state {
+            ToolFunctionCallState::NoSuchTool => no_such_tool_placeholder(),
+            ToolFunctionCallState::Initializing => unknown_tool_placeholder(),
+            ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
+                view.view().into_any_element()
+            }
+        }
+    }
+
+    pub fn content_for_tool_call(
+        &self,
+        tool_call: &ToolFunctionCall,
+        project_context: &mut ProjectContext,
+        cx: &mut WindowContext,
+    ) -> String {
+        match &tool_call.state {
+            ToolFunctionCallState::Initializing => String::new(),
+            ToolFunctionCallState::NoSuchTool => {
+                format!("No such tool: {}", tool_call.name)
+            }
+            ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
+                view.generate(project_context, cx)
+            }
+        }
+    }
+
+    pub fn serialize_tool_call(
+        &self,
+        call: &ToolFunctionCall,
+        cx: &mut WindowContext,
+    ) -> Result<SavedToolFunctionCall> {
+        Ok(SavedToolFunctionCall {
             id: call.id.clone(),
             name: call.name.clone(),
             arguments: call.arguments.clone(),
-            result: call.result.as_ref().map(|result| match result {
-                ToolFunctionCallResult::NoSuchTool => SavedToolFunctionCallResult::NoSuchTool,
-                ToolFunctionCallResult::ParsingFailed => SavedToolFunctionCallResult::ParsingFailed,
-                ToolFunctionCallResult::Finished {
-                    serialized_output, ..
-                } => SavedToolFunctionCallResult::Finished {
-                    serialized_output: match serialized_output {
-                        Ok(value) => Ok(value.clone()),
-                        Err(e) => Err(e.to_string()),
-                    },
-                },
-            }),
-        }
+            state: match &call.state {
+                ToolFunctionCallState::Initializing => SavedToolFunctionCallState::Initializing,
+                ToolFunctionCallState::NoSuchTool => SavedToolFunctionCallState::NoSuchTool,
+                ToolFunctionCallState::KnownTool(_) => SavedToolFunctionCallState::KnownTool,
+                ToolFunctionCallState::ExecutedTool(view) => {
+                    SavedToolFunctionCallState::ExecutedTool(view.serialize_output(cx)?)
+                }
+            },
+        })
     }
 
     pub fn deserialize_tool_call(
         &self,
         call: &SavedToolFunctionCall,
         cx: &mut WindowContext,
-    ) -> ToolFunctionCall {
-        if let Some(tool) = &self.registered_tools.get(&call.name) {
-            (tool.deserialize)(call, cx)
-        } else {
-            ToolFunctionCall {
-                id: call.id.clone(),
-                name: call.name.clone(),
-                arguments: call.arguments.clone(),
-                result: Some(ToolFunctionCallResult::NoSuchTool),
-            }
-        }
+    ) -> Result<ToolFunctionCall> {
+        let Some(tool) = self.registered_tools.get(&call.name) else {
+            return Err(anyhow!("no such tool {}", call.name));
+        };
+
+        Ok(ToolFunctionCall {
+            id: call.id.clone(),
+            name: call.name.clone(),
+            arguments: call.arguments.clone(),
+            state: match &call.state {
+                SavedToolFunctionCallState::Initializing => ToolFunctionCallState::Initializing,
+                SavedToolFunctionCallState::NoSuchTool => ToolFunctionCallState::NoSuchTool,
+                SavedToolFunctionCallState::KnownTool => {
+                    log::error!("Deserialized tool that had not executed");
+                    let view = (tool.build_view)(cx);
+                    view.set_input(&call.arguments, cx);
+                    ToolFunctionCallState::KnownTool(view)
+                }
+                SavedToolFunctionCallState::ExecutedTool(output) => {
+                    let view = (tool.build_view)(cx);
+                    view.set_input(&call.arguments, cx);
+                    view.deserialize_output(output, cx)?;
+                    ToolFunctionCallState::ExecutedTool(view)
+                }
+            },
+        })
     }
 
     pub fn register<T: 'static + LanguageModelTool>(
@@ -231,114 +311,7 @@ impl ToolRegistry {
             type_id: TypeId::of::<T>(),
             definition: tool.definition(),
             enabled: AtomicBool::new(true),
-            deserialize: Box::new({
-                let tool = tool.clone();
-                move |tool_call: &SavedToolFunctionCall, cx: &mut WindowContext| {
-                    let id = tool_call.id.clone();
-                    let name = tool_call.name.clone();
-                    let arguments = tool_call.arguments.clone();
-
-                    let Ok(input) = serde_json::from_str::<T::Input>(&tool_call.arguments) else {
-                        return ToolFunctionCall {
-                            id,
-                            name: name.clone(),
-                            arguments,
-                            result: Some(ToolFunctionCallResult::ParsingFailed),
-                        };
-                    };
-
-                    let result = match &tool_call.result {
-                        Some(result) => match result {
-                            SavedToolFunctionCallResult::NoSuchTool => {
-                                Some(ToolFunctionCallResult::NoSuchTool)
-                            }
-                            SavedToolFunctionCallResult::ParsingFailed => {
-                                Some(ToolFunctionCallResult::ParsingFailed)
-                            }
-                            SavedToolFunctionCallResult::Finished { serialized_output } => {
-                                let output = match serialized_output {
-                                    Ok(value) => {
-                                        match serde_json::from_str::<T::Output>(value.get()) {
-                                            Ok(value) => Ok(value),
-                                            Err(_) => {
-                                                return ToolFunctionCall {
-                                                    id,
-                                                    name: name.clone(),
-                                                    arguments,
-                                                    result: Some(
-                                                        ToolFunctionCallResult::ParsingFailed,
-                                                    ),
-                                                };
-                                            }
-                                        }
-                                    }
-                                    Err(e) => Err(anyhow!("{e}")),
-                                };
-
-                                let view = tool.view(input, output, cx).into();
-                                Some(ToolFunctionCallResult::Finished {
-                                    serialized_output: serialized_output.clone(),
-                                    generate_fn: generate::<T>,
-                                    view,
-                                })
-                            }
-                        },
-                        None => None,
-                    };
-
-                    ToolFunctionCall {
-                        id: tool_call.id.clone(),
-                        name: name.clone(),
-                        arguments: tool_call.arguments.clone(),
-                        result,
-                    }
-                }
-            }),
-            execute: Box::new({
-                let tool = tool.clone();
-                move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
-                    let id = tool_call.id.clone();
-                    let name = tool_call.name.clone();
-                    let arguments = tool_call.arguments.clone();
-
-                    let Ok(input) = serde_json::from_str::<T::Input>(&arguments) else {
-                        return Task::ready(Ok(ToolFunctionCall {
-                            id,
-                            name: name.clone(),
-                            arguments,
-                            result: Some(ToolFunctionCallResult::ParsingFailed),
-                        }));
-                    };
-
-                    let result = tool.execute(&input, cx);
-                    let tool = tool.clone();
-                    cx.spawn(move |mut cx| async move {
-                        let result = result.await;
-                        let serialized_output = result
-                            .as_ref()
-                            .map_err(ToString::to_string)
-                            .and_then(|output| {
-                                Ok(RawValue::from_string(
-                                    serde_json::to_string(output).map_err(|e| e.to_string())?,
-                                )
-                                .unwrap())
-                            });
-                        let view = cx.update(|cx| tool.view(input, result, cx))?;
-
-                        Ok(ToolFunctionCall {
-                            id,
-                            name: name.clone(),
-                            arguments,
-                            result: Some(ToolFunctionCallResult::Finished {
-                                serialized_output,
-                                view: view.into(),
-                                generate_fn: generate::<T>,
-                            }),
-                        })
-                    })
-                }
-            }),
-            render_running: render_running::<T>,
+            build_view: Box::new(move |cx: &mut WindowContext| Box::new(tool.view(cx))),
         };
 
         let previous = self.registered_tools.insert(name.clone(), registered_tool);
@@ -347,83 +320,40 @@ impl ToolRegistry {
         }
 
         return Ok(());
+    }
+}
 
-        fn render_running<T: LanguageModelTool>(
-            tool_call: &ToolFunctionCall,
-            cx: &mut WindowContext,
-        ) -> AnyElement {
-            // Attempt to parse the string arguments that are JSON as a JSON value
-            let maybe_arguments = serde_json::to_value(tool_call.arguments.clone()).ok();
+impl<T: ToolOutput> ToolView for View<T> {
+    fn view(&self) -> AnyView {
+        self.clone().into()
+    }
 
-            T::render_running(&maybe_arguments, cx).into_any_element()
-        }
+    fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
+        self.update(cx, |view, cx| view.generate(project, cx))
+    }
 
-        fn generate<T: LanguageModelTool>(
-            view: AnyView,
-            project: &mut ProjectContext,
-            cx: &mut WindowContext,
-        ) -> String {
-            view.downcast::<T::View>()
-                .unwrap()
-                .update(cx, |view, cx| T::View::generate(view, project, cx))
+    fn set_input(&self, input: &str, cx: &mut WindowContext) {
+        if let Ok(input) = serde_json::from_str::<T::Input>(input) {
+            self.update(cx, |view, cx| {
+                view.set_input(input, cx);
+                cx.notify();
+            });
         }
     }
 
-    /// Task yields an error if the window for the given WindowContext is closed before the task completes.
-    pub fn call(
-        &self,
-        tool_call: &ToolFunctionCall,
-        cx: &mut WindowContext,
-    ) -> Task<Result<ToolFunctionCall>> {
-        let name = tool_call.name.clone();
-        let arguments = tool_call.arguments.clone();
-        let id = tool_call.id.clone();
-
-        let tool = match self.registered_tools.get(&name) {
-            Some(tool) => tool,
-            None => {
-                let name = name.clone();
-                return Task::ready(Ok(ToolFunctionCall {
-                    id,
-                    name: name.clone(),
-                    arguments,
-                    result: Some(ToolFunctionCallResult::NoSuchTool),
-                }));
-            }
-        };
-
-        (tool.execute)(tool_call, cx)
+    fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>> {
+        self.update(cx, |view, cx| view.execute(cx))
     }
-}
 
-impl ToolFunctionCallResult {
-    pub fn generate(
-        &self,
-        name: &String,
-        project: &mut ProjectContext,
-        cx: &mut WindowContext,
-    ) -> String {
-        match self {
-            ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
-            ToolFunctionCallResult::ParsingFailed => {
-                format!("Unable to parse arguments for {name}")
-            }
-            ToolFunctionCallResult::Finished {
-                generate_fn, view, ..
-            } => (generate_fn)(view.clone(), project, cx),
-        }
+    fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>> {
+        let output = self.update(cx, |view, cx| view.serialize(cx));
+        Ok(RawValue::from_string(serde_json::to_string(&output)?)?)
     }
 
-    fn into_any_element(&self, name: &String) -> AnyElement {
-        match self {
-            ToolFunctionCallResult::NoSuchTool => {
-                format!("Language Model attempted to call {name}").into_any_element()
-            }
-            ToolFunctionCallResult::ParsingFailed => {
-                format!("Language Model called {name} with bad arguments").into_any_element()
-            }
-            ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
-        }
+    fn deserialize_output(&self, output: &RawValue, cx: &mut WindowContext) -> Result<()> {
+        let state = serde_json::from_str::<T::SerializedState>(output.get())?;
+        self.update(cx, |view, cx| view.deserialize(state, cx))?;
+        Ok(())
     }
 }
 
@@ -453,10 +383,6 @@ mod test {
         unit: String,
     }
 
-    struct WeatherTool {
-        current_weather: WeatherResult,
-    }
-
     #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
     struct WeatherResult {
         location: String,
@@ -465,24 +391,81 @@ mod test {
     }
 
     struct WeatherView {
-        result: WeatherResult,
+        input: Option<WeatherQuery>,
+        result: Option<WeatherResult>,
+
+        // Fake API call
+        current_weather: WeatherResult,
+    }
+
+    #[derive(Clone, Serialize)]
+    struct WeatherTool {
+        current_weather: WeatherResult,
+    }
+
+    impl WeatherView {
+        fn new(current_weather: WeatherResult) -> Self {
+            Self {
+                input: None,
+                result: None,
+                current_weather,
+            }
+        }
     }
 
     impl Render for WeatherView {
         fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
-            div().child(format!("temperature: {}", self.result.temperature))
+            match self.result {
+                Some(ref result) => div()
+                    .child(format!("temperature: {}", result.temperature))
+                    .into_any_element(),
+                None => div().child("Calculating weather...").into_any_element(),
+            }
         }
     }
 
     impl ToolOutput for WeatherView {
-        fn generate(&self, _output: &mut ProjectContext, _cx: &mut WindowContext) -> String {
+        type Input = WeatherQuery;
+
+        type SerializedState = WeatherResult;
+
+        fn generate(&self, _output: &mut ProjectContext, _cx: &mut ViewContext<Self>) -> String {
             serde_json::to_string(&self.result).unwrap()
         }
+
+        fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
+            self.input = Some(input);
+            cx.notify();
+        }
+
+        fn execute(&mut self, _cx: &mut ViewContext<Self>) -> Task<Result<()>> {
+            let input = self.input.as_ref().unwrap();
+
+            let _location = input.location.clone();
+            let _unit = input.unit.clone();
+
+            let weather = self.current_weather.clone();
+
+            self.result = Some(weather);
+
+            Task::ready(Ok(()))
+        }
+
+        fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
+            self.current_weather.clone()
+        }
+
+        fn deserialize(
+            &mut self,
+            output: Self::SerializedState,
+            _cx: &mut ViewContext<Self>,
+        ) -> Result<()> {
+            self.current_weather = output;
+            Ok(())
+        }
     }
 
     impl LanguageModelTool for WeatherTool {
-        type Input = WeatherQuery;
-        type Output = WeatherResult;
         type View = WeatherView;
 
         fn name(&self) -> String {
@@ -493,29 +476,8 @@ mod test {
             "Fetches the current weather for a given location.".to_string()
         }
 
-        fn execute(
-            &self,
-            input: &Self::Input,
-            _cx: &mut WindowContext,
-        ) -> Task<Result<Self::Output>> {
-            let _location = input.location.clone();
-            let _unit = input.unit.clone();
-
-            let weather = self.current_weather.clone();
-
-            Task::ready(Ok(weather))
-        }
-
-        fn view(
-            &self,
-            _input: Self::Input,
-            result: Result<Self::Output>,
-            cx: &mut WindowContext,
-        ) -> View<Self::View> {
-            cx.new_view(|_cx| {
-                let result = result.unwrap();
-                WeatherView { result }
-            })
+        fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
+            cx.new_view(|_cx| WeatherView::new(self.current_weather.clone()))
         }
     }
 
@@ -564,18 +526,14 @@ mod test {
             })
         );
 
-        let args = json!({
-            "location": "San Francisco",
-            "unit": "Celsius"
-        });
-
-        let query: WeatherQuery = serde_json::from_value(args).unwrap();
+        let view = cx.update(|cx| tool.view(cx));
 
-        let result = cx.update(|cx| tool.execute(&query, cx)).await;
+        cx.update(|cx| {
+            view.set_input(&r#"{"location": "San Francisco", "unit": "Celsius"}"#, cx);
+        });
 
-        assert!(result.is_ok());
-        let result = result.unwrap();
+        let finished = cx.update(|cx| view.execute(cx)).await;
 
-        assert_eq!(result, tool.current_weather);
+        assert!(finished.is_ok());
     }
 }

crates/multi_buffer/src/multi_buffer.rs 🔗

@@ -1603,6 +1603,11 @@ impl MultiBuffer {
         "untitled".into()
     }
 
+    pub fn set_title(&mut self, title: String, cx: &mut ModelContext<Self>) {
+        self.title = Some(title);
+        cx.notify();
+    }
+
     #[cfg(any(test, feature = "test-support"))]
     pub fn is_parsing(&self, cx: &AppContext) -> bool {
         self.as_singleton().unwrap().read(cx).is_parsing()
@@ -3151,10 +3156,10 @@ impl MultiBufferSnapshot {
                         .redacted_ranges(excerpt.range.context.clone())
                         .map(move |mut redacted_range| {
                             // Re-base onto the excerpts coordinates in the multibuffer
-                            redacted_range.start =
-                                excerpt_offset + (redacted_range.start - excerpt_buffer_start);
-                            redacted_range.end =
-                                excerpt_offset + (redacted_range.end - excerpt_buffer_start);
+                            redacted_range.start = excerpt_offset
+                                + redacted_range.start.saturating_sub(excerpt_buffer_start);
+                            redacted_range.end = excerpt_offset
+                                + redacted_range.end.saturating_sub(excerpt_buffer_start);
 
                             redacted_range
                         })
@@ -3179,10 +3184,13 @@ impl MultiBufferSnapshot {
                     .runnable_ranges(excerpt.range.context.clone())
                     .map(move |mut runnable| {
                         // Re-base onto the excerpts coordinates in the multibuffer
-                        runnable.run_range.start =
-                            excerpt_offset + (runnable.run_range.start - excerpt_buffer_start);
-                        runnable.run_range.end =
-                            excerpt_offset + (runnable.run_range.end - excerpt_buffer_start);
+                        runnable.run_range.start = excerpt_offset
+                            + runnable
+                                .run_range
+                                .start
+                                .saturating_sub(excerpt_buffer_start);
+                        runnable.run_range.end = excerpt_offset
+                            + runnable.run_range.end.saturating_sub(excerpt_buffer_start);
                         runnable
                     })
                     .skip_while(move |runnable| runnable.run_range.end < range.start)