1use anyhow::{anyhow, Result};
2use gpui::{AnyElement, AppContext, Task, WindowContext};
3use std::{any::Any, collections::HashMap};
4
5use crate::tool::{
6 LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
7};
8
9pub struct ToolRegistry {
10 tools: HashMap<String, Box<dyn Fn(&ToolFunctionCall, &AppContext) -> Task<ToolFunctionCall>>>,
11 definitions: Vec<ToolFunctionDefinition>,
12}
13
14impl ToolRegistry {
15 pub fn new() -> Self {
16 Self {
17 tools: HashMap::new(),
18 definitions: Vec::new(),
19 }
20 }
21
22 pub fn definitions(&self) -> &[ToolFunctionDefinition] {
23 &self.definitions
24 }
25
26 pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
27 fn render<T: 'static + LanguageModelTool>(
28 tool_call_id: &str,
29 input: &Box<dyn Any>,
30 output: &Box<dyn Any>,
31 cx: &mut WindowContext,
32 ) -> AnyElement {
33 T::render(
34 tool_call_id,
35 input.as_ref().downcast_ref::<T::Input>().unwrap(),
36 output.as_ref().downcast_ref::<T::Output>().unwrap(),
37 cx,
38 )
39 }
40
41 fn format<T: 'static + LanguageModelTool>(
42 input: &Box<dyn Any>,
43 output: &Box<dyn Any>,
44 ) -> String {
45 T::format(
46 input.as_ref().downcast_ref::<T::Input>().unwrap(),
47 output.as_ref().downcast_ref::<T::Output>().unwrap(),
48 )
49 }
50
51 self.definitions.push(tool.definition());
52 let name = tool.name();
53 let previous = self.tools.insert(
54 name.clone(),
55 Box::new(move |tool_call: &ToolFunctionCall, cx: &AppContext| {
56 let name = tool_call.name.clone();
57 let arguments = tool_call.arguments.clone();
58 let id = tool_call.id.clone();
59
60 let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
61 return Task::ready(ToolFunctionCall {
62 id,
63 name: name.clone(),
64 arguments,
65 result: Some(ToolFunctionCallResult::ParsingFailed),
66 });
67 };
68
69 let result = tool.execute(&input, cx);
70
71 cx.spawn(move |_cx| async move {
72 match result.await {
73 Ok(result) => {
74 let result: T::Output = result;
75 ToolFunctionCall {
76 id,
77 name: name.clone(),
78 arguments,
79 result: Some(ToolFunctionCallResult::Finished {
80 input: Box::new(input),
81 output: Box::new(result),
82 render_fn: render::<T>,
83 format_fn: format::<T>,
84 }),
85 }
86 }
87 Err(_error) => ToolFunctionCall {
88 id,
89 name: name.clone(),
90 arguments,
91 result: Some(ToolFunctionCallResult::ExecutionFailed {
92 input: Box::new(input),
93 }),
94 },
95 }
96 })
97 }),
98 );
99
100 if previous.is_some() {
101 return Err(anyhow!("already registered a tool with name {}", name));
102 }
103
104 Ok(())
105 }
106
107 pub fn call(&self, tool_call: &ToolFunctionCall, cx: &AppContext) -> Task<ToolFunctionCall> {
108 let name = tool_call.name.clone();
109 let arguments = tool_call.arguments.clone();
110 let id = tool_call.id.clone();
111
112 let tool = match self.tools.get(&name) {
113 Some(tool) => tool,
114 None => {
115 let name = name.clone();
116 return Task::ready(ToolFunctionCall {
117 id,
118 name: name.clone(),
119 arguments,
120 result: Some(ToolFunctionCallResult::NoSuchTool),
121 });
122 }
123 };
124
125 tool(tool_call, cx)
126 }
127}
128
#[cfg(test)]
mod test {
    use super::*;

    use gpui::{div, AnyElement, Element, ParentElement, TestAppContext, WindowContext};
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    struct WeatherTool {
        current_weather: WeatherResult,
    }

    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    impl LanguageModelTool for WeatherTool {
        type Input = WeatherQuery;
        type Output = WeatherResult;

        fn name(&self) -> String {
            "get_current_weather".to_string()
        }

        fn description(&self) -> String {
            "Fetches the current weather for a given location.".to_string()
        }

        fn execute(&self, input: &WeatherQuery, _cx: &AppContext) -> Task<Result<Self::Output>> {
            let _location = input.location.clone();
            let _unit = input.unit.clone();

            let weather = self.current_weather.clone();

            Task::ready(Ok(weather))
        }

        fn render(
            _tool_call_id: &str,
            _input: &Self::Input,
            output: &Self::Output,
            _cx: &mut WindowContext,
        ) -> AnyElement {
            div()
                .child(format!(
                    "The current temperature in {} is {} {}",
                    output.location, output.temperature, output.unit
                ))
                .into_any()
        }

        fn format(_input: &Self::Input, output: &Self::Output) -> String {
            format!(
                "The current temperature in {} is {} {}",
                output.location, output.temperature, output.unit
            )
        }
    }

    /// Builds the fixture tool, which always reports 21 Celsius in San Francisco.
    fn weather_tool() -> WeatherTool {
        WeatherTool {
            current_weather: WeatherResult {
                location: "San Francisco".to_string(),
                temperature: 21.0,
                unit: "Celsius".to_string(),
            },
        }
    }

    #[gpui::test]
    async fn test_function_registry(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();

        let mut registry = ToolRegistry::new();
        registry.register(weather_tool()).unwrap();

        let result = cx
            .update(|cx| {
                registry.call(
                    &ToolFunctionCall {
                        name: "get_current_weather".to_string(),
                        arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"#
                            .to_string(),
                        id: "test-123".to_string(),
                        result: None,
                    },
                    cx,
                )
            })
            .await;

        assert_eq!(result.id, "test-123");
        assert_eq!(result.name, "get_current_weather");
        match &result.result {
            Some(ToolFunctionCallResult::Finished {
                input,
                output,
                format_fn,
                ..
            }) => {
                assert_eq!(
                    format_fn(input, output),
                    "The current temperature in San Francisco is 21 Celsius"
                );
            }
            _ => panic!("expected the weather tool call to finish successfully"),
        }
    }

    #[gpui::test]
    async fn test_call_unknown_tool(cx: &mut TestAppContext) {
        let registry = ToolRegistry::new();

        let result = cx
            .update(|cx| {
                registry.call(
                    &ToolFunctionCall {
                        name: "does_not_exist".to_string(),
                        arguments: "{}".to_string(),
                        id: "test-456".to_string(),
                        result: None,
                    },
                    cx,
                )
            })
            .await;

        assert_eq!(result.id, "test-456");
        assert_eq!(result.name, "does_not_exist");
        assert!(matches!(
            result.result,
            Some(ToolFunctionCallResult::NoSuchTool)
        ));
    }

    #[gpui::test]
    async fn test_call_with_malformed_arguments(cx: &mut TestAppContext) {
        let mut registry = ToolRegistry::new();
        registry.register(weather_tool()).unwrap();

        let result = cx
            .update(|cx| {
                registry.call(
                    &ToolFunctionCall {
                        name: "get_current_weather".to_string(),
                        arguments: "not json".to_string(),
                        id: "test-789".to_string(),
                        result: None,
                    },
                    cx,
                )
            })
            .await;

        assert_eq!(result.arguments, "not json");
        assert!(matches!(
            result.result,
            Some(ToolFunctionCallResult::ParsingFailed)
        ));
    }

    #[gpui::test]
    fn test_duplicate_registration_is_rejected(_cx: &mut TestAppContext) {
        let mut registry = ToolRegistry::new();
        registry.register(weather_tool()).unwrap();
        assert!(registry.register(weather_tool()).is_err());
    }

    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();

        let tool = weather_tool();
        let definition = tool.definition();

        assert_eq!(definition.name, "get_current_weather");
        assert_eq!(
            definition.description,
            "Fetches the current weather for a given location."
        );
        assert_eq!(
            serde_json::to_value(&definition.parameters).unwrap(),
            json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "title": "WeatherQuery",
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "unit": {
                        "type": "string"
                    }
                },
                "required": ["location", "unit"]
            })
        );

        let args = json!({
            "location": "San Francisco",
            "unit": "Celsius"
        });

        let query: WeatherQuery = serde_json::from_value(args).unwrap();

        let result = cx.update(|cx| tool.execute(&query, cx)).await;

        assert!(result.is_ok());
        let result = result.unwrap();

        assert_eq!(result, tool.current_weather);
    }
}