1use anyhow::{anyhow, Result};
2use gpui::{div, AnyElement, IntoElement as _, ParentElement, Styled, Task, WindowContext};
3use std::{
4 any::TypeId,
5 collections::HashMap,
6 sync::atomic::{AtomicBool, Ordering::SeqCst},
7};
8
9use crate::tool::{
10 LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
11};
12
13// Internal Tool representation for the registry
14pub struct Tool {
15 enabled: AtomicBool,
16 type_id: TypeId,
17 call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
18 render_running: Box<dyn Fn(&mut WindowContext) -> gpui::AnyElement>,
19 definition: ToolFunctionDefinition,
20}
21
22impl Tool {
23 fn new(
24 type_id: TypeId,
25 call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
26 render_running: Box<dyn Fn(&mut WindowContext) -> gpui::AnyElement>,
27 definition: ToolFunctionDefinition,
28 ) -> Self {
29 Self {
30 enabled: AtomicBool::new(true),
31 type_id,
32 call,
33 render_running,
34 definition,
35 }
36 }
37}
38
39pub struct ToolRegistry {
40 tools: HashMap<String, Tool>,
41}
42
43impl ToolRegistry {
44 pub fn new() -> Self {
45 Self {
46 tools: HashMap::new(),
47 }
48 }
49
50 pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
51 for tool in self.tools.values() {
52 if tool.type_id == TypeId::of::<T>() {
53 tool.enabled.store(is_enabled, SeqCst);
54 return;
55 }
56 }
57 }
58
59 pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
60 for tool in self.tools.values() {
61 if tool.type_id == TypeId::of::<T>() {
62 return tool.enabled.load(SeqCst);
63 }
64 }
65 false
66 }
67
68 pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
69 self.tools
70 .values()
71 .filter(|tool| tool.enabled.load(SeqCst))
72 .map(|tool| tool.definition.clone())
73 .collect()
74 }
75
76 pub fn render_tool_call(
77 &self,
78 tool_call: &ToolFunctionCall,
79 cx: &mut WindowContext,
80 ) -> AnyElement {
81 match &tool_call.result {
82 Some(result) => div()
83 .p_2()
84 .child(result.into_any_element(&tool_call.name))
85 .into_any_element(),
86 None => self
87 .tools
88 .get(&tool_call.name)
89 .map(|tool| (tool.render_running)(cx))
90 .unwrap_or_else(|| div().into_any_element()),
91 }
92 }
93
94 pub fn register<T: 'static + LanguageModelTool>(
95 &mut self,
96 tool: T,
97 _cx: &mut WindowContext,
98 ) -> Result<()> {
99 let definition = tool.definition();
100
101 let name = tool.name();
102
103 let registered_tool = Tool::new(
104 TypeId::of::<T>(),
105 Box::new(
106 move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
107 let name = tool_call.name.clone();
108 let arguments = tool_call.arguments.clone();
109 let id = tool_call.id.clone();
110
111 let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
112 return Task::ready(Ok(ToolFunctionCall {
113 id,
114 name: name.clone(),
115 arguments,
116 result: Some(ToolFunctionCallResult::ParsingFailed),
117 }));
118 };
119
120 let result = tool.execute(&input, cx);
121
122 cx.spawn(move |mut cx| async move {
123 let result: Result<T::Output> = result.await;
124 let for_model = T::format(&input, &result);
125 let view = cx.update(|cx| T::output_view(id.clone(), input, result, cx))?;
126
127 Ok(ToolFunctionCall {
128 id,
129 name: name.clone(),
130 arguments,
131 result: Some(ToolFunctionCallResult::Finished {
132 view: view.into(),
133 for_model,
134 }),
135 })
136 })
137 },
138 ),
139 Box::new(|cx| T::render_running(cx).into_any_element()),
140 definition,
141 );
142
143 let previous = self.tools.insert(name.clone(), registered_tool);
144
145 if previous.is_some() {
146 return Err(anyhow!("already registered a tool with name {}", name));
147 }
148
149 Ok(())
150 }
151
152 /// Task yields an error if the window for the given WindowContext is closed before the task completes.
153 pub fn call(
154 &self,
155 tool_call: &ToolFunctionCall,
156 cx: &mut WindowContext,
157 ) -> Task<Result<ToolFunctionCall>> {
158 let name = tool_call.name.clone();
159 let arguments = tool_call.arguments.clone();
160 let id = tool_call.id.clone();
161
162 let tool = match self.tools.get(&name) {
163 Some(tool) => tool,
164 None => {
165 let name = name.clone();
166 return Task::ready(Ok(ToolFunctionCall {
167 id,
168 name: name.clone(),
169 arguments,
170 result: Some(ToolFunctionCallResult::NoSuchTool),
171 }));
172 }
173 };
174
175 (tool.call)(tool_call, cx)
176 }
177}
178
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{div, prelude::*, EmptyView, Render, TestAppContext, View};
    use schemars::{schema_for, JsonSchema};
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Input accepted by the sample weather tool.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// Sample tool that always reports the weather it was constructed with.
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    /// Output produced by the sample weather tool.
    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    /// View that displays a `WeatherResult`.
    struct WeatherView {
        result: WeatherResult,
    }

    impl Render for WeatherView {
        fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
            let label = format!("temperature: {}", self.result.temperature);
            div().child(label)
        }
    }

    impl LanguageModelTool for WeatherTool {
        type Input = WeatherQuery;
        type Output = WeatherResult;
        type View = WeatherView;

        fn name(&self) -> String {
            String::from("get_current_weather")
        }

        fn description(&self) -> String {
            String::from("Fetches the current weather for a given location.")
        }

        /// Ignores the query and resolves immediately with the stored weather.
        fn execute(
            &self,
            input: &Self::Input,
            _cx: &mut WindowContext,
        ) -> Task<Result<Self::Output>> {
            let (_location, _unit) = (input.location.clone(), input.unit.clone());

            Task::ready(Ok(self.current_weather.clone()))
        }

        /// Builds the view for a finished call; panics if the call failed.
        fn output_view(
            _tool_call_id: String,
            _input: Self::Input,
            result: Result<Self::Output>,
            cx: &mut WindowContext,
        ) -> View<Self::View> {
            cx.new_view(|_cx| WeatherView {
                result: result.unwrap(),
            })
        }

        /// Serializes the output as JSON; panics if the call failed.
        fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
            let weather = output.as_ref().unwrap();
            serde_json::to_string(weather).unwrap()
        }
    }

    /// Checks the sample tool's definition (name, description, JSON schema)
    /// and that executing it returns the configured weather.
    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);

        let tool = WeatherTool {
            current_weather: WeatherResult {
                location: "San Francisco".to_string(),
                temperature: 21.0,
                unit: "Celsius".to_string(),
            },
        };

        let definitions = vec![tool.definition()];
        assert_eq!(definitions.len(), 1);
        let definition = &definitions[0];

        let expected = ToolFunctionDefinition {
            name: "get_current_weather".to_string(),
            description: "Fetches the current weather for a given location.".to_string(),
            parameters: schema_for!(WeatherQuery),
        };

        assert_eq!(definition.name, expected.name);
        assert_eq!(definition.description, expected.description);

        // Schema generated for the tool's input, as a plain JSON value.
        let actual_schema = serde_json::to_value(&definition.parameters).unwrap();

        assert_eq!(
            actual_schema,
            json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "title": "WeatherQuery",
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "unit": {
                        "type": "string"
                    }
                },
                "required": ["location", "unit"]
            })
        );

        let raw_arguments = json!({
            "location": "San Francisco",
            "unit": "Celsius"
        });

        let query: WeatherQuery = serde_json::from_value(raw_arguments).unwrap();

        let outcome = cx.update(|cx| tool.execute(&query, cx)).await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), tool.current_weather);
    }
}