registry.rs

  1use anyhow::{anyhow, Result};
  2use gpui::{Task, WindowContext};
  3use std::{
  4    any::TypeId,
  5    collections::HashMap,
  6    sync::atomic::{AtomicBool, Ordering::SeqCst},
  7};
  8
  9use crate::tool::{
 10    LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
 11};
 12
 13// Internal Tool representation for the registry
 14pub struct Tool {
 15    enabled: AtomicBool,
 16    type_id: TypeId,
 17    call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
 18    definition: ToolFunctionDefinition,
 19}
 20
 21impl Tool {
 22    fn new(
 23        type_id: TypeId,
 24        call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
 25        definition: ToolFunctionDefinition,
 26    ) -> Self {
 27        Self {
 28            enabled: AtomicBool::new(true),
 29            type_id,
 30            call,
 31            definition,
 32        }
 33    }
 34}
 35
 36pub struct ToolRegistry {
 37    tools: HashMap<String, Tool>,
 38}
 39
 40impl ToolRegistry {
 41    pub fn new() -> Self {
 42        Self {
 43            tools: HashMap::new(),
 44        }
 45    }
 46
 47    pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
 48        for tool in self.tools.values() {
 49            if tool.type_id == TypeId::of::<T>() {
 50                tool.enabled.store(is_enabled, SeqCst);
 51                return;
 52            }
 53        }
 54    }
 55
 56    pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
 57        for tool in self.tools.values() {
 58            if tool.type_id == TypeId::of::<T>() {
 59                return tool.enabled.load(SeqCst);
 60            }
 61        }
 62        false
 63    }
 64
 65    pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
 66        self.tools
 67            .values()
 68            .filter(|tool| tool.enabled.load(SeqCst))
 69            .map(|tool| tool.definition.clone())
 70            .collect()
 71    }
 72
 73    pub fn register<T: 'static + LanguageModelTool>(
 74        &mut self,
 75        tool: T,
 76        _cx: &mut WindowContext,
 77    ) -> Result<()> {
 78        let definition = tool.definition();
 79
 80        let name = tool.name();
 81
 82        let registered_tool = Tool::new(
 83            TypeId::of::<T>(),
 84            Box::new(
 85                move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
 86                    let name = tool_call.name.clone();
 87                    let arguments = tool_call.arguments.clone();
 88                    let id = tool_call.id.clone();
 89
 90                    let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
 91                        return Task::ready(Ok(ToolFunctionCall {
 92                            id,
 93                            name: name.clone(),
 94                            arguments,
 95                            result: Some(ToolFunctionCallResult::ParsingFailed),
 96                        }));
 97                    };
 98
 99                    let result = tool.execute(&input, cx);
100
101                    cx.spawn(move |mut cx| async move {
102                        let result: Result<T::Output> = result.await;
103                        let for_model = T::format(&input, &result);
104                        let view = cx.update(|cx| T::output_view(id.clone(), input, result, cx))?;
105
106                        Ok(ToolFunctionCall {
107                            id,
108                            name: name.clone(),
109                            arguments,
110                            result: Some(ToolFunctionCallResult::Finished {
111                                view: view.into(),
112                                for_model,
113                            }),
114                        })
115                    })
116                },
117            ),
118            definition,
119        );
120
121        let previous = self.tools.insert(name.clone(), registered_tool);
122
123        if previous.is_some() {
124            return Err(anyhow!("already registered a tool with name {}", name));
125        }
126
127        Ok(())
128    }
129
130    /// Task yields an error if the window for the given WindowContext is closed before the task completes.
131    pub fn call(
132        &self,
133        tool_call: &ToolFunctionCall,
134        cx: &mut WindowContext,
135    ) -> Task<Result<ToolFunctionCall>> {
136        let name = tool_call.name.clone();
137        let arguments = tool_call.arguments.clone();
138        let id = tool_call.id.clone();
139
140        let tool = match self.tools.get(&name) {
141            Some(tool) => tool,
142            None => {
143                let name = name.clone();
144                return Task::ready(Ok(ToolFunctionCall {
145                    id,
146                    name: name.clone(),
147                    arguments,
148                    result: Some(ToolFunctionCallResult::NoSuchTool),
149                }));
150            }
151        };
152
153        (tool.call)(tool_call, cx)
154    }
155}
156
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{div, prelude::*, Render, TestAppContext};
    use gpui::{EmptyView, View};
    use schemars::schema_for;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// Test tool that always reports `current_weather`, whatever the query says.
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    struct WeatherView {
        result: WeatherResult,
    }

    impl Render for WeatherView {
        fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
            div().child(format!("temperature: {}", self.result.temperature))
        }
    }

    impl LanguageModelTool for WeatherTool {
        type Input = WeatherQuery;
        type Output = WeatherResult;
        type View = WeatherView;

        fn name(&self) -> String {
            "get_current_weather".to_string()
        }

        fn description(&self) -> String {
            "Fetches the current weather for a given location.".to_string()
        }

        fn execute(
            &self,
            input: &Self::Input,
            _cx: &mut WindowContext,
        ) -> Task<Result<Self::Output>> {
            let _location = input.location.clone();
            let _unit = input.unit.clone();

            let weather = self.current_weather.clone();

            Task::ready(Ok(weather))
        }

        fn output_view(
            _tool_call_id: String,
            _input: Self::Input,
            result: Result<Self::Output>,
            cx: &mut WindowContext,
        ) -> View<Self::View> {
            cx.new_view(|_cx| {
                let result = result.unwrap();
                WeatherView { result }
            })
        }

        fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
            serde_json::to_string(&output.as_ref().unwrap()).unwrap()
        }
    }

    /// Builds the fixture tool used by every test: 21 degrees Celsius in San Francisco.
    fn weather_tool() -> WeatherTool {
        WeatherTool {
            current_weather: WeatherResult {
                location: "San Francisco".to_string(),
                temperature: 21.0,
                unit: "Celsius".to_string(),
            },
        }
    }

    /// Builds an unresolved call to the tool `name` with the raw JSON `arguments`.
    fn tool_call(name: &str, arguments: &str) -> ToolFunctionCall {
        ToolFunctionCall {
            id: "call-1".to_string(),
            name: name.to_string(),
            arguments: arguments.to_string(),
            result: None,
        }
    }

    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);

        let tool = weather_tool();

        let tools = vec![tool.definition()];
        assert_eq!(tools.len(), 1);

        let expected = ToolFunctionDefinition {
            name: "get_current_weather".to_string(),
            description: "Fetches the current weather for a given location.".to_string(),
            parameters: schema_for!(WeatherQuery),
        };

        assert_eq!(tools[0].name, expected.name);
        assert_eq!(tools[0].description, expected.description);

        let expected_schema = serde_json::to_value(&tools[0].parameters).unwrap();

        assert_eq!(
            expected_schema,
            json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "title": "WeatherQuery",
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "unit": {
                        "type": "string"
                    }
                },
                "required": ["location", "unit"]
            })
        );

        let args = json!({
            "location": "San Francisco",
            "unit": "Celsius"
        });

        let query: WeatherQuery = serde_json::from_value(args).unwrap();

        let result = cx.update(|cx| tool.execute(&query, cx)).await;

        assert!(result.is_ok());
        let result = result.unwrap();

        assert_eq!(result, tool.current_weather);
    }

    /// Registering the same name twice fails; enabling/disabling controls `definitions`.
    #[gpui::test]
    async fn test_registry_enable_disable_and_duplicates(cx: &mut TestAppContext) {
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);
        let mut registry = ToolRegistry::new();

        assert!(!registry.is_tool_enabled::<WeatherTool>());
        assert!(registry.definitions().is_empty());

        cx.update(|cx| registry.register(weather_tool(), cx))
            .unwrap();
        assert!(cx
            .update(|cx| registry.register(weather_tool(), cx))
            .is_err());

        assert!(registry.is_tool_enabled::<WeatherTool>());
        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 1);
        assert_eq!(definitions[0].name, "get_current_weather");

        registry.set_tool_enabled::<WeatherTool>(false);
        assert!(!registry.is_tool_enabled::<WeatherTool>());
        assert!(registry.definitions().is_empty());

        registry.set_tool_enabled::<WeatherTool>(true);
        assert!(registry.is_tool_enabled::<WeatherTool>());
        assert_eq!(registry.definitions().len(), 1);
    }

    /// Unknown tool names and unparsable arguments resolve to descriptive call results.
    #[gpui::test]
    async fn test_registry_call_reports_unknown_tool_and_bad_arguments(
        cx: &mut TestAppContext,
    ) {
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);
        let mut registry = ToolRegistry::new();
        cx.update(|cx| registry.register(weather_tool(), cx))
            .unwrap();

        let unknown = tool_call("no_such_tool", "{}");
        let resolved = cx
            .update(|cx| registry.call(&unknown, cx))
            .await
            .unwrap();
        assert!(matches!(
            resolved.result,
            Some(ToolFunctionCallResult::NoSuchTool)
        ));
        assert_eq!(resolved.name, "no_such_tool");
        assert_eq!(resolved.id, "call-1");

        let malformed = tool_call("get_current_weather", "not json");
        let resolved = cx
            .update(|cx| registry.call(&malformed, cx))
            .await
            .unwrap();
        assert!(matches!(
            resolved.result,
            Some(ToolFunctionCallResult::ParsingFailed)
        ));
        assert_eq!(resolved.arguments, "not json");
    }
}