registry.rs

  1use anyhow::{anyhow, Result};
  2use gpui::{div, AnyElement, IntoElement as _, ParentElement, Styled, Task, WindowContext};
  3use std::{
  4    any::TypeId,
  5    collections::HashMap,
  6    sync::atomic::{AtomicBool, Ordering::SeqCst},
  7};
  8
  9use crate::tool::{
 10    LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
 11};
 12
 13// Internal Tool representation for the registry
 14pub struct Tool {
 15    enabled: AtomicBool,
 16    type_id: TypeId,
 17    call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
 18    render_running: Box<dyn Fn(&mut WindowContext) -> gpui::AnyElement>,
 19    definition: ToolFunctionDefinition,
 20}
 21
 22impl Tool {
 23    fn new(
 24        type_id: TypeId,
 25        call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
 26        render_running: Box<dyn Fn(&mut WindowContext) -> gpui::AnyElement>,
 27        definition: ToolFunctionDefinition,
 28    ) -> Self {
 29        Self {
 30            enabled: AtomicBool::new(true),
 31            type_id,
 32            call,
 33            render_running,
 34            definition,
 35        }
 36    }
 37}
 38
 39pub struct ToolRegistry {
 40    tools: HashMap<String, Tool>,
 41}
 42
 43impl ToolRegistry {
 44    pub fn new() -> Self {
 45        Self {
 46            tools: HashMap::new(),
 47        }
 48    }
 49
 50    pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
 51        for tool in self.tools.values() {
 52            if tool.type_id == TypeId::of::<T>() {
 53                tool.enabled.store(is_enabled, SeqCst);
 54                return;
 55            }
 56        }
 57    }
 58
 59    pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
 60        for tool in self.tools.values() {
 61            if tool.type_id == TypeId::of::<T>() {
 62                return tool.enabled.load(SeqCst);
 63            }
 64        }
 65        false
 66    }
 67
 68    pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
 69        self.tools
 70            .values()
 71            .filter(|tool| tool.enabled.load(SeqCst))
 72            .map(|tool| tool.definition.clone())
 73            .collect()
 74    }
 75
 76    pub fn render_tool_call(
 77        &self,
 78        tool_call: &ToolFunctionCall,
 79        cx: &mut WindowContext,
 80    ) -> AnyElement {
 81        match &tool_call.result {
 82            Some(result) => div()
 83                .p_2()
 84                .child(result.into_any_element(&tool_call.name))
 85                .into_any_element(),
 86            None => self
 87                .tools
 88                .get(&tool_call.name)
 89                .map(|tool| (tool.render_running)(cx))
 90                .unwrap_or_else(|| div().into_any_element()),
 91        }
 92    }
 93
 94    pub fn register<T: 'static + LanguageModelTool>(
 95        &mut self,
 96        tool: T,
 97        _cx: &mut WindowContext,
 98    ) -> Result<()> {
 99        let definition = tool.definition();
100
101        let name = tool.name();
102
103        let registered_tool = Tool::new(
104            TypeId::of::<T>(),
105            Box::new(
106                move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
107                    let name = tool_call.name.clone();
108                    let arguments = tool_call.arguments.clone();
109                    let id = tool_call.id.clone();
110
111                    let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
112                        return Task::ready(Ok(ToolFunctionCall {
113                            id,
114                            name: name.clone(),
115                            arguments,
116                            result: Some(ToolFunctionCallResult::ParsingFailed),
117                        }));
118                    };
119
120                    let result = tool.execute(&input, cx);
121
122                    cx.spawn(move |mut cx| async move {
123                        let result: Result<T::Output> = result.await;
124                        let for_model = T::format(&input, &result);
125                        let view = cx.update(|cx| T::output_view(id.clone(), input, result, cx))?;
126
127                        Ok(ToolFunctionCall {
128                            id,
129                            name: name.clone(),
130                            arguments,
131                            result: Some(ToolFunctionCallResult::Finished {
132                                view: view.into(),
133                                for_model,
134                            }),
135                        })
136                    })
137                },
138            ),
139            Box::new(|cx| T::render_running(cx).into_any_element()),
140            definition,
141        );
142
143        let previous = self.tools.insert(name.clone(), registered_tool);
144
145        if previous.is_some() {
146            return Err(anyhow!("already registered a tool with name {}", name));
147        }
148
149        Ok(())
150    }
151
152    /// Task yields an error if the window for the given WindowContext is closed before the task completes.
153    pub fn call(
154        &self,
155        tool_call: &ToolFunctionCall,
156        cx: &mut WindowContext,
157    ) -> Task<Result<ToolFunctionCall>> {
158        let name = tool_call.name.clone();
159        let arguments = tool_call.arguments.clone();
160        let id = tool_call.id.clone();
161
162        let tool = match self.tools.get(&name) {
163            Some(tool) => tool,
164            None => {
165                let name = name.clone();
166                return Task::ready(Ok(ToolFunctionCall {
167                    id,
168                    name: name.clone(),
169                    arguments,
170                    result: Some(ToolFunctionCallResult::NoSuchTool),
171                }));
172            }
173        };
174
175        (tool.call)(tool_call, cx)
176    }
177}
178
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{div, prelude::*, Render, TestAppContext};
    use gpui::{EmptyView, View};
    use schemars::schema_for;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    // Arguments accepted by `get_current_weather`.
    // NOTE: plain `//` comments on purpose — `///` docs on a `JsonSchema` type are emitted as
    // `description` entries and would change the schema asserted in the test below.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// Test double that ignores its query and always reports `current_weather`.
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    /// Weather reading returned by [`WeatherTool`].
    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    /// Minimal view that displays the temperature of a [`WeatherResult`].
    struct WeatherView {
        result: WeatherResult,
    }

    impl Render for WeatherView {
        fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
            let label = format!("temperature: {}", self.result.temperature);
            div().child(label)
        }
    }

    impl LanguageModelTool for WeatherTool {
        type Input = WeatherQuery;
        type Output = WeatherResult;
        type View = WeatherView;

        fn name(&self) -> String {
            String::from("get_current_weather")
        }

        fn description(&self) -> String {
            String::from("Fetches the current weather for a given location.")
        }

        fn execute(
            &self,
            _input: &Self::Input,
            _cx: &mut WindowContext,
        ) -> Task<Result<Self::Output>> {
            // The query is deliberately ignored; the canned reading is returned as-is.
            Task::ready(Ok(self.current_weather.clone()))
        }

        fn output_view(
            _tool_call_id: String,
            _input: Self::Input,
            result: Result<Self::Output>,
            cx: &mut WindowContext,
        ) -> View<Self::View> {
            // Test-only: a failed result is a bug in the test, so panicking is acceptable.
            cx.new_view(|_cx| WeatherView {
                result: result.unwrap(),
            })
        }

        fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
            let reading = output.as_ref().unwrap();
            serde_json::to_string(reading).unwrap()
        }
    }

    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);

        let tool = WeatherTool {
            current_weather: WeatherResult {
                location: "San Francisco".to_string(),
                temperature: 21.0,
                unit: "Celsius".to_string(),
            },
        };

        // The tool should describe itself with the expected name, description and schema.
        let definitions = vec![tool.definition()];
        assert_eq!(definitions.len(), 1);
        let definition = &definitions[0];

        let expected = ToolFunctionDefinition {
            name: "get_current_weather".to_string(),
            description: "Fetches the current weather for a given location.".to_string(),
            parameters: schema_for!(WeatherQuery),
        };
        assert_eq!(definition.name, expected.name);
        assert_eq!(definition.description, expected.description);

        let actual_schema = serde_json::to_value(&definition.parameters).unwrap();
        assert_eq!(
            actual_schema,
            json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "title": "WeatherQuery",
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "unit": {
                        "type": "string"
                    }
                },
                "required": ["location", "unit"]
            })
        );

        // Executing the tool with well-formed arguments yields the canned reading.
        let query: WeatherQuery = serde_json::from_value(json!({
            "location": "San Francisco",
            "unit": "Celsius"
        }))
        .unwrap();

        let outcome = cx.update(|cx| tool.execute(&query, cx)).await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), tool.current_weather);
    }
}