1use crate::ProjectContext;
2use anyhow::{anyhow, Result};
3use gpui::{AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext};
4use repair_json::repair;
5use schemars::{schema::RootSchema, schema_for, JsonSchema};
6use serde::{de::DeserializeOwned, Deserialize, Serialize};
7use serde_json::value::RawValue;
8use std::{
9 any::TypeId,
10 collections::HashMap,
11 fmt::Display,
12 sync::{
13 atomic::{AtomicBool, Ordering::SeqCst},
14 Arc,
15 },
16};
17use ui::ViewContext;
18
19pub struct ToolRegistry {
20 registered_tools: HashMap<String, RegisteredTool>,
21}
22
23#[derive(Default)]
24pub struct ToolFunctionCall {
25 pub id: String,
26 pub name: String,
27 pub arguments: String,
28 state: ToolFunctionCallState,
29}
30
31#[derive(Default)]
32pub enum ToolFunctionCallState {
33 #[default]
34 Initializing,
35 NoSuchTool,
36 KnownTool(Box<dyn ToolView>),
37 ExecutedTool(Box<dyn ToolView>),
38}
39
40pub trait ToolView {
41 fn view(&self) -> AnyView;
42 fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
43 fn set_input(&self, input: &str, cx: &mut WindowContext);
44 fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
45 fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
46 fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
47}
48
49#[derive(Default, Serialize, Deserialize)]
50pub struct SavedToolFunctionCall {
51 pub id: String,
52 pub name: String,
53 pub arguments: String,
54 pub state: SavedToolFunctionCallState,
55}
56
57#[derive(Default, Serialize, Deserialize)]
58pub enum SavedToolFunctionCallState {
59 #[default]
60 Initializing,
61 NoSuchTool,
62 KnownTool,
63 ExecutedTool(Box<RawValue>),
64}
65
66#[derive(Clone, Debug)]
67pub struct ToolFunctionDefinition {
68 pub name: String,
69 pub description: String,
70 pub parameters: RootSchema,
71}
72
73pub trait LanguageModelTool {
74 type View: ToolOutput;
75
76 /// Returns the name of the tool.
77 ///
78 /// This name is exposed to the language model to allow the model to pick
79 /// which tools to use. As this name is used to identify the tool within a
80 /// tool registry, it should be unique.
81 fn name(&self) -> String;
82
83 /// Returns the description of the tool.
84 ///
85 /// This can be used to _prompt_ the model as to what the tool does.
86 fn description(&self) -> String;
87
88 /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
89 fn definition(&self) -> ToolFunctionDefinition {
90 let root_schema = schema_for!(<Self::View as ToolOutput>::Input);
91
92 ToolFunctionDefinition {
93 name: self.name(),
94 description: self.description(),
95 parameters: root_schema,
96 }
97 }
98
99 /// A view of the output of running the tool, for displaying to the user.
100 fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
101}
102
103pub fn tool_running_placeholder() -> AnyElement {
104 ui::Label::new("Researching...").into_any_element()
105}
106
107pub fn unknown_tool_placeholder() -> AnyElement {
108 ui::Label::new("Unknown tool").into_any_element()
109}
110
111pub fn no_such_tool_placeholder() -> AnyElement {
112 ui::Label::new("No such tool").into_any_element()
113}
114
115pub trait ToolOutput: Render {
116 /// The input type that will be passed in to `execute` when the tool is called
117 /// by the language model.
118 type Input: DeserializeOwned + JsonSchema;
119
120 /// The output returned by executing the tool.
121 type SerializedState: DeserializeOwned + Serialize;
122
123 fn generate(&self, project: &mut ProjectContext, cx: &mut ViewContext<Self>) -> String;
124 fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>);
125 fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>>;
126
127 fn serialize(&self, cx: &mut ViewContext<Self>) -> Self::SerializedState;
128 fn deserialize(
129 &mut self,
130 output: Self::SerializedState,
131 cx: &mut ViewContext<Self>,
132 ) -> Result<()>;
133}
134
135struct RegisteredTool {
136 enabled: AtomicBool,
137 type_id: TypeId,
138 build_view: Box<dyn Fn(&mut WindowContext) -> Box<dyn ToolView>>,
139 definition: ToolFunctionDefinition,
140}
141
142impl ToolRegistry {
143 pub fn new() -> Self {
144 Self {
145 registered_tools: HashMap::new(),
146 }
147 }
148
149 pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
150 for tool in self.registered_tools.values() {
151 if tool.type_id == TypeId::of::<T>() {
152 tool.enabled.store(is_enabled, SeqCst);
153 return;
154 }
155 }
156 }
157
158 pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
159 for tool in self.registered_tools.values() {
160 if tool.type_id == TypeId::of::<T>() {
161 return tool.enabled.load(SeqCst);
162 }
163 }
164 false
165 }
166
167 pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
168 self.registered_tools
169 .values()
170 .filter(|tool| tool.enabled.load(SeqCst))
171 .map(|tool| tool.definition.clone())
172 .collect()
173 }
174
175 pub fn view_for_tool(&self, name: &str, cx: &mut WindowContext) -> Option<Box<dyn ToolView>> {
176 let tool = self.registered_tools.get(name)?;
177 Some((tool.build_view)(cx))
178 }
179
180 pub fn update_tool_call(
181 &self,
182 call: &mut ToolFunctionCall,
183 name: Option<&str>,
184 arguments: Option<&str>,
185 cx: &mut WindowContext,
186 ) {
187 if let Some(name) = name {
188 call.name.push_str(name);
189 }
190 if let Some(arguments) = arguments {
191 if call.arguments.is_empty() {
192 if let Some(view) = self.view_for_tool(&call.name, cx) {
193 call.state = ToolFunctionCallState::KnownTool(view);
194 } else {
195 call.state = ToolFunctionCallState::NoSuchTool;
196 }
197 }
198 call.arguments.push_str(arguments);
199
200 if let ToolFunctionCallState::KnownTool(view) = &call.state {
201 if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
202 view.set_input(&repaired_arguments, cx)
203 }
204 }
205 }
206 }
207
208 pub fn execute_tool_call(
209 &self,
210 tool_call: &ToolFunctionCall,
211 cx: &mut WindowContext,
212 ) -> Option<Task<Result<()>>> {
213 if let ToolFunctionCallState::KnownTool(view) = &tool_call.state {
214 Some(view.execute(cx))
215 } else {
216 None
217 }
218 }
219
220 pub fn render_tool_call(
221 &self,
222 tool_call: &ToolFunctionCall,
223 _cx: &mut WindowContext,
224 ) -> AnyElement {
225 match &tool_call.state {
226 ToolFunctionCallState::NoSuchTool => no_such_tool_placeholder(),
227 ToolFunctionCallState::Initializing => unknown_tool_placeholder(),
228 ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
229 view.view().into_any_element()
230 }
231 }
232 }
233
234 pub fn content_for_tool_call(
235 &self,
236 tool_call: &ToolFunctionCall,
237 project_context: &mut ProjectContext,
238 cx: &mut WindowContext,
239 ) -> String {
240 match &tool_call.state {
241 ToolFunctionCallState::Initializing => String::new(),
242 ToolFunctionCallState::NoSuchTool => {
243 format!("No such tool: {}", tool_call.name)
244 }
245 ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
246 view.generate(project_context, cx)
247 }
248 }
249 }
250
251 pub fn serialize_tool_call(
252 &self,
253 call: &ToolFunctionCall,
254 cx: &mut WindowContext,
255 ) -> Result<SavedToolFunctionCall> {
256 Ok(SavedToolFunctionCall {
257 id: call.id.clone(),
258 name: call.name.clone(),
259 arguments: call.arguments.clone(),
260 state: match &call.state {
261 ToolFunctionCallState::Initializing => SavedToolFunctionCallState::Initializing,
262 ToolFunctionCallState::NoSuchTool => SavedToolFunctionCallState::NoSuchTool,
263 ToolFunctionCallState::KnownTool(_) => SavedToolFunctionCallState::KnownTool,
264 ToolFunctionCallState::ExecutedTool(view) => {
265 SavedToolFunctionCallState::ExecutedTool(view.serialize_output(cx)?)
266 }
267 },
268 })
269 }
270
271 pub fn deserialize_tool_call(
272 &self,
273 call: &SavedToolFunctionCall,
274 cx: &mut WindowContext,
275 ) -> Result<ToolFunctionCall> {
276 let Some(tool) = self.registered_tools.get(&call.name) else {
277 return Err(anyhow!("no such tool {}", call.name));
278 };
279
280 Ok(ToolFunctionCall {
281 id: call.id.clone(),
282 name: call.name.clone(),
283 arguments: call.arguments.clone(),
284 state: match &call.state {
285 SavedToolFunctionCallState::Initializing => ToolFunctionCallState::Initializing,
286 SavedToolFunctionCallState::NoSuchTool => ToolFunctionCallState::NoSuchTool,
287 SavedToolFunctionCallState::KnownTool => {
288 log::error!("Deserialized tool that had not executed");
289 let view = (tool.build_view)(cx);
290 view.set_input(&call.arguments, cx);
291 ToolFunctionCallState::KnownTool(view)
292 }
293 SavedToolFunctionCallState::ExecutedTool(output) => {
294 let view = (tool.build_view)(cx);
295 view.set_input(&call.arguments, cx);
296 view.deserialize_output(output, cx)?;
297 ToolFunctionCallState::ExecutedTool(view)
298 }
299 },
300 })
301 }
302
303 pub fn register<T: 'static + LanguageModelTool>(
304 &mut self,
305 tool: T,
306 _cx: &mut WindowContext,
307 ) -> Result<()> {
308 let name = tool.name();
309 let tool = Arc::new(tool);
310 let registered_tool = RegisteredTool {
311 type_id: TypeId::of::<T>(),
312 definition: tool.definition(),
313 enabled: AtomicBool::new(true),
314 build_view: Box::new(move |cx: &mut WindowContext| Box::new(tool.view(cx))),
315 };
316
317 let previous = self.registered_tools.insert(name.clone(), registered_tool);
318 if previous.is_some() {
319 return Err(anyhow!("already registered a tool with name {}", name));
320 }
321
322 return Ok(());
323 }
324}
325
326impl<T: ToolOutput> ToolView for View<T> {
327 fn view(&self) -> AnyView {
328 self.clone().into()
329 }
330
331 fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
332 self.update(cx, |view, cx| view.generate(project, cx))
333 }
334
335 fn set_input(&self, input: &str, cx: &mut WindowContext) {
336 if let Ok(input) = serde_json::from_str::<T::Input>(input) {
337 self.update(cx, |view, cx| {
338 view.set_input(input, cx);
339 cx.notify();
340 });
341 }
342 }
343
344 fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>> {
345 self.update(cx, |view, cx| view.execute(cx))
346 }
347
348 fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>> {
349 let output = self.update(cx, |view, cx| view.serialize(cx));
350 Ok(RawValue::from_string(serde_json::to_string(&output)?)?)
351 }
352
353 fn deserialize_output(&self, output: &RawValue, cx: &mut WindowContext) -> Result<()> {
354 let state = serde_json::from_str::<T::SerializedState>(output.get())?;
355 self.update(cx, |view, cx| view.deserialize(state, cx))?;
356 Ok(())
357 }
358}
359
360impl Display for ToolFunctionDefinition {
361 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362 let schema = serde_json::to_string(&self.parameters).ok();
363 let schema = schema.unwrap_or("None".to_string());
364 write!(f, "Name: {}:\n", self.name)?;
365 write!(f, "Description: {}\n", self.description)?;
366 write!(f, "Parameters: {}", schema)
367 }
368}
369
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{div, prelude::*, EmptyView, Render, TestAppContext, View};
    use schemars::schema_for;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Input accepted by the fake weather tool.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// Output produced by the fake weather tool.
    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    /// View for the fake weather tool.
    struct WeatherView {
        input: Option<WeatherQuery>,
        result: Option<WeatherResult>,

        // Stands in for a real API call: `execute` reports this canned value.
        current_weather: WeatherResult,
    }

    /// The fake weather tool itself; it only carries the canned result.
    #[derive(Clone, Serialize)]
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    impl WeatherView {
        /// Creates a view with no input and no result yet.
        fn new(current_weather: WeatherResult) -> Self {
            Self {
                current_weather,
                input: None,
                result: None,
            }
        }
    }

    impl Render for WeatherView {
        fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
            match &self.result {
                None => div().child("Calculating weather...").into_any_element(),
                Some(result) => div()
                    .child(format!("temperature: {}", result.temperature))
                    .into_any_element(),
            }
        }
    }

    impl ToolOutput for WeatherView {
        type Input = WeatherQuery;

        type SerializedState = WeatherResult;

        fn generate(&self, _output: &mut ProjectContext, _cx: &mut ViewContext<Self>) -> String {
            serde_json::to_string(&self.result).unwrap()
        }

        fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
            self.input = Some(input);
            cx.notify();
        }

        fn execute(&mut self, _cx: &mut ViewContext<Self>) -> Task<Result<()>> {
            // Panics when no input was set, which is acceptable for this fixture.
            let query = self.input.as_ref().unwrap();

            // A real tool would use these to perform the lookup.
            let _location = query.location.clone();
            let _unit = query.unit.clone();

            self.result = Some(self.current_weather.clone());

            Task::ready(Ok(()))
        }

        fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
            self.current_weather.clone()
        }

        fn deserialize(
            &mut self,
            output: Self::SerializedState,
            _cx: &mut ViewContext<Self>,
        ) -> Result<()> {
            self.current_weather = output;
            Ok(())
        }
    }

    impl LanguageModelTool for WeatherTool {
        type View = WeatherView;

        fn name(&self) -> String {
            "get_current_weather".to_string()
        }

        fn description(&self) -> String {
            "Fetches the current weather for a given location.".to_string()
        }

        fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
            let current_weather = self.current_weather.clone();
            cx.new_view(|_cx| WeatherView::new(current_weather))
        }
    }

    /// Checks the generated tool definition (name, description, schema) and
    /// that a view can take JSON input and execute successfully.
    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);

        let tool = WeatherTool {
            current_weather: WeatherResult {
                location: "San Francisco".to_string(),
                temperature: 21.0,
                unit: "Celsius".to_string(),
            },
        };

        let tools = vec![tool.definition()];
        assert_eq!(tools.len(), 1);

        let expected = ToolFunctionDefinition {
            name: "get_current_weather".to_string(),
            description: "Fetches the current weather for a given location.".to_string(),
            parameters: schema_for!(WeatherQuery),
        };

        assert_eq!(tools[0].name, expected.name);
        assert_eq!(tools[0].description, expected.description);

        // The schema generated for the tool, compared against a literal
        // JSON document.
        let actual_schema = serde_json::to_value(&tools[0].parameters).unwrap();

        assert_eq!(
            actual_schema,
            json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "title": "WeatherQuery",
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "unit": {
                        "type": "string"
                    }
                },
                "required": ["location", "unit"]
            })
        );

        let view = cx.update(|cx| tool.view(cx));

        cx.update(|cx| {
            view.set_input(r#"{"location": "San Francisco", "unit": "Celsius"}"#, cx);
        });

        let outcome = cx.update(|cx| view.execute(cx)).await;

        assert!(outcome.is_ok());
    }
}