Detailed changes
@@ -390,7 +390,6 @@ dependencies = [
"language",
"languages",
"log",
- "nanoid",
"node_runtime",
"open_ai",
"picker",
@@ -419,7 +418,9 @@ dependencies = [
"collections",
"futures 0.3.28",
"gpui",
+ "log",
"project",
+ "repair_json",
"schemars",
"serde",
"serde_json",
@@ -8050,6 +8051,15 @@ dependencies = [
"bytecheck",
]
+[[package]]
+name = "repair_json"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5ee191e184125fe72cb59b74160e25584e3908f2aaa84cbda1e161347102aa15"
+dependencies = [
+ "thiserror",
+]
+
[[package]]
name = "reqwest"
version = "0.11.20"
@@ -10185,18 +10195,18 @@ dependencies = [
[[package]]
name = "thiserror"
-version = "1.0.48"
+version = "1.0.60"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7"
+checksum = "579e9083ca58dd9dcf91a9923bb9054071b9ebbd800b342194c9feb0ee89fc18"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
-version = "1.0.48"
+version = "1.0.60"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35"
+checksum = "e2470041c06ec3ac1ab38d0356a6119054dedaea53e12fbefc0de730a1c08524"
dependencies = [
"proc-macro2",
"quote",
@@ -307,6 +307,7 @@ pulldown-cmark = { version = "0.10.0", default-features = false }
rand = "0.8.5"
refineable = { path = "./crates/refineable" }
regex = "1.5"
+repair_json = "0.1.0"
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
rust-embed = { version = "8.0", features = ["include-exclude"] }
schemars = "0.8"
@@ -29,7 +29,6 @@ fuzzy.workspace = true
gpui.workspace = true
language.workspace = true
log.workspace = true
-nanoid.workspace = true
open_ai.workspace = true
picker.workspace = true
project.workspace = true
@@ -536,25 +536,27 @@ impl AssistantChat {
body.push_str(content);
}
- for tool_call in delta.tool_calls {
- let index = tool_call.index as usize;
+ for tool_call_delta in delta.tool_calls {
+ let index = tool_call_delta.index as usize;
if index >= message.tool_calls.len() {
message.tool_calls.resize_with(index + 1, Default::default);
}
- let call = &mut message.tool_calls[index];
+ let tool_call = &mut message.tool_calls[index];
- if let Some(id) = &tool_call.id {
- call.id.push_str(id);
+ if let Some(id) = &tool_call_delta.id {
+ tool_call.id.push_str(id);
}
- match tool_call.variant {
- Some(proto::tool_call_delta::Variant::Function(tool_call)) => {
- if let Some(name) = &tool_call.name {
- call.name.push_str(name);
- }
- if let Some(arguments) = &tool_call.arguments {
- call.arguments.push_str(arguments);
- }
+ match tool_call_delta.variant {
+ Some(proto::tool_call_delta::Variant::Function(
+ tool_call_delta,
+ )) => {
+ this.tool_registry.update_tool_call(
+ tool_call,
+ tool_call_delta.name.as_deref(),
+ tool_call_delta.arguments.as_deref(),
+ cx,
+ );
}
None => {}
}
@@ -587,34 +589,20 @@ impl AssistantChat {
} else {
if let Some(current_message) = messages.last_mut() {
for tool_call in current_message.tool_calls.iter() {
- tool_tasks.push(this.tool_registry.call(tool_call, cx));
+ tool_tasks
+ .extend(this.tool_registry.execute_tool_call(&tool_call, cx));
}
}
}
}
})?;
+ // This ends recursion on calling for responses after tools
if tool_tasks.is_empty() {
return Ok(());
}
- let tools = join_all(tool_tasks.into_iter()).await;
- // If the WindowContext went away for any tool's view we don't include it
- // especially since the below call would fail for the same reason.
- let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();
-
- this.update(cx, |this, cx| {
- if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
- this.messages.last_mut()
- {
- if let Some(current_message) = messages.last_mut() {
- current_message.tool_calls = tools;
- cx.notify();
- } else {
- unreachable!()
- }
- }
- })?;
+ join_all(tool_tasks.into_iter()).await;
}
}
@@ -948,13 +936,11 @@ impl AssistantChat {
for tool_call in &message.tool_calls {
// Every tool call _must_ have a result by ID, otherwise OpenAI will error.
- let content = match &tool_call.result {
- Some(result) => {
- result.generate(&tool_call.name, &mut project_context, cx)
- }
- None => "".to_string(),
- };
-
+ let content = self.tool_registry.content_for_tool_call(
+ tool_call,
+ &mut project_context,
+ cx,
+ );
completion_messages.push(CompletionMessage::Tool {
content,
tool_call_id: tool_call.id.clone(),
@@ -1003,7 +989,11 @@ impl AssistantChat {
tool_calls: message
.tool_calls
.iter()
- .map(|tool_call| self.tool_registry.serialize_tool_call(tool_call))
+ .filter_map(|tool_call| {
+ self.tool_registry
+ .serialize_tool_call(tool_call, cx)
+ .log_err()
+ })
.collect(),
})
.collect(),
@@ -1,7 +1,7 @@
use std::{path::PathBuf, sync::Arc};
use anyhow::{anyhow, Result};
-use assistant_tooling::{LanguageModelAttachment, ProjectContext, ToolOutput};
+use assistant_tooling::{AttachmentOutput, LanguageModelAttachment, ProjectContext};
use editor::Editor;
use gpui::{Render, Task, View, WeakModel, WeakView};
use language::Buffer;
@@ -52,7 +52,7 @@ impl Render for FileAttachmentView {
}
}
-impl ToolOutput for FileAttachmentView {
+impl AttachmentOutput for FileAttachmentView {
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
if let Some(path) = &self.project_path {
project.add_file(path.clone());
@@ -4,7 +4,8 @@ use editor::{
display_map::{BlockContext, BlockDisposition, BlockProperties, BlockStyle},
Editor, MultiBuffer,
};
-use gpui::{prelude::*, AnyElement, Model, Task, View, WeakView};
+use futures::{channel::mpsc::UnboundedSender, StreamExt as _};
+use gpui::{prelude::*, AnyElement, AsyncWindowContext, Model, Task, View, WeakView};
use language::ToPoint;
use project::{search::SearchQuery, Project, ProjectPath};
use schemars::JsonSchema;
@@ -25,14 +26,19 @@ impl AnnotationTool {
}
}
-#[derive(Debug, Deserialize, JsonSchema, Clone)]
+#[derive(Default, Debug, Deserialize, JsonSchema, Clone)]
pub struct AnnotationInput {
/// Name for this set of annotations
+ #[serde(default = "default_title")]
title: String,
/// Excerpts from the file to show to the user.
excerpts: Vec<Excerpt>,
}
+fn default_title() -> String {
+ "Untitled".to_string()
+}
+
#[derive(Debug, Deserialize, JsonSchema, Clone)]
struct Excerpt {
/// Path to the file
@@ -44,8 +50,6 @@ struct Excerpt {
}
impl LanguageModelTool for AnnotationTool {
- type Input = AnnotationInput;
- type Output = String;
type View = AnnotationResultView;
fn name(&self) -> String {
@@ -56,67 +60,100 @@ impl LanguageModelTool for AnnotationTool {
"Dynamically annotate symbols in the current codebase. Opens a buffer in a panel in their editor, to the side of the conversation. The annotations are shown in the editor as a block decoration.".to_string()
}
- fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
- let workspace = self.workspace.clone();
- let project = self.project.clone();
- let excerpts = input.excerpts.clone();
- let title = input.title.clone();
+ fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
+ cx.new_view(|cx| {
+ let (tx, mut rx) = futures::channel::mpsc::unbounded();
+ cx.spawn(|view, mut cx| async move {
+ while let Some(excerpt) = rx.next().await {
+ AnnotationResultView::add_excerpt(view.clone(), excerpt, &mut cx).await?;
+ }
+ anyhow::Ok(())
+ })
+ .detach();
+
+ AnnotationResultView {
+ project: self.project.clone(),
+ workspace: self.workspace.clone(),
+ tx,
+ pending_excerpt: None,
+ added_editor_to_workspace: false,
+ editor: None,
+ error: None,
+ rendered_excerpt_count: 0,
+ }
+ })
+ }
+}
+
+pub struct AnnotationResultView {
+ workspace: WeakView<Workspace>,
+ project: Model<Project>,
+ pending_excerpt: Option<Excerpt>,
+ added_editor_to_workspace: bool,
+ editor: Option<View<Editor>>,
+ tx: UnboundedSender<Excerpt>,
+ error: Option<anyhow::Error>,
+ rendered_excerpt_count: usize,
+}
+
+impl AnnotationResultView {
+ async fn add_excerpt(
+ this: WeakView<Self>,
+ excerpt: Excerpt,
+ cx: &mut AsyncWindowContext,
+ ) -> Result<()> {
+ let project = this.update(cx, |this, _cx| this.project.clone())?;
let worktree_id = project.update(cx, |project, cx| {
let worktree = project.worktrees().next()?;
let worktree_id = worktree.read(cx).id();
Some(worktree_id)
- });
+ })?;
let worktree_id = if let Some(worktree_id) = worktree_id {
worktree_id
} else {
- return Task::ready(Err(anyhow::anyhow!("No worktree found")));
+ return Err(anyhow::anyhow!("No worktree found"));
};
- let buffer_tasks = project.update(cx, |project, cx| {
- excerpts
- .iter()
- .map(|excerpt| {
- project.open_buffer(
- ProjectPath {
- worktree_id,
- path: Path::new(&excerpt.path).into(),
- },
- cx,
- )
+ let buffer_task = project.update(cx, |project, cx| {
+ project.open_buffer(
+ ProjectPath {
+ worktree_id,
+ path: Path::new(&excerpt.path).into(),
+ },
+ cx,
+ )
+ })?;
+
+ let buffer = match buffer_task.await {
+ Ok(buffer) => buffer,
+ Err(error) => {
+ return this.update(cx, |this, cx| {
+ this.error = Some(error);
+ cx.notify();
})
- .collect::<Vec<_>>()
- });
+ }
+ };
- cx.spawn(move |mut cx| async move {
- let buffers = futures::future::try_join_all(buffer_tasks).await?;
+ let snapshot = buffer.update(cx, |buffer, _cx| buffer.snapshot())?;
+ let query = SearchQuery::text(&excerpt.text_passage, false, false, false, vec![], vec![])?;
+ let matches = query.search(&snapshot, None).await;
+ let Some(first_match) = matches.first() else {
+ log::warn!(
+ "text {:?} does not appear in '{}'",
+ excerpt.text_passage,
+ excerpt.path
+ );
+ return Ok(());
+ };
- let multibuffer = cx.new_model(|_cx| {
- MultiBuffer::new(0, language::Capability::ReadWrite).with_title(title)
- })?;
- let editor =
- cx.new_view(|cx| Editor::for_multibuffer(multibuffer, Some(project), cx))?;
-
- for (excerpt, buffer) in excerpts.iter().zip(buffers.iter()) {
- let snapshot = buffer.update(&mut cx, |buffer, _cx| buffer.snapshot())?;
-
- let query =
- SearchQuery::text(&excerpt.text_passage, false, false, false, vec![], vec![])?;
-
- let matches = query.search(&snapshot, None).await;
- let Some(first_match) = matches.first() else {
- log::warn!(
- "text {:?} does not appear in '{}'",
- excerpt.text_passage,
- excerpt.path
- );
- continue;
- };
- let mut start = first_match.start.to_point(&snapshot);
- start.column = 0;
+ this.update(cx, |this, cx| {
+ let mut start = first_match.start.to_point(&snapshot);
+ start.column = 0;
- editor.update(&mut cx, |editor, cx| {
+ if let Some(editor) = &this.editor {
+ editor.update(cx, |editor, cx| {
let ranges = editor.buffer().update(cx, |multibuffer, cx| {
multibuffer.push_excerpts_with_context_lines(
buffer.clone(),
@@ -125,7 +162,8 @@ impl LanguageModelTool for AnnotationTool {
cx,
)
});
- let annotation = SharedString::from(excerpt.annotation.clone());
+
+ let annotation = SharedString::from(excerpt.annotation);
editor.insert_blocks(
[BlockProperties {
position: ranges[0].start,
@@ -137,30 +175,22 @@ impl LanguageModelTool for AnnotationTool {
None,
cx,
);
- })?;
- }
+ });
- workspace
- .update(&mut cx, |workspace, cx| {
- workspace.add_item_to_active_pane(Box::new(editor.clone()), None, cx);
- })
- .log_err();
+ if !this.added_editor_to_workspace {
+ this.added_editor_to_workspace = true;
+ this.workspace
+ .update(cx, |workspace, cx| {
+ workspace.add_item_to_active_pane(Box::new(editor.clone()), None, cx);
+ })
+ .log_err();
+ }
+ }
+ })?;
- anyhow::Ok("showed comments to users in a new view".into())
- })
+ Ok(())
}
- fn view(
- &self,
- _: Self::Input,
- output: Result<Self::Output>,
- cx: &mut WindowContext,
- ) -> View<Self::View> {
- cx.new_view(|_cx| AnnotationResultView { output })
- }
-}
-
-impl AnnotationTool {
fn render_note_block(explanation: &SharedString, cx: &mut BlockContext) -> AnyElement {
let anchor_x = cx.anchor_x;
let gutter_width = cx.gutter_dimensions.width;
@@ -186,24 +216,89 @@ impl AnnotationTool {
}
}
-pub struct AnnotationResultView {
- output: Result<String>,
-}
-
impl Render for AnnotationResultView {
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
- match &self.output {
- Ok(output) => div().child(output.clone().into_any_element()),
- Err(error) => div().child(format!("failed to open path: {:?}", error)),
+ if let Some(error) = &self.error {
+ ui::Label::new(error.to_string()).into_any_element()
+ } else {
+ ui::Label::new(SharedString::from(format!(
+ "Opened a buffer with {} excerpts",
+ self.rendered_excerpt_count
+ )))
+ .into_any_element()
}
}
}
impl ToolOutput for AnnotationResultView {
- fn generate(&self, _: &mut ProjectContext, _: &mut WindowContext) -> String {
- match &self.output {
- Ok(output) => output.clone(),
- Err(err) => format!("Failed to create buffer: {err:?}"),
+ type Input = AnnotationInput;
+ type SerializedState = Option<String>;
+
+ fn generate(&self, _: &mut ProjectContext, _: &mut ViewContext<Self>) -> String {
+ if let Some(error) = &self.error {
+ format!("Failed to create buffer: {error:?}")
+ } else {
+ format!(
+ "opened {} excerpts in a buffer",
+ self.rendered_excerpt_count
+ )
+ }
+ }
+
+ fn set_input(&mut self, mut input: Self::Input, cx: &mut ViewContext<Self>) {
+ let editor = if let Some(editor) = &self.editor {
+ editor.clone()
+ } else {
+ let multibuffer = cx.new_model(|_cx| {
+ MultiBuffer::new(0, language::Capability::ReadWrite).with_title(String::new())
+ });
+ let editor = cx.new_view(|cx| {
+ Editor::for_multibuffer(multibuffer.clone(), Some(self.project.clone()), cx)
+ });
+
+ self.editor = Some(editor.clone());
+ editor
+ };
+
+ editor.update(cx, |editor, cx| {
+ editor.buffer().update(cx, |multibuffer, cx| {
+ if multibuffer.title(cx) != input.title {
+ multibuffer.set_title(input.title.clone(), cx);
+ }
+ });
+
+ self.pending_excerpt = input.excerpts.pop();
+ for excerpt in input.excerpts.iter().skip(self.rendered_excerpt_count) {
+ self.tx.unbounded_send(excerpt.clone()).ok();
+ }
+ self.rendered_excerpt_count = input.excerpts.len();
+ });
+
+ cx.notify();
+ }
+
+ fn execute(&mut self, _cx: &mut ViewContext<Self>) -> Task<Result<()>> {
+ if let Some(excerpt) = self.pending_excerpt.take() {
+ self.rendered_excerpt_count += 1;
+ self.tx.unbounded_send(excerpt.clone()).ok();
+ }
+
+ self.tx.close_channel();
+ Task::ready(Ok(()))
+ }
+
+ fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
+ self.error.as_ref().map(|error| error.to_string())
+ }
+
+ fn deserialize(
+ &mut self,
+ output: Self::SerializedState,
+ _cx: &mut ViewContext<Self>,
+ ) -> Result<()> {
+ if let Some(error_message) = output {
+ self.error = Some(anyhow::anyhow!("{}", error_message));
}
+ Ok(())
}
}
@@ -1,4 +1,4 @@
-use anyhow::Result;
+use anyhow::{anyhow, Result};
use assistant_tooling::{LanguageModelTool, ProjectContext, ToolOutput};
use editor::Editor;
use gpui::{prelude::*, Model, Task, View, WeakView};
@@ -20,7 +20,7 @@ impl CreateBufferTool {
}
}
-#[derive(Debug, Deserialize, JsonSchema)]
+#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub struct CreateBufferInput {
/// The contents of the buffer.
text: String,
@@ -32,8 +32,6 @@ pub struct CreateBufferInput {
}
impl LanguageModelTool for CreateBufferTool {
- type Input = CreateBufferInput;
- type Output = ();
type View = CreateBufferView;
fn name(&self) -> String {
@@ -44,13 +42,59 @@ impl LanguageModelTool for CreateBufferTool {
"Create a new buffer in the current codebase".to_string()
}
- fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
+ fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
+ cx.new_view(|_cx| CreateBufferView {
+ workspace: self.workspace.clone(),
+ project: self.project.clone(),
+ input: None,
+ error: None,
+ })
+ }
+}
+
+pub struct CreateBufferView {
+ workspace: WeakView<Workspace>,
+ project: Model<Project>,
+ input: Option<CreateBufferInput>,
+ error: Option<anyhow::Error>,
+}
+
+impl Render for CreateBufferView {
+ fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
+ div().child("Opening a buffer")
+ }
+}
+
+impl ToolOutput for CreateBufferView {
+ type Input = CreateBufferInput;
+
+ type SerializedState = ();
+
+ fn generate(&self, _project: &mut ProjectContext, _cx: &mut ViewContext<Self>) -> String {
+ let Some(input) = self.input.as_ref() else {
+ return "No input".to_string();
+ };
+
+ match &self.error {
+ None => format!("Created a new {} buffer", input.language),
+ Some(err) => format!("Failed to create buffer: {err:?}"),
+ }
+ }
+
+ fn set_input(&mut self, input: Self::Input, _cx: &mut ViewContext<Self>) {
+ self.input = Some(input);
+ }
+
+ fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
cx.spawn({
let workspace = self.workspace.clone();
let project = self.project.clone();
- let text = input.text.clone();
- let language_name = input.language.clone();
- |mut cx| async move {
+ let input = self.input.clone();
+ |_this, mut cx| async move {
+ let input = input.ok_or_else(|| anyhow!("no input"))?;
+
+ let text = input.text.clone();
+ let language_name = input.language.clone();
let language = cx
.update(|cx| {
project
@@ -86,35 +130,15 @@ impl LanguageModelTool for CreateBufferTool {
})
}
- fn view(
- &self,
- input: Self::Input,
- output: Result<Self::Output>,
- cx: &mut WindowContext,
- ) -> View<Self::View> {
- cx.new_view(|_cx| CreateBufferView {
- language: input.language,
- output,
- })
+ fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
+ ()
}
-}
-pub struct CreateBufferView {
- language: String,
- output: Result<()>,
-}
-
-impl Render for CreateBufferView {
- fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
- div().child("Opening a buffer")
- }
-}
-
-impl ToolOutput for CreateBufferView {
- fn generate(&self, _: &mut ProjectContext, _: &mut WindowContext) -> String {
- match &self.output {
- Ok(_) => format!("Created a new {} buffer", self.language),
- Err(err) => format!("Failed to create buffer: {err:?}"),
- }
+ fn deserialize(
+ &mut self,
+ _output: Self::SerializedState,
+ _cx: &mut ViewContext<Self>,
+ ) -> Result<()> {
+ Ok(())
}
}
@@ -1,4 +1,4 @@
-use anyhow::{anyhow, Result};
+use anyhow::Result;
use assistant_tooling::{LanguageModelTool, ToolOutput};
use collections::BTreeMap;
use gpui::{prelude::*, Model, Task};
@@ -6,9 +6,8 @@ use project::ProjectPath;
use schemars::JsonSchema;
use semantic_index::{ProjectIndex, Status};
use serde::{Deserialize, Serialize};
-use serde_json::Value;
use std::{fmt::Write as _, ops::Range, path::Path, sync::Arc};
-use ui::{div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext};
+use ui::{prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext};
const DEFAULT_SEARCH_LIMIT: usize = 20;
@@ -16,10 +15,26 @@ pub struct ProjectIndexTool {
project_index: Model<ProjectIndex>,
}
-// Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model.
-// Any changes or deletions to the `CodebaseQuery` comments will change model behavior.
+#[derive(Default)]
+enum ProjectIndexToolState {
+ #[default]
+ CollectingQuery,
+ Searching,
+ Error(anyhow::Error),
+ Finished {
+ excerpts: BTreeMap<ProjectPath, Vec<Range<usize>>>,
+ index_status: Status,
+ },
+}
+
+pub struct ProjectIndexView {
+ project_index: Model<ProjectIndex>,
+ input: CodebaseQuery,
+ expanded_header: bool,
+ state: ProjectIndexToolState,
+}
-#[derive(Deserialize, JsonSchema)]
+#[derive(Default, Deserialize, JsonSchema)]
pub struct CodebaseQuery {
/// Semantic search query
query: String,
@@ -27,21 +42,14 @@ pub struct CodebaseQuery {
limit: Option<usize>,
}
-pub struct ProjectIndexView {
- input: CodebaseQuery,
- status: Status,
- excerpts: Result<BTreeMap<ProjectPath, Vec<Range<usize>>>>,
- element_id: ElementId,
- expanded_header: bool,
-}
-
#[derive(Serialize, Deserialize)]
-pub struct ProjectIndexOutput {
- status: Status,
+pub struct SerializedState {
+ index_status: Status,
+ error_message: Option<String>,
worktrees: BTreeMap<Arc<Path>, WorktreeIndexOutput>,
}
-#[derive(Serialize, Deserialize)]
+#[derive(Default, Serialize, Deserialize)]
struct WorktreeIndexOutput {
excerpts: BTreeMap<Arc<Path>, Vec<Range<usize>>>,
}
@@ -56,58 +64,80 @@ impl ProjectIndexView {
impl Render for ProjectIndexView {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let query = self.input.query.clone();
- let excerpts = match &self.excerpts {
- Err(err) => {
- return div().child(Label::new(format!("Error: {}", err)).color(Color::Error));
+
+ let (header_text, content) = match &self.state {
+ ProjectIndexToolState::Error(error) => {
+ return format!("failed to search: {error:?}").into_any_element()
+ }
+ ProjectIndexToolState::CollectingQuery | ProjectIndexToolState::Searching => {
+ ("Searching...".to_string(), div())
+ }
+ ProjectIndexToolState::Finished { excerpts, .. } => {
+ let file_count = excerpts.len();
+
+ let header_text = format!(
+ "Read {} {}",
+ file_count,
+ if file_count == 1 { "file" } else { "files" }
+ );
+
+ let el = v_flex().gap_2().children(excerpts.keys().map(|path| {
+ h_flex().gap_2().child(Icon::new(IconName::File)).child(
+ Label::new(path.path.to_string_lossy().to_string()).color(Color::Muted),
+ )
+ }));
+
+ (header_text, el)
}
- Ok(excerpts) => excerpts,
};
- let file_count = excerpts.len();
let header = h_flex()
.gap_2()
.child(Icon::new(IconName::File))
- .child(format!(
- "Read {} {}",
- file_count,
- if file_count == 1 { "file" } else { "files" }
- ));
-
- v_flex().gap_3().child(
- CollapsibleContainer::new(self.element_id.clone(), self.expanded_header)
- .start_slot(header)
- .on_click(cx.listener(move |this, _, cx| {
- this.toggle_header(cx);
- }))
- .child(
- v_flex()
- .gap_3()
- .p_3()
- .child(
- h_flex()
- .gap_2()
- .child(Icon::new(IconName::MagnifyingGlass))
- .child(Label::new(format!("`{}`", query)).color(Color::Muted)),
- )
- .child(v_flex().gap_2().children(excerpts.keys().map(|path| {
- h_flex().gap_2().child(Icon::new(IconName::File)).child(
- Label::new(path.path.to_string_lossy().to_string())
- .color(Color::Muted),
+ .child(header_text);
+
+ v_flex()
+ .gap_3()
+ .child(
+ CollapsibleContainer::new("collapsible-container", self.expanded_header)
+ .start_slot(header)
+ .on_click(cx.listener(move |this, _, cx| {
+ this.toggle_header(cx);
+ }))
+ .child(
+ v_flex()
+ .gap_3()
+ .p_3()
+ .child(
+ h_flex()
+ .gap_2()
+ .child(Icon::new(IconName::MagnifyingGlass))
+ .child(Label::new(format!("`{}`", query)).color(Color::Muted)),
)
- }))),
- ),
- )
+ .child(content),
+ ),
+ )
+ .into_any_element()
}
}
impl ToolOutput for ProjectIndexView {
+ type Input = CodebaseQuery;
+ type SerializedState = SerializedState;
+
fn generate(
&self,
context: &mut assistant_tooling::ProjectContext,
- _: &mut WindowContext,
+ _: &mut ViewContext<Self>,
) -> String {
- match &self.excerpts {
- Ok(excerpts) => {
+ match &self.state {
+ ProjectIndexToolState::CollectingQuery => String::new(),
+ ProjectIndexToolState::Searching => String::new(),
+ ProjectIndexToolState::Error(error) => format!("failed to search: {error:?}"),
+ ProjectIndexToolState::Finished {
+ excerpts,
+ index_status,
+ } => {
let mut body = "found results in the following paths:\n".to_string();
for (project_path, ranges) in excerpts {
@@ -115,139 +145,151 @@ impl ToolOutput for ProjectIndexView {
writeln!(&mut body, "* {}", &project_path.path.display()).unwrap();
}
- if self.status != Status::Idle {
+ if *index_status != Status::Idle {
body.push_str("Still indexing. Results may be incomplete.\n");
}
body
}
- Err(err) => format!("Error: {}", err),
}
}
-}
-
-impl ProjectIndexTool {
- pub fn new(project_index: Model<ProjectIndex>) -> Self {
- Self { project_index }
- }
-}
-impl LanguageModelTool for ProjectIndexTool {
- type Input = CodebaseQuery;
- type Output = ProjectIndexOutput;
- type View = ProjectIndexView;
-
- fn name(&self) -> String {
- "query_codebase".to_string()
+ fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
+ self.input = input;
+ cx.notify();
}
- fn description(&self) -> String {
- "Semantic search against the user's current codebase, returning excerpts related to the query by computing a dot product against embeddings of code chunks in the code base and an embedding of the query.".to_string()
- }
+ fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
+ self.state = ProjectIndexToolState::Searching;
+ cx.notify();
- fn execute(&self, query: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
let project_index = self.project_index.read(cx);
- let status = project_index.status();
+ let index_status = project_index.status();
let search = project_index.search(
- query.query.clone(),
- query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
+ self.input.query.clone(),
+ self.input.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
cx,
);
- cx.spawn(|mut cx| async move {
- let search_results = search.await?;
-
- cx.update(|cx| {
- let mut output = ProjectIndexOutput {
- status,
- worktrees: Default::default(),
- };
-
- for search_result in search_results {
- let worktree_path = search_result.worktree.read(cx).abs_path();
- let excerpts = &mut output
- .worktrees
- .entry(worktree_path)
- .or_insert(WorktreeIndexOutput {
- excerpts: Default::default(),
- })
- .excerpts;
-
- let excerpts_for_path = excerpts.entry(search_result.path).or_default();
- let ix = match excerpts_for_path
- .binary_search_by_key(&search_result.range.start, |r| r.start)
- {
- Ok(ix) | Err(ix) => ix,
- };
- excerpts_for_path.insert(ix, search_result.range);
+ cx.spawn(|this, mut cx| async move {
+ let search_result = search.await;
+ this.update(&mut cx, |this, cx| {
+ match search_result {
+ Ok(search_results) => {
+ let mut excerpts = BTreeMap::<ProjectPath, Vec<Range<usize>>>::new();
+ for search_result in search_results {
+ let project_path = ProjectPath {
+ worktree_id: search_result.worktree.read(cx).id(),
+ path: search_result.path,
+ };
+ excerpts
+ .entry(project_path)
+ .or_default()
+ .push(search_result.range);
+ }
+ this.state = ProjectIndexToolState::Finished {
+ excerpts,
+ index_status,
+ };
+ }
+ Err(error) => {
+ this.state = ProjectIndexToolState::Error(error);
+ }
}
-
- output
+ cx.notify();
})
})
}
- fn view(
- &self,
- input: Self::Input,
- output: Result<Self::Output>,
- cx: &mut WindowContext,
- ) -> gpui::View<Self::View> {
- cx.new_view(|cx| {
- let status;
- let excerpts;
- match output {
- Ok(output) => {
- status = output.status;
- let project_index = self.project_index.read(cx);
- if let Some(project) = project_index.project().upgrade() {
- let project = project.read(cx);
- excerpts = Ok(output
- .worktrees
- .into_iter()
- .filter_map(|(abs_path, output)| {
- for worktree in project.worktrees() {
- let worktree = worktree.read(cx);
- if worktree.abs_path() == abs_path {
- return Some((worktree.id(), output.excerpts));
- }
- }
- None
- })
- .flat_map(|(worktree_id, excerpts)| {
- excerpts.into_iter().map(move |(path, ranges)| {
- (ProjectPath { worktree_id, path }, ranges)
- })
- })
- .collect::<BTreeMap<_, _>>());
- } else {
- excerpts = Err(anyhow!("project was dropped"));
+ fn serialize(&self, cx: &mut ViewContext<Self>) -> Self::SerializedState {
+ let mut serialized = SerializedState {
+ error_message: None,
+ index_status: Status::Idle,
+ worktrees: Default::default(),
+ };
+ match &self.state {
+ ProjectIndexToolState::Error(err) => serialized.error_message = Some(err.to_string()),
+ ProjectIndexToolState::Finished {
+ excerpts,
+ index_status,
+ } => {
+ serialized.index_status = *index_status;
+ if let Some(project) = self.project_index.read(cx).project().upgrade() {
+ let project = project.read(cx);
+ for (project_path, excerpts) in excerpts {
+ if let Some(worktree) =
+ project.worktree_for_id(project_path.worktree_id, cx)
+ {
+ let worktree_path = worktree.read(cx).abs_path();
+ serialized
+ .worktrees
+ .entry(worktree_path)
+ .or_default()
+ .excerpts
+ .insert(project_path.path.clone(), excerpts.clone());
+ }
}
}
- Err(err) => {
- status = Status::Idle;
- excerpts = Err(err);
+ }
+ _ => {}
+ }
+ serialized
+ }
+
+ fn deserialize(
+ &mut self,
+ serialized: Self::SerializedState,
+ cx: &mut ViewContext<Self>,
+ ) -> Result<()> {
+ if !serialized.worktrees.is_empty() {
+ let mut excerpts = BTreeMap::<ProjectPath, Vec<Range<usize>>>::new();
+ if let Some(project) = self.project_index.read(cx).project().upgrade() {
+ let project = project.read(cx);
+ for (worktree_path, worktree_state) in serialized.worktrees {
+ if let Some(worktree) = project
+ .worktrees()
+ .find(|worktree| worktree.read(cx).abs_path() == worktree_path)
+ {
+ let worktree_id = worktree.read(cx).id();
+ for (path, serialized_excerpts) in worktree_state.excerpts {
+ excerpts.insert(ProjectPath { worktree_id, path }, serialized_excerpts);
+ }
+ }
}
+ }
+ self.state = ProjectIndexToolState::Finished {
+ excerpts,
+ index_status: serialized.index_status,
};
+ }
+ cx.notify();
+ Ok(())
+ }
+}
- ProjectIndexView {
- input,
- status,
- excerpts,
- element_id: ElementId::Name(nanoid::nanoid!().into()),
- expanded_header: false,
- }
- })
+impl ProjectIndexTool {
+ pub fn new(project_index: Model<ProjectIndex>) -> Self {
+ Self { project_index }
}
+}
- fn render_running(arguments: &Option<Value>, _: &mut WindowContext) -> impl IntoElement {
- let text: String = arguments
- .as_ref()
- .and_then(|arguments| arguments.get("query"))
- .and_then(|query| query.as_str())
- .map(|query| format!("Searching for: {}", query))
- .unwrap_or_else(|| "Preparing search...".to_string());
+impl LanguageModelTool for ProjectIndexTool {
+ type View = ProjectIndexView;
- CollapsibleContainer::new(ElementId::Name(nanoid::nanoid!().into()), false).start_slot(text)
+ fn name(&self) -> String {
+ "query_codebase".to_string()
+ }
+
+ fn description(&self) -> String {
+ "Semantic search against the user's current codebase, returning excerpts related to the query by computing a dot product against embeddings of code chunks in the code base and an embedding of the query.".to_string()
+ }
+
+ fn view(&self, cx: &mut WindowContext) -> gpui::View<Self::View> {
+ cx.new_view(|_| ProjectIndexView {
+ state: ProjectIndexToolState::CollectingQuery,
+ input: Default::default(),
+ expanded_header: false,
+ project_index: self.project_index.clone(),
+ })
}
}
@@ -16,7 +16,9 @@ anyhow.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
+log.workspace = true
project.workspace = true
+repair_json.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
@@ -3,11 +3,11 @@ mod project_context;
mod tool_registry;
pub use attachment_registry::{
- AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment, UserAttachment,
+ AttachmentOutput, AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment,
+ UserAttachment,
};
pub use project_context::ProjectContext;
pub use tool_registry::{
- tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall,
- SavedToolFunctionCallResult, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
- ToolOutput, ToolRegistry,
+ tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall, SavedToolFunctionCallState,
+ ToolFunctionCall, ToolFunctionCallState, ToolFunctionDefinition, ToolOutput, ToolRegistry,
};
@@ -1,4 +1,4 @@
-use crate::{ProjectContext, ToolOutput};
+use crate::ProjectContext;
use anyhow::{anyhow, Result};
use collections::HashMap;
use futures::future::join_all;
@@ -18,9 +18,13 @@ pub struct AttachmentRegistry {
registered_attachments: HashMap<TypeId, RegisteredAttachment>,
}
+pub trait AttachmentOutput {
+ fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
+}
+
pub trait LanguageModelAttachment {
type Output: DeserializeOwned + Serialize + 'static;
- type View: Render + ToolOutput;
+ type View: Render + AttachmentOutput;
fn name(&self) -> Arc<str>;
fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
@@ -1,11 +1,10 @@
use crate::ProjectContext;
use anyhow::{anyhow, Result};
-use gpui::{
- div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
-};
+use gpui::{AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext};
+use repair_json::repair;
use schemars::{schema::RootSchema, schema_for, JsonSchema};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
-use serde_json::{value::RawValue, Value};
+use serde_json::value::RawValue;
use std::{
any::TypeId,
collections::HashMap,
@@ -15,6 +14,7 @@ use std::{
Arc,
},
};
+use ui::ViewContext;
pub struct ToolRegistry {
registered_tools: HashMap<String, RegisteredTool>,
@@ -25,7 +25,25 @@ pub struct ToolFunctionCall {
pub id: String,
pub name: String,
pub arguments: String,
- pub result: Option<ToolFunctionCallResult>,
+ state: ToolFunctionCallState,
+}
+
+#[derive(Default)]
+pub enum ToolFunctionCallState {
+ #[default]
+ Initializing,
+ NoSuchTool,
+ KnownTool(Box<dyn ToolView>),
+ ExecutedTool(Box<dyn ToolView>),
+}
+
+pub trait ToolView {
+ fn view(&self) -> AnyView;
+ fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
+ fn set_input(&self, input: &str, cx: &mut WindowContext);
+ fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
+ fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
+ fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
}
#[derive(Default, Serialize, Deserialize)]
@@ -33,29 +51,19 @@ pub struct SavedToolFunctionCall {
pub id: String,
pub name: String,
pub arguments: String,
- pub result: Option<SavedToolFunctionCallResult>,
-}
-
-pub enum ToolFunctionCallResult {
- NoSuchTool,
- ParsingFailed,
- Finished {
- view: AnyView,
- serialized_output: Result<Box<RawValue>, String>,
- generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
- },
+ pub state: SavedToolFunctionCallState,
}
-#[derive(Serialize, Deserialize)]
-pub enum SavedToolFunctionCallResult {
+#[derive(Default, Serialize, Deserialize)]
+pub enum SavedToolFunctionCallState {
+ #[default]
+ Initializing,
NoSuchTool,
- ParsingFailed,
- Finished {
- serialized_output: Result<Box<RawValue>, String>,
- },
+ KnownTool,
+ ExecutedTool(Box<RawValue>),
}
-#[derive(Clone)]
+#[derive(Clone, Debug)]
pub struct ToolFunctionDefinition {
pub name: String,
pub description: String,
@@ -63,14 +71,7 @@ pub struct ToolFunctionDefinition {
}
pub trait LanguageModelTool {
- /// The input type that will be passed in to `execute` when the tool is called
- /// by the language model.
- type Input: DeserializeOwned + JsonSchema;
-
- /// The output returned by executing the tool.
- type Output: DeserializeOwned + Serialize + 'static;
-
- type View: Render + ToolOutput;
+ type View: ToolOutput;
/// Returns the name of the tool.
///
@@ -86,7 +87,7 @@ pub trait LanguageModelTool {
/// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
fn definition(&self) -> ToolFunctionDefinition {
- let root_schema = schema_for!(Self::Input);
+ let root_schema = schema_for!(<Self::View as ToolOutput>::Input);
ToolFunctionDefinition {
name: self.name(),
@@ -95,36 +96,46 @@ pub trait LanguageModelTool {
}
}
- /// Executes the tool with the given input.
- fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
-
/// A view of the output of running the tool, for displaying to the user.
- fn view(
- &self,
- input: Self::Input,
- output: Result<Self::Output>,
- cx: &mut WindowContext,
- ) -> View<Self::View>;
-
- fn render_running(_arguments: &Option<Value>, _cx: &mut WindowContext) -> impl IntoElement {
- tool_running_placeholder()
- }
+ fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
}
pub fn tool_running_placeholder() -> AnyElement {
ui::Label::new("Researching...").into_any_element()
}
-pub trait ToolOutput: Sized {
- fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
+pub fn unknown_tool_placeholder() -> AnyElement {
+ ui::Label::new("Unknown tool").into_any_element()
+}
+
+pub fn no_such_tool_placeholder() -> AnyElement {
+ ui::Label::new("No such tool").into_any_element()
+}
+
+pub trait ToolOutput: Render {
+ /// The input type that will be passed in to `execute` when the tool is called
+ /// by the language model.
+ type Input: DeserializeOwned + JsonSchema;
+
+ /// The output returned by executing the tool.
+ type SerializedState: DeserializeOwned + Serialize;
+
+ fn generate(&self, project: &mut ProjectContext, cx: &mut ViewContext<Self>) -> String;
+ fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>);
+ fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>>;
+
+ fn serialize(&self, cx: &mut ViewContext<Self>) -> Self::SerializedState;
+ fn deserialize(
+ &mut self,
+ output: Self::SerializedState,
+ cx: &mut ViewContext<Self>,
+ ) -> Result<()>;
}
struct RegisteredTool {
enabled: AtomicBool,
type_id: TypeId,
- execute: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
- deserialize: Box<dyn Fn(&SavedToolFunctionCall, &mut WindowContext) -> ToolFunctionCall>,
- render_running: fn(&ToolFunctionCall, &mut WindowContext) -> gpui::AnyElement,
+ build_view: Box<dyn Fn(&mut WindowContext) -> Box<dyn ToolView>>,
definition: ToolFunctionDefinition,
}
@@ -161,63 +172,132 @@ impl ToolRegistry {
.collect()
}
- pub fn render_tool_call(
+ pub fn view_for_tool(&self, name: &str, cx: &mut WindowContext) -> Option<Box<dyn ToolView>> {
+ let tool = self.registered_tools.get(name)?;
+ Some((tool.build_view)(cx))
+ }
+
+ pub fn update_tool_call(
&self,
- tool_call: &ToolFunctionCall,
+ call: &mut ToolFunctionCall,
+ name: Option<&str>,
+ arguments: Option<&str>,
cx: &mut WindowContext,
- ) -> AnyElement {
- match &tool_call.result {
- Some(result) => div()
- .p_2()
- .child(result.into_any_element(&tool_call.name))
- .into_any_element(),
- None => {
- let tool = self.registered_tools.get(&tool_call.name);
-
- if let Some(tool) = tool {
- (tool.render_running)(&tool_call, cx)
+ ) {
+ if let Some(name) = name {
+ call.name.push_str(name);
+ }
+ if let Some(arguments) = arguments {
+ if call.arguments.is_empty() {
+ if let Some(view) = self.view_for_tool(&call.name, cx) {
+ call.state = ToolFunctionCallState::KnownTool(view);
} else {
- tool_running_placeholder()
+ call.state = ToolFunctionCallState::NoSuchTool;
+ }
+ }
+ call.arguments.push_str(arguments);
+
+ if let ToolFunctionCallState::KnownTool(view) = &call.state {
+ if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
+ view.set_input(&repaired_arguments, cx)
}
}
}
}
- pub fn serialize_tool_call(&self, call: &ToolFunctionCall) -> SavedToolFunctionCall {
- SavedToolFunctionCall {
+ pub fn execute_tool_call(
+ &self,
+ tool_call: &ToolFunctionCall,
+ cx: &mut WindowContext,
+ ) -> Option<Task<Result<()>>> {
+ if let ToolFunctionCallState::KnownTool(view) = &tool_call.state {
+ Some(view.execute(cx))
+ } else {
+ None
+ }
+ }
+
+ pub fn render_tool_call(
+ &self,
+ tool_call: &ToolFunctionCall,
+ _cx: &mut WindowContext,
+ ) -> AnyElement {
+ match &tool_call.state {
+ ToolFunctionCallState::NoSuchTool => no_such_tool_placeholder(),
+ ToolFunctionCallState::Initializing => unknown_tool_placeholder(),
+ ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
+ view.view().into_any_element()
+ }
+ }
+ }
+
+ pub fn content_for_tool_call(
+ &self,
+ tool_call: &ToolFunctionCall,
+ project_context: &mut ProjectContext,
+ cx: &mut WindowContext,
+ ) -> String {
+ match &tool_call.state {
+ ToolFunctionCallState::Initializing => String::new(),
+ ToolFunctionCallState::NoSuchTool => {
+ format!("No such tool: {}", tool_call.name)
+ }
+ ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
+ view.generate(project_context, cx)
+ }
+ }
+ }
+
+ pub fn serialize_tool_call(
+ &self,
+ call: &ToolFunctionCall,
+ cx: &mut WindowContext,
+ ) -> Result<SavedToolFunctionCall> {
+ Ok(SavedToolFunctionCall {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
- result: call.result.as_ref().map(|result| match result {
- ToolFunctionCallResult::NoSuchTool => SavedToolFunctionCallResult::NoSuchTool,
- ToolFunctionCallResult::ParsingFailed => SavedToolFunctionCallResult::ParsingFailed,
- ToolFunctionCallResult::Finished {
- serialized_output, ..
- } => SavedToolFunctionCallResult::Finished {
- serialized_output: match serialized_output {
- Ok(value) => Ok(value.clone()),
- Err(e) => Err(e.to_string()),
- },
- },
- }),
- }
+ state: match &call.state {
+ ToolFunctionCallState::Initializing => SavedToolFunctionCallState::Initializing,
+ ToolFunctionCallState::NoSuchTool => SavedToolFunctionCallState::NoSuchTool,
+ ToolFunctionCallState::KnownTool(_) => SavedToolFunctionCallState::KnownTool,
+ ToolFunctionCallState::ExecutedTool(view) => {
+ SavedToolFunctionCallState::ExecutedTool(view.serialize_output(cx)?)
+ }
+ },
+ })
}
pub fn deserialize_tool_call(
&self,
call: &SavedToolFunctionCall,
cx: &mut WindowContext,
- ) -> ToolFunctionCall {
- if let Some(tool) = &self.registered_tools.get(&call.name) {
- (tool.deserialize)(call, cx)
- } else {
- ToolFunctionCall {
- id: call.id.clone(),
- name: call.name.clone(),
- arguments: call.arguments.clone(),
- result: Some(ToolFunctionCallResult::NoSuchTool),
- }
- }
+ ) -> Result<ToolFunctionCall> {
+ let Some(tool) = self.registered_tools.get(&call.name) else {
+ return Err(anyhow!("no such tool {}", call.name));
+ };
+
+ Ok(ToolFunctionCall {
+ id: call.id.clone(),
+ name: call.name.clone(),
+ arguments: call.arguments.clone(),
+ state: match &call.state {
+ SavedToolFunctionCallState::Initializing => ToolFunctionCallState::Initializing,
+ SavedToolFunctionCallState::NoSuchTool => ToolFunctionCallState::NoSuchTool,
+ SavedToolFunctionCallState::KnownTool => {
+ log::error!("Deserialized tool that had not executed");
+ let view = (tool.build_view)(cx);
+ view.set_input(&call.arguments, cx);
+ ToolFunctionCallState::KnownTool(view)
+ }
+ SavedToolFunctionCallState::ExecutedTool(output) => {
+ let view = (tool.build_view)(cx);
+ view.set_input(&call.arguments, cx);
+ view.deserialize_output(output, cx)?;
+ ToolFunctionCallState::ExecutedTool(view)
+ }
+ },
+ })
}
pub fn register<T: 'static + LanguageModelTool>(
@@ -231,114 +311,7 @@ impl ToolRegistry {
type_id: TypeId::of::<T>(),
definition: tool.definition(),
enabled: AtomicBool::new(true),
- deserialize: Box::new({
- let tool = tool.clone();
- move |tool_call: &SavedToolFunctionCall, cx: &mut WindowContext| {
- let id = tool_call.id.clone();
- let name = tool_call.name.clone();
- let arguments = tool_call.arguments.clone();
-
- let Ok(input) = serde_json::from_str::<T::Input>(&tool_call.arguments) else {
- return ToolFunctionCall {
- id,
- name: name.clone(),
- arguments,
- result: Some(ToolFunctionCallResult::ParsingFailed),
- };
- };
-
- let result = match &tool_call.result {
- Some(result) => match result {
- SavedToolFunctionCallResult::NoSuchTool => {
- Some(ToolFunctionCallResult::NoSuchTool)
- }
- SavedToolFunctionCallResult::ParsingFailed => {
- Some(ToolFunctionCallResult::ParsingFailed)
- }
- SavedToolFunctionCallResult::Finished { serialized_output } => {
- let output = match serialized_output {
- Ok(value) => {
- match serde_json::from_str::<T::Output>(value.get()) {
- Ok(value) => Ok(value),
- Err(_) => {
- return ToolFunctionCall {
- id,
- name: name.clone(),
- arguments,
- result: Some(
- ToolFunctionCallResult::ParsingFailed,
- ),
- };
- }
- }
- }
- Err(e) => Err(anyhow!("{e}")),
- };
-
- let view = tool.view(input, output, cx).into();
- Some(ToolFunctionCallResult::Finished {
- serialized_output: serialized_output.clone(),
- generate_fn: generate::<T>,
- view,
- })
- }
- },
- None => None,
- };
-
- ToolFunctionCall {
- id: tool_call.id.clone(),
- name: name.clone(),
- arguments: tool_call.arguments.clone(),
- result,
- }
- }
- }),
- execute: Box::new({
- let tool = tool.clone();
- move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
- let id = tool_call.id.clone();
- let name = tool_call.name.clone();
- let arguments = tool_call.arguments.clone();
-
- let Ok(input) = serde_json::from_str::<T::Input>(&arguments) else {
- return Task::ready(Ok(ToolFunctionCall {
- id,
- name: name.clone(),
- arguments,
- result: Some(ToolFunctionCallResult::ParsingFailed),
- }));
- };
-
- let result = tool.execute(&input, cx);
- let tool = tool.clone();
- cx.spawn(move |mut cx| async move {
- let result = result.await;
- let serialized_output = result
- .as_ref()
- .map_err(ToString::to_string)
- .and_then(|output| {
- Ok(RawValue::from_string(
- serde_json::to_string(output).map_err(|e| e.to_string())?,
- )
- .unwrap())
- });
- let view = cx.update(|cx| tool.view(input, result, cx))?;
-
- Ok(ToolFunctionCall {
- id,
- name: name.clone(),
- arguments,
- result: Some(ToolFunctionCallResult::Finished {
- serialized_output,
- view: view.into(),
- generate_fn: generate::<T>,
- }),
- })
- })
- }
- }),
- render_running: render_running::<T>,
+ build_view: Box::new(move |cx: &mut WindowContext| Box::new(tool.view(cx))),
};
let previous = self.registered_tools.insert(name.clone(), registered_tool);
@@ -347,83 +320,40 @@ impl ToolRegistry {
}
return Ok(());
+ }
+}
- fn render_running<T: LanguageModelTool>(
- tool_call: &ToolFunctionCall,
- cx: &mut WindowContext,
- ) -> AnyElement {
- // Attempt to parse the string arguments that are JSON as a JSON value
- let maybe_arguments = serde_json::to_value(tool_call.arguments.clone()).ok();
+impl<T: ToolOutput> ToolView for View<T> {
+ fn view(&self) -> AnyView {
+ self.clone().into()
+ }
- T::render_running(&maybe_arguments, cx).into_any_element()
- }
+ fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
+ self.update(cx, |view, cx| view.generate(project, cx))
+ }
- fn generate<T: LanguageModelTool>(
- view: AnyView,
- project: &mut ProjectContext,
- cx: &mut WindowContext,
- ) -> String {
- view.downcast::<T::View>()
- .unwrap()
- .update(cx, |view, cx| T::View::generate(view, project, cx))
+ fn set_input(&self, input: &str, cx: &mut WindowContext) {
+ if let Ok(input) = serde_json::from_str::<T::Input>(input) {
+ self.update(cx, |view, cx| {
+ view.set_input(input, cx);
+ cx.notify();
+ });
}
}
- /// Task yields an error if the window for the given WindowContext is closed before the task completes.
- pub fn call(
- &self,
- tool_call: &ToolFunctionCall,
- cx: &mut WindowContext,
- ) -> Task<Result<ToolFunctionCall>> {
- let name = tool_call.name.clone();
- let arguments = tool_call.arguments.clone();
- let id = tool_call.id.clone();
-
- let tool = match self.registered_tools.get(&name) {
- Some(tool) => tool,
- None => {
- let name = name.clone();
- return Task::ready(Ok(ToolFunctionCall {
- id,
- name: name.clone(),
- arguments,
- result: Some(ToolFunctionCallResult::NoSuchTool),
- }));
- }
- };
-
- (tool.execute)(tool_call, cx)
+ fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>> {
+ self.update(cx, |view, cx| view.execute(cx))
}
-}
-impl ToolFunctionCallResult {
- pub fn generate(
- &self,
- name: &String,
- project: &mut ProjectContext,
- cx: &mut WindowContext,
- ) -> String {
- match self {
- ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
- ToolFunctionCallResult::ParsingFailed => {
- format!("Unable to parse arguments for {name}")
- }
- ToolFunctionCallResult::Finished {
- generate_fn, view, ..
- } => (generate_fn)(view.clone(), project, cx),
- }
+ fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>> {
+ let output = self.update(cx, |view, cx| view.serialize(cx));
+ Ok(RawValue::from_string(serde_json::to_string(&output)?)?)
}
- fn into_any_element(&self, name: &String) -> AnyElement {
- match self {
- ToolFunctionCallResult::NoSuchTool => {
- format!("Language Model attempted to call {name}").into_any_element()
- }
- ToolFunctionCallResult::ParsingFailed => {
- format!("Language Model called {name} with bad arguments").into_any_element()
- }
- ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
- }
+ fn deserialize_output(&self, output: &RawValue, cx: &mut WindowContext) -> Result<()> {
+ let state = serde_json::from_str::<T::SerializedState>(output.get())?;
+ self.update(cx, |view, cx| view.deserialize(state, cx))?;
+ Ok(())
}
}
@@ -453,10 +383,6 @@ mod test {
unit: String,
}
- struct WeatherTool {
- current_weather: WeatherResult,
- }
-
#[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
struct WeatherResult {
location: String,
@@ -465,24 +391,81 @@ mod test {
}
struct WeatherView {
- result: WeatherResult,
+ input: Option<WeatherQuery>,
+ result: Option<WeatherResult>,
+
+ // Fake API call
+ current_weather: WeatherResult,
+ }
+
+ #[derive(Clone, Serialize)]
+ struct WeatherTool {
+ current_weather: WeatherResult,
+ }
+
+ impl WeatherView {
+ fn new(current_weather: WeatherResult) -> Self {
+ Self {
+ input: None,
+ result: None,
+ current_weather,
+ }
+ }
}
impl Render for WeatherView {
fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
- div().child(format!("temperature: {}", self.result.temperature))
+ match self.result {
+ Some(ref result) => div()
+ .child(format!("temperature: {}", result.temperature))
+ .into_any_element(),
+ None => div().child("Calculating weather...").into_any_element(),
+ }
}
}
impl ToolOutput for WeatherView {
- fn generate(&self, _output: &mut ProjectContext, _cx: &mut WindowContext) -> String {
+ type Input = WeatherQuery;
+
+ type SerializedState = WeatherResult;
+
+ fn generate(&self, _output: &mut ProjectContext, _cx: &mut ViewContext<Self>) -> String {
serde_json::to_string(&self.result).unwrap()
}
+
+ fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
+ self.input = Some(input);
+ cx.notify();
+ }
+
+ fn execute(&mut self, _cx: &mut ViewContext<Self>) -> Task<Result<()>> {
+ let input = self.input.as_ref().unwrap();
+
+ let _location = input.location.clone();
+ let _unit = input.unit.clone();
+
+ let weather = self.current_weather.clone();
+
+ self.result = Some(weather);
+
+ Task::ready(Ok(()))
+ }
+
+ fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
+ self.current_weather.clone()
+ }
+
+ fn deserialize(
+ &mut self,
+ output: Self::SerializedState,
+ _cx: &mut ViewContext<Self>,
+ ) -> Result<()> {
+ self.current_weather = output;
+ Ok(())
+ }
}
impl LanguageModelTool for WeatherTool {
- type Input = WeatherQuery;
- type Output = WeatherResult;
type View = WeatherView;
fn name(&self) -> String {
@@ -493,29 +476,8 @@ mod test {
"Fetches the current weather for a given location.".to_string()
}
- fn execute(
- &self,
- input: &Self::Input,
- _cx: &mut WindowContext,
- ) -> Task<Result<Self::Output>> {
- let _location = input.location.clone();
- let _unit = input.unit.clone();
-
- let weather = self.current_weather.clone();
-
- Task::ready(Ok(weather))
- }
-
- fn view(
- &self,
- _input: Self::Input,
- result: Result<Self::Output>,
- cx: &mut WindowContext,
- ) -> View<Self::View> {
- cx.new_view(|_cx| {
- let result = result.unwrap();
- WeatherView { result }
- })
+ fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
+ cx.new_view(|_cx| WeatherView::new(self.current_weather.clone()))
}
}
@@ -564,18 +526,14 @@ mod test {
})
);
- let args = json!({
- "location": "San Francisco",
- "unit": "Celsius"
- });
-
- let query: WeatherQuery = serde_json::from_value(args).unwrap();
+ let view = cx.update(|cx| tool.view(cx));
- let result = cx.update(|cx| tool.execute(&query, cx)).await;
+ cx.update(|cx| {
+ view.set_input(&r#"{"location": "San Francisco", "unit": "Celsius"}"#, cx);
+ });
- assert!(result.is_ok());
- let result = result.unwrap();
+ let finished = cx.update(|cx| view.execute(cx)).await;
- assert_eq!(result, tool.current_weather);
+ assert!(finished.is_ok());
}
}
@@ -1603,6 +1603,11 @@ impl MultiBuffer {
"untitled".into()
}
+ pub fn set_title(&mut self, title: String, cx: &mut ModelContext<Self>) {
+ self.title = Some(title);
+ cx.notify();
+ }
+
#[cfg(any(test, feature = "test-support"))]
pub fn is_parsing(&self, cx: &AppContext) -> bool {
self.as_singleton().unwrap().read(cx).is_parsing()
@@ -3151,10 +3156,10 @@ impl MultiBufferSnapshot {
.redacted_ranges(excerpt.range.context.clone())
.map(move |mut redacted_range| {
// Re-base onto the excerpts coordinates in the multibuffer
- redacted_range.start =
- excerpt_offset + (redacted_range.start - excerpt_buffer_start);
- redacted_range.end =
- excerpt_offset + (redacted_range.end - excerpt_buffer_start);
+ redacted_range.start = excerpt_offset
+ + redacted_range.start.saturating_sub(excerpt_buffer_start);
+ redacted_range.end = excerpt_offset
+ + redacted_range.end.saturating_sub(excerpt_buffer_start);
redacted_range
})
@@ -3179,10 +3184,13 @@ impl MultiBufferSnapshot {
.runnable_ranges(excerpt.range.context.clone())
.map(move |mut runnable| {
// Re-base onto the excerpts coordinates in the multibuffer
- runnable.run_range.start =
- excerpt_offset + (runnable.run_range.start - excerpt_buffer_start);
- runnable.run_range.end =
- excerpt_offset + (runnable.run_range.end - excerpt_buffer_start);
+ runnable.run_range.start = excerpt_offset
+ + runnable
+ .run_range
+ .start
+ .saturating_sub(excerpt_buffer_start);
+ runnable.run_range.end = excerpt_offset
+ + runnable.run_range.end.saturating_sub(excerpt_buffer_start);
runnable
})
.skip_while(move |runnable| runnable.run_range.end < range.start)