1use anyhow::{anyhow, Result};
2use gpui::{Task, WindowContext};
3use std::{
4 any::TypeId,
5 collections::HashMap,
6 sync::atomic::{AtomicBool, Ordering::SeqCst},
7};
8
9use crate::tool::{
10 LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
11};
12
13// Internal Tool representation for the registry
14pub struct Tool {
15 enabled: AtomicBool,
16 type_id: TypeId,
17 call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
18 definition: ToolFunctionDefinition,
19}
20
21impl Tool {
22 fn new(
23 type_id: TypeId,
24 call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
25 definition: ToolFunctionDefinition,
26 ) -> Self {
27 Self {
28 enabled: AtomicBool::new(true),
29 type_id,
30 call,
31 definition,
32 }
33 }
34}
35
36pub struct ToolRegistry {
37 tools: HashMap<String, Tool>,
38}
39
40impl ToolRegistry {
41 pub fn new() -> Self {
42 Self {
43 tools: HashMap::new(),
44 }
45 }
46
47 pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
48 for tool in self.tools.values() {
49 if tool.type_id == TypeId::of::<T>() {
50 tool.enabled.store(is_enabled, SeqCst);
51 return;
52 }
53 }
54 }
55
56 pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
57 for tool in self.tools.values() {
58 if tool.type_id == TypeId::of::<T>() {
59 return tool.enabled.load(SeqCst);
60 }
61 }
62 false
63 }
64
65 pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
66 self.tools
67 .values()
68 .filter(|tool| tool.enabled.load(SeqCst))
69 .map(|tool| tool.definition.clone())
70 .collect()
71 }
72
73 pub fn register<T: 'static + LanguageModelTool>(
74 &mut self,
75 tool: T,
76 _cx: &mut WindowContext,
77 ) -> Result<()> {
78 let definition = tool.definition();
79
80 let name = tool.name();
81
82 let registered_tool = Tool::new(
83 TypeId::of::<T>(),
84 Box::new(
85 move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
86 let name = tool_call.name.clone();
87 let arguments = tool_call.arguments.clone();
88 let id = tool_call.id.clone();
89
90 let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
91 return Task::ready(Ok(ToolFunctionCall {
92 id,
93 name: name.clone(),
94 arguments,
95 result: Some(ToolFunctionCallResult::ParsingFailed),
96 }));
97 };
98
99 let result = tool.execute(&input, cx);
100
101 cx.spawn(move |mut cx| async move {
102 let result: Result<T::Output> = result.await;
103 let for_model = T::format(&input, &result);
104 let view = cx.update(|cx| T::output_view(id.clone(), input, result, cx))?;
105
106 Ok(ToolFunctionCall {
107 id,
108 name: name.clone(),
109 arguments,
110 result: Some(ToolFunctionCallResult::Finished {
111 view: view.into(),
112 for_model,
113 }),
114 })
115 })
116 },
117 ),
118 definition,
119 );
120
121 let previous = self.tools.insert(name.clone(), registered_tool);
122
123 if previous.is_some() {
124 return Err(anyhow!("already registered a tool with name {}", name));
125 }
126
127 Ok(())
128 }
129
130 /// Task yields an error if the window for the given WindowContext is closed before the task completes.
131 pub fn call(
132 &self,
133 tool_call: &ToolFunctionCall,
134 cx: &mut WindowContext,
135 ) -> Task<Result<ToolFunctionCall>> {
136 let name = tool_call.name.clone();
137 let arguments = tool_call.arguments.clone();
138 let id = tool_call.id.clone();
139
140 let tool = match self.tools.get(&name) {
141 Some(tool) => tool,
142 None => {
143 let name = name.clone();
144 return Task::ready(Ok(ToolFunctionCall {
145 id,
146 name: name.clone(),
147 arguments,
148 result: Some(ToolFunctionCallResult::NoSuchTool),
149 }));
150 }
151 };
152
153 (tool.call)(tool_call, cx)
154 }
155}
156
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{div, prelude::*, EmptyView, Render, TestAppContext, View};
    use schemars::{schema_for, JsonSchema};
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Input accepted by the example weather tool.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// Output produced by the example weather tool.
    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    /// Example tool that always answers with a canned weather report.
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    /// View that displays the temperature of a weather report.
    struct WeatherView {
        result: WeatherResult,
    }

    impl Render for WeatherView {
        fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
            let label = format!("temperature: {}", self.result.temperature);
            div().child(label)
        }
    }

    impl LanguageModelTool for WeatherTool {
        type Input = WeatherQuery;
        type Output = WeatherResult;
        type View = WeatherView;

        fn name(&self) -> String {
            String::from("get_current_weather")
        }

        fn description(&self) -> String {
            String::from("Fetches the current weather for a given location.")
        }

        /// Ignores the query contents and resolves immediately with the canned report.
        fn execute(
            &self,
            input: &Self::Input,
            _cx: &mut WindowContext,
        ) -> Task<Result<Self::Output>> {
            let (_location, _unit) = (input.location.clone(), input.unit.clone());

            Task::ready(Ok(self.current_weather.clone()))
        }

        /// Builds the view for a result; panics if the result is an error.
        fn output_view(
            _tool_call_id: String,
            _input: Self::Input,
            result: Result<Self::Output>,
            cx: &mut WindowContext,
        ) -> View<Self::View> {
            cx.new_view(|_cx| WeatherView {
                result: result.unwrap(),
            })
        }

        /// Serializes the output as JSON; panics if the output is an error.
        fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
            let weather = output.as_ref().unwrap();
            serde_json::to_string(&weather).unwrap()
        }
    }

    /// Checks the generated definition (name, description, JSON schema) of the
    /// weather tool and that executing it yields the canned report.
    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);

        let canned_report = WeatherResult {
            location: "San Francisco".to_string(),
            temperature: 21.0,
            unit: "Celsius".to_string(),
        };
        let tool = WeatherTool {
            current_weather: canned_report,
        };

        let definitions = vec![tool.definition()];
        assert_eq!(definitions.len(), 1);
        let definition = &definitions[0];

        let expected = ToolFunctionDefinition {
            name: "get_current_weather".to_string(),
            description: "Fetches the current weather for a given location.".to_string(),
            parameters: schema_for!(WeatherQuery),
        };

        assert_eq!(definition.name, expected.name);
        assert_eq!(definition.description, expected.description);

        // The schema generated for the tool's input, as a JSON value.
        let actual_schema = serde_json::to_value(&definition.parameters).unwrap();

        assert_eq!(
            actual_schema,
            json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "title": "WeatherQuery",
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "unit": {
                        "type": "string"
                    }
                },
                "required": ["location", "unit"]
            })
        );

        let raw_arguments = json!({
            "location": "San Francisco",
            "unit": "Celsius"
        });
        let query: WeatherQuery = serde_json::from_value(raw_arguments).unwrap();

        let outcome = cx.update(|cx| tool.execute(&query, cx)).await;

        assert!(outcome.is_ok());
        let weather = outcome.unwrap();

        assert_eq!(weather, tool.current_weather);
    }
}