tool_registry.rs

  1use crate::ProjectContext;
  2use anyhow::{anyhow, Result};
  3use gpui::{
  4    div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
  5};
  6use schemars::{schema::RootSchema, schema_for, JsonSchema};
  7use serde::{de::DeserializeOwned, Deserialize, Serialize};
  8use serde_json::{value::RawValue, Value};
  9use std::{
 10    any::TypeId,
 11    collections::HashMap,
 12    fmt::Display,
 13    sync::{
 14        atomic::{AtomicBool, Ordering::SeqCst},
 15        Arc,
 16    },
 17};
 18
 19pub struct ToolRegistry {
 20    registered_tools: HashMap<String, RegisteredTool>,
 21}
 22
 23#[derive(Default)]
 24pub struct ToolFunctionCall {
 25    pub id: String,
 26    pub name: String,
 27    pub arguments: String,
 28    pub result: Option<ToolFunctionCallResult>,
 29}
 30
 31#[derive(Default, Serialize, Deserialize)]
 32pub struct SavedToolFunctionCall {
 33    pub id: String,
 34    pub name: String,
 35    pub arguments: String,
 36    pub result: Option<SavedToolFunctionCallResult>,
 37}
 38
 39pub enum ToolFunctionCallResult {
 40    NoSuchTool,
 41    ParsingFailed,
 42    Finished {
 43        view: AnyView,
 44        serialized_output: Result<Box<RawValue>, String>,
 45        generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
 46    },
 47}
 48
 49#[derive(Serialize, Deserialize)]
 50pub enum SavedToolFunctionCallResult {
 51    NoSuchTool,
 52    ParsingFailed,
 53    Finished {
 54        serialized_output: Result<Box<RawValue>, String>,
 55    },
 56}
 57
 58#[derive(Clone)]
 59pub struct ToolFunctionDefinition {
 60    pub name: String,
 61    pub description: String,
 62    pub parameters: RootSchema,
 63}
 64
 65pub trait LanguageModelTool {
 66    /// The input type that will be passed in to `execute` when the tool is called
 67    /// by the language model.
 68    type Input: DeserializeOwned + JsonSchema;
 69
 70    /// The output returned by executing the tool.
 71    type Output: DeserializeOwned + Serialize + 'static;
 72
 73    type View: Render + ToolOutput;
 74
 75    /// Returns the name of the tool.
 76    ///
 77    /// This name is exposed to the language model to allow the model to pick
 78    /// which tools to use. As this name is used to identify the tool within a
 79    /// tool registry, it should be unique.
 80    fn name(&self) -> String;
 81
 82    /// Returns the description of the tool.
 83    ///
 84    /// This can be used to _prompt_ the model as to what the tool does.
 85    fn description(&self) -> String;
 86
 87    /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
 88    fn definition(&self) -> ToolFunctionDefinition {
 89        let root_schema = schema_for!(Self::Input);
 90
 91        ToolFunctionDefinition {
 92            name: self.name(),
 93            description: self.description(),
 94            parameters: root_schema,
 95        }
 96    }
 97
 98    /// Executes the tool with the given input.
 99    fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
100
101    /// A view of the output of running the tool, for displaying to the user.
102    fn view(
103        &self,
104        input: Self::Input,
105        output: Result<Self::Output>,
106        cx: &mut WindowContext,
107    ) -> View<Self::View>;
108
109    fn render_running(_arguments: &Option<Value>, _cx: &mut WindowContext) -> impl IntoElement {
110        tool_running_placeholder()
111    }
112}
113
114pub fn tool_running_placeholder() -> AnyElement {
115    ui::Label::new("Researching...").into_any_element()
116}
117
118pub trait ToolOutput: Sized {
119    fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
120}
121
122struct RegisteredTool {
123    enabled: AtomicBool,
124    type_id: TypeId,
125    execute: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
126    deserialize: Box<dyn Fn(&SavedToolFunctionCall, &mut WindowContext) -> ToolFunctionCall>,
127    render_running: fn(&ToolFunctionCall, &mut WindowContext) -> gpui::AnyElement,
128    definition: ToolFunctionDefinition,
129}
130
131impl ToolRegistry {
132    pub fn new() -> Self {
133        Self {
134            registered_tools: HashMap::new(),
135        }
136    }
137
138    pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
139        for tool in self.registered_tools.values() {
140            if tool.type_id == TypeId::of::<T>() {
141                tool.enabled.store(is_enabled, SeqCst);
142                return;
143            }
144        }
145    }
146
147    pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
148        for tool in self.registered_tools.values() {
149            if tool.type_id == TypeId::of::<T>() {
150                return tool.enabled.load(SeqCst);
151            }
152        }
153        false
154    }
155
156    pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
157        self.registered_tools
158            .values()
159            .filter(|tool| tool.enabled.load(SeqCst))
160            .map(|tool| tool.definition.clone())
161            .collect()
162    }
163
164    pub fn render_tool_call(
165        &self,
166        tool_call: &ToolFunctionCall,
167        cx: &mut WindowContext,
168    ) -> AnyElement {
169        match &tool_call.result {
170            Some(result) => div()
171                .p_2()
172                .child(result.into_any_element(&tool_call.name))
173                .into_any_element(),
174            None => {
175                let tool = self.registered_tools.get(&tool_call.name);
176
177                if let Some(tool) = tool {
178                    (tool.render_running)(&tool_call, cx)
179                } else {
180                    tool_running_placeholder()
181                }
182            }
183        }
184    }
185
186    pub fn serialize_tool_call(&self, call: &ToolFunctionCall) -> SavedToolFunctionCall {
187        SavedToolFunctionCall {
188            id: call.id.clone(),
189            name: call.name.clone(),
190            arguments: call.arguments.clone(),
191            result: call.result.as_ref().map(|result| match result {
192                ToolFunctionCallResult::NoSuchTool => SavedToolFunctionCallResult::NoSuchTool,
193                ToolFunctionCallResult::ParsingFailed => SavedToolFunctionCallResult::ParsingFailed,
194                ToolFunctionCallResult::Finished {
195                    serialized_output, ..
196                } => SavedToolFunctionCallResult::Finished {
197                    serialized_output: match serialized_output {
198                        Ok(value) => Ok(value.clone()),
199                        Err(e) => Err(e.to_string()),
200                    },
201                },
202            }),
203        }
204    }
205
206    pub fn deserialize_tool_call(
207        &self,
208        call: &SavedToolFunctionCall,
209        cx: &mut WindowContext,
210    ) -> ToolFunctionCall {
211        if let Some(tool) = &self.registered_tools.get(&call.name) {
212            (tool.deserialize)(call, cx)
213        } else {
214            ToolFunctionCall {
215                id: call.id.clone(),
216                name: call.name.clone(),
217                arguments: call.arguments.clone(),
218                result: Some(ToolFunctionCallResult::NoSuchTool),
219            }
220        }
221    }
222
223    pub fn register<T: 'static + LanguageModelTool>(
224        &mut self,
225        tool: T,
226        _cx: &mut WindowContext,
227    ) -> Result<()> {
228        let name = tool.name();
229        let tool = Arc::new(tool);
230        let registered_tool = RegisteredTool {
231            type_id: TypeId::of::<T>(),
232            definition: tool.definition(),
233            enabled: AtomicBool::new(true),
234            deserialize: Box::new({
235                let tool = tool.clone();
236                move |tool_call: &SavedToolFunctionCall, cx: &mut WindowContext| {
237                    let id = tool_call.id.clone();
238                    let name = tool_call.name.clone();
239                    let arguments = tool_call.arguments.clone();
240
241                    let Ok(input) = serde_json::from_str::<T::Input>(&tool_call.arguments) else {
242                        return ToolFunctionCall {
243                            id,
244                            name: name.clone(),
245                            arguments,
246                            result: Some(ToolFunctionCallResult::ParsingFailed),
247                        };
248                    };
249
250                    let result = match &tool_call.result {
251                        Some(result) => match result {
252                            SavedToolFunctionCallResult::NoSuchTool => {
253                                Some(ToolFunctionCallResult::NoSuchTool)
254                            }
255                            SavedToolFunctionCallResult::ParsingFailed => {
256                                Some(ToolFunctionCallResult::ParsingFailed)
257                            }
258                            SavedToolFunctionCallResult::Finished { serialized_output } => {
259                                let output = match serialized_output {
260                                    Ok(value) => {
261                                        match serde_json::from_str::<T::Output>(value.get()) {
262                                            Ok(value) => Ok(value),
263                                            Err(_) => {
264                                                return ToolFunctionCall {
265                                                    id,
266                                                    name: name.clone(),
267                                                    arguments,
268                                                    result: Some(
269                                                        ToolFunctionCallResult::ParsingFailed,
270                                                    ),
271                                                };
272                                            }
273                                        }
274                                    }
275                                    Err(e) => Err(anyhow!("{e}")),
276                                };
277
278                                let view = tool.view(input, output, cx).into();
279                                Some(ToolFunctionCallResult::Finished {
280                                    serialized_output: serialized_output.clone(),
281                                    generate_fn: generate::<T>,
282                                    view,
283                                })
284                            }
285                        },
286                        None => None,
287                    };
288
289                    ToolFunctionCall {
290                        id: tool_call.id.clone(),
291                        name: name.clone(),
292                        arguments: tool_call.arguments.clone(),
293                        result,
294                    }
295                }
296            }),
297            execute: Box::new({
298                let tool = tool.clone();
299                move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
300                    let id = tool_call.id.clone();
301                    let name = tool_call.name.clone();
302                    let arguments = tool_call.arguments.clone();
303
304                    let Ok(input) = serde_json::from_str::<T::Input>(&arguments) else {
305                        return Task::ready(Ok(ToolFunctionCall {
306                            id,
307                            name: name.clone(),
308                            arguments,
309                            result: Some(ToolFunctionCallResult::ParsingFailed),
310                        }));
311                    };
312
313                    let result = tool.execute(&input, cx);
314                    let tool = tool.clone();
315                    cx.spawn(move |mut cx| async move {
316                        let result = result.await;
317                        let serialized_output = result
318                            .as_ref()
319                            .map_err(ToString::to_string)
320                            .and_then(|output| {
321                                Ok(RawValue::from_string(
322                                    serde_json::to_string(output).map_err(|e| e.to_string())?,
323                                )
324                                .unwrap())
325                            });
326                        let view = cx.update(|cx| tool.view(input, result, cx))?;
327
328                        Ok(ToolFunctionCall {
329                            id,
330                            name: name.clone(),
331                            arguments,
332                            result: Some(ToolFunctionCallResult::Finished {
333                                serialized_output,
334                                view: view.into(),
335                                generate_fn: generate::<T>,
336                            }),
337                        })
338                    })
339                }
340            }),
341            render_running: render_running::<T>,
342        };
343
344        let previous = self.registered_tools.insert(name.clone(), registered_tool);
345        if previous.is_some() {
346            return Err(anyhow!("already registered a tool with name {}", name));
347        }
348
349        return Ok(());
350
351        fn render_running<T: LanguageModelTool>(
352            tool_call: &ToolFunctionCall,
353            cx: &mut WindowContext,
354        ) -> AnyElement {
355            // Attempt to parse the string arguments that are JSON as a JSON value
356            let maybe_arguments = serde_json::to_value(tool_call.arguments.clone()).ok();
357
358            T::render_running(&maybe_arguments, cx).into_any_element()
359        }
360
361        fn generate<T: LanguageModelTool>(
362            view: AnyView,
363            project: &mut ProjectContext,
364            cx: &mut WindowContext,
365        ) -> String {
366            view.downcast::<T::View>()
367                .unwrap()
368                .update(cx, |view, cx| T::View::generate(view, project, cx))
369        }
370    }
371
372    /// Task yields an error if the window for the given WindowContext is closed before the task completes.
373    pub fn call(
374        &self,
375        tool_call: &ToolFunctionCall,
376        cx: &mut WindowContext,
377    ) -> Task<Result<ToolFunctionCall>> {
378        let name = tool_call.name.clone();
379        let arguments = tool_call.arguments.clone();
380        let id = tool_call.id.clone();
381
382        let tool = match self.registered_tools.get(&name) {
383            Some(tool) => tool,
384            None => {
385                let name = name.clone();
386                return Task::ready(Ok(ToolFunctionCall {
387                    id,
388                    name: name.clone(),
389                    arguments,
390                    result: Some(ToolFunctionCallResult::NoSuchTool),
391                }));
392            }
393        };
394
395        (tool.execute)(tool_call, cx)
396    }
397}
398
399impl ToolFunctionCallResult {
400    pub fn generate(
401        &self,
402        name: &String,
403        project: &mut ProjectContext,
404        cx: &mut WindowContext,
405    ) -> String {
406        match self {
407            ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
408            ToolFunctionCallResult::ParsingFailed => {
409                format!("Unable to parse arguments for {name}")
410            }
411            ToolFunctionCallResult::Finished {
412                generate_fn, view, ..
413            } => (generate_fn)(view.clone(), project, cx),
414        }
415    }
416
417    fn into_any_element(&self, name: &String) -> AnyElement {
418        match self {
419            ToolFunctionCallResult::NoSuchTool => {
420                format!("Language Model attempted to call {name}").into_any_element()
421            }
422            ToolFunctionCallResult::ParsingFailed => {
423                format!("Language Model called {name} with bad arguments").into_any_element()
424            }
425            ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
426        }
427    }
428}
429
430impl Display for ToolFunctionDefinition {
431    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432        let schema = serde_json::to_string(&self.parameters).ok();
433        let schema = schema.unwrap_or("None".to_string());
434        write!(f, "Name: {}:\n", self.name)?;
435        write!(f, "Description: {}\n", self.description)?;
436        write!(f, "Parameters: {}", schema)
437    }
438}
439
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{div, prelude::*, EmptyView, Render, TestAppContext, View};
    use schemars::{schema_for, JsonSchema};
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Input of the example weather tool.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// Example tool that always reports a fixed weather result.
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    /// Output of the example weather tool.
    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    /// View displaying a weather result.
    struct WeatherView {
        result: WeatherResult,
    }

    impl Render for WeatherView {
        fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
            let text = format!("temperature: {}", self.result.temperature);
            div().child(text)
        }
    }

    impl ToolOutput for WeatherView {
        fn generate(&self, _output: &mut ProjectContext, _cx: &mut WindowContext) -> String {
            serde_json::to_string(&self.result).unwrap()
        }
    }

    impl LanguageModelTool for WeatherTool {
        type Input = WeatherQuery;
        type Output = WeatherResult;
        type View = WeatherView;

        fn name(&self) -> String {
            String::from("get_current_weather")
        }

        fn description(&self) -> String {
            String::from("Fetches the current weather for a given location.")
        }

        fn execute(
            &self,
            input: &Self::Input,
            _cx: &mut WindowContext,
        ) -> Task<Result<Self::Output>> {
            // The query is not consulted: the tool always answers with its canned result.
            let (_location, _unit) = (&input.location, &input.unit);

            Task::ready(Ok(self.current_weather.clone()))
        }

        fn view(
            &self,
            _input: Self::Input,
            result: Result<Self::Output>,
            cx: &mut WindowContext,
        ) -> View<Self::View> {
            cx.new_view(|_cx| WeatherView {
                result: result.unwrap(),
            })
        }
    }

    /// Checks the generated tool definition and runs the tool once directly.
    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);

        let tool = WeatherTool {
            current_weather: WeatherResult {
                location: "San Francisco".to_string(),
                temperature: 21.0,
                unit: "Celsius".to_string(),
            },
        };

        let definitions = vec![tool.definition()];
        assert_eq!(definitions.len(), 1);
        let definition = &definitions[0];

        let expected = ToolFunctionDefinition {
            name: "get_current_weather".to_string(),
            description: "Fetches the current weather for a given location.".to_string(),
            parameters: schema_for!(WeatherQuery),
        };

        assert_eq!(definition.name, expected.name);
        assert_eq!(definition.description, expected.description);

        let actual_schema = serde_json::to_value(&definition.parameters).unwrap();

        assert_eq!(
            actual_schema,
            json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "title": "WeatherQuery",
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "unit": {
                        "type": "string"
                    }
                },
                "required": ["location", "unit"]
            })
        );

        let query: WeatherQuery = serde_json::from_value(json!({
            "location": "San Francisco",
            "unit": "Celsius"
        }))
        .unwrap();

        let outcome = cx.update(|cx| tool.execute(&query, cx)).await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), tool.current_weather);
    }
}