1use crate::ProjectContext;
2use anyhow::{anyhow, Result};
3use gpui::{AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext};
4use repair_json::repair;
5use schemars::{schema::RootSchema, schema_for, JsonSchema};
6use serde::{de::DeserializeOwned, Deserialize, Serialize};
7use serde_json::value::RawValue;
8use std::{
9 any::TypeId,
10 collections::HashMap,
11 fmt::Display,
12 mem,
13 sync::atomic::{AtomicBool, Ordering::SeqCst},
14};
15use ui::ViewContext;
16
/// A collection of language-model tools, keyed by the name each tool reports
/// from `LanguageModelTool::name`.
///
/// The registry builds the views for incoming tool calls, executes them, and
/// converts calls to and from their saved (serializable) form.
pub struct ToolRegistry {
    /// Registered tools indexed by tool name.
    registered_tools: HashMap<String, RegisteredTool>,
}
20
/// A single tool invocation requested by the language model.
///
/// `name` and `arguments` are appended to piece by piece by
/// `ToolRegistry::update_tool_call`, so either may hold a partial value while
/// the call is still being received.
#[derive(Default)]
pub struct ToolFunctionCall {
    /// Identifier of this call; copied verbatim when the call is saved and restored.
    pub id: String,
    /// Name of the tool being called (possibly still incomplete).
    pub name: String,
    /// JSON text of the call arguments accumulated so far (possibly still incomplete).
    pub arguments: String,
    /// Where this call is in its lifecycle, including the tool's view once one exists.
    state: ToolFunctionCallState,
}
28
/// Lifecycle of a `ToolFunctionCall`.
#[derive(Default)]
enum ToolFunctionCallState {
    /// No arguments have been received yet, so no tool has been looked up.
    #[default]
    Initializing,
    /// The call's name did not match any registered tool.
    NoSuchTool,
    /// The tool was found and its view was built, but it has not been executed.
    KnownTool(Box<dyn InternalToolView>),
    /// `execute` has been invoked on the tool's view.
    ExecutedTool(Box<dyn InternalToolView>),
}
37
/// Type-erased interface over a tool's view.
///
/// Implemented in this file for `View<T>` where `T: ToolView`, which lets the
/// registry store views of different tools behind `Box<dyn InternalToolView>`.
trait InternalToolView {
    /// Returns the underlying view as an `AnyView` for rendering.
    fn view(&self) -> AnyView;
    /// Produces the tool's textual content by delegating to `ToolView::generate`.
    fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
    /// Attempts to parse `input` as the tool's JSON input and hand it to the view.
    fn try_set_input(&self, input: &str, cx: &mut WindowContext);
    /// Starts executing the tool and returns the task tracking that execution.
    fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
    /// Serializes the view's state to raw JSON.
    fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
    /// Restores the view's state from raw JSON produced by `serialize_output`.
    fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
}
46
/// Serializable form of a `ToolFunctionCall`.
///
/// Produced by `ToolRegistry::serialize_tool_call` and turned back into a live
/// call by `ToolRegistry::deserialize_tool_call`.
#[derive(Default, Serialize, Deserialize)]
pub struct SavedToolFunctionCall {
    /// Copy of `ToolFunctionCall::id`.
    id: String,
    /// Copy of `ToolFunctionCall::name`.
    name: String,
    /// Copy of `ToolFunctionCall::arguments`.
    arguments: String,
    /// Saved lifecycle state, including the tool's serialized output if it was executed.
    state: SavedToolFunctionCallState,
}
54
/// Serializable mirror of `ToolFunctionCallState`.
///
/// The live views are not stored; only an executed tool keeps data, in the
/// form of the raw JSON returned by `InternalToolView::serialize_output`.
#[derive(Default, Serialize, Deserialize)]
enum SavedToolFunctionCallState {
    /// The call had not received any arguments when it was saved.
    #[default]
    Initializing,
    /// The call named a tool that was not registered.
    NoSuchTool,
    /// The tool was known but had not been executed when the call was saved.
    KnownTool,
    /// The tool had been executed; holds its serialized state as raw JSON.
    ExecutedTool(Box<RawValue>),
}
63
/// Description of a tool as presented to the language model.
///
/// Built by `LanguageModelTool::definition` from the tool's name, its
/// description, and the JSON schema of its input type.
#[derive(Clone, Debug, PartialEq)]
pub struct ToolFunctionDefinition {
    /// The tool's name, as returned by `LanguageModelTool::name`.
    pub name: String,
    /// The tool's description, as returned by `LanguageModelTool::description`.
    pub description: String,
    /// JSON schema describing the arguments the tool accepts.
    pub parameters: RootSchema,
}
70
71pub trait LanguageModelTool {
72 type View: ToolView;
73
74 /// Returns the name of the tool.
75 ///
76 /// This name is exposed to the language model to allow the model to pick
77 /// which tools to use. As this name is used to identify the tool within a
78 /// tool registry, it should be unique.
79 fn name(&self) -> String;
80
81 /// Returns the description of the tool.
82 ///
83 /// This can be used to _prompt_ the model as to what the tool does.
84 fn description(&self) -> String;
85
86 /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
87 fn definition(&self) -> ToolFunctionDefinition {
88 let root_schema = schema_for!(<Self::View as ToolView>::Input);
89
90 ToolFunctionDefinition {
91 name: self.name(),
92 description: self.description(),
93 parameters: root_schema,
94 }
95 }
96
97 /// A view of the output of running the tool, for displaying to the user.
98 fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
99}
100
/// The per-call view of a tool: it receives the call's input, executes the
/// tool, renders the outcome, and can save and restore its own state.
pub trait ToolView: Render {
    /// The input type for a call, deserialized from the JSON arguments sent
    /// by the language model and handed to `set_input`.
    type Input: DeserializeOwned + JsonSchema;

    /// The state produced by `serialize` and consumed by `deserialize`.
    type SerializedState: DeserializeOwned + Serialize;

    /// Produces the textual content representing this tool call.
    fn generate(&self, project: &mut ProjectContext, cx: &mut ViewContext<Self>) -> String;
    /// Stores the parsed input for this call. In this file it is invoked each
    /// time the accumulated arguments parse successfully, so it may be called
    /// more than once for the same call.
    fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>);
    /// Runs the tool and returns a task that resolves when execution finishes.
    fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>>;

    /// Captures the view's state so it can be saved.
    fn serialize(&self, cx: &mut ViewContext<Self>) -> Self::SerializedState;
    /// Restores the view's state from a value previously returned by `serialize`.
    fn deserialize(
        &mut self,
        output: Self::SerializedState,
        cx: &mut ViewContext<Self>,
    ) -> Result<()>;
}
120
/// Registry entry for one tool.
struct RegisteredTool {
    /// Whether the tool is included in `ToolRegistry::definitions`. Stored as
    /// an atomic so it can be toggled through a shared `&ToolRegistry`.
    enabled: AtomicBool,
    /// `TypeId` of the `LanguageModelTool` implementation, used to find the
    /// entry by type.
    type_id: TypeId,
    /// Creates a fresh, type-erased view for a call to this tool.
    build_view: Box<dyn Fn(&mut WindowContext) -> Box<dyn InternalToolView>>,
    /// The definition computed when the tool was registered.
    definition: ToolFunctionDefinition,
}
127
128impl ToolRegistry {
129 pub fn new() -> Self {
130 Self {
131 registered_tools: HashMap::new(),
132 }
133 }
134
135 pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
136 for tool in self.registered_tools.values() {
137 if tool.type_id == TypeId::of::<T>() {
138 tool.enabled.store(is_enabled, SeqCst);
139 return;
140 }
141 }
142 }
143
144 pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
145 for tool in self.registered_tools.values() {
146 if tool.type_id == TypeId::of::<T>() {
147 return tool.enabled.load(SeqCst);
148 }
149 }
150 false
151 }
152
153 pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
154 self.registered_tools
155 .values()
156 .filter(|tool| tool.enabled.load(SeqCst))
157 .map(|tool| tool.definition.clone())
158 .collect()
159 }
160
161 pub fn update_tool_call(
162 &self,
163 call: &mut ToolFunctionCall,
164 name: Option<&str>,
165 arguments: Option<&str>,
166 cx: &mut WindowContext,
167 ) {
168 if let Some(name) = name {
169 call.name.push_str(name);
170 }
171 if let Some(arguments) = arguments {
172 if call.arguments.is_empty() {
173 if let Some(tool) = self.registered_tools.get(&call.name) {
174 let view = (tool.build_view)(cx);
175 call.state = ToolFunctionCallState::KnownTool(view);
176 } else {
177 call.state = ToolFunctionCallState::NoSuchTool;
178 }
179 }
180 call.arguments.push_str(arguments);
181
182 if let ToolFunctionCallState::KnownTool(view) = &call.state {
183 if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
184 view.try_set_input(&repaired_arguments, cx)
185 }
186 }
187 }
188 }
189
190 pub fn execute_tool_call(
191 &self,
192 tool_call: &mut ToolFunctionCall,
193 cx: &mut WindowContext,
194 ) -> Option<Task<Result<()>>> {
195 if let ToolFunctionCallState::KnownTool(view) = mem::take(&mut tool_call.state) {
196 let task = view.execute(cx);
197 tool_call.state = ToolFunctionCallState::ExecutedTool(view);
198 Some(task)
199 } else {
200 None
201 }
202 }
203
204 pub fn render_tool_call(
205 &self,
206 tool_call: &ToolFunctionCall,
207 _cx: &mut WindowContext,
208 ) -> Option<AnyElement> {
209 match &tool_call.state {
210 ToolFunctionCallState::NoSuchTool => {
211 Some(ui::Label::new("No such tool").into_any_element())
212 }
213 ToolFunctionCallState::Initializing => None,
214 ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
215 Some(view.view().into_any_element())
216 }
217 }
218 }
219
220 pub fn content_for_tool_call(
221 &self,
222 tool_call: &ToolFunctionCall,
223 project_context: &mut ProjectContext,
224 cx: &mut WindowContext,
225 ) -> String {
226 match &tool_call.state {
227 ToolFunctionCallState::Initializing => String::new(),
228 ToolFunctionCallState::NoSuchTool => {
229 format!("No such tool: {}", tool_call.name)
230 }
231 ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
232 view.generate(project_context, cx)
233 }
234 }
235 }
236
237 pub fn serialize_tool_call(
238 &self,
239 call: &ToolFunctionCall,
240 cx: &mut WindowContext,
241 ) -> Result<SavedToolFunctionCall> {
242 Ok(SavedToolFunctionCall {
243 id: call.id.clone(),
244 name: call.name.clone(),
245 arguments: call.arguments.clone(),
246 state: match &call.state {
247 ToolFunctionCallState::Initializing => SavedToolFunctionCallState::Initializing,
248 ToolFunctionCallState::NoSuchTool => SavedToolFunctionCallState::NoSuchTool,
249 ToolFunctionCallState::KnownTool(_) => SavedToolFunctionCallState::KnownTool,
250 ToolFunctionCallState::ExecutedTool(view) => {
251 SavedToolFunctionCallState::ExecutedTool(view.serialize_output(cx)?)
252 }
253 },
254 })
255 }
256
257 pub fn deserialize_tool_call(
258 &self,
259 call: &SavedToolFunctionCall,
260 cx: &mut WindowContext,
261 ) -> Result<ToolFunctionCall> {
262 let Some(tool) = self.registered_tools.get(&call.name) else {
263 return Err(anyhow!("no such tool {}", call.name));
264 };
265
266 Ok(ToolFunctionCall {
267 id: call.id.clone(),
268 name: call.name.clone(),
269 arguments: call.arguments.clone(),
270 state: match &call.state {
271 SavedToolFunctionCallState::Initializing => ToolFunctionCallState::Initializing,
272 SavedToolFunctionCallState::NoSuchTool => ToolFunctionCallState::NoSuchTool,
273 SavedToolFunctionCallState::KnownTool => {
274 log::error!("Deserialized tool that had not executed");
275 let view = (tool.build_view)(cx);
276 view.try_set_input(&call.arguments, cx);
277 ToolFunctionCallState::KnownTool(view)
278 }
279 SavedToolFunctionCallState::ExecutedTool(output) => {
280 let view = (tool.build_view)(cx);
281 view.try_set_input(&call.arguments, cx);
282 view.deserialize_output(output, cx)?;
283 ToolFunctionCallState::ExecutedTool(view)
284 }
285 },
286 })
287 }
288
289 pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
290 let name = tool.name();
291 let registered_tool = RegisteredTool {
292 type_id: TypeId::of::<T>(),
293 definition: tool.definition(),
294 enabled: AtomicBool::new(true),
295 build_view: Box::new(move |cx: &mut WindowContext| Box::new(tool.view(cx))),
296 };
297
298 let previous = self.registered_tools.insert(name.clone(), registered_tool);
299 if previous.is_some() {
300 return Err(anyhow!("already registered a tool with name {}", name));
301 }
302
303 return Ok(());
304 }
305}
306
307impl<T: ToolView> InternalToolView for View<T> {
308 fn view(&self) -> AnyView {
309 self.clone().into()
310 }
311
312 fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
313 self.update(cx, |view, cx| view.generate(project, cx))
314 }
315
316 fn try_set_input(&self, input: &str, cx: &mut WindowContext) {
317 if let Ok(input) = serde_json::from_str::<T::Input>(input) {
318 self.update(cx, |view, cx| {
319 view.set_input(input, cx);
320 cx.notify();
321 });
322 }
323 }
324
325 fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>> {
326 self.update(cx, |view, cx| view.execute(cx))
327 }
328
329 fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>> {
330 let output = self.update(cx, |view, cx| view.serialize(cx));
331 Ok(RawValue::from_string(serde_json::to_string(&output)?)?)
332 }
333
334 fn deserialize_output(&self, output: &RawValue, cx: &mut WindowContext) -> Result<()> {
335 let state = serde_json::from_str::<T::SerializedState>(output.get())?;
336 self.update(cx, |view, cx| view.deserialize(state, cx))?;
337 Ok(())
338 }
339}
340
341impl Display for ToolFunctionDefinition {
342 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343 let schema = serde_json::to_string(&self.parameters).ok();
344 let schema = schema.unwrap_or("None".to_string());
345 write!(f, "Name: {}:\n", self.name)?;
346 write!(f, "Description: {}\n", self.description)?;
347 write!(f, "Parameters: {}", schema)
348 }
349}
350
/// Tests for the tool registry, built around a fake weather tool.
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{div, prelude::*, Render, TestAppContext};
    use gpui::{EmptyView, View};
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Input the model supplies when calling the weather tool.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// Weather data reported by the fake tool; also used as its serialized state.
    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    /// View for one call to the weather tool.
    struct WeatherView {
        // Input received for this call, if any has parsed so far.
        input: Option<WeatherQuery>,
        // Set once `execute` has run.
        result: Option<WeatherResult>,

        // Fake API call
        current_weather: WeatherResult,
    }

    /// The fake tool; it hands its canned weather to every view it creates.
    #[derive(Clone, Serialize)]
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    impl WeatherView {
        /// Creates a view with no input and no result yet.
        fn new(current_weather: WeatherResult) -> Self {
            Self {
                input: None,
                result: None,
                current_weather,
            }
        }
    }

    impl Render for WeatherView {
        /// Shows the temperature once a result exists, and a placeholder before that.
        fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
            match self.result {
                Some(ref result) => div()
                    .child(format!("temperature: {}", result.temperature))
                    .into_any_element(),
                None => div().child("Calculating weather...").into_any_element(),
            }
        }
    }

    impl ToolView for WeatherView {
        type Input = WeatherQuery;

        type SerializedState = WeatherResult;

        /// Returns the current result (possibly `null`) as JSON text.
        fn generate(&self, _output: &mut ProjectContext, _cx: &mut ViewContext<Self>) -> String {
            serde_json::to_string(&self.result).unwrap()
        }

        /// Stores the parsed input and requests a re-render.
        fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
            self.input = Some(input);
            cx.notify();
        }

        /// "Runs" the tool by copying the canned weather into `result`.
        ///
        /// Panics if no input has been set.
        fn execute(&mut self, _cx: &mut ViewContext<Self>) -> Task<Result<()>> {
            let input = self.input.as_ref().unwrap();

            // The query values are read but not used by the fake implementation.
            let _location = input.location.clone();
            let _unit = input.unit.clone();

            let weather = self.current_weather.clone();

            self.result = Some(weather);

            Task::ready(Ok(()))
        }

        /// Saves the canned weather (`current_weather`, not `result`).
        fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
            self.current_weather.clone()
        }

        /// Restores the canned weather from saved state; `result` is left untouched.
        fn deserialize(
            &mut self,
            output: Self::SerializedState,
            _cx: &mut ViewContext<Self>,
        ) -> Result<()> {
            self.current_weather = output;
            Ok(())
        }
    }

    impl LanguageModelTool for WeatherTool {
        type View = WeatherView;

        fn name(&self) -> String {
            "get_current_weather".to_string()
        }

        fn description(&self) -> String {
            "Fetches the current weather for a given location.".to_string()
        }

        /// Creates a new view seeded with this tool's canned weather.
        fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
            cx.new_view(|_cx| WeatherView::new(self.current_weather.clone()))
        }
    }

    /// Registers the weather tool, checks its generated definition, then
    /// streams a call in fragments and executes it.
    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);

        let mut registry = ToolRegistry::new();
        registry
            .register(WeatherTool {
                current_weather: WeatherResult {
                    location: "San Francisco".to_string(),
                    temperature: 21.0,
                    unit: "Celsius".to_string(),
                },
            })
            .unwrap();

        // The definition's schema is derived from `WeatherQuery`.
        let definitions = registry.definitions();
        assert_eq!(
            definitions,
            [ToolFunctionDefinition {
                name: "get_current_weather".to_string(),
                description: "Fetches the current weather for a given location.".to_string(),
                parameters: serde_json::from_value(json!({
                    "$schema": "http://json-schema.org/draft-07/schema#",
                    "title": "WeatherQuery",
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string"
                        },
                        "unit": {
                            "type": "string"
                        }
                    },
                    "required": ["location", "unit"]
                }))
                .unwrap(),
            }]
        );

        // Start with only the first part of the tool name, as if it were
        // still being streamed.
        let mut call = ToolFunctionCall {
            id: "the-id".to_string(),
            name: "get_cur".to_string(),
            ..Default::default()
        };

        let task = cx.update(|cx| {
            // First fragment completes the name and delivers partial arguments.
            registry.update_tool_call(
                &mut call,
                Some("rent_weather"),
                Some(r#"{"location": "San Francisco","#),
                cx,
            );
            // Second fragment completes the arguments.
            registry.update_tool_call(&mut call, None, Some(r#" "unit": "Celsius"}"#), cx);
            registry.execute_tool_call(&mut call, cx).unwrap()
        });
        task.await.unwrap();

        // Executing a known tool must leave the call in the executed state.
        match &call.state {
            ToolFunctionCallState::ExecutedTool(_view) => {}
            _ => panic!(),
        }
    }
}