1use anyhow::{anyhow, Result};
2use gpui::{
3 div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
4};
5use schemars::{schema::RootSchema, schema_for, JsonSchema};
6use serde::Deserialize;
7use serde_json::Value;
8use std::{
9 any::TypeId,
10 collections::HashMap,
11 fmt::Display,
12 sync::atomic::{AtomicBool, Ordering::SeqCst},
13};
14
15use crate::ProjectContext;
16
17pub struct ToolRegistry {
18 registered_tools: HashMap<String, RegisteredTool>,
19}
20
21#[derive(Default, Deserialize)]
22pub struct ToolFunctionCall {
23 pub id: String,
24 pub name: String,
25 pub arguments: String,
26 #[serde(skip)]
27 pub result: Option<ToolFunctionCallResult>,
28}
29
30pub enum ToolFunctionCallResult {
31 NoSuchTool,
32 ParsingFailed,
33 Finished {
34 view: AnyView,
35 generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
36 },
37}
38
39#[derive(Clone)]
40pub struct ToolFunctionDefinition {
41 pub name: String,
42 pub description: String,
43 pub parameters: RootSchema,
44}
45
46pub trait LanguageModelTool {
47 /// The input type that will be passed in to `execute` when the tool is called
48 /// by the language model.
49 type Input: for<'de> Deserialize<'de> + JsonSchema;
50
51 /// The output returned by executing the tool.
52 type Output: 'static;
53
54 type View: Render + ToolOutput;
55
56 /// Returns the name of the tool.
57 ///
58 /// This name is exposed to the language model to allow the model to pick
59 /// which tools to use. As this name is used to identify the tool within a
60 /// tool registry, it should be unique.
61 fn name(&self) -> String;
62
63 /// Returns the description of the tool.
64 ///
65 /// This can be used to _prompt_ the model as to what the tool does.
66 fn description(&self) -> String;
67
68 /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
69 fn definition(&self) -> ToolFunctionDefinition {
70 let root_schema = schema_for!(Self::Input);
71
72 ToolFunctionDefinition {
73 name: self.name(),
74 description: self.description(),
75 parameters: root_schema,
76 }
77 }
78
79 /// Executes the tool with the given input.
80 fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
81
82 /// A view of the output of running the tool, for displaying to the user.
83 fn output_view(
84 input: Self::Input,
85 output: Result<Self::Output>,
86 cx: &mut WindowContext,
87 ) -> View<Self::View>;
88
89 fn render_running(_arguments: &Option<Value>, _cx: &mut WindowContext) -> impl IntoElement {
90 tool_running_placeholder()
91 }
92}
93
94pub fn tool_running_placeholder() -> AnyElement {
95 ui::Label::new("Researching...").into_any_element()
96}
97
98pub trait ToolOutput: Sized {
99 fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
100}
101
102struct RegisteredTool {
103 enabled: AtomicBool,
104 type_id: TypeId,
105 call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
106 render_running: fn(&ToolFunctionCall, &mut WindowContext) -> gpui::AnyElement,
107 definition: ToolFunctionDefinition,
108}
109
110impl ToolRegistry {
111 pub fn new() -> Self {
112 Self {
113 registered_tools: HashMap::new(),
114 }
115 }
116
117 pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
118 for tool in self.registered_tools.values() {
119 if tool.type_id == TypeId::of::<T>() {
120 tool.enabled.store(is_enabled, SeqCst);
121 return;
122 }
123 }
124 }
125
126 pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
127 for tool in self.registered_tools.values() {
128 if tool.type_id == TypeId::of::<T>() {
129 return tool.enabled.load(SeqCst);
130 }
131 }
132 false
133 }
134
135 pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
136 self.registered_tools
137 .values()
138 .filter(|tool| tool.enabled.load(SeqCst))
139 .map(|tool| tool.definition.clone())
140 .collect()
141 }
142
143 pub fn render_tool_call(
144 &self,
145 tool_call: &ToolFunctionCall,
146 cx: &mut WindowContext,
147 ) -> AnyElement {
148 match &tool_call.result {
149 Some(result) => div()
150 .p_2()
151 .child(result.into_any_element(&tool_call.name))
152 .into_any_element(),
153 None => {
154 let tool = self.registered_tools.get(&tool_call.name);
155
156 if let Some(tool) = tool {
157 (tool.render_running)(&tool_call, cx)
158 } else {
159 tool_running_placeholder()
160 }
161 }
162 }
163 }
164
165 pub fn register<T: 'static + LanguageModelTool>(
166 &mut self,
167 tool: T,
168 _cx: &mut WindowContext,
169 ) -> Result<()> {
170 let name = tool.name();
171 let registered_tool = RegisteredTool {
172 type_id: TypeId::of::<T>(),
173 definition: tool.definition(),
174 enabled: AtomicBool::new(true),
175 call: Box::new(
176 move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
177 let name = tool_call.name.clone();
178 let arguments = tool_call.arguments.clone();
179 let id = tool_call.id.clone();
180
181 let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
182 return Task::ready(Ok(ToolFunctionCall {
183 id,
184 name: name.clone(),
185 arguments,
186 result: Some(ToolFunctionCallResult::ParsingFailed),
187 }));
188 };
189
190 let result = tool.execute(&input, cx);
191
192 cx.spawn(move |mut cx| async move {
193 let result: Result<T::Output> = result.await;
194 let view = cx.update(|cx| T::output_view(input, result, cx))?;
195
196 Ok(ToolFunctionCall {
197 id,
198 name: name.clone(),
199 arguments,
200 result: Some(ToolFunctionCallResult::Finished {
201 view: view.into(),
202 generate_fn: generate::<T>,
203 }),
204 })
205 })
206 },
207 ),
208 render_running: render_running::<T>,
209 };
210
211 let previous = self.registered_tools.insert(name.clone(), registered_tool);
212 if previous.is_some() {
213 return Err(anyhow!("already registered a tool with name {}", name));
214 }
215
216 return Ok(());
217
218 fn render_running<T: LanguageModelTool>(
219 tool_call: &ToolFunctionCall,
220 cx: &mut WindowContext,
221 ) -> AnyElement {
222 // Attempt to parse the string arguments that are JSON as a JSON value
223 let maybe_arguments = serde_json::to_value(tool_call.arguments.clone()).ok();
224
225 T::render_running(&maybe_arguments, cx).into_any_element()
226 }
227
228 fn generate<T: LanguageModelTool>(
229 view: AnyView,
230 project: &mut ProjectContext,
231 cx: &mut WindowContext,
232 ) -> String {
233 view.downcast::<T::View>()
234 .unwrap()
235 .update(cx, |view, cx| T::View::generate(view, project, cx))
236 }
237 }
238
239 /// Task yields an error if the window for the given WindowContext is closed before the task completes.
240 pub fn call(
241 &self,
242 tool_call: &ToolFunctionCall,
243 cx: &mut WindowContext,
244 ) -> Task<Result<ToolFunctionCall>> {
245 let name = tool_call.name.clone();
246 let arguments = tool_call.arguments.clone();
247 let id = tool_call.id.clone();
248
249 let tool = match self.registered_tools.get(&name) {
250 Some(tool) => tool,
251 None => {
252 let name = name.clone();
253 return Task::ready(Ok(ToolFunctionCall {
254 id,
255 name: name.clone(),
256 arguments,
257 result: Some(ToolFunctionCallResult::NoSuchTool),
258 }));
259 }
260 };
261
262 (tool.call)(tool_call, cx)
263 }
264}
265
266impl ToolFunctionCallResult {
267 pub fn generate(
268 &self,
269 name: &String,
270 project: &mut ProjectContext,
271 cx: &mut WindowContext,
272 ) -> String {
273 match self {
274 ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
275 ToolFunctionCallResult::ParsingFailed => {
276 format!("Unable to parse arguments for {name}")
277 }
278 ToolFunctionCallResult::Finished { generate_fn, view } => {
279 (generate_fn)(view.clone(), project, cx)
280 }
281 }
282 }
283
284 fn into_any_element(&self, name: &String) -> AnyElement {
285 match self {
286 ToolFunctionCallResult::NoSuchTool => {
287 format!("Language Model attempted to call {name}").into_any_element()
288 }
289 ToolFunctionCallResult::ParsingFailed => {
290 format!("Language Model called {name} with bad arguments").into_any_element()
291 }
292 ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
293 }
294 }
295}
296
297impl Display for ToolFunctionDefinition {
298 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
299 let schema = serde_json::to_string(&self.parameters).ok();
300 let schema = schema.unwrap_or("None".to_string());
301 write!(f, "Name: {}:\n", self.name)?;
302 write!(f, "Description: {}\n", self.description)?;
303 write!(f, "Parameters: {}", schema)
304 }
305}
306
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{div, prelude::*, Render, TestAppContext};
    use gpui::{EmptyView, View};
    use schemars::schema_for;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Arguments for the fake weather tool.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// A fake tool that always reports `current_weather`, whatever the query.
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    /// Output view for the weather tool.
    struct WeatherView {
        result: WeatherResult,
    }

    impl Render for WeatherView {
        fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
            div().child(format!("temperature: {}", self.result.temperature))
        }
    }

    impl ToolOutput for WeatherView {
        fn generate(&self, _project: &mut ProjectContext, _cx: &mut WindowContext) -> String {
            serde_json::to_string(&self.result).unwrap()
        }
    }

    impl LanguageModelTool for WeatherTool {
        type Input = WeatherQuery;
        type Output = WeatherResult;
        type View = WeatherView;

        fn name(&self) -> String {
            "get_current_weather".to_string()
        }

        fn description(&self) -> String {
            "Fetches the current weather for a given location.".to_string()
        }

        fn execute(
            &self,
            input: &Self::Input,
            _cx: &mut WindowContext,
        ) -> Task<Result<Self::Output>> {
            let _location = input.location.clone();
            let _unit = input.unit.clone();

            let weather = self.current_weather.clone();

            Task::ready(Ok(weather))
        }

        fn output_view(
            _input: Self::Input,
            result: Result<Self::Output>,
            cx: &mut WindowContext,
        ) -> View<Self::View> {
            cx.new_view(|_cx| {
                let result = result.unwrap();
                WeatherView { result }
            })
        }
    }

    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);

        let tool = WeatherTool {
            current_weather: WeatherResult {
                location: "San Francisco".to_string(),
                temperature: 21.0,
                unit: "Celsius".to_string(),
            },
        };

        let tools = vec![tool.definition()];
        assert_eq!(tools.len(), 1);

        let expected = ToolFunctionDefinition {
            name: "get_current_weather".to_string(),
            description: "Fetches the current weather for a given location.".to_string(),
            parameters: schema_for!(WeatherQuery),
        };

        assert_eq!(tools[0].name, expected.name);
        assert_eq!(tools[0].description, expected.description);

        let actual_schema = serde_json::to_value(&tools[0].parameters).unwrap();

        // The schema in the generated definition must match the one derived directly from the
        // input type; previously `expected.parameters` was built but never checked.
        assert_eq!(
            actual_schema,
            serde_json::to_value(&expected.parameters).unwrap()
        );

        assert_eq!(
            actual_schema,
            json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "title": "WeatherQuery",
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "unit": {
                        "type": "string"
                    }
                },
                "required": ["location", "unit"]
            })
        );

        let args = json!({
            "location": "San Francisco",
            "unit": "Celsius"
        });

        let query: WeatherQuery = serde_json::from_value(args).unwrap();

        let result = cx.update(|cx| tool.execute(&query, cx)).await;

        assert!(result.is_ok());
        let result = result.unwrap();

        assert_eq!(result, tool.current_weather);
    }
}