registry.rs

  1use anyhow::{anyhow, Result};
  2use gpui::{AnyView, Task, WindowContext};
  3use std::collections::HashMap;
  4
  5use crate::tool::{
  6    LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
  7};
  8
  9pub struct ToolRegistry {
 10    tools: HashMap<
 11        String,
 12        Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
 13    >,
 14    definitions: Vec<ToolFunctionDefinition>,
 15    status_views: Vec<AnyView>,
 16}
 17
 18impl ToolRegistry {
 19    pub fn new() -> Self {
 20        Self {
 21            tools: HashMap::new(),
 22            definitions: Vec::new(),
 23            status_views: Vec::new(),
 24        }
 25    }
 26
 27    pub fn definitions(&self) -> &[ToolFunctionDefinition] {
 28        &self.definitions
 29    }
 30
 31    pub fn register<T: 'static + LanguageModelTool>(
 32        &mut self,
 33        tool: T,
 34        cx: &mut WindowContext,
 35    ) -> Result<()> {
 36        self.definitions.push(tool.definition());
 37
 38        if let Some(tool_view) = tool.status_view(cx) {
 39            self.status_views.push(tool_view);
 40        }
 41
 42        let name = tool.name();
 43        let previous = self.tools.insert(
 44            name.clone(),
 45            // registry.call(tool_call, cx)
 46            Box::new(
 47                move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
 48                    let name = tool_call.name.clone();
 49                    let arguments = tool_call.arguments.clone();
 50                    let id = tool_call.id.clone();
 51
 52                    let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
 53                        return Task::ready(Ok(ToolFunctionCall {
 54                            id,
 55                            name: name.clone(),
 56                            arguments,
 57                            result: Some(ToolFunctionCallResult::ParsingFailed),
 58                        }));
 59                    };
 60
 61                    let result = tool.execute(&input, cx);
 62
 63                    cx.spawn(move |mut cx| async move {
 64                        let result: Result<T::Output> = result.await;
 65                        let for_model = T::format(&input, &result);
 66                        let view = cx.update(|cx| T::output_view(id.clone(), input, result, cx))?;
 67
 68                        Ok(ToolFunctionCall {
 69                            id,
 70                            name: name.clone(),
 71                            arguments,
 72                            result: Some(ToolFunctionCallResult::Finished {
 73                                view: view.into(),
 74                                for_model,
 75                            }),
 76                        })
 77                    })
 78                },
 79            ),
 80        );
 81
 82        if previous.is_some() {
 83            return Err(anyhow!("already registered a tool with name {}", name));
 84        }
 85
 86        Ok(())
 87    }
 88
 89    /// Task yields an error if the window for the given WindowContext is closed before the task completes.
 90    pub fn call(
 91        &self,
 92        tool_call: &ToolFunctionCall,
 93        cx: &mut WindowContext,
 94    ) -> Task<Result<ToolFunctionCall>> {
 95        let name = tool_call.name.clone();
 96        let arguments = tool_call.arguments.clone();
 97        let id = tool_call.id.clone();
 98
 99        let tool = match self.tools.get(&name) {
100            Some(tool) => tool,
101            None => {
102                let name = name.clone();
103                return Task::ready(Ok(ToolFunctionCall {
104                    id,
105                    name: name.clone(),
106                    arguments,
107                    result: Some(ToolFunctionCallResult::NoSuchTool),
108                }));
109            }
110        };
111
112        tool(tool_call, cx)
113    }
114
115    pub fn status_views(&self) -> &[AnyView] {
116        &self.status_views
117    }
118}
119
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{div, prelude::*, Render, TestAppContext};
    use gpui::{EmptyView, View};
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// A fake tool that always reports `current_weather`, whatever the query.
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    struct WeatherView {
        result: WeatherResult,
    }

    impl Render for WeatherView {
        fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
            div().child(format!("temperature: {}", self.result.temperature))
        }
    }

    impl LanguageModelTool for WeatherTool {
        type Input = WeatherQuery;
        type Output = WeatherResult;
        type View = WeatherView;

        fn name(&self) -> String {
            "get_current_weather".to_string()
        }

        fn description(&self) -> String {
            "Fetches the current weather for a given location.".to_string()
        }

        fn execute(
            &self,
            _input: &Self::Input,
            _cx: &mut WindowContext,
        ) -> Task<Result<Self::Output>> {
            Task::ready(Ok(self.current_weather.clone()))
        }

        fn output_view(
            _tool_call_id: String,
            _input: Self::Input,
            result: Result<Self::Output>,
            cx: &mut WindowContext,
        ) -> View<Self::View> {
            cx.new_view(|_cx| {
                let result = result.expect("the fake weather tool never fails");
                WeatherView { result }
            })
        }

        fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
            serde_json::to_string(&output.as_ref().unwrap()).unwrap()
        }
    }

    fn san_francisco_weather() -> WeatherResult {
        WeatherResult {
            location: "San Francisco".to_string(),
            temperature: 21.0,
            unit: "Celsius".to_string(),
        }
    }

    fn weather_tool() -> WeatherTool {
        WeatherTool {
            current_weather: san_francisco_weather(),
        }
    }

    /// Builds a pending call to the weather tool with the given raw JSON arguments.
    fn weather_call(arguments: String) -> ToolFunctionCall {
        ToolFunctionCall {
            id: "call_1".to_string(),
            name: "get_current_weather".to_string(),
            arguments,
            result: None,
        }
    }

    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);
        let tool = weather_tool();

        let definition = tool.definition();
        assert_eq!(definition.name, "get_current_weather");
        assert_eq!(
            definition.description,
            "Fetches the current weather for a given location."
        );
        assert_eq!(
            serde_json::to_value(&definition.parameters).unwrap(),
            json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "title": "WeatherQuery",
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "unit": {
                        "type": "string"
                    }
                },
                "required": ["location", "unit"]
            })
        );

        let query: WeatherQuery = serde_json::from_value(json!({
            "location": "San Francisco",
            "unit": "Celsius"
        }))
        .unwrap();

        let result = cx.update(|cx| tool.execute(&query, cx)).await;
        assert_eq!(result.unwrap(), tool.current_weather);
    }

    #[gpui::test]
    async fn test_registry_dispatches_calls(cx: &mut TestAppContext) {
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);
        let mut registry = ToolRegistry::new();

        cx.update(|cx| registry.register(weather_tool(), cx))
            .unwrap();
        assert!(cx
            .update(|cx| registry.register(weather_tool(), cx))
            .is_err());

        let unknown = ToolFunctionCall {
            name: "get_current_time".to_string(),
            ..weather_call(String::new())
        };
        let response = cx.update(|cx| registry.call(&unknown, cx)).await.unwrap();
        assert!(matches!(
            response.result,
            Some(ToolFunctionCallResult::NoSuchTool)
        ));

        let malformed = weather_call("not json".to_string());
        let response = cx.update(|cx| registry.call(&malformed, cx)).await.unwrap();
        assert!(matches!(
            response.result,
            Some(ToolFunctionCallResult::ParsingFailed)
        ));

        let valid = weather_call(
            json!({
                "location": "San Francisco",
                "unit": "Celsius"
            })
            .to_string(),
        );
        let response = cx.update(|cx| registry.call(&valid, cx)).await.unwrap();
        assert_eq!(response.id, "call_1");
        match response.result {
            Some(ToolFunctionCallResult::Finished { for_model, .. }) => {
                assert_eq!(
                    for_model,
                    serde_json::to_string(&san_francisco_weather()).unwrap()
                );
            }
            _ => panic!("expected the weather tool call to finish"),
        }
    }
}