tool_registry.rs

  1use anyhow::{anyhow, Result};
  2use gpui::{
  3    div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
  4};
  5use schemars::{schema::RootSchema, schema_for, JsonSchema};
  6use serde::Deserialize;
  7use serde_json::Value;
  8use std::{
  9    any::TypeId,
 10    collections::HashMap,
 11    fmt::Display,
 12    sync::atomic::{AtomicBool, Ordering::SeqCst},
 13};
 14
 15use crate::ProjectContext;
 16
 17pub struct ToolRegistry {
 18    registered_tools: HashMap<String, RegisteredTool>,
 19}
 20
 21#[derive(Default, Deserialize)]
 22pub struct ToolFunctionCall {
 23    pub id: String,
 24    pub name: String,
 25    pub arguments: String,
 26    #[serde(skip)]
 27    pub result: Option<ToolFunctionCallResult>,
 28}
 29
 30pub enum ToolFunctionCallResult {
 31    NoSuchTool,
 32    ParsingFailed,
 33    Finished {
 34        view: AnyView,
 35        generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
 36    },
 37}
 38
 39#[derive(Clone)]
 40pub struct ToolFunctionDefinition {
 41    pub name: String,
 42    pub description: String,
 43    pub parameters: RootSchema,
 44}
 45
 46pub trait LanguageModelTool {
 47    /// The input type that will be passed in to `execute` when the tool is called
 48    /// by the language model.
 49    type Input: for<'de> Deserialize<'de> + JsonSchema;
 50
 51    /// The output returned by executing the tool.
 52    type Output: 'static;
 53
 54    type View: Render + ToolOutput;
 55
 56    /// Returns the name of the tool.
 57    ///
 58    /// This name is exposed to the language model to allow the model to pick
 59    /// which tools to use. As this name is used to identify the tool within a
 60    /// tool registry, it should be unique.
 61    fn name(&self) -> String;
 62
 63    /// Returns the description of the tool.
 64    ///
 65    /// This can be used to _prompt_ the model as to what the tool does.
 66    fn description(&self) -> String;
 67
 68    /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
 69    fn definition(&self) -> ToolFunctionDefinition {
 70        let root_schema = schema_for!(Self::Input);
 71
 72        ToolFunctionDefinition {
 73            name: self.name(),
 74            description: self.description(),
 75            parameters: root_schema,
 76        }
 77    }
 78
 79    /// Executes the tool with the given input.
 80    fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
 81
 82    /// A view of the output of running the tool, for displaying to the user.
 83    fn output_view(
 84        input: Self::Input,
 85        output: Result<Self::Output>,
 86        cx: &mut WindowContext,
 87    ) -> View<Self::View>;
 88
 89    fn render_running(_arguments: &Option<Value>, _cx: &mut WindowContext) -> impl IntoElement {
 90        tool_running_placeholder()
 91    }
 92}
 93
 94pub fn tool_running_placeholder() -> AnyElement {
 95    ui::Label::new("Researching...").into_any_element()
 96}
 97
 98pub trait ToolOutput: Sized {
 99    fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
100}
101
102struct RegisteredTool {
103    enabled: AtomicBool,
104    type_id: TypeId,
105    call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
106    render_running: fn(&ToolFunctionCall, &mut WindowContext) -> gpui::AnyElement,
107    definition: ToolFunctionDefinition,
108}
109
110impl ToolRegistry {
111    pub fn new() -> Self {
112        Self {
113            registered_tools: HashMap::new(),
114        }
115    }
116
117    pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
118        for tool in self.registered_tools.values() {
119            if tool.type_id == TypeId::of::<T>() {
120                tool.enabled.store(is_enabled, SeqCst);
121                return;
122            }
123        }
124    }
125
126    pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
127        for tool in self.registered_tools.values() {
128            if tool.type_id == TypeId::of::<T>() {
129                return tool.enabled.load(SeqCst);
130            }
131        }
132        false
133    }
134
135    pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
136        self.registered_tools
137            .values()
138            .filter(|tool| tool.enabled.load(SeqCst))
139            .map(|tool| tool.definition.clone())
140            .collect()
141    }
142
143    pub fn render_tool_call(
144        &self,
145        tool_call: &ToolFunctionCall,
146        cx: &mut WindowContext,
147    ) -> AnyElement {
148        match &tool_call.result {
149            Some(result) => div()
150                .p_2()
151                .child(result.into_any_element(&tool_call.name))
152                .into_any_element(),
153            None => {
154                let tool = self.registered_tools.get(&tool_call.name);
155
156                if let Some(tool) = tool {
157                    (tool.render_running)(&tool_call, cx)
158                } else {
159                    tool_running_placeholder()
160                }
161            }
162        }
163    }
164
165    pub fn register<T: 'static + LanguageModelTool>(
166        &mut self,
167        tool: T,
168        _cx: &mut WindowContext,
169    ) -> Result<()> {
170        let name = tool.name();
171        let registered_tool = RegisteredTool {
172            type_id: TypeId::of::<T>(),
173            definition: tool.definition(),
174            enabled: AtomicBool::new(true),
175            call: Box::new(
176                move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
177                    let name = tool_call.name.clone();
178                    let arguments = tool_call.arguments.clone();
179                    let id = tool_call.id.clone();
180
181                    let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
182                        return Task::ready(Ok(ToolFunctionCall {
183                            id,
184                            name: name.clone(),
185                            arguments,
186                            result: Some(ToolFunctionCallResult::ParsingFailed),
187                        }));
188                    };
189
190                    let result = tool.execute(&input, cx);
191
192                    cx.spawn(move |mut cx| async move {
193                        let result: Result<T::Output> = result.await;
194                        let view = cx.update(|cx| T::output_view(input, result, cx))?;
195
196                        Ok(ToolFunctionCall {
197                            id,
198                            name: name.clone(),
199                            arguments,
200                            result: Some(ToolFunctionCallResult::Finished {
201                                view: view.into(),
202                                generate_fn: generate::<T>,
203                            }),
204                        })
205                    })
206                },
207            ),
208            render_running: render_running::<T>,
209        };
210
211        let previous = self.registered_tools.insert(name.clone(), registered_tool);
212        if previous.is_some() {
213            return Err(anyhow!("already registered a tool with name {}", name));
214        }
215
216        return Ok(());
217
218        fn render_running<T: LanguageModelTool>(
219            tool_call: &ToolFunctionCall,
220            cx: &mut WindowContext,
221        ) -> AnyElement {
222            // Attempt to parse the string arguments that are JSON as a JSON value
223            let maybe_arguments = serde_json::to_value(tool_call.arguments.clone()).ok();
224
225            T::render_running(&maybe_arguments, cx).into_any_element()
226        }
227
228        fn generate<T: LanguageModelTool>(
229            view: AnyView,
230            project: &mut ProjectContext,
231            cx: &mut WindowContext,
232        ) -> String {
233            view.downcast::<T::View>()
234                .unwrap()
235                .update(cx, |view, cx| T::View::generate(view, project, cx))
236        }
237    }
238
239    /// Task yields an error if the window for the given WindowContext is closed before the task completes.
240    pub fn call(
241        &self,
242        tool_call: &ToolFunctionCall,
243        cx: &mut WindowContext,
244    ) -> Task<Result<ToolFunctionCall>> {
245        let name = tool_call.name.clone();
246        let arguments = tool_call.arguments.clone();
247        let id = tool_call.id.clone();
248
249        let tool = match self.registered_tools.get(&name) {
250            Some(tool) => tool,
251            None => {
252                let name = name.clone();
253                return Task::ready(Ok(ToolFunctionCall {
254                    id,
255                    name: name.clone(),
256                    arguments,
257                    result: Some(ToolFunctionCallResult::NoSuchTool),
258                }));
259            }
260        };
261
262        (tool.call)(tool_call, cx)
263    }
264}
265
266impl ToolFunctionCallResult {
267    pub fn generate(
268        &self,
269        name: &String,
270        project: &mut ProjectContext,
271        cx: &mut WindowContext,
272    ) -> String {
273        match self {
274            ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
275            ToolFunctionCallResult::ParsingFailed => {
276                format!("Unable to parse arguments for {name}")
277            }
278            ToolFunctionCallResult::Finished { generate_fn, view } => {
279                (generate_fn)(view.clone(), project, cx)
280            }
281        }
282    }
283
284    fn into_any_element(&self, name: &String) -> AnyElement {
285        match self {
286            ToolFunctionCallResult::NoSuchTool => {
287                format!("Language Model attempted to call {name}").into_any_element()
288            }
289            ToolFunctionCallResult::ParsingFailed => {
290                format!("Language Model called {name} with bad arguments").into_any_element()
291            }
292            ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
293        }
294    }
295}
296
297impl Display for ToolFunctionDefinition {
298    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
299        let schema = serde_json::to_string(&self.parameters).ok();
300        let schema = schema.unwrap_or("None".to_string());
301        write!(f, "Name: {}:\n", self.name)?;
302        write!(f, "Description: {}\n", self.description)?;
303        write!(f, "Parameters: {}", schema)
304    }
305}
306
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{div, prelude::*, Render, TestAppContext};
    use gpui::{EmptyView, View};
    use schemars::schema_for;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Arguments for the `get_current_weather` test tool.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// Test tool that ignores its query and always reports `current_weather`.
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    /// Output view for [`WeatherTool`].
    struct WeatherView {
        result: WeatherResult,
    }

    impl Render for WeatherView {
        fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
            let temperature = self.result.temperature;
            div().child(format!("temperature: {}", temperature))
        }
    }

    impl ToolOutput for WeatherView {
        fn generate(&self, _output: &mut ProjectContext, _cx: &mut WindowContext) -> String {
            serde_json::to_string(&self.result).unwrap()
        }
    }

    impl LanguageModelTool for WeatherTool {
        type Input = WeatherQuery;
        type Output = WeatherResult;
        type View = WeatherView;

        fn name(&self) -> String {
            String::from("get_current_weather")
        }

        fn description(&self) -> String {
            String::from("Fetches the current weather for a given location.")
        }

        fn execute(
            &self,
            _input: &Self::Input,
            _cx: &mut WindowContext,
        ) -> Task<Result<Self::Output>> {
            // Canned response: the query is intentionally ignored.
            Task::ready(Ok(self.current_weather.clone()))
        }

        fn output_view(
            _input: Self::Input,
            result: Result<Self::Output>,
            cx: &mut WindowContext,
        ) -> View<Self::View> {
            cx.new_view(|_cx| WeatherView {
                result: result.unwrap(),
            })
        }
    }

    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);

        let tool = WeatherTool {
            current_weather: WeatherResult {
                location: "San Francisco".to_string(),
                temperature: 21.0,
                unit: "Celsius".to_string(),
            },
        };

        let definitions = vec![tool.definition()];
        assert_eq!(definitions.len(), 1);
        let definition = &definitions[0];

        let expected = ToolFunctionDefinition {
            name: "get_current_weather".to_string(),
            description: "Fetches the current weather for a given location.".to_string(),
            parameters: schema_for!(WeatherQuery),
        };

        assert_eq!(definition.name, expected.name);
        assert_eq!(definition.description, expected.description);

        let actual_schema = serde_json::to_value(&definition.parameters).unwrap();

        assert_eq!(
            actual_schema,
            json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "title": "WeatherQuery",
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "unit": {
                        "type": "string"
                    }
                },
                "required": ["location", "unit"]
            })
        );

        let query: WeatherQuery = serde_json::from_value(json!({
            "location": "San Francisco",
            "unit": "Celsius"
        }))
        .unwrap();

        let outcome = cx.update(|cx| tool.execute(&query, cx)).await;
        assert!(outcome.is_ok());

        let weather = outcome.unwrap();
        assert_eq!(weather, tool.current_weather);
    }
}