registry.rs

  1use anyhow::{anyhow, Result};
  2use gpui::{Task, WindowContext};
  3use std::collections::HashMap;
  4
  5use crate::tool::{
  6    LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
  7};
  8
  9pub struct ToolRegistry {
 10    tools: HashMap<
 11        String,
 12        Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
 13    >,
 14    definitions: Vec<ToolFunctionDefinition>,
 15}
 16
 17impl ToolRegistry {
 18    pub fn new() -> Self {
 19        Self {
 20            tools: HashMap::new(),
 21            definitions: Vec::new(),
 22        }
 23    }
 24
 25    pub fn definitions(&self) -> &[ToolFunctionDefinition] {
 26        &self.definitions
 27    }
 28
 29    pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
 30        self.definitions.push(tool.definition());
 31        let name = tool.name();
 32        let previous = self.tools.insert(
 33            name.clone(),
 34            // registry.call(tool_call, cx)
 35            Box::new(
 36                move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
 37                    let name = tool_call.name.clone();
 38                    let arguments = tool_call.arguments.clone();
 39                    let id = tool_call.id.clone();
 40
 41                    let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
 42                        return Task::ready(Ok(ToolFunctionCall {
 43                            id,
 44                            name: name.clone(),
 45                            arguments,
 46                            result: Some(ToolFunctionCallResult::ParsingFailed),
 47                        }));
 48                    };
 49
 50                    let result = tool.execute(&input, cx);
 51
 52                    cx.spawn(move |mut cx| async move {
 53                        let result: Result<T::Output> = result.await;
 54                        let for_model = T::format(&input, &result);
 55                        let view = cx.update(|cx| T::new_view(id.clone(), input, result, cx))?;
 56
 57                        Ok(ToolFunctionCall {
 58                            id,
 59                            name: name.clone(),
 60                            arguments,
 61                            result: Some(ToolFunctionCallResult::Finished {
 62                                view: view.into(),
 63                                for_model,
 64                            }),
 65                        })
 66                    })
 67                },
 68            ),
 69        );
 70
 71        if previous.is_some() {
 72            return Err(anyhow!("already registered a tool with name {}", name));
 73        }
 74
 75        Ok(())
 76    }
 77
 78    /// Task yields an error if the window for the given WindowContext is closed before the task completes.
 79    pub fn call(
 80        &self,
 81        tool_call: &ToolFunctionCall,
 82        cx: &mut WindowContext,
 83    ) -> Task<Result<ToolFunctionCall>> {
 84        let name = tool_call.name.clone();
 85        let arguments = tool_call.arguments.clone();
 86        let id = tool_call.id.clone();
 87
 88        let tool = match self.tools.get(&name) {
 89            Some(tool) => tool,
 90            None => {
 91                let name = name.clone();
 92                return Task::ready(Ok(ToolFunctionCall {
 93                    id,
 94                    name: name.clone(),
 95                    arguments,
 96                    result: Some(ToolFunctionCallResult::NoSuchTool),
 97                }));
 98            }
 99        };
100
101        tool(tool_call, cx)
102    }
103}
104
#[cfg(test)]
mod test {
    use super::*;
    use gpui::View;
    use gpui::{div, prelude::*, Render, TestAppContext};
    use schemars::schema_for;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// Fixture tool that ignores its input and always reports `current_weather`.
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    struct WeatherView {
        result: WeatherResult,
    }

    impl Render for WeatherView {
        fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
            div().child(format!("temperature: {}", self.result.temperature))
        }
    }

    impl LanguageModelTool for WeatherTool {
        type Input = WeatherQuery;
        type Output = WeatherResult;
        type View = WeatherView;

        fn name(&self) -> String {
            "get_current_weather".to_string()
        }

        fn description(&self) -> String {
            "Fetches the current weather for a given location.".to_string()
        }

        fn execute(
            &self,
            _input: &Self::Input,
            _cx: &gpui::AppContext,
        ) -> Task<Result<Self::Output>> {
            // Canned data: the requested location and unit are intentionally ignored.
            Task::ready(Ok(self.current_weather.clone()))
        }

        fn new_view(
            _tool_call_id: String,
            _input: Self::Input,
            result: Result<Self::Output>,
            cx: &mut WindowContext,
        ) -> View<Self::View> {
            cx.new_view(|_cx| {
                let result = result.unwrap();
                WeatherView { result }
            })
        }

        fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
            serde_json::to_string(&output.as_ref().unwrap()).unwrap()
        }
    }

    /// Builds the fixture tool shared by the tests: 21 degrees Celsius in San Francisco.
    fn san_francisco_weather_tool() -> WeatherTool {
        WeatherTool {
            current_weather: WeatherResult {
                location: "San Francisco".to_string(),
                temperature: 21.0,
                unit: "Celsius".to_string(),
            },
        }
    }

    #[gpui::test]
    async fn test_function_registry(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();

        let mut registry = ToolRegistry::new();

        registry.register(san_francisco_weather_tool()).unwrap();

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 1);
        assert_eq!(definitions[0].name, "get_current_weather");

        // A second tool registered under the same name must be rejected.
        let err = registry
            .register(san_francisco_weather_tool())
            .unwrap_err();
        assert!(err.to_string().contains("get_current_weather"));

        // TODO: Put this back in after the interface is stabilized.
        // NOTE(review): `ToolRegistry::call` takes a `&mut WindowContext`, so this presumably
        // needs a test window rather than `cx.update` — confirm when re-enabling.
        //
        // let result = cx
        //     .update(|cx| {
        //         registry.call(
        //             &ToolFunctionCall {
        //                 name: "get_current_weather".to_string(),
        //                 arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"#
        //                     .to_string(),
        //                 id: "test-123".to_string(),
        //                 result: None,
        //             },
        //             cx,
        //         )
        //     })
        //     .await;
        //
        // assert!(result.is_ok());
        // let result = result.unwrap();
        //
        // let expected = r#"{"location":"San Francisco","temperature":21.0,"unit":"Celsius"}"#;
        // assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();

        let tool = san_francisco_weather_tool();

        let tools = vec![tool.definition()];
        assert_eq!(tools.len(), 1);

        let expected = ToolFunctionDefinition {
            name: "get_current_weather".to_string(),
            description: "Fetches the current weather for a given location.".to_string(),
            parameters: schema_for!(WeatherQuery),
        };

        assert_eq!(tools[0].name, expected.name);
        assert_eq!(tools[0].description, expected.description);

        let expected_schema = serde_json::to_value(&tools[0].parameters).unwrap();

        assert_eq!(
            expected_schema,
            json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "title": "WeatherQuery",
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "unit": {
                        "type": "string"
                    }
                },
                "required": ["location", "unit"]
            })
        );

        let args = json!({
            "location": "San Francisco",
            "unit": "Celsius"
        });

        let query: WeatherQuery = serde_json::from_value(args).unwrap();

        let result = cx.update(|cx| tool.execute(&query, cx)).await;

        assert!(result.is_ok());
        let result = result.unwrap();

        assert_eq!(result, tool.current_weather);
    }
}