1use anyhow::{anyhow, Result};
2use gpui::{Task, WindowContext};
3use std::collections::HashMap;
4
5use crate::tool::{
6 LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
7};
8
9pub struct ToolRegistry {
10 tools: HashMap<
11 String,
12 Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
13 >,
14 definitions: Vec<ToolFunctionDefinition>,
15}
16
17impl ToolRegistry {
18 pub fn new() -> Self {
19 Self {
20 tools: HashMap::new(),
21 definitions: Vec::new(),
22 }
23 }
24
25 pub fn definitions(&self) -> &[ToolFunctionDefinition] {
26 &self.definitions
27 }
28
29 pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
30 self.definitions.push(tool.definition());
31 let name = tool.name();
32 let previous = self.tools.insert(
33 name.clone(),
34 // registry.call(tool_call, cx)
35 Box::new(
36 move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
37 let name = tool_call.name.clone();
38 let arguments = tool_call.arguments.clone();
39 let id = tool_call.id.clone();
40
41 let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
42 return Task::ready(Ok(ToolFunctionCall {
43 id,
44 name: name.clone(),
45 arguments,
46 result: Some(ToolFunctionCallResult::ParsingFailed),
47 }));
48 };
49
50 let result = tool.execute(&input, cx);
51
52 cx.spawn(move |mut cx| async move {
53 let result: Result<T::Output> = result.await;
54 let for_model = T::format(&input, &result);
55 let view = cx.update(|cx| T::new_view(id.clone(), input, result, cx))?;
56
57 Ok(ToolFunctionCall {
58 id,
59 name: name.clone(),
60 arguments,
61 result: Some(ToolFunctionCallResult::Finished {
62 view: view.into(),
63 for_model,
64 }),
65 })
66 })
67 },
68 ),
69 );
70
71 if previous.is_some() {
72 return Err(anyhow!("already registered a tool with name {}", name));
73 }
74
75 Ok(())
76 }
77
78 /// Task yields an error if the window for the given WindowContext is closed before the task completes.
79 pub fn call(
80 &self,
81 tool_call: &ToolFunctionCall,
82 cx: &mut WindowContext,
83 ) -> Task<Result<ToolFunctionCall>> {
84 let name = tool_call.name.clone();
85 let arguments = tool_call.arguments.clone();
86 let id = tool_call.id.clone();
87
88 let tool = match self.tools.get(&name) {
89 Some(tool) => tool,
90 None => {
91 let name = name.clone();
92 return Task::ready(Ok(ToolFunctionCall {
93 id,
94 name: name.clone(),
95 arguments,
96 result: Some(ToolFunctionCallResult::NoSuchTool),
97 }));
98 }
99 };
100
101 tool(tool_call, cx)
102 }
103}
104
#[cfg(test)]
mod test {
    use super::*;
    use gpui::View;
    use gpui::{div, prelude::*, Render, TestAppContext};
    use schemars::schema_for;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// Test tool whose `execute` always reports `current_weather`, ignoring the query.
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    struct WeatherView {
        result: WeatherResult,
    }

    impl Render for WeatherView {
        fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
            div().child(format!("temperature: {}", self.result.temperature))
        }
    }

    impl LanguageModelTool for WeatherTool {
        type Input = WeatherQuery;
        type Output = WeatherResult;
        type View = WeatherView;

        fn name(&self) -> String {
            "get_current_weather".to_string()
        }

        fn description(&self) -> String {
            "Fetches the current weather for a given location.".to_string()
        }

        fn execute(
            &self,
            _input: &Self::Input,
            _cx: &gpui::AppContext,
        ) -> Task<Result<Self::Output>> {
            // The query is intentionally ignored: this stub always returns its fixed weather.
            Task::ready(Ok(self.current_weather.clone()))
        }

        fn new_view(
            _tool_call_id: String,
            _input: Self::Input,
            result: Result<Self::Output>,
            cx: &mut WindowContext,
        ) -> View<Self::View> {
            cx.new_view(|_cx| {
                let result = result.unwrap();
                WeatherView { result }
            })
        }

        fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
            serde_json::to_string(&output.as_ref().unwrap()).unwrap()
        }
    }

    /// Builds the fixed San Francisco weather tool shared by the tests below.
    fn san_francisco_weather_tool() -> WeatherTool {
        WeatherTool {
            current_weather: WeatherResult {
                location: "San Francisco".to_string(),
                temperature: 21.0,
                unit: "Celsius".to_string(),
            },
        }
    }

    #[gpui::test]
    async fn test_function_registry(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();

        let mut registry = ToolRegistry::new();

        registry.register(san_francisco_weather_tool()).unwrap();

        // Registering one tool records exactly one definition, under the tool's name.
        assert_eq!(registry.definitions().len(), 1);
        assert_eq!(registry.definitions()[0].name, "get_current_weather");

        // A second tool with the same name must be rejected.
        assert!(registry.register(san_francisco_weather_tool()).is_err());

        // let _result = cx
        //     .update(|cx| {
        //         registry.call(
        //             &ToolFunctionCall {
        //                 name: "get_current_weather".to_string(),
        //                 arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"#
        //                     .to_string(),
        //                 id: "test-123".to_string(),
        //                 result: None,
        //             },
        //             cx,
        //         )
        //     })
        //     .await;

        // assert!(result.is_ok());
        // let result = result.unwrap();

        // let expected = r#"{"location":"San Francisco","temperature":21.0,"unit":"Celsius"}"#;

        // todo!(): Put this back in after the interface is stabilized
        // assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();

        let tool = san_francisco_weather_tool();

        let tools = vec![tool.definition()];
        assert_eq!(tools.len(), 1);

        let expected = ToolFunctionDefinition {
            name: "get_current_weather".to_string(),
            description: "Fetches the current weather for a given location.".to_string(),
            parameters: schema_for!(WeatherQuery),
        };

        assert_eq!(tools[0].name, expected.name);
        assert_eq!(tools[0].description, expected.description);

        // Serialize the generated schema and compare it against the literal JSON schema.
        let actual_schema = serde_json::to_value(&tools[0].parameters).unwrap();

        assert_eq!(
            actual_schema,
            json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "title": "WeatherQuery",
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "unit": {
                        "type": "string"
                    }
                },
                "required": ["location", "unit"]
            })
        );

        let args = json!({
            "location": "San Francisco",
            "unit": "Celsius"
        });

        let query: WeatherQuery = serde_json::from_value(args).unwrap();

        let result = cx.update(|cx| tool.execute(&query, cx)).await;

        assert!(result.is_ok());
        let result = result.unwrap();

        assert_eq!(result, tool.current_weather);
    }
}