Detailed changes
@@ -1,4 +1,5 @@
-use anyhow::Context as _;
+/// This example creates a basic Chat UI with a function for rolling a die.
+use anyhow::{Context as _, Result};
use assets::Assets;
use assistant2::AssistantPanel;
use assistant_tooling::{LanguageModelTool, ToolRegistry};
@@ -83,9 +84,32 @@ struct DiceRoll {
rolls: Vec<DieRoll>,
}
+pub struct DiceView {
+ result: Result<DiceRoll>,
+}
+
+impl Render for DiceView {
+ fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
+ let output = match &self.result {
+ Ok(output) => output,
+ Err(_) => return "Somehow dice failed 🎲".into_any_element(),
+ };
+
+ h_flex()
+ .children(
+ output
+ .rolls
+ .iter()
+ .map(|roll| div().p_2().child(roll.render())),
+ )
+ .into_any_element()
+ }
+}
+
impl LanguageModelTool for RollDiceTool {
type Input = DiceParams;
type Output = DiceRoll;
+ type View = DiceView;
fn name(&self) -> String {
"roll_dice".to_string()
@@ -110,23 +134,21 @@ impl LanguageModelTool for RollDiceTool {
return Task::ready(Ok(DiceRoll { rolls }));
}
- fn render(
- _tool_call_id: &str,
- _input: &Self::Input,
- output: &Self::Output,
- _cx: &mut WindowContext,
- ) -> gpui::AnyElement {
- h_flex()
- .children(
- output
- .rolls
- .iter()
- .map(|roll| div().p_2().child(roll.render())),
- )
- .into_any_element()
+ fn new_view(
+ _tool_call_id: String,
+ _input: Self::Input,
+ result: Result<Self::Output>,
+ cx: &mut WindowContext,
+ ) -> gpui::View<Self::View> {
+ cx.new_view(|_cx| DiceView { result })
}
- fn format(_input: &Self::Input, output: &Self::Output) -> String {
+ fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
+ let output = match output {
+ Ok(output) => output,
+ Err(_) => return "Somehow dice failed 🎲".to_string(),
+ };
+
let mut result = String::new();
for roll in &output.rolls {
let die = &roll.die;
@@ -322,9 +322,11 @@ impl AssistantChat {
};
call_count += 1;
+ let messages = this.completion_messages(cx);
+
CompletionProvider::get(cx).complete(
this.model.clone(),
- this.completion_messages(cx),
+ messages,
Vec::new(),
1.0,
definitions,
@@ -407,6 +409,10 @@ impl AssistantChat {
}
let tools = join_all(tool_tasks.into_iter()).await;
+ // If the WindowContext went away for any tool's view we don't include it
+ // especially since the below call would fail for the same reason.
+ let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();
+
this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(AssistantMessage { tool_calls, .. })) =
this.messages.last_mut()
@@ -561,10 +567,9 @@ impl AssistantChat {
let result = &tool_call.result;
let name = tool_call.name.clone();
match result {
- Some(result) => div()
- .p_2()
- .child(result.render(&name, &tool_call.id, cx))
- .into_any(),
+ Some(result) => {
+ div().p_2().child(result.into_any_element(&name)).into_any()
+ }
None => div()
.p_2()
.child(Label::new(name).color(Color::Modified))
@@ -577,7 +582,7 @@ impl AssistantChat {
}
}
- fn completion_messages(&self, cx: &WindowContext) -> Vec<CompletionMessage> {
+ fn completion_messages(&self, cx: &mut WindowContext) -> Vec<CompletionMessage> {
let mut completion_messages = Vec::new();
for message in &self.messages {
@@ -1,10 +1,10 @@
use anyhow::Result;
use assistant_tooling::LanguageModelTool;
-use gpui::{prelude::*, AnyElement, AppContext, Model, Task};
+use gpui::{prelude::*, AppContext, Model, Task};
use project::Fs;
use schemars::JsonSchema;
use semantic_index::ProjectIndex;
-use serde::{Deserialize, Serialize};
+use serde::Deserialize;
use std::sync::Arc;
use ui::{
div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, SharedString,
@@ -14,11 +14,13 @@ use util::ResultExt as _;
const DEFAULT_SEARCH_LIMIT: usize = 20;
-#[derive(Serialize, Clone)]
+#[derive(Clone)]
pub struct CodebaseExcerpt {
path: SharedString,
text: SharedString,
score: f32,
+ element_id: ElementId,
+ expanded: bool,
}
// Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model.
@@ -32,6 +34,79 @@ pub struct CodebaseQuery {
limit: Option<usize>,
}
+pub struct ProjectIndexView {
+ input: CodebaseQuery,
+ output: Result<Vec<CodebaseExcerpt>>,
+}
+
+impl ProjectIndexView {
+ fn toggle_expanded(&mut self, element_id: ElementId, cx: &mut ViewContext<Self>) {
+ if let Ok(excerpts) = &mut self.output {
+ if let Some(excerpt) = excerpts
+ .iter_mut()
+ .find(|excerpt| excerpt.element_id == element_id)
+ {
+ excerpt.expanded = !excerpt.expanded;
+ cx.notify();
+ }
+ }
+ }
+}
+
+impl Render for ProjectIndexView {
+ fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ let query = self.input.query.clone();
+
+ let result = &self.output;
+
+ let excerpts = match result {
+ Err(err) => {
+ return div().child(Label::new(format!("Error: {}", err)).color(Color::Error));
+ }
+ Ok(excerpts) => excerpts,
+ };
+
+ div()
+ .v_flex()
+ .gap_2()
+ .child(
+ div()
+ .p_2()
+ .rounded_md()
+ .bg(cx.theme().colors().editor_background)
+ .child(
+ h_flex()
+ .child(Label::new("Query: ").color(Color::Modified))
+ .child(Label::new(query).color(Color::Muted)),
+ ),
+ )
+ .children(excerpts.iter().map(|excerpt| {
+ let element_id = excerpt.element_id.clone();
+ let expanded = excerpt.expanded;
+
+ CollapsibleContainer::new(element_id.clone(), expanded)
+ .start_slot(
+ h_flex()
+ .gap_1()
+ .child(Icon::new(IconName::File).color(Color::Muted))
+ .child(Label::new(excerpt.path.clone()).color(Color::Muted)),
+ )
+ .on_click(cx.listener(move |this, _, cx| {
+ this.toggle_expanded(element_id.clone(), cx);
+ }))
+ .child(
+ div()
+ .p_2()
+ .rounded_md()
+ .bg(cx.theme().colors().editor_background)
+ .child(
+ excerpt.text.clone(), // todo!(): Show as an editor block
+ ),
+ )
+ }))
+ }
+}
+
pub struct ProjectIndexTool {
project_index: Model<ProjectIndex>,
fs: Arc<dyn Fs>,
@@ -47,6 +122,7 @@ impl ProjectIndexTool {
impl LanguageModelTool for ProjectIndexTool {
type Input = CodebaseQuery;
type Output = Vec<CodebaseExcerpt>;
+ type View = ProjectIndexView;
fn name(&self) -> String {
"query_codebase".to_string()
@@ -90,6 +166,8 @@ impl LanguageModelTool for ProjectIndexTool {
}
anyhow::Ok(CodebaseExcerpt {
+ element_id: ElementId::Name(nanoid::nanoid!().into()),
+ expanded: false,
path: path.to_string_lossy().to_string().into(),
text: SharedString::from(text[start..end].to_string()),
score: result.score,
@@ -106,71 +184,37 @@ impl LanguageModelTool for ProjectIndexTool {
})
}
- fn render(
- _tool_call_id: &str,
- input: &Self::Input,
- excerpts: &Self::Output,
+ fn new_view(
+ _tool_call_id: String,
+ input: Self::Input,
+ output: Result<Self::Output>,
cx: &mut WindowContext,
- ) -> AnyElement {
- let query = input.query.clone();
-
- div()
- .v_flex()
- .gap_2()
- .child(
- div()
- .p_2()
- .rounded_md()
- .bg(cx.theme().colors().editor_background)
- .child(
- h_flex()
- .child(Label::new("Query: ").color(Color::Modified))
- .child(Label::new(query).color(Color::Muted)),
- ),
- )
- .children(excerpts.iter().map(|excerpt| {
- // This render doesn't have state/model, so we can't use the listener
- // let expanded = excerpt.expanded;
- // let element_id = excerpt.element_id.clone();
- let element_id = ElementId::Name(nanoid::nanoid!().into());
- let expanded = false;
-
- CollapsibleContainer::new(element_id.clone(), expanded)
- .start_slot(
- h_flex()
- .gap_1()
- .child(Icon::new(IconName::File).color(Color::Muted))
- .child(Label::new(excerpt.path.clone()).color(Color::Muted)),
- )
- // .on_click(cx.listener(move |this, _, cx| {
- // this.toggle_expanded(element_id.clone(), cx);
- // }))
- .child(
- div()
- .p_2()
- .rounded_md()
- .bg(cx.theme().colors().editor_background)
- .child(
- excerpt.text.clone(), // todo!(): Show as an editor block
- ),
- )
- }))
- .into_any_element()
+ ) -> gpui::View<Self::View> {
+ cx.new_view(|_cx| ProjectIndexView { input, output })
}
- fn format(_input: &Self::Input, excerpts: &Self::Output) -> String {
- let mut body = "Semantic search results:\n".to_string();
-
- for excerpt in excerpts {
- body.push_str("Excerpt from ");
- body.push_str(excerpt.path.as_ref());
- body.push_str(", score ");
- body.push_str(&excerpt.score.to_string());
- body.push_str(":\n");
- body.push_str("~~~\n");
- body.push_str(excerpt.text.as_ref());
- body.push_str("~~~\n");
+ fn format(_input: &Self::Input, output: &Result<Self::Output>) -> String {
+ match &output {
+ Ok(excerpts) => {
+ if excerpts.len() == 0 {
+ return "No results found".to_string();
+ }
+
+ let mut body = "Semantic search results:\n".to_string();
+
+ for excerpt in excerpts {
+ body.push_str("Excerpt from ");
+ body.push_str(excerpt.path.as_ref());
+ body.push_str(", score ");
+ body.push_str(&excerpt.score.to_string());
+ body.push_str(":\n");
+ body.push_str("~~~\n");
+ body.push_str(excerpt.text.as_ref());
+ body.push_str("~~~\n");
+ }
+ body
+ }
+ Err(err) => format!("Error: {}", err),
}
- body
}
}
@@ -1,13 +1,16 @@
use anyhow::{anyhow, Result};
-use gpui::{AnyElement, AppContext, Task, WindowContext};
-use std::{any::Any, collections::HashMap};
+use gpui::{Task, WindowContext};
+use std::collections::HashMap;
use crate::tool::{
LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
};
pub struct ToolRegistry {
- tools: HashMap<String, Box<dyn Fn(&ToolFunctionCall, &AppContext) -> Task<ToolFunctionCall>>>,
+ tools: HashMap<
+ String,
+ Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
+ >,
definitions: Vec<ToolFunctionDefinition>,
}
@@ -24,77 +27,45 @@ impl ToolRegistry {
}
pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
- fn render<T: 'static + LanguageModelTool>(
- tool_call_id: &str,
- input: &Box<dyn Any>,
- output: &Box<dyn Any>,
- cx: &mut WindowContext,
- ) -> AnyElement {
- T::render(
- tool_call_id,
- input.as_ref().downcast_ref::<T::Input>().unwrap(),
- output.as_ref().downcast_ref::<T::Output>().unwrap(),
- cx,
- )
- }
-
- fn format<T: 'static + LanguageModelTool>(
- input: &Box<dyn Any>,
- output: &Box<dyn Any>,
- ) -> String {
- T::format(
- input.as_ref().downcast_ref::<T::Input>().unwrap(),
- output.as_ref().downcast_ref::<T::Output>().unwrap(),
- )
- }
-
self.definitions.push(tool.definition());
let name = tool.name();
let previous = self.tools.insert(
name.clone(),
- Box::new(move |tool_call: &ToolFunctionCall, cx: &AppContext| {
- let name = tool_call.name.clone();
- let arguments = tool_call.arguments.clone();
- let id = tool_call.id.clone();
-
- let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
- return Task::ready(ToolFunctionCall {
- id,
- name: name.clone(),
- arguments,
- result: Some(ToolFunctionCallResult::ParsingFailed),
- });
- };
-
- let result = tool.execute(&input, cx);
-
- cx.spawn(move |_cx| async move {
- match result.await {
- Ok(result) => {
- let result: T::Output = result;
- ToolFunctionCall {
- id,
- name: name.clone(),
- arguments,
- result: Some(ToolFunctionCallResult::Finished {
- input: Box::new(input),
- output: Box::new(result),
- render_fn: render::<T>,
- format_fn: format::<T>,
- }),
- }
- }
- Err(_error) => ToolFunctionCall {
+ // registry.call(tool_call, cx)
+ Box::new(
+ move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
+ let name = tool_call.name.clone();
+ let arguments = tool_call.arguments.clone();
+ let id = tool_call.id.clone();
+
+ let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
+ return Task::ready(Ok(ToolFunctionCall {
id,
name: name.clone(),
arguments,
- result: Some(ToolFunctionCallResult::ExecutionFailed {
- input: Box::new(input),
+ result: Some(ToolFunctionCallResult::ParsingFailed),
+ }));
+ };
+
+ let result = tool.execute(&input, cx);
+
+ cx.spawn(move |mut cx| async move {
+ let result: Result<T::Output> = result.await;
+ let for_model = T::format(&input, &result);
+ let view = cx.update(|cx| T::new_view(id.clone(), input, result, cx))?;
+
+ Ok(ToolFunctionCall {
+ id,
+ name: name.clone(),
+ arguments,
+ result: Some(ToolFunctionCallResult::Finished {
+ view: view.into(),
+ for_model,
}),
- },
- }
- })
- }),
+ })
+ })
+ },
+ ),
);
if previous.is_some() {
@@ -104,7 +75,12 @@ impl ToolRegistry {
Ok(())
}
- pub fn call(&self, tool_call: &ToolFunctionCall, cx: &AppContext) -> Task<ToolFunctionCall> {
+ /// Task yields an error if the window for the given WindowContext is closed before the task completes.
+ pub fn call(
+ &self,
+ tool_call: &ToolFunctionCall,
+ cx: &mut WindowContext,
+ ) -> Task<Result<ToolFunctionCall>> {
let name = tool_call.name.clone();
let arguments = tool_call.arguments.clone();
let id = tool_call.id.clone();
@@ -113,12 +89,12 @@ impl ToolRegistry {
Some(tool) => tool,
None => {
let name = name.clone();
- return Task::ready(ToolFunctionCall {
+ return Task::ready(Ok(ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::NoSuchTool),
- });
+ }));
}
};
@@ -128,12 +104,10 @@ impl ToolRegistry {
#[cfg(test)]
mod test {
-
use super::*;
-
+ use gpui::View;
+ use gpui::{div, prelude::*, Render, TestAppContext};
use schemars::schema_for;
-
- use gpui::{div, AnyElement, Element, ParentElement, TestAppContext, WindowContext};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
@@ -155,9 +129,20 @@ mod test {
unit: String,
}
+ struct WeatherView {
+ result: WeatherResult,
+ }
+
+ impl Render for WeatherView {
+ fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
+ div().child(format!("temperature: {}", self.result.temperature))
+ }
+ }
+
impl LanguageModelTool for WeatherTool {
type Input = WeatherQuery;
type Output = WeatherResult;
+ type View = WeatherView;
fn name(&self) -> String {
"get_current_weather".to_string()
@@ -167,7 +152,11 @@ mod test {
"Fetches the current weather for a given location.".to_string()
}
- fn execute(&self, input: &WeatherQuery, _cx: &AppContext) -> Task<Result<Self::Output>> {
+ fn execute(
+ &self,
+ input: &Self::Input,
+ _cx: &gpui::AppContext,
+ ) -> Task<Result<Self::Output>> {
let _location = input.location.clone();
let _unit = input.unit.clone();
@@ -176,25 +165,20 @@ mod test {
Task::ready(Ok(weather))
}
- fn render(
- _tool_call_id: &str,
- _input: &Self::Input,
- output: &Self::Output,
- _cx: &mut WindowContext,
- ) -> AnyElement {
- div()
- .child(format!(
- "The current temperature in {} is {} {}",
- output.location, output.temperature, output.unit
- ))
- .into_any()
+ fn new_view(
+ _tool_call_id: String,
+ _input: Self::Input,
+ result: Result<Self::Output>,
+ cx: &mut WindowContext,
+ ) -> View<Self::View> {
+ cx.new_view(|_cx| {
+ let result = result.unwrap();
+ WeatherView { result }
+ })
}
- fn format(_input: &Self::Input, output: &Self::Output) -> String {
- format!(
- "The current temperature in {} is {} {}",
- output.location, output.temperature, output.unit
- )
+ fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
+ serde_json::to_string(&output.as_ref().unwrap()).unwrap()
}
}
@@ -214,20 +198,20 @@ mod test {
registry.register(tool).unwrap();
- let _result = cx
- .update(|cx| {
- registry.call(
- &ToolFunctionCall {
- name: "get_current_weather".to_string(),
- arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"#
- .to_string(),
- id: "test-123".to_string(),
- result: None,
- },
- cx,
- )
- })
- .await;
+ // let _result = cx
+ // .update(|cx| {
+ // registry.call(
+ // &ToolFunctionCall {
+ // name: "get_current_weather".to_string(),
+ // arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"#
+ // .to_string(),
+ // id: "test-123".to_string(),
+ // result: None,
+ // },
+ // cx,
+ // )
+ // })
+ // .await;
// assert!(result.is_ok());
// let result = result.unwrap();
@@ -1,11 +1,8 @@
use anyhow::Result;
-use gpui::{div, AnyElement, AppContext, Element, ParentElement as _, Task, WindowContext};
+use gpui::{AnyElement, AnyView, AppContext, IntoElement as _, Render, Task, View, WindowContext};
use schemars::{schema::RootSchema, schema_for, JsonSchema};
use serde::Deserialize;
-use std::{
- any::Any,
- fmt::{Debug, Display},
-};
+use std::fmt::Display;
#[derive(Default, Deserialize)]
pub struct ToolFunctionCall {
@@ -19,71 +16,29 @@ pub struct ToolFunctionCall {
pub enum ToolFunctionCallResult {
NoSuchTool,
ParsingFailed,
- ExecutionFailed {
- input: Box<dyn Any>,
- },
- Finished {
- input: Box<dyn Any>,
- output: Box<dyn Any>,
- render_fn: fn(
- // tool_call_id
- &str,
- // LanguageModelTool::Input
- &Box<dyn Any>,
- // LanguageModelTool::Output
- &Box<dyn Any>,
- &mut WindowContext,
- ) -> AnyElement,
- format_fn: fn(
- // LanguageModelTool::Input
- &Box<dyn Any>,
- // LanguageModelTool::Output
- &Box<dyn Any>,
- ) -> String,
- },
+ Finished { for_model: String, view: AnyView },
}
impl ToolFunctionCallResult {
- pub fn render(
- &self,
- tool_name: &str,
- tool_call_id: &str,
- cx: &mut WindowContext,
- ) -> AnyElement {
+ pub fn format(&self, name: &String) -> String {
match self {
- ToolFunctionCallResult::NoSuchTool => {
- div().child(format!("no such tool {tool_name}")).into_any()
+ ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
+ ToolFunctionCallResult::ParsingFailed => {
+ format!("Unable to parse arguments for {name}")
}
- ToolFunctionCallResult::ParsingFailed => div()
- .child(format!("failed to parse input for tool {tool_name}"))
- .into_any(),
- ToolFunctionCallResult::ExecutionFailed { .. } => div()
- .child(format!("failed to execute tool {tool_name}"))
- .into_any(),
- ToolFunctionCallResult::Finished {
- input,
- output,
- render_fn,
- ..
- } => render_fn(tool_call_id, input, output, cx),
+ ToolFunctionCallResult::Finished { for_model, .. } => for_model.clone(),
}
}
- pub fn format(&self, tool: &str) -> String {
+ pub fn into_any_element(&self, name: &String) -> AnyElement {
match self {
- ToolFunctionCallResult::NoSuchTool => format!("no such tool {tool}"),
- ToolFunctionCallResult::ParsingFailed => {
- format!("failed to parse input for tool {tool}")
+ ToolFunctionCallResult::NoSuchTool => {
+ format!("Language Model attempted to call {name}").into_any_element()
}
- ToolFunctionCallResult::ExecutionFailed { input: _input } => {
- format!("failed to execute tool {tool}")
+ ToolFunctionCallResult::ParsingFailed => {
+ format!("Language Model called {name} with bad arguments").into_any_element()
}
- ToolFunctionCallResult::Finished {
- input,
- output,
- format_fn,
- ..
- } => format_fn(input, output),
+ ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
}
}
}
@@ -105,19 +60,6 @@ impl Display for ToolFunctionDefinition {
}
}
-impl Debug for ToolFunctionDefinition {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- let schema = serde_json::to_string(&self.parameters).ok();
- let schema = schema.unwrap_or("None".to_string());
-
- f.debug_struct("ToolFunctionDefinition")
- .field("name", &self.name)
- .field("description", &self.description)
- .field("parameters", &schema)
- .finish()
- }
-}
-
pub trait LanguageModelTool {
/// The input type that will be passed in to `execute` when the tool is called
/// by the language model.
@@ -126,6 +68,8 @@ pub trait LanguageModelTool {
/// The output returned by executing the tool.
type Output: 'static;
+ type View: Render;
+
/// The name of the tool is exposed to the language model to allow
/// the model to pick which tools to use. As this name is used to
/// identify the tool within a tool registry, it should be unique.
@@ -149,12 +93,12 @@ pub trait LanguageModelTool {
/// Execute the tool
fn execute(&self, input: &Self::Input, cx: &AppContext) -> Task<Result<Self::Output>>;
- fn render(
- tool_call_id: &str,
- input: &Self::Input,
- output: &Self::Output,
- cx: &mut WindowContext,
- ) -> AnyElement;
+ fn format(input: &Self::Input, output: &Result<Self::Output>) -> String;
- fn format(input: &Self::Input, output: &Self::Output) -> String;
+ fn new_view(
+ tool_call_id: String,
+ input: Self::Input,
+ output: Result<Self::Output>,
+ cx: &mut WindowContext,
+ ) -> View<Self::View>;
}