1use crate::ProjectContext;
2use anyhow::{anyhow, Result};
3use gpui::{
4 div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
5};
6use schemars::{schema::RootSchema, schema_for, JsonSchema};
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8use serde_json::{value::RawValue, Value};
9use std::{
10 any::TypeId,
11 collections::HashMap,
12 fmt::Display,
13 sync::{
14 atomic::{AtomicBool, Ordering::SeqCst},
15 Arc,
16 },
17};
18
19pub struct ToolRegistry {
20 registered_tools: HashMap<String, RegisteredTool>,
21}
22
23#[derive(Default)]
24pub struct ToolFunctionCall {
25 pub id: String,
26 pub name: String,
27 pub arguments: String,
28 pub result: Option<ToolFunctionCallResult>,
29}
30
31#[derive(Default, Serialize, Deserialize)]
32pub struct SavedToolFunctionCall {
33 pub id: String,
34 pub name: String,
35 pub arguments: String,
36 pub result: Option<SavedToolFunctionCallResult>,
37}
38
39pub enum ToolFunctionCallResult {
40 NoSuchTool,
41 ParsingFailed,
42 Finished {
43 view: AnyView,
44 serialized_output: Result<Box<RawValue>, String>,
45 generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
46 },
47}
48
49#[derive(Serialize, Deserialize)]
50pub enum SavedToolFunctionCallResult {
51 NoSuchTool,
52 ParsingFailed,
53 Finished {
54 serialized_output: Result<Box<RawValue>, String>,
55 },
56}
57
58#[derive(Clone)]
59pub struct ToolFunctionDefinition {
60 pub name: String,
61 pub description: String,
62 pub parameters: RootSchema,
63}
64
65pub trait LanguageModelTool {
66 /// The input type that will be passed in to `execute` when the tool is called
67 /// by the language model.
68 type Input: DeserializeOwned + JsonSchema;
69
70 /// The output returned by executing the tool.
71 type Output: DeserializeOwned + Serialize + 'static;
72
73 type View: Render + ToolOutput;
74
75 /// Returns the name of the tool.
76 ///
77 /// This name is exposed to the language model to allow the model to pick
78 /// which tools to use. As this name is used to identify the tool within a
79 /// tool registry, it should be unique.
80 fn name(&self) -> String;
81
82 /// Returns the description of the tool.
83 ///
84 /// This can be used to _prompt_ the model as to what the tool does.
85 fn description(&self) -> String;
86
87 /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
88 fn definition(&self) -> ToolFunctionDefinition {
89 let root_schema = schema_for!(Self::Input);
90
91 ToolFunctionDefinition {
92 name: self.name(),
93 description: self.description(),
94 parameters: root_schema,
95 }
96 }
97
98 /// Executes the tool with the given input.
99 fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
100
101 /// A view of the output of running the tool, for displaying to the user.
102 fn view(
103 &self,
104 input: Self::Input,
105 output: Result<Self::Output>,
106 cx: &mut WindowContext,
107 ) -> View<Self::View>;
108
109 fn render_running(_arguments: &Option<Value>, _cx: &mut WindowContext) -> impl IntoElement {
110 tool_running_placeholder()
111 }
112}
113
114pub fn tool_running_placeholder() -> AnyElement {
115 ui::Label::new("Researching...").into_any_element()
116}
117
118pub trait ToolOutput: Sized {
119 fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
120}
121
122struct RegisteredTool {
123 enabled: AtomicBool,
124 type_id: TypeId,
125 execute: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
126 deserialize: Box<dyn Fn(&SavedToolFunctionCall, &mut WindowContext) -> ToolFunctionCall>,
127 render_running: fn(&ToolFunctionCall, &mut WindowContext) -> gpui::AnyElement,
128 definition: ToolFunctionDefinition,
129}
130
131impl ToolRegistry {
132 pub fn new() -> Self {
133 Self {
134 registered_tools: HashMap::new(),
135 }
136 }
137
138 pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
139 for tool in self.registered_tools.values() {
140 if tool.type_id == TypeId::of::<T>() {
141 tool.enabled.store(is_enabled, SeqCst);
142 return;
143 }
144 }
145 }
146
147 pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
148 for tool in self.registered_tools.values() {
149 if tool.type_id == TypeId::of::<T>() {
150 return tool.enabled.load(SeqCst);
151 }
152 }
153 false
154 }
155
156 pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
157 self.registered_tools
158 .values()
159 .filter(|tool| tool.enabled.load(SeqCst))
160 .map(|tool| tool.definition.clone())
161 .collect()
162 }
163
164 pub fn render_tool_call(
165 &self,
166 tool_call: &ToolFunctionCall,
167 cx: &mut WindowContext,
168 ) -> AnyElement {
169 match &tool_call.result {
170 Some(result) => div()
171 .p_2()
172 .child(result.into_any_element(&tool_call.name))
173 .into_any_element(),
174 None => {
175 let tool = self.registered_tools.get(&tool_call.name);
176
177 if let Some(tool) = tool {
178 (tool.render_running)(&tool_call, cx)
179 } else {
180 tool_running_placeholder()
181 }
182 }
183 }
184 }
185
186 pub fn serialize_tool_call(&self, call: &ToolFunctionCall) -> SavedToolFunctionCall {
187 SavedToolFunctionCall {
188 id: call.id.clone(),
189 name: call.name.clone(),
190 arguments: call.arguments.clone(),
191 result: call.result.as_ref().map(|result| match result {
192 ToolFunctionCallResult::NoSuchTool => SavedToolFunctionCallResult::NoSuchTool,
193 ToolFunctionCallResult::ParsingFailed => SavedToolFunctionCallResult::ParsingFailed,
194 ToolFunctionCallResult::Finished {
195 serialized_output, ..
196 } => SavedToolFunctionCallResult::Finished {
197 serialized_output: match serialized_output {
198 Ok(value) => Ok(value.clone()),
199 Err(e) => Err(e.to_string()),
200 },
201 },
202 }),
203 }
204 }
205
206 pub fn deserialize_tool_call(
207 &self,
208 call: &SavedToolFunctionCall,
209 cx: &mut WindowContext,
210 ) -> ToolFunctionCall {
211 if let Some(tool) = &self.registered_tools.get(&call.name) {
212 (tool.deserialize)(call, cx)
213 } else {
214 ToolFunctionCall {
215 id: call.id.clone(),
216 name: call.name.clone(),
217 arguments: call.arguments.clone(),
218 result: Some(ToolFunctionCallResult::NoSuchTool),
219 }
220 }
221 }
222
223 pub fn register<T: 'static + LanguageModelTool>(
224 &mut self,
225 tool: T,
226 _cx: &mut WindowContext,
227 ) -> Result<()> {
228 let name = tool.name();
229 let tool = Arc::new(tool);
230 let registered_tool = RegisteredTool {
231 type_id: TypeId::of::<T>(),
232 definition: tool.definition(),
233 enabled: AtomicBool::new(true),
234 deserialize: Box::new({
235 let tool = tool.clone();
236 move |tool_call: &SavedToolFunctionCall, cx: &mut WindowContext| {
237 let id = tool_call.id.clone();
238 let name = tool_call.name.clone();
239 let arguments = tool_call.arguments.clone();
240
241 let Ok(input) = serde_json::from_str::<T::Input>(&tool_call.arguments) else {
242 return ToolFunctionCall {
243 id,
244 name: name.clone(),
245 arguments,
246 result: Some(ToolFunctionCallResult::ParsingFailed),
247 };
248 };
249
250 let result = match &tool_call.result {
251 Some(result) => match result {
252 SavedToolFunctionCallResult::NoSuchTool => {
253 Some(ToolFunctionCallResult::NoSuchTool)
254 }
255 SavedToolFunctionCallResult::ParsingFailed => {
256 Some(ToolFunctionCallResult::ParsingFailed)
257 }
258 SavedToolFunctionCallResult::Finished { serialized_output } => {
259 let output = match serialized_output {
260 Ok(value) => {
261 match serde_json::from_str::<T::Output>(value.get()) {
262 Ok(value) => Ok(value),
263 Err(_) => {
264 return ToolFunctionCall {
265 id,
266 name: name.clone(),
267 arguments,
268 result: Some(
269 ToolFunctionCallResult::ParsingFailed,
270 ),
271 };
272 }
273 }
274 }
275 Err(e) => Err(anyhow!("{e}")),
276 };
277
278 let view = tool.view(input, output, cx).into();
279 Some(ToolFunctionCallResult::Finished {
280 serialized_output: serialized_output.clone(),
281 generate_fn: generate::<T>,
282 view,
283 })
284 }
285 },
286 None => None,
287 };
288
289 ToolFunctionCall {
290 id: tool_call.id.clone(),
291 name: name.clone(),
292 arguments: tool_call.arguments.clone(),
293 result,
294 }
295 }
296 }),
297 execute: Box::new({
298 let tool = tool.clone();
299 move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
300 let id = tool_call.id.clone();
301 let name = tool_call.name.clone();
302 let arguments = tool_call.arguments.clone();
303
304 let Ok(input) = serde_json::from_str::<T::Input>(&arguments) else {
305 return Task::ready(Ok(ToolFunctionCall {
306 id,
307 name: name.clone(),
308 arguments,
309 result: Some(ToolFunctionCallResult::ParsingFailed),
310 }));
311 };
312
313 let result = tool.execute(&input, cx);
314 let tool = tool.clone();
315 cx.spawn(move |mut cx| async move {
316 let result = result.await;
317 let serialized_output = result
318 .as_ref()
319 .map_err(ToString::to_string)
320 .and_then(|output| {
321 Ok(RawValue::from_string(
322 serde_json::to_string(output).map_err(|e| e.to_string())?,
323 )
324 .unwrap())
325 });
326 let view = cx.update(|cx| tool.view(input, result, cx))?;
327
328 Ok(ToolFunctionCall {
329 id,
330 name: name.clone(),
331 arguments,
332 result: Some(ToolFunctionCallResult::Finished {
333 serialized_output,
334 view: view.into(),
335 generate_fn: generate::<T>,
336 }),
337 })
338 })
339 }
340 }),
341 render_running: render_running::<T>,
342 };
343
344 let previous = self.registered_tools.insert(name.clone(), registered_tool);
345 if previous.is_some() {
346 return Err(anyhow!("already registered a tool with name {}", name));
347 }
348
349 return Ok(());
350
351 fn render_running<T: LanguageModelTool>(
352 tool_call: &ToolFunctionCall,
353 cx: &mut WindowContext,
354 ) -> AnyElement {
355 // Attempt to parse the string arguments that are JSON as a JSON value
356 let maybe_arguments = serde_json::to_value(tool_call.arguments.clone()).ok();
357
358 T::render_running(&maybe_arguments, cx).into_any_element()
359 }
360
361 fn generate<T: LanguageModelTool>(
362 view: AnyView,
363 project: &mut ProjectContext,
364 cx: &mut WindowContext,
365 ) -> String {
366 view.downcast::<T::View>()
367 .unwrap()
368 .update(cx, |view, cx| T::View::generate(view, project, cx))
369 }
370 }
371
372 /// Task yields an error if the window for the given WindowContext is closed before the task completes.
373 pub fn call(
374 &self,
375 tool_call: &ToolFunctionCall,
376 cx: &mut WindowContext,
377 ) -> Task<Result<ToolFunctionCall>> {
378 let name = tool_call.name.clone();
379 let arguments = tool_call.arguments.clone();
380 let id = tool_call.id.clone();
381
382 let tool = match self.registered_tools.get(&name) {
383 Some(tool) => tool,
384 None => {
385 let name = name.clone();
386 return Task::ready(Ok(ToolFunctionCall {
387 id,
388 name: name.clone(),
389 arguments,
390 result: Some(ToolFunctionCallResult::NoSuchTool),
391 }));
392 }
393 };
394
395 (tool.execute)(tool_call, cx)
396 }
397}
398
399impl ToolFunctionCallResult {
400 pub fn generate(
401 &self,
402 name: &String,
403 project: &mut ProjectContext,
404 cx: &mut WindowContext,
405 ) -> String {
406 match self {
407 ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
408 ToolFunctionCallResult::ParsingFailed => {
409 format!("Unable to parse arguments for {name}")
410 }
411 ToolFunctionCallResult::Finished {
412 generate_fn, view, ..
413 } => (generate_fn)(view.clone(), project, cx),
414 }
415 }
416
417 fn into_any_element(&self, name: &String) -> AnyElement {
418 match self {
419 ToolFunctionCallResult::NoSuchTool => {
420 format!("Language Model attempted to call {name}").into_any_element()
421 }
422 ToolFunctionCallResult::ParsingFailed => {
423 format!("Language Model called {name} with bad arguments").into_any_element()
424 }
425 ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
426 }
427 }
428}
429
430impl Display for ToolFunctionDefinition {
431 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432 let schema = serde_json::to_string(&self.parameters).ok();
433 let schema = schema.unwrap_or("None".to_string());
434 write!(f, "Name: {}:\n", self.name)?;
435 write!(f, "Description: {}\n", self.description)?;
436 write!(f, "Parameters: {}", schema)
437 }
438}
439
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{div, prelude::*, EmptyView, Render, TestAppContext, View};
    use schemars::{schema_for, JsonSchema};
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Input of the fake weather tool.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// A tool that always reports the weather it was constructed with.
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    /// Output of the fake weather tool.
    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    /// View over a weather result.
    struct WeatherView {
        result: WeatherResult,
    }

    impl Render for WeatherView {
        fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
            let text = format!("temperature: {}", self.result.temperature);
            div().child(text)
        }
    }

    impl ToolOutput for WeatherView {
        fn generate(&self, _output: &mut ProjectContext, _cx: &mut WindowContext) -> String {
            serde_json::to_string(&self.result).unwrap()
        }
    }

    impl LanguageModelTool for WeatherTool {
        type Input = WeatherQuery;
        type Output = WeatherResult;
        type View = WeatherView;

        fn name(&self) -> String {
            String::from("get_current_weather")
        }

        fn description(&self) -> String {
            String::from("Fetches the current weather for a given location.")
        }

        fn execute(
            &self,
            input: &Self::Input,
            _cx: &mut WindowContext,
        ) -> Task<Result<Self::Output>> {
            // The query is read but does not influence the canned answer.
            let _location = input.location.clone();
            let _unit = input.unit.clone();

            Task::ready(Ok(self.current_weather.clone()))
        }

        fn view(
            &self,
            _input: Self::Input,
            result: Result<Self::Output>,
            cx: &mut WindowContext,
        ) -> View<Self::View> {
            cx.new_view(|_cx| WeatherView {
                result: result.unwrap(),
            })
        }
    }

    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);

        let tool = WeatherTool {
            current_weather: WeatherResult {
                location: "San Francisco".to_string(),
                temperature: 21.0,
                unit: "Celsius".to_string(),
            },
        };

        // The definition exposes the tool's name, description and input schema.
        let definitions = vec![tool.definition()];
        assert_eq!(definitions.len(), 1);
        let definition = &definitions[0];

        let expected = ToolFunctionDefinition {
            name: "get_current_weather".to_string(),
            description: "Fetches the current weather for a given location.".to_string(),
            parameters: schema_for!(WeatherQuery),
        };

        assert_eq!(definition.name, expected.name);
        assert_eq!(definition.description, expected.description);

        let actual_schema = serde_json::to_value(&definition.parameters).unwrap();
        let expected_schema = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "WeatherQuery",
            "type": "object",
            "properties": {
                "location": {
                    "type": "string"
                },
                "unit": {
                    "type": "string"
                }
            },
            "required": ["location", "unit"]
        });
        assert_eq!(actual_schema, expected_schema);

        // Executing the tool yields the weather it was constructed with.
        let arguments = json!({
            "location": "San Francisco",
            "unit": "Celsius"
        });
        let query: WeatherQuery = serde_json::from_value(arguments).unwrap();

        let outcome = cx.update(|cx| tool.execute(&query, cx)).await;
        assert!(outcome.is_ok());

        let weather = outcome.unwrap();
        assert_eq!(weather, tool.current_weather);
    }
}