@@ -11,8 +11,7 @@ use crate::ui::UserOrAssistant;
use ::ui::{div, prelude::*, Color, Tooltip, ViewContext};
use anyhow::{Context, Result};
use assistant_tooling::{
- tool_running_placeholder, AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry,
- UserAttachment,
+ AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, UserAttachment,
};
use attachments::ActiveEditorAttachmentTool;
use client::{proto, Client, UserStore};
@@ -130,16 +129,13 @@ impl AssistantPanel {
let mut tool_registry = ToolRegistry::new();
tool_registry
- .register(ProjectIndexTool::new(project_index.clone()), cx)
+ .register(ProjectIndexTool::new(project_index.clone()))
.unwrap();
tool_registry
- .register(
- CreateBufferTool::new(workspace.clone(), project.clone()),
- cx,
- )
+ .register(CreateBufferTool::new(workspace.clone(), project.clone()))
.unwrap();
tool_registry
- .register(AnnotationTool::new(workspace.clone(), project.clone()), cx)
+ .register(AnnotationTool::new(workspace.clone(), project.clone()))
.unwrap();
let mut attachment_registry = AttachmentRegistry::new();
@@ -588,9 +584,9 @@ impl AssistantChat {
cx.notify();
} else {
if let Some(current_message) = messages.last_mut() {
- for tool_call in current_message.tool_calls.iter() {
+ for tool_call in current_message.tool_calls.iter_mut() {
tool_tasks
- .extend(this.tool_registry.execute_tool_call(&tool_call, cx));
+ .extend(this.tool_registry.execute_tool_call(tool_call, cx));
}
}
}
@@ -847,7 +843,7 @@ impl AssistantChat {
let tools = message
.tool_calls
.iter()
- .map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
+ .filter_map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
.collect::<Vec<AnyElement>>();
if !tools.is_empty() {
@@ -856,7 +852,7 @@ impl AssistantChat {
}
if message_elements.is_empty() {
- message_elements.push(tool_running_placeholder());
+ message_elements.push(::ui::Label::new("Researching...").into_any_element())
}
div()
@@ -9,10 +9,8 @@ use std::{
any::TypeId,
collections::HashMap,
fmt::Display,
- sync::{
- atomic::{AtomicBool, Ordering::SeqCst},
- Arc,
- },
+ mem,
+ sync::atomic::{AtomicBool, Ordering::SeqCst},
};
use ui::ViewContext;
@@ -29,7 +27,7 @@ pub struct ToolFunctionCall {
}
#[derive(Default)]
-pub enum ToolFunctionCallState {
+enum ToolFunctionCallState {
#[default]
Initializing,
NoSuchTool,
@@ -37,10 +35,10 @@ pub enum ToolFunctionCallState {
ExecutedTool(Box<dyn ToolView>),
}
-pub trait ToolView {
+trait ToolView {
fn view(&self) -> AnyView;
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
- fn set_input(&self, input: &str, cx: &mut WindowContext);
+ fn try_set_input(&self, input: &str, cx: &mut WindowContext);
fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
@@ -48,14 +46,14 @@ pub trait ToolView {
#[derive(Default, Serialize, Deserialize)]
pub struct SavedToolFunctionCall {
- pub id: String,
- pub name: String,
- pub arguments: String,
- pub state: SavedToolFunctionCallState,
+ id: String,
+ name: String,
+ arguments: String,
+ state: SavedToolFunctionCallState,
}
#[derive(Default, Serialize, Deserialize)]
-pub enum SavedToolFunctionCallState {
+enum SavedToolFunctionCallState {
#[default]
Initializing,
NoSuchTool,
@@ -63,7 +61,7 @@ pub enum SavedToolFunctionCallState {
ExecutedTool(Box<RawValue>),
}
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq)]
pub struct ToolFunctionDefinition {
pub name: String,
pub description: String,
@@ -100,18 +98,6 @@ pub trait LanguageModelTool {
fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
}
-pub fn tool_running_placeholder() -> AnyElement {
- ui::Label::new("Researching...").into_any_element()
-}
-
-pub fn unknown_tool_placeholder() -> AnyElement {
- ui::Label::new("Unknown tool").into_any_element()
-}
-
-pub fn no_such_tool_placeholder() -> AnyElement {
- ui::Label::new("No such tool").into_any_element()
-}
-
pub trait ToolOutput: Render {
/// The input type that will be passed in to `execute` when the tool is called
/// by the language model.
@@ -172,11 +158,6 @@ impl ToolRegistry {
.collect()
}
- pub fn view_for_tool(&self, name: &str, cx: &mut WindowContext) -> Option<Box<dyn ToolView>> {
- let tool = self.registered_tools.get(name)?;
- Some((tool.build_view)(cx))
- }
-
pub fn update_tool_call(
&self,
call: &mut ToolFunctionCall,
@@ -189,7 +170,8 @@ impl ToolRegistry {
}
if let Some(arguments) = arguments {
if call.arguments.is_empty() {
- if let Some(view) = self.view_for_tool(&call.name, cx) {
+ if let Some(tool) = self.registered_tools.get(&call.name) {
+ let view = (tool.build_view)(cx);
call.state = ToolFunctionCallState::KnownTool(view);
} else {
call.state = ToolFunctionCallState::NoSuchTool;
@@ -199,7 +181,7 @@ impl ToolRegistry {
if let ToolFunctionCallState::KnownTool(view) = &call.state {
if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
- view.set_input(&repaired_arguments, cx)
+ view.try_set_input(&repaired_arguments, cx)
}
}
}
@@ -207,11 +189,13 @@ impl ToolRegistry {
pub fn execute_tool_call(
&self,
- tool_call: &ToolFunctionCall,
+ tool_call: &mut ToolFunctionCall,
cx: &mut WindowContext,
) -> Option<Task<Result<()>>> {
- if let ToolFunctionCallState::KnownTool(view) = &tool_call.state {
- Some(view.execute(cx))
+ if let ToolFunctionCallState::KnownTool(view) = mem::take(&mut tool_call.state) {
+ let task = view.execute(cx);
+ tool_call.state = ToolFunctionCallState::ExecutedTool(view);
+ Some(task)
} else {
None
}
@@ -221,12 +205,14 @@ impl ToolRegistry {
&self,
tool_call: &ToolFunctionCall,
_cx: &mut WindowContext,
- ) -> AnyElement {
+ ) -> Option<AnyElement> {
match &tool_call.state {
- ToolFunctionCallState::NoSuchTool => no_such_tool_placeholder(),
- ToolFunctionCallState::Initializing => unknown_tool_placeholder(),
+ ToolFunctionCallState::NoSuchTool => {
+ Some(ui::Label::new("No such tool").into_any_element())
+ }
+ ToolFunctionCallState::Initializing => None,
ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
- view.view().into_any_element()
+ Some(view.view().into_any_element())
}
}
}
@@ -287,12 +273,12 @@ impl ToolRegistry {
SavedToolFunctionCallState::KnownTool => {
log::error!("Deserialized tool that had not executed");
let view = (tool.build_view)(cx);
- view.set_input(&call.arguments, cx);
+ view.try_set_input(&call.arguments, cx);
ToolFunctionCallState::KnownTool(view)
}
SavedToolFunctionCallState::ExecutedTool(output) => {
let view = (tool.build_view)(cx);
- view.set_input(&call.arguments, cx);
+ view.try_set_input(&call.arguments, cx);
view.deserialize_output(output, cx)?;
ToolFunctionCallState::ExecutedTool(view)
}
@@ -300,13 +286,8 @@ impl ToolRegistry {
})
}
- pub fn register<T: 'static + LanguageModelTool>(
- &mut self,
- tool: T,
- _cx: &mut WindowContext,
- ) -> Result<()> {
+ pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
let name = tool.name();
- let tool = Arc::new(tool);
let registered_tool = RegisteredTool {
type_id: TypeId::of::<T>(),
definition: tool.definition(),
@@ -332,7 +313,7 @@ impl<T: ToolOutput> ToolView for View<T> {
self.update(cx, |view, cx| view.generate(project, cx))
}
- fn set_input(&self, input: &str, cx: &mut WindowContext) {
+ fn try_set_input(&self, input: &str, cx: &mut WindowContext) {
if let Ok(input) = serde_json::from_str::<T::Input>(input) {
self.update(cx, |view, cx| {
view.set_input(input, cx);
@@ -372,7 +353,6 @@ mod test {
use super::*;
use gpui::{div, prelude::*, Render, TestAppContext};
use gpui::{EmptyView, View};
- use schemars::schema_for;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
@@ -483,57 +463,64 @@ mod test {
#[gpui::test]
async fn test_openai_weather_example(cx: &mut TestAppContext) {
- cx.background_executor.run_until_parked();
let (_, cx) = cx.add_window_view(|_cx| EmptyView);
- let tool = WeatherTool {
- current_weather: WeatherResult {
- location: "San Francisco".to_string(),
- temperature: 21.0,
- unit: "Celsius".to_string(),
- },
- };
-
- let tools = vec![tool.definition()];
- assert_eq!(tools.len(), 1);
-
- let expected = ToolFunctionDefinition {
- name: "get_current_weather".to_string(),
- description: "Fetches the current weather for a given location.".to_string(),
- parameters: schema_for!(WeatherQuery),
- };
-
- assert_eq!(tools[0].name, expected.name);
- assert_eq!(tools[0].description, expected.description);
-
- let expected_schema = serde_json::to_value(&tools[0].parameters).unwrap();
+ let mut registry = ToolRegistry::new();
+ registry
+ .register(WeatherTool {
+ current_weather: WeatherResult {
+ location: "San Francisco".to_string(),
+ temperature: 21.0,
+ unit: "Celsius".to_string(),
+ },
+ })
+ .unwrap();
+ let definitions = registry.definitions();
assert_eq!(
- expected_schema,
- json!({
- "$schema": "http://json-schema.org/draft-07/schema#",
- "title": "WeatherQuery",
- "type": "object",
- "properties": {
- "location": {
- "type": "string"
+ definitions,
+ [ToolFunctionDefinition {
+ name: "get_current_weather".to_string(),
+ description: "Fetches the current weather for a given location.".to_string(),
+ parameters: serde_json::from_value(json!({
+ "$schema": "http://json-schema.org/draft-07/schema#",
+ "title": "WeatherQuery",
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string"
+ },
+ "unit": {
+ "type": "string"
+ }
},
- "unit": {
- "type": "string"
- }
- },
- "required": ["location", "unit"]
- })
+ "required": ["location", "unit"]
+ }))
+ .unwrap(),
+ }]
);
- let view = cx.update(|cx| tool.view(cx));
+ let mut call = ToolFunctionCall {
+ id: "the-id".to_string(),
+ name: "get_cur".to_string(),
+ ..Default::default()
+ };
- cx.update(|cx| {
- view.set_input(&r#"{"location": "San Francisco", "unit": "Celsius"}"#, cx);
+ let task = cx.update(|cx| {
+ registry.update_tool_call(
+ &mut call,
+ Some("rent_weather"),
+ Some(r#"{"location": "San Francisco","#),
+ cx,
+ );
+ registry.update_tool_call(&mut call, None, Some(r#" "unit": "Celsius"}"#), cx);
+ registry.execute_tool_call(&mut call, cx).unwrap()
});
+ task.await.unwrap();
- let finished = cx.update(|cx| view.execute(cx)).await;
-
- assert!(finished.is_ok());
+ match &call.state {
+ ToolFunctionCallState::ExecutedTool(_view) => {}
+ _ => panic!(),
+ }
}
}