tool_registry.rs

  1use anyhow::{anyhow, Result};
  2use gpui::{
  3    div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
  4};
  5use schemars::{schema::RootSchema, schema_for, JsonSchema};
  6use serde::Deserialize;
  7use std::{
  8    any::TypeId,
  9    collections::HashMap,
 10    fmt::Display,
 11    sync::atomic::{AtomicBool, Ordering::SeqCst},
 12};
 13
 14use crate::ProjectContext;
 15
 16pub struct ToolRegistry {
 17    registered_tools: HashMap<String, RegisteredTool>,
 18}
 19
/// A tool invocation requested by the language model, together with its outcome once run.
///
/// Only `id`, `name` and `arguments` are deserialized; `result` is skipped by serde and is
/// populated on the `ToolFunctionCall` returned by [`ToolRegistry::call`].
#[derive(Default, Deserialize)]
pub struct ToolFunctionCall {
    /// Identifier of this call.
    /// NOTE(review): presumably assigned by the model provider — confirm.
    pub id: String,
    /// Name of the tool to invoke; used as the lookup key into the registry.
    pub name: String,
    /// JSON-encoded arguments, deserialized into the tool's `Input` type when called.
    pub arguments: String,
    /// Outcome of the call; `None` while no result exists yet (rendered as the tool's
    /// "running" state by `ToolRegistry::render_tool_call`).
    #[serde(skip)]
    pub result: Option<ToolFunctionCallResult>,
}
 28
/// Outcome of attempting to run a [`ToolFunctionCall`].
pub enum ToolFunctionCallResult {
    /// No tool is registered under the requested name.
    NoSuchTool,
    /// The call's `arguments` could not be deserialized into the tool's input type.
    ParsingFailed,
    /// The tool's task completed (successfully or not) and an output view was built
    /// from its result.
    Finished {
        /// Type-erased view showing the tool's output.
        view: AnyView,
        /// Turns `view` into text through the tool view's `ToolOutput::generate`.
        generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
    },
}
 37
/// Description of a tool as advertised to the language model.
#[derive(Clone)]
pub struct ToolFunctionDefinition {
    /// Tool name; must be unique within a [`ToolRegistry`], where it is the lookup key.
    pub name: String,
    /// Description used to prompt the model about what the tool does.
    pub description: String,
    /// JSON schema of the tool's input type.
    pub parameters: RootSchema,
}
 44
 45pub trait LanguageModelTool {
 46    /// The input type that will be passed in to `execute` when the tool is called
 47    /// by the language model.
 48    type Input: for<'de> Deserialize<'de> + JsonSchema;
 49
 50    /// The output returned by executing the tool.
 51    type Output: 'static;
 52
 53    type View: Render + ToolOutput;
 54
 55    /// Returns the name of the tool.
 56    ///
 57    /// This name is exposed to the language model to allow the model to pick
 58    /// which tools to use. As this name is used to identify the tool within a
 59    /// tool registry, it should be unique.
 60    fn name(&self) -> String;
 61
 62    /// Returns the description of the tool.
 63    ///
 64    /// This can be used to _prompt_ the model as to what the tool does.
 65    fn description(&self) -> String;
 66
 67    /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
 68    fn definition(&self) -> ToolFunctionDefinition {
 69        let root_schema = schema_for!(Self::Input);
 70
 71        ToolFunctionDefinition {
 72            name: self.name(),
 73            description: self.description(),
 74            parameters: root_schema,
 75        }
 76    }
 77
 78    /// Executes the tool with the given input.
 79    fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
 80
 81    fn output_view(
 82        input: Self::Input,
 83        output: Result<Self::Output>,
 84        cx: &mut WindowContext,
 85    ) -> View<Self::View>;
 86
 87    fn render_running(_cx: &mut WindowContext) -> impl IntoElement {
 88        div()
 89    }
 90}
 91
/// Implemented by a tool's output view so its contents can be converted to text.
pub trait ToolOutput: Sized {
    /// Returns a textual representation of this output.
    /// NOTE(review): presumably this text is reported back to the model as the tool
    /// result — confirm at the call sites of `ToolFunctionCallResult::generate`.
    fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
}
 95
/// Type-erased registry entry for a single [`LanguageModelTool`].
struct RegisteredTool {
    /// Whether the tool's definition is included in `ToolRegistry::definitions`.
    /// Atomic so it can be toggled through `&self`.
    enabled: AtomicBool,
    /// `TypeId` of the concrete tool type, used by the type-based enable/disable lookups.
    type_id: TypeId,
    /// Parses a call's arguments, runs the tool, and resolves to the completed call.
    call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
    /// Renders the tool's placeholder while a call has no result yet.
    render_running: fn(&mut WindowContext) -> gpui::AnyElement,
    /// Definition advertised to the model, computed once at registration.
    definition: ToolFunctionDefinition,
}
103
104impl ToolRegistry {
105    pub fn new() -> Self {
106        Self {
107            registered_tools: HashMap::new(),
108        }
109    }
110
111    pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
112        for tool in self.registered_tools.values() {
113            if tool.type_id == TypeId::of::<T>() {
114                tool.enabled.store(is_enabled, SeqCst);
115                return;
116            }
117        }
118    }
119
120    pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
121        for tool in self.registered_tools.values() {
122            if tool.type_id == TypeId::of::<T>() {
123                return tool.enabled.load(SeqCst);
124            }
125        }
126        false
127    }
128
129    pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
130        self.registered_tools
131            .values()
132            .filter(|tool| tool.enabled.load(SeqCst))
133            .map(|tool| tool.definition.clone())
134            .collect()
135    }
136
137    pub fn render_tool_call(
138        &self,
139        tool_call: &ToolFunctionCall,
140        cx: &mut WindowContext,
141    ) -> AnyElement {
142        match &tool_call.result {
143            Some(result) => div()
144                .p_2()
145                .child(result.into_any_element(&tool_call.name))
146                .into_any_element(),
147            None => self
148                .registered_tools
149                .get(&tool_call.name)
150                .map(|tool| (tool.render_running)(cx))
151                .unwrap_or_else(|| div().into_any_element()),
152        }
153    }
154
155    pub fn register<T: 'static + LanguageModelTool>(
156        &mut self,
157        tool: T,
158        _cx: &mut WindowContext,
159    ) -> Result<()> {
160        let name = tool.name();
161        let registered_tool = RegisteredTool {
162            type_id: TypeId::of::<T>(),
163            definition: tool.definition(),
164            enabled: AtomicBool::new(true),
165            call: Box::new(
166                move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
167                    let name = tool_call.name.clone();
168                    let arguments = tool_call.arguments.clone();
169                    let id = tool_call.id.clone();
170
171                    let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
172                        return Task::ready(Ok(ToolFunctionCall {
173                            id,
174                            name: name.clone(),
175                            arguments,
176                            result: Some(ToolFunctionCallResult::ParsingFailed),
177                        }));
178                    };
179
180                    let result = tool.execute(&input, cx);
181
182                    cx.spawn(move |mut cx| async move {
183                        let result: Result<T::Output> = result.await;
184                        let view = cx.update(|cx| T::output_view(input, result, cx))?;
185
186                        Ok(ToolFunctionCall {
187                            id,
188                            name: name.clone(),
189                            arguments,
190                            result: Some(ToolFunctionCallResult::Finished {
191                                view: view.into(),
192                                generate_fn: generate::<T>,
193                            }),
194                        })
195                    })
196                },
197            ),
198            render_running: render_running::<T>,
199        };
200
201        let previous = self.registered_tools.insert(name.clone(), registered_tool);
202        if previous.is_some() {
203            return Err(anyhow!("already registered a tool with name {}", name));
204        }
205
206        return Ok(());
207
208        fn render_running<T: LanguageModelTool>(cx: &mut WindowContext) -> AnyElement {
209            T::render_running(cx).into_any_element()
210        }
211
212        fn generate<T: LanguageModelTool>(
213            view: AnyView,
214            project: &mut ProjectContext,
215            cx: &mut WindowContext,
216        ) -> String {
217            view.downcast::<T::View>()
218                .unwrap()
219                .update(cx, |view, cx| T::View::generate(view, project, cx))
220        }
221    }
222
223    /// Task yields an error if the window for the given WindowContext is closed before the task completes.
224    pub fn call(
225        &self,
226        tool_call: &ToolFunctionCall,
227        cx: &mut WindowContext,
228    ) -> Task<Result<ToolFunctionCall>> {
229        let name = tool_call.name.clone();
230        let arguments = tool_call.arguments.clone();
231        let id = tool_call.id.clone();
232
233        let tool = match self.registered_tools.get(&name) {
234            Some(tool) => tool,
235            None => {
236                let name = name.clone();
237                return Task::ready(Ok(ToolFunctionCall {
238                    id,
239                    name: name.clone(),
240                    arguments,
241                    result: Some(ToolFunctionCallResult::NoSuchTool),
242                }));
243            }
244        };
245
246        (tool.call)(tool_call, cx)
247    }
248}
249
250impl ToolFunctionCallResult {
251    pub fn generate(
252        &self,
253        name: &String,
254        project: &mut ProjectContext,
255        cx: &mut WindowContext,
256    ) -> String {
257        match self {
258            ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
259            ToolFunctionCallResult::ParsingFailed => {
260                format!("Unable to parse arguments for {name}")
261            }
262            ToolFunctionCallResult::Finished { generate_fn, view } => {
263                (generate_fn)(view.clone(), project, cx)
264            }
265        }
266    }
267
268    fn into_any_element(&self, name: &String) -> AnyElement {
269        match self {
270            ToolFunctionCallResult::NoSuchTool => {
271                format!("Language Model attempted to call {name}").into_any_element()
272            }
273            ToolFunctionCallResult::ParsingFailed => {
274                format!("Language Model called {name} with bad arguments").into_any_element()
275            }
276            ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
277        }
278    }
279}
280
281impl Display for ToolFunctionDefinition {
282    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283        let schema = serde_json::to_string(&self.parameters).ok();
284        let schema = schema.unwrap_or("None".to_string());
285        write!(f, "Name: {}:\n", self.name)?;
286        write!(f, "Description: {}\n", self.description)?;
287        write!(f, "Parameters: {}", schema)
288    }
289}
290
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{div, prelude::*, Render, TestAppContext};
    use gpui::{EmptyView, View};
    use schemars::schema_for;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// Fake weather tool that always reports `current_weather`, whatever the query.
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    struct WeatherView {
        result: WeatherResult,
    }

    impl Render for WeatherView {
        fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
            div().child(format!("temperature: {}", self.result.temperature))
        }
    }

    impl ToolOutput for WeatherView {
        fn generate(&self, _output: &mut ProjectContext, _cx: &mut WindowContext) -> String {
            serde_json::to_string(&self.result).unwrap()
        }
    }

    impl LanguageModelTool for WeatherTool {
        type Input = WeatherQuery;
        type Output = WeatherResult;
        type View = WeatherView;

        fn name(&self) -> String {
            "get_current_weather".to_string()
        }

        fn description(&self) -> String {
            "Fetches the current weather for a given location.".to_string()
        }

        fn execute(
            &self,
            _input: &Self::Input,
            _cx: &mut WindowContext,
        ) -> Task<Result<Self::Output>> {
            // The fake ignores the query and returns the canned weather.
            Task::ready(Ok(self.current_weather.clone()))
        }

        fn output_view(
            _input: Self::Input,
            result: Result<Self::Output>,
            cx: &mut WindowContext,
        ) -> View<Self::View> {
            cx.new_view(|_cx| {
                let result = result.unwrap();
                WeatherView { result }
            })
        }
    }

    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);

        let tool = WeatherTool {
            current_weather: WeatherResult {
                location: "San Francisco".to_string(),
                temperature: 21.0,
                unit: "Celsius".to_string(),
            },
        };

        let definition = tool.definition();

        let expected = ToolFunctionDefinition {
            name: "get_current_weather".to_string(),
            description: "Fetches the current weather for a given location.".to_string(),
            parameters: schema_for!(WeatherQuery),
        };

        assert_eq!(definition.name, expected.name);
        assert_eq!(definition.description, expected.description);

        let actual_schema = serde_json::to_value(&definition.parameters).unwrap();

        // The advertised schema must be exactly what schemars derives for the input type...
        assert_eq!(
            actual_schema,
            serde_json::to_value(&expected.parameters).unwrap()
        );

        // ...and that derived schema is pinned to this exact JSON.
        assert_eq!(
            actual_schema,
            json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "title": "WeatherQuery",
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "unit": {
                        "type": "string"
                    }
                },
                "required": ["location", "unit"]
            })
        );

        let args = json!({
            "location": "San Francisco",
            "unit": "Celsius"
        });

        let query: WeatherQuery = serde_json::from_value(args).unwrap();

        let result = cx.update(|cx| tool.execute(&query, cx)).await;
        let result = result.expect("the fake weather tool never fails");

        assert_eq!(result, tool.current_weather);
    }
}