tool_registry.rs

  1use crate::ProjectContext;
  2use anyhow::{anyhow, Result};
  3use gpui::{AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext};
  4use repair_json::repair;
  5use schemars::{schema::RootSchema, schema_for, JsonSchema};
  6use serde::{de::DeserializeOwned, Deserialize, Serialize};
  7use serde_json::value::RawValue;
  8use std::{
  9    any::TypeId,
 10    collections::HashMap,
 11    fmt::Display,
 12    mem,
 13    sync::atomic::{AtomicBool, Ordering::SeqCst},
 14};
 15use ui::ViewContext;
 16
 17pub struct ToolRegistry {
 18    registered_tools: HashMap<String, RegisteredTool>,
 19}
 20
 21#[derive(Default)]
 22pub struct ToolFunctionCall {
 23    pub id: String,
 24    pub name: String,
 25    pub arguments: String,
 26    state: ToolFunctionCallState,
 27}
 28
 29#[derive(Default)]
 30enum ToolFunctionCallState {
 31    #[default]
 32    Initializing,
 33    NoSuchTool,
 34    KnownTool(Box<dyn InternalToolView>),
 35    ExecutedTool(Box<dyn InternalToolView>),
 36}
 37
 38trait InternalToolView {
 39    fn view(&self) -> AnyView;
 40    fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
 41    fn try_set_input(&self, input: &str, cx: &mut WindowContext);
 42    fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
 43    fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
 44    fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
 45}
 46
 47#[derive(Default, Serialize, Deserialize)]
 48pub struct SavedToolFunctionCall {
 49    id: String,
 50    name: String,
 51    arguments: String,
 52    state: SavedToolFunctionCallState,
 53}
 54
 55#[derive(Default, Serialize, Deserialize)]
 56enum SavedToolFunctionCallState {
 57    #[default]
 58    Initializing,
 59    NoSuchTool,
 60    KnownTool,
 61    ExecutedTool(Box<RawValue>),
 62}
 63
 64#[derive(Clone, Debug, PartialEq)]
 65pub struct ToolFunctionDefinition {
 66    pub name: String,
 67    pub description: String,
 68    pub parameters: RootSchema,
 69}
 70
 71pub trait LanguageModelTool {
 72    type View: ToolView;
 73
 74    /// Returns the name of the tool.
 75    ///
 76    /// This name is exposed to the language model to allow the model to pick
 77    /// which tools to use. As this name is used to identify the tool within a
 78    /// tool registry, it should be unique.
 79    fn name(&self) -> String;
 80
 81    /// Returns the description of the tool.
 82    ///
 83    /// This can be used to _prompt_ the model as to what the tool does.
 84    fn description(&self) -> String;
 85
 86    /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
 87    fn definition(&self) -> ToolFunctionDefinition {
 88        let root_schema = schema_for!(<Self::View as ToolView>::Input);
 89
 90        ToolFunctionDefinition {
 91            name: self.name(),
 92            description: self.description(),
 93            parameters: root_schema,
 94        }
 95    }
 96
 97    /// A view of the output of running the tool, for displaying to the user.
 98    fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
 99}
100
101pub trait ToolView: Render {
102    /// The input type that will be passed in to `execute` when the tool is called
103    /// by the language model.
104    type Input: DeserializeOwned + JsonSchema;
105
106    /// The output returned by executing the tool.
107    type SerializedState: DeserializeOwned + Serialize;
108
109    fn generate(&self, project: &mut ProjectContext, cx: &mut ViewContext<Self>) -> String;
110    fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>);
111    fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>>;
112
113    fn serialize(&self, cx: &mut ViewContext<Self>) -> Self::SerializedState;
114    fn deserialize(
115        &mut self,
116        output: Self::SerializedState,
117        cx: &mut ViewContext<Self>,
118    ) -> Result<()>;
119}
120
121struct RegisteredTool {
122    enabled: AtomicBool,
123    type_id: TypeId,
124    build_view: Box<dyn Fn(&mut WindowContext) -> Box<dyn InternalToolView>>,
125    definition: ToolFunctionDefinition,
126}
127
128impl ToolRegistry {
129    pub fn new() -> Self {
130        Self {
131            registered_tools: HashMap::new(),
132        }
133    }
134
135    pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
136        for tool in self.registered_tools.values() {
137            if tool.type_id == TypeId::of::<T>() {
138                tool.enabled.store(is_enabled, SeqCst);
139                return;
140            }
141        }
142    }
143
144    pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
145        for tool in self.registered_tools.values() {
146            if tool.type_id == TypeId::of::<T>() {
147                return tool.enabled.load(SeqCst);
148            }
149        }
150        false
151    }
152
153    pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
154        self.registered_tools
155            .values()
156            .filter(|tool| tool.enabled.load(SeqCst))
157            .map(|tool| tool.definition.clone())
158            .collect()
159    }
160
161    pub fn update_tool_call(
162        &self,
163        call: &mut ToolFunctionCall,
164        name: Option<&str>,
165        arguments: Option<&str>,
166        cx: &mut WindowContext,
167    ) {
168        if let Some(name) = name {
169            call.name.push_str(name);
170        }
171        if let Some(arguments) = arguments {
172            if call.arguments.is_empty() {
173                if let Some(tool) = self.registered_tools.get(&call.name) {
174                    let view = (tool.build_view)(cx);
175                    call.state = ToolFunctionCallState::KnownTool(view);
176                } else {
177                    call.state = ToolFunctionCallState::NoSuchTool;
178                }
179            }
180            call.arguments.push_str(arguments);
181
182            if let ToolFunctionCallState::KnownTool(view) = &call.state {
183                if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
184                    view.try_set_input(&repaired_arguments, cx)
185                }
186            }
187        }
188    }
189
190    pub fn execute_tool_call(
191        &self,
192        tool_call: &mut ToolFunctionCall,
193        cx: &mut WindowContext,
194    ) -> Option<Task<Result<()>>> {
195        if let ToolFunctionCallState::KnownTool(view) = mem::take(&mut tool_call.state) {
196            let task = view.execute(cx);
197            tool_call.state = ToolFunctionCallState::ExecutedTool(view);
198            Some(task)
199        } else {
200            None
201        }
202    }
203
204    pub fn render_tool_call(
205        &self,
206        tool_call: &ToolFunctionCall,
207        _cx: &mut WindowContext,
208    ) -> Option<AnyElement> {
209        match &tool_call.state {
210            ToolFunctionCallState::NoSuchTool => {
211                Some(ui::Label::new("No such tool").into_any_element())
212            }
213            ToolFunctionCallState::Initializing => None,
214            ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
215                Some(view.view().into_any_element())
216            }
217        }
218    }
219
220    pub fn content_for_tool_call(
221        &self,
222        tool_call: &ToolFunctionCall,
223        project_context: &mut ProjectContext,
224        cx: &mut WindowContext,
225    ) -> String {
226        match &tool_call.state {
227            ToolFunctionCallState::Initializing => String::new(),
228            ToolFunctionCallState::NoSuchTool => {
229                format!("No such tool: {}", tool_call.name)
230            }
231            ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
232                view.generate(project_context, cx)
233            }
234        }
235    }
236
237    pub fn serialize_tool_call(
238        &self,
239        call: &ToolFunctionCall,
240        cx: &mut WindowContext,
241    ) -> Result<SavedToolFunctionCall> {
242        Ok(SavedToolFunctionCall {
243            id: call.id.clone(),
244            name: call.name.clone(),
245            arguments: call.arguments.clone(),
246            state: match &call.state {
247                ToolFunctionCallState::Initializing => SavedToolFunctionCallState::Initializing,
248                ToolFunctionCallState::NoSuchTool => SavedToolFunctionCallState::NoSuchTool,
249                ToolFunctionCallState::KnownTool(_) => SavedToolFunctionCallState::KnownTool,
250                ToolFunctionCallState::ExecutedTool(view) => {
251                    SavedToolFunctionCallState::ExecutedTool(view.serialize_output(cx)?)
252                }
253            },
254        })
255    }
256
257    pub fn deserialize_tool_call(
258        &self,
259        call: &SavedToolFunctionCall,
260        cx: &mut WindowContext,
261    ) -> Result<ToolFunctionCall> {
262        let Some(tool) = self.registered_tools.get(&call.name) else {
263            return Err(anyhow!("no such tool {}", call.name));
264        };
265
266        Ok(ToolFunctionCall {
267            id: call.id.clone(),
268            name: call.name.clone(),
269            arguments: call.arguments.clone(),
270            state: match &call.state {
271                SavedToolFunctionCallState::Initializing => ToolFunctionCallState::Initializing,
272                SavedToolFunctionCallState::NoSuchTool => ToolFunctionCallState::NoSuchTool,
273                SavedToolFunctionCallState::KnownTool => {
274                    log::error!("Deserialized tool that had not executed");
275                    let view = (tool.build_view)(cx);
276                    view.try_set_input(&call.arguments, cx);
277                    ToolFunctionCallState::KnownTool(view)
278                }
279                SavedToolFunctionCallState::ExecutedTool(output) => {
280                    let view = (tool.build_view)(cx);
281                    view.try_set_input(&call.arguments, cx);
282                    view.deserialize_output(output, cx)?;
283                    ToolFunctionCallState::ExecutedTool(view)
284                }
285            },
286        })
287    }
288
289    pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
290        let name = tool.name();
291        let registered_tool = RegisteredTool {
292            type_id: TypeId::of::<T>(),
293            definition: tool.definition(),
294            enabled: AtomicBool::new(true),
295            build_view: Box::new(move |cx: &mut WindowContext| Box::new(tool.view(cx))),
296        };
297
298        let previous = self.registered_tools.insert(name.clone(), registered_tool);
299        if previous.is_some() {
300            return Err(anyhow!("already registered a tool with name {}", name));
301        }
302
303        return Ok(());
304    }
305}
306
307impl<T: ToolView> InternalToolView for View<T> {
308    fn view(&self) -> AnyView {
309        self.clone().into()
310    }
311
312    fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
313        self.update(cx, |view, cx| view.generate(project, cx))
314    }
315
316    fn try_set_input(&self, input: &str, cx: &mut WindowContext) {
317        if let Ok(input) = serde_json::from_str::<T::Input>(input) {
318            self.update(cx, |view, cx| {
319                view.set_input(input, cx);
320                cx.notify();
321            });
322        }
323    }
324
325    fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>> {
326        self.update(cx, |view, cx| view.execute(cx))
327    }
328
329    fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>> {
330        let output = self.update(cx, |view, cx| view.serialize(cx));
331        Ok(RawValue::from_string(serde_json::to_string(&output)?)?)
332    }
333
334    fn deserialize_output(&self, output: &RawValue, cx: &mut WindowContext) -> Result<()> {
335        let state = serde_json::from_str::<T::SerializedState>(output.get())?;
336        self.update(cx, |view, cx| view.deserialize(state, cx))?;
337        Ok(())
338    }
339}
340
341impl Display for ToolFunctionDefinition {
342    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343        let schema = serde_json::to_string(&self.parameters).ok();
344        let schema = schema.unwrap_or("None".to_string());
345        write!(f, "Name: {}:\n", self.name)?;
346        write!(f, "Description: {}\n", self.description)?;
347        write!(f, "Parameters: {}", schema)
348    }
349}
350
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{div, prelude::*, Render, TestAppContext};
    use gpui::{EmptyView, View};
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Input the model supplies when asking for the weather.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// Weather reading produced by the fake tool.
    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    /// View backing the fake weather tool.
    struct WeatherView {
        input: Option<WeatherQuery>,
        result: Option<WeatherResult>,

        // Stands in for a real API call
        current_weather: WeatherResult,
    }

    /// Fake tool that always reports the weather it was constructed with.
    #[derive(Clone, Serialize)]
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    impl WeatherView {
        /// Creates a view with no input and no result yet.
        fn new(current_weather: WeatherResult) -> Self {
            Self {
                input: None,
                result: None,
                current_weather,
            }
        }
    }

    impl Render for WeatherView {
        fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
            if let Some(result) = &self.result {
                div()
                    .child(format!("temperature: {}", result.temperature))
                    .into_any_element()
            } else {
                div().child("Calculating weather...").into_any_element()
            }
        }
    }

    impl ToolView for WeatherView {
        type Input = WeatherQuery;

        type SerializedState = WeatherResult;

        fn generate(&self, _output: &mut ProjectContext, _cx: &mut ViewContext<Self>) -> String {
            serde_json::to_string(&self.result).unwrap()
        }

        fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
            self.input = Some(input);
            cx.notify();
        }

        fn execute(&mut self, _cx: &mut ViewContext<Self>) -> Task<Result<()>> {
            // Executing without input is a test bug, hence the unwrap.
            let query = self.input.as_ref().unwrap();
            let _location = query.location.clone();
            let _unit = query.unit.clone();

            // "Fetch" the weather: just report the canned reading.
            self.result = Some(self.current_weather.clone());

            Task::ready(Ok(()))
        }

        fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
            self.current_weather.clone()
        }

        fn deserialize(
            &mut self,
            output: Self::SerializedState,
            _cx: &mut ViewContext<Self>,
        ) -> Result<()> {
            self.current_weather = output;
            Ok(())
        }
    }

    impl LanguageModelTool for WeatherTool {
        type View = WeatherView;

        fn name(&self) -> String {
            "get_current_weather".to_string()
        }

        fn description(&self) -> String {
            "Fetches the current weather for a given location.".to_string()
        }

        fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
            let current_weather = self.current_weather.clone();
            cx.new_view(|_cx| WeatherView::new(current_weather))
        }
    }

    /// Registers the weather tool, checks its advertised definition, then
    /// streams a call in fragments and executes it.
    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);

        let weather_tool = WeatherTool {
            current_weather: WeatherResult {
                location: "San Francisco".to_string(),
                temperature: 21.0,
                unit: "Celsius".to_string(),
            },
        };
        let mut registry = ToolRegistry::new();
        registry.register(weather_tool).unwrap();

        assert_eq!(
            registry.definitions(),
            [ToolFunctionDefinition {
                name: "get_current_weather".to_string(),
                description: "Fetches the current weather for a given location.".to_string(),
                parameters: serde_json::from_value(json!({
                    "$schema": "http://json-schema.org/draft-07/schema#",
                    "title": "WeatherQuery",
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string"
                        },
                        "unit": {
                            "type": "string"
                        }
                    },
                    "required": ["location", "unit"]
                }))
                .unwrap(),
            }]
        );

        // The call starts with only part of the tool name; the rest of the
        // name and the arguments arrive as fragments.
        let mut call = ToolFunctionCall {
            id: "the-id".to_string(),
            name: "get_cur".to_string(),
            ..Default::default()
        };

        let task = cx.update(|cx| {
            registry.update_tool_call(
                &mut call,
                Some("rent_weather"),
                Some(r#"{"location": "San Francisco","#),
                cx,
            );
            registry.update_tool_call(&mut call, None, Some(r#" "unit": "Celsius"}"#), cx);
            registry.execute_tool_call(&mut call, cx).unwrap()
        });
        task.await.unwrap();

        assert!(matches!(
            &call.state,
            ToolFunctionCallState::ExecutedTool(_)
        ));
    }
}