registry.rs

  1use anyhow::{anyhow, Result};
  2use gpui::{AnyElement, AppContext, Task, WindowContext};
  3use std::{any::Any, collections::HashMap};
  4
  5use crate::tool::{
  6    LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
  7};
  8
  9pub struct ToolRegistry {
 10    tools: HashMap<String, Box<dyn Fn(&ToolFunctionCall, &AppContext) -> Task<ToolFunctionCall>>>,
 11    definitions: Vec<ToolFunctionDefinition>,
 12}
 13
 14impl ToolRegistry {
 15    pub fn new() -> Self {
 16        Self {
 17            tools: HashMap::new(),
 18            definitions: Vec::new(),
 19        }
 20    }
 21
 22    pub fn definitions(&self) -> &[ToolFunctionDefinition] {
 23        &self.definitions
 24    }
 25
 26    pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
 27        fn render<T: 'static + LanguageModelTool>(
 28            tool_call_id: &str,
 29            input: &Box<dyn Any>,
 30            output: &Box<dyn Any>,
 31            cx: &mut WindowContext,
 32        ) -> AnyElement {
 33            T::render(
 34                tool_call_id,
 35                input.as_ref().downcast_ref::<T::Input>().unwrap(),
 36                output.as_ref().downcast_ref::<T::Output>().unwrap(),
 37                cx,
 38            )
 39        }
 40
 41        fn format<T: 'static + LanguageModelTool>(
 42            input: &Box<dyn Any>,
 43            output: &Box<dyn Any>,
 44        ) -> String {
 45            T::format(
 46                input.as_ref().downcast_ref::<T::Input>().unwrap(),
 47                output.as_ref().downcast_ref::<T::Output>().unwrap(),
 48            )
 49        }
 50
 51        self.definitions.push(tool.definition());
 52        let name = tool.name();
 53        let previous = self.tools.insert(
 54            name.clone(),
 55            Box::new(move |tool_call: &ToolFunctionCall, cx: &AppContext| {
 56                let name = tool_call.name.clone();
 57                let arguments = tool_call.arguments.clone();
 58                let id = tool_call.id.clone();
 59
 60                let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
 61                    return Task::ready(ToolFunctionCall {
 62                        id,
 63                        name: name.clone(),
 64                        arguments,
 65                        result: Some(ToolFunctionCallResult::ParsingFailed),
 66                    });
 67                };
 68
 69                let result = tool.execute(&input, cx);
 70
 71                cx.spawn(move |_cx| async move {
 72                    match result.await {
 73                        Ok(result) => {
 74                            let result: T::Output = result;
 75                            ToolFunctionCall {
 76                                id,
 77                                name: name.clone(),
 78                                arguments,
 79                                result: Some(ToolFunctionCallResult::Finished {
 80                                    input: Box::new(input),
 81                                    output: Box::new(result),
 82                                    render_fn: render::<T>,
 83                                    format_fn: format::<T>,
 84                                }),
 85                            }
 86                        }
 87                        Err(_error) => ToolFunctionCall {
 88                            id,
 89                            name: name.clone(),
 90                            arguments,
 91                            result: Some(ToolFunctionCallResult::ExecutionFailed {
 92                                input: Box::new(input),
 93                            }),
 94                        },
 95                    }
 96                })
 97            }),
 98        );
 99
100        if previous.is_some() {
101            return Err(anyhow!("already registered a tool with name {}", name));
102        }
103
104        Ok(())
105    }
106
107    pub fn call(&self, tool_call: &ToolFunctionCall, cx: &AppContext) -> Task<ToolFunctionCall> {
108        let name = tool_call.name.clone();
109        let arguments = tool_call.arguments.clone();
110        let id = tool_call.id.clone();
111
112        let tool = match self.tools.get(&name) {
113            Some(tool) => tool,
114            None => {
115                let name = name.clone();
116                return Task::ready(ToolFunctionCall {
117                    id,
118                    name: name.clone(),
119                    arguments,
120                    result: Some(ToolFunctionCallResult::NoSuchTool),
121                });
122            }
123        };
124
125        tool(tool_call, cx)
126    }
127}
128
#[cfg(test)]
mod test {
    use super::*;

    use gpui::{div, AnyElement, Element, ParentElement, TestAppContext, WindowContext};
    use schemars::{schema_for, JsonSchema};
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// JSON arguments for a well-formed `get_current_weather` call.
    const WEATHER_ARGUMENTS: &str = r#"{ "location": "San Francisco", "unit": "Celsius" }"#;

    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// Test tool that ignores its query and always reports `current_weather`.
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    impl LanguageModelTool for WeatherTool {
        type Input = WeatherQuery;
        type Output = WeatherResult;

        fn name(&self) -> String {
            "get_current_weather".to_string()
        }

        fn description(&self) -> String {
            "Fetches the current weather for a given location.".to_string()
        }

        fn execute(&self, input: &WeatherQuery, _cx: &AppContext) -> Task<Result<Self::Output>> {
            let _location = input.location.clone();
            let _unit = input.unit.clone();

            let weather = self.current_weather.clone();

            Task::ready(Ok(weather))
        }

        fn render(
            _tool_call_id: &str,
            _input: &Self::Input,
            output: &Self::Output,
            _cx: &mut WindowContext,
        ) -> AnyElement {
            div()
                .child(format!(
                    "The current temperature in {} is {} {}",
                    output.location, output.temperature, output.unit
                ))
                .into_any()
        }

        fn format(_input: &Self::Input, output: &Self::Output) -> String {
            format!(
                "The current temperature in {} is {} {}",
                output.location, output.temperature, output.unit
            )
        }
    }

    /// Returns a weather tool that always reports 21 degrees Celsius in San Francisco.
    fn weather_tool() -> WeatherTool {
        WeatherTool {
            current_weather: WeatherResult {
                location: "San Francisco".to_string(),
                temperature: 21.0,
                unit: "Celsius".to_string(),
            },
        }
    }

    /// Builds a not-yet-executed call (`result: None`) with a fixed id.
    fn tool_call(name: &str, arguments: &str) -> ToolFunctionCall {
        ToolFunctionCall {
            id: "test-123".to_string(),
            name: name.to_string(),
            arguments: arguments.to_string(),
            result: None,
        }
    }

    #[gpui::test]
    async fn test_function_registry(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();

        let mut registry = ToolRegistry::new();
        let tool = weather_tool();
        let expected_output = tool.current_weather.clone();
        registry.register(tool).unwrap();

        let result = cx
            .update(|cx| registry.call(&tool_call("get_current_weather", WEATHER_ARGUMENTS), cx))
            .await;

        assert_eq!(result.id, "test-123");
        assert_eq!(result.name, "get_current_weather");
        assert_eq!(result.arguments, WEATHER_ARGUMENTS);

        let Some(ToolFunctionCallResult::Finished {
            input,
            output,
            format_fn,
            ..
        }) = &result.result
        else {
            panic!("expected the weather tool call to finish successfully");
        };

        assert_eq!(
            output.as_ref().downcast_ref::<WeatherResult>(),
            Some(&expected_output)
        );
        assert_eq!(
            format_fn(input, output),
            "The current temperature in San Francisco is 21 Celsius"
        );
    }

    #[gpui::test]
    async fn test_unknown_tool(cx: &mut TestAppContext) {
        let registry = ToolRegistry::new();

        let result = cx
            .update(|cx| registry.call(&tool_call("no_such_tool", "{}"), cx))
            .await;

        assert_eq!(result.name, "no_such_tool");
        assert_eq!(result.arguments, "{}");
        assert!(matches!(
            result.result,
            Some(ToolFunctionCallResult::NoSuchTool)
        ));
    }

    #[gpui::test]
    async fn test_malformed_arguments(cx: &mut TestAppContext) {
        let mut registry = ToolRegistry::new();
        registry.register(weather_tool()).unwrap();

        let result = cx
            .update(|cx| registry.call(&tool_call("get_current_weather", "not json"), cx))
            .await;

        assert_eq!(result.arguments, "not json");
        assert!(matches!(
            result.result,
            Some(ToolFunctionCallResult::ParsingFailed)
        ));
    }

    #[test]
    fn test_duplicate_registration_is_rejected() {
        let mut registry = ToolRegistry::new();
        registry.register(weather_tool()).unwrap();

        assert!(registry.register(weather_tool()).is_err());
    }

    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();

        let tool = weather_tool();
        let definition = tool.definition();

        let expected = ToolFunctionDefinition {
            name: "get_current_weather".to_string(),
            description: "Fetches the current weather for a given location.".to_string(),
            parameters: schema_for!(WeatherQuery),
        };

        assert_eq!(definition.name, expected.name);
        assert_eq!(definition.description, expected.description);

        let actual_schema = serde_json::to_value(&definition.parameters).unwrap();
        assert_eq!(
            actual_schema,
            serde_json::to_value(&expected.parameters).unwrap()
        );
        assert_eq!(
            actual_schema,
            json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "title": "WeatherQuery",
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "unit": {
                        "type": "string"
                    }
                },
                "required": ["location", "unit"]
            })
        );

        let args = json!({
            "location": "San Francisco",
            "unit": "Celsius"
        });

        let query: WeatherQuery = serde_json::from_value(args).unwrap();

        let result = cx.update(|cx| tool.execute(&query, cx)).await;

        assert!(result.is_ok());
        let result = result.unwrap();

        assert_eq!(result, tool.current_weather);
    }
}
299}