1use anyhow::{anyhow, Result};
2use gpui::{AnyView, Task, WindowContext};
3use std::collections::HashMap;
4
5use crate::tool::{
6 LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
7};
8
9pub struct ToolRegistry {
10 tools: HashMap<
11 String,
12 Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
13 >,
14 definitions: Vec<ToolFunctionDefinition>,
15 status_views: Vec<AnyView>,
16}
17
18impl ToolRegistry {
19 pub fn new() -> Self {
20 Self {
21 tools: HashMap::new(),
22 definitions: Vec::new(),
23 status_views: Vec::new(),
24 }
25 }
26
27 pub fn definitions(&self) -> &[ToolFunctionDefinition] {
28 &self.definitions
29 }
30
31 pub fn register<T: 'static + LanguageModelTool>(
32 &mut self,
33 tool: T,
34 cx: &mut WindowContext,
35 ) -> Result<()> {
36 self.definitions.push(tool.definition());
37
38 if let Some(tool_view) = tool.status_view(cx) {
39 self.status_views.push(tool_view);
40 }
41
42 let name = tool.name();
43 let previous = self.tools.insert(
44 name.clone(),
45 // registry.call(tool_call, cx)
46 Box::new(
47 move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
48 let name = tool_call.name.clone();
49 let arguments = tool_call.arguments.clone();
50 let id = tool_call.id.clone();
51
52 let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
53 return Task::ready(Ok(ToolFunctionCall {
54 id,
55 name: name.clone(),
56 arguments,
57 result: Some(ToolFunctionCallResult::ParsingFailed),
58 }));
59 };
60
61 let result = tool.execute(&input, cx);
62
63 cx.spawn(move |mut cx| async move {
64 let result: Result<T::Output> = result.await;
65 let for_model = T::format(&input, &result);
66 let view = cx.update(|cx| T::output_view(id.clone(), input, result, cx))?;
67
68 Ok(ToolFunctionCall {
69 id,
70 name: name.clone(),
71 arguments,
72 result: Some(ToolFunctionCallResult::Finished {
73 view: view.into(),
74 for_model,
75 }),
76 })
77 })
78 },
79 ),
80 );
81
82 if previous.is_some() {
83 return Err(anyhow!("already registered a tool with name {}", name));
84 }
85
86 Ok(())
87 }
88
89 /// Task yields an error if the window for the given WindowContext is closed before the task completes.
90 pub fn call(
91 &self,
92 tool_call: &ToolFunctionCall,
93 cx: &mut WindowContext,
94 ) -> Task<Result<ToolFunctionCall>> {
95 let name = tool_call.name.clone();
96 let arguments = tool_call.arguments.clone();
97 let id = tool_call.id.clone();
98
99 let tool = match self.tools.get(&name) {
100 Some(tool) => tool,
101 None => {
102 let name = name.clone();
103 return Task::ready(Ok(ToolFunctionCall {
104 id,
105 name: name.clone(),
106 arguments,
107 result: Some(ToolFunctionCallResult::NoSuchTool),
108 }));
109 }
110 };
111
112 tool(tool_call, cx)
113 }
114
115 pub fn status_views(&self) -> &[AnyView] {
116 &self.status_views
117 }
118}
119
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{div, prelude::*, Render, TestAppContext};
    use gpui::{EmptyView, View};
    use schemars::schema_for;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Arguments the model supplies when calling the fake weather tool.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// A fake tool that always reports the reading it was constructed with.
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    /// The reading produced by [`WeatherTool`].
    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    /// Displays a [`WeatherResult`] as a single line of text.
    struct WeatherView {
        result: WeatherResult,
    }

    impl Render for WeatherView {
        fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
            let label = format!("temperature: {}", self.result.temperature);
            div().child(label)
        }
    }

    impl LanguageModelTool for WeatherTool {
        type Input = WeatherQuery;
        type Output = WeatherResult;
        type View = WeatherView;

        fn name(&self) -> String {
            String::from("get_current_weather")
        }

        fn description(&self) -> String {
            String::from("Fetches the current weather for a given location.")
        }

        fn execute(
            &self,
            _input: &Self::Input,
            _cx: &mut WindowContext,
        ) -> Task<Result<Self::Output>> {
            // The query is ignored: this fake always answers with its configured reading.
            Task::ready(Ok(self.current_weather.clone()))
        }

        fn output_view(
            _tool_call_id: String,
            _input: Self::Input,
            result: Result<Self::Output>,
            cx: &mut WindowContext,
        ) -> View<Self::View> {
            cx.new_view(|_cx| WeatherView {
                result: result.unwrap(),
            })
        }

        fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
            let reading = output.as_ref().unwrap();
            serde_json::to_string(reading).unwrap()
        }
    }

    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);

        let tool = WeatherTool {
            current_weather: WeatherResult {
                location: String::from("San Francisco"),
                temperature: 21.0,
                unit: String::from("Celsius"),
            },
        };

        // The definition advertised to the model.
        let definitions = vec![tool.definition()];
        assert_eq!(definitions.len(), 1);
        let definition = &definitions[0];

        let expected = ToolFunctionDefinition {
            name: String::from("get_current_weather"),
            description: String::from("Fetches the current weather for a given location."),
            parameters: schema_for!(WeatherQuery),
        };
        assert_eq!(definition.name, expected.name);
        assert_eq!(definition.description, expected.description);

        // The parameter schema must serialize to the JSON Schema the model expects.
        let actual_schema = serde_json::to_value(&definition.parameters).unwrap();
        let wanted_schema = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "WeatherQuery",
            "type": "object",
            "properties": {
                "location": {
                    "type": "string"
                },
                "unit": {
                    "type": "string"
                }
            },
            "required": ["location", "unit"]
        });
        assert_eq!(actual_schema, wanted_schema);

        // Simulate the model's call arguments and run the tool.
        let query: WeatherQuery = serde_json::from_value(json!({
            "location": "San Francisco",
            "unit": "Celsius"
        }))
        .unwrap();

        let outcome = cx.update(|cx| tool.execute(&query, cx)).await;
        assert!(outcome.is_ok());

        let weather = outcome.unwrap();
        assert_eq!(weather, tool.current_weather);
    }
}