tool_registry.rs

  1use crate::ProjectContext;
  2use anyhow::{anyhow, Result};
  3use gpui::{AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext};
  4use repair_json::repair;
  5use schemars::{schema::RootSchema, schema_for, JsonSchema};
  6use serde::{de::DeserializeOwned, Deserialize, Serialize};
  7use serde_json::value::RawValue;
  8use std::{
  9    any::TypeId,
 10    collections::HashMap,
 11    fmt::Display,
 12    sync::{
 13        atomic::{AtomicBool, Ordering::SeqCst},
 14        Arc,
 15    },
 16};
 17use ui::ViewContext;
 18
 19pub struct ToolRegistry {
 20    registered_tools: HashMap<String, RegisteredTool>,
 21}
 22
 23#[derive(Default)]
 24pub struct ToolFunctionCall {
 25    pub id: String,
 26    pub name: String,
 27    pub arguments: String,
 28    state: ToolFunctionCallState,
 29}
 30
 31#[derive(Default)]
 32pub enum ToolFunctionCallState {
 33    #[default]
 34    Initializing,
 35    NoSuchTool,
 36    KnownTool(Box<dyn ToolView>),
 37    ExecutedTool(Box<dyn ToolView>),
 38}
 39
 40pub trait ToolView {
 41    fn view(&self) -> AnyView;
 42    fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
 43    fn set_input(&self, input: &str, cx: &mut WindowContext);
 44    fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
 45    fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
 46    fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
 47}
 48
 49#[derive(Default, Serialize, Deserialize)]
 50pub struct SavedToolFunctionCall {
 51    pub id: String,
 52    pub name: String,
 53    pub arguments: String,
 54    pub state: SavedToolFunctionCallState,
 55}
 56
 57#[derive(Default, Serialize, Deserialize)]
 58pub enum SavedToolFunctionCallState {
 59    #[default]
 60    Initializing,
 61    NoSuchTool,
 62    KnownTool,
 63    ExecutedTool(Box<RawValue>),
 64}
 65
 66#[derive(Clone, Debug)]
 67pub struct ToolFunctionDefinition {
 68    pub name: String,
 69    pub description: String,
 70    pub parameters: RootSchema,
 71}
 72
 73pub trait LanguageModelTool {
 74    type View: ToolOutput;
 75
 76    /// Returns the name of the tool.
 77    ///
 78    /// This name is exposed to the language model to allow the model to pick
 79    /// which tools to use. As this name is used to identify the tool within a
 80    /// tool registry, it should be unique.
 81    fn name(&self) -> String;
 82
 83    /// Returns the description of the tool.
 84    ///
 85    /// This can be used to _prompt_ the model as to what the tool does.
 86    fn description(&self) -> String;
 87
 88    /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
 89    fn definition(&self) -> ToolFunctionDefinition {
 90        let root_schema = schema_for!(<Self::View as ToolOutput>::Input);
 91
 92        ToolFunctionDefinition {
 93            name: self.name(),
 94            description: self.description(),
 95            parameters: root_schema,
 96        }
 97    }
 98
 99    /// A view of the output of running the tool, for displaying to the user.
100    fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
101}
102
103pub fn tool_running_placeholder() -> AnyElement {
104    ui::Label::new("Researching...").into_any_element()
105}
106
107pub fn unknown_tool_placeholder() -> AnyElement {
108    ui::Label::new("Unknown tool").into_any_element()
109}
110
111pub fn no_such_tool_placeholder() -> AnyElement {
112    ui::Label::new("No such tool").into_any_element()
113}
114
115pub trait ToolOutput: Render {
116    /// The input type that will be passed in to `execute` when the tool is called
117    /// by the language model.
118    type Input: DeserializeOwned + JsonSchema;
119
120    /// The output returned by executing the tool.
121    type SerializedState: DeserializeOwned + Serialize;
122
123    fn generate(&self, project: &mut ProjectContext, cx: &mut ViewContext<Self>) -> String;
124    fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>);
125    fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>>;
126
127    fn serialize(&self, cx: &mut ViewContext<Self>) -> Self::SerializedState;
128    fn deserialize(
129        &mut self,
130        output: Self::SerializedState,
131        cx: &mut ViewContext<Self>,
132    ) -> Result<()>;
133}
134
135struct RegisteredTool {
136    enabled: AtomicBool,
137    type_id: TypeId,
138    build_view: Box<dyn Fn(&mut WindowContext) -> Box<dyn ToolView>>,
139    definition: ToolFunctionDefinition,
140}
141
142impl ToolRegistry {
143    pub fn new() -> Self {
144        Self {
145            registered_tools: HashMap::new(),
146        }
147    }
148
149    pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
150        for tool in self.registered_tools.values() {
151            if tool.type_id == TypeId::of::<T>() {
152                tool.enabled.store(is_enabled, SeqCst);
153                return;
154            }
155        }
156    }
157
158    pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
159        for tool in self.registered_tools.values() {
160            if tool.type_id == TypeId::of::<T>() {
161                return tool.enabled.load(SeqCst);
162            }
163        }
164        false
165    }
166
167    pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
168        self.registered_tools
169            .values()
170            .filter(|tool| tool.enabled.load(SeqCst))
171            .map(|tool| tool.definition.clone())
172            .collect()
173    }
174
175    pub fn view_for_tool(&self, name: &str, cx: &mut WindowContext) -> Option<Box<dyn ToolView>> {
176        let tool = self.registered_tools.get(name)?;
177        Some((tool.build_view)(cx))
178    }
179
180    pub fn update_tool_call(
181        &self,
182        call: &mut ToolFunctionCall,
183        name: Option<&str>,
184        arguments: Option<&str>,
185        cx: &mut WindowContext,
186    ) {
187        if let Some(name) = name {
188            call.name.push_str(name);
189        }
190        if let Some(arguments) = arguments {
191            if call.arguments.is_empty() {
192                if let Some(view) = self.view_for_tool(&call.name, cx) {
193                    call.state = ToolFunctionCallState::KnownTool(view);
194                } else {
195                    call.state = ToolFunctionCallState::NoSuchTool;
196                }
197            }
198            call.arguments.push_str(arguments);
199
200            if let ToolFunctionCallState::KnownTool(view) = &call.state {
201                if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
202                    view.set_input(&repaired_arguments, cx)
203                }
204            }
205        }
206    }
207
208    pub fn execute_tool_call(
209        &self,
210        tool_call: &ToolFunctionCall,
211        cx: &mut WindowContext,
212    ) -> Option<Task<Result<()>>> {
213        if let ToolFunctionCallState::KnownTool(view) = &tool_call.state {
214            Some(view.execute(cx))
215        } else {
216            None
217        }
218    }
219
220    pub fn render_tool_call(
221        &self,
222        tool_call: &ToolFunctionCall,
223        _cx: &mut WindowContext,
224    ) -> AnyElement {
225        match &tool_call.state {
226            ToolFunctionCallState::NoSuchTool => no_such_tool_placeholder(),
227            ToolFunctionCallState::Initializing => unknown_tool_placeholder(),
228            ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
229                view.view().into_any_element()
230            }
231        }
232    }
233
234    pub fn content_for_tool_call(
235        &self,
236        tool_call: &ToolFunctionCall,
237        project_context: &mut ProjectContext,
238        cx: &mut WindowContext,
239    ) -> String {
240        match &tool_call.state {
241            ToolFunctionCallState::Initializing => String::new(),
242            ToolFunctionCallState::NoSuchTool => {
243                format!("No such tool: {}", tool_call.name)
244            }
245            ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
246                view.generate(project_context, cx)
247            }
248        }
249    }
250
251    pub fn serialize_tool_call(
252        &self,
253        call: &ToolFunctionCall,
254        cx: &mut WindowContext,
255    ) -> Result<SavedToolFunctionCall> {
256        Ok(SavedToolFunctionCall {
257            id: call.id.clone(),
258            name: call.name.clone(),
259            arguments: call.arguments.clone(),
260            state: match &call.state {
261                ToolFunctionCallState::Initializing => SavedToolFunctionCallState::Initializing,
262                ToolFunctionCallState::NoSuchTool => SavedToolFunctionCallState::NoSuchTool,
263                ToolFunctionCallState::KnownTool(_) => SavedToolFunctionCallState::KnownTool,
264                ToolFunctionCallState::ExecutedTool(view) => {
265                    SavedToolFunctionCallState::ExecutedTool(view.serialize_output(cx)?)
266                }
267            },
268        })
269    }
270
271    pub fn deserialize_tool_call(
272        &self,
273        call: &SavedToolFunctionCall,
274        cx: &mut WindowContext,
275    ) -> Result<ToolFunctionCall> {
276        let Some(tool) = self.registered_tools.get(&call.name) else {
277            return Err(anyhow!("no such tool {}", call.name));
278        };
279
280        Ok(ToolFunctionCall {
281            id: call.id.clone(),
282            name: call.name.clone(),
283            arguments: call.arguments.clone(),
284            state: match &call.state {
285                SavedToolFunctionCallState::Initializing => ToolFunctionCallState::Initializing,
286                SavedToolFunctionCallState::NoSuchTool => ToolFunctionCallState::NoSuchTool,
287                SavedToolFunctionCallState::KnownTool => {
288                    log::error!("Deserialized tool that had not executed");
289                    let view = (tool.build_view)(cx);
290                    view.set_input(&call.arguments, cx);
291                    ToolFunctionCallState::KnownTool(view)
292                }
293                SavedToolFunctionCallState::ExecutedTool(output) => {
294                    let view = (tool.build_view)(cx);
295                    view.set_input(&call.arguments, cx);
296                    view.deserialize_output(output, cx)?;
297                    ToolFunctionCallState::ExecutedTool(view)
298                }
299            },
300        })
301    }
302
303    pub fn register<T: 'static + LanguageModelTool>(
304        &mut self,
305        tool: T,
306        _cx: &mut WindowContext,
307    ) -> Result<()> {
308        let name = tool.name();
309        let tool = Arc::new(tool);
310        let registered_tool = RegisteredTool {
311            type_id: TypeId::of::<T>(),
312            definition: tool.definition(),
313            enabled: AtomicBool::new(true),
314            build_view: Box::new(move |cx: &mut WindowContext| Box::new(tool.view(cx))),
315        };
316
317        let previous = self.registered_tools.insert(name.clone(), registered_tool);
318        if previous.is_some() {
319            return Err(anyhow!("already registered a tool with name {}", name));
320        }
321
322        return Ok(());
323    }
324}
325
326impl<T: ToolOutput> ToolView for View<T> {
327    fn view(&self) -> AnyView {
328        self.clone().into()
329    }
330
331    fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
332        self.update(cx, |view, cx| view.generate(project, cx))
333    }
334
335    fn set_input(&self, input: &str, cx: &mut WindowContext) {
336        if let Ok(input) = serde_json::from_str::<T::Input>(input) {
337            self.update(cx, |view, cx| {
338                view.set_input(input, cx);
339                cx.notify();
340            });
341        }
342    }
343
344    fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>> {
345        self.update(cx, |view, cx| view.execute(cx))
346    }
347
348    fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>> {
349        let output = self.update(cx, |view, cx| view.serialize(cx));
350        Ok(RawValue::from_string(serde_json::to_string(&output)?)?)
351    }
352
353    fn deserialize_output(&self, output: &RawValue, cx: &mut WindowContext) -> Result<()> {
354        let state = serde_json::from_str::<T::SerializedState>(output.get())?;
355        self.update(cx, |view, cx| view.deserialize(state, cx))?;
356        Ok(())
357    }
358}
359
360impl Display for ToolFunctionDefinition {
361    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362        let schema = serde_json::to_string(&self.parameters).ok();
363        let schema = schema.unwrap_or("None".to_string());
364        write!(f, "Name: {}:\n", self.name)?;
365        write!(f, "Description: {}\n", self.description)?;
366        write!(f, "Parameters: {}", schema)
367    }
368}
369
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{div, prelude::*, Render, TestAppContext};
    use gpui::{EmptyView, View};
    use schemars::schema_for;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    // Arguments the model supplies to `get_current_weather`. Only `//` comments here: `///`
    // doc comments would be emitted by schemars as a "description" in the asserted schema.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// Canned weather reading returned by the fake tool.
    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    /// Test view: stores the query, and on execution copies the canned reading into `result`.
    struct WeatherView {
        input: Option<WeatherQuery>,
        result: Option<WeatherResult>,

        // Fake API call
        current_weather: WeatherResult,
    }

    /// Test tool whose views always report `current_weather`.
    #[derive(Clone, Serialize)]
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    impl WeatherView {
        fn new(current_weather: WeatherResult) -> Self {
            Self {
                current_weather,
                input: None,
                result: None,
            }
        }
    }

    impl Render for WeatherView {
        fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
            if let Some(result) = &self.result {
                div()
                    .child(format!("temperature: {}", result.temperature))
                    .into_any_element()
            } else {
                div().child("Calculating weather...").into_any_element()
            }
        }
    }

    impl ToolOutput for WeatherView {
        type Input = WeatherQuery;

        type SerializedState = WeatherResult;

        fn generate(&self, _output: &mut ProjectContext, _cx: &mut ViewContext<Self>) -> String {
            serde_json::to_string(&self.result).unwrap()
        }

        fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
            self.input = Some(input);
            cx.notify();
        }

        fn execute(&mut self, _cx: &mut ViewContext<Self>) -> Task<Result<()>> {
            // The query's contents are not consulted (the canned reading stands in for an API
            // call), but executing without any input is still treated as a bug and panics.
            let _query = self.input.as_ref().unwrap();
            self.result = Some(self.current_weather.clone());
            Task::ready(Ok(()))
        }

        fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
            self.current_weather.clone()
        }

        fn deserialize(
            &mut self,
            output: Self::SerializedState,
            _cx: &mut ViewContext<Self>,
        ) -> Result<()> {
            self.current_weather = output;
            Ok(())
        }
    }

    impl LanguageModelTool for WeatherTool {
        type View = WeatherView;

        fn name(&self) -> String {
            "get_current_weather".to_string()
        }

        fn description(&self) -> String {
            "Fetches the current weather for a given location.".to_string()
        }

        fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
            cx.new_view(|_cx| WeatherView::new(self.current_weather.clone()))
        }
    }

    /// Checks the generated definition/schema, then drives a view through input and execution.
    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);

        let weather_tool = WeatherTool {
            current_weather: WeatherResult {
                location: "San Francisco".to_string(),
                temperature: 21.0,
                unit: "Celsius".to_string(),
            },
        };

        let definitions = vec![weather_tool.definition()];
        assert_eq!(definitions.len(), 1);

        let expected = ToolFunctionDefinition {
            name: "get_current_weather".to_string(),
            description: "Fetches the current weather for a given location.".to_string(),
            parameters: schema_for!(WeatherQuery),
        };

        let definition = &definitions[0];
        assert_eq!(definition.name, expected.name);
        assert_eq!(definition.description, expected.description);

        let actual_schema = serde_json::to_value(&definition.parameters).unwrap();

        assert_eq!(
            actual_schema,
            json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "title": "WeatherQuery",
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "unit": {
                        "type": "string"
                    }
                },
                "required": ["location", "unit"]
            })
        );

        let view = cx.update(|cx| weather_tool.view(cx));

        cx.update(|cx| {
            view.set_input(r#"{"location": "San Francisco", "unit": "Celsius"}"#, cx);
        });

        let outcome = cx.update(|cx| view.execute(cx)).await;

        assert!(outcome.is_ok());
    }
}