1use anyhow::{anyhow, Result};
2use gpui::{
3 div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
4};
5use schemars::{schema::RootSchema, schema_for, JsonSchema};
6use serde::Deserialize;
7use std::{
8 any::TypeId,
9 collections::HashMap,
10 fmt::Display,
11 sync::atomic::{AtomicBool, Ordering::SeqCst},
12};
13
14use crate::ProjectContext;
15
/// Registry of tools that a language model may call, keyed by tool name.
///
/// Tools are added with `ToolRegistry::register`, whose key is the tool's
/// `LanguageModelTool::name`. They are then invoked by name through
/// `ToolRegistry::call`.
pub struct ToolRegistry {
    // Tool name -> type-erased tool. Names are unique within a registry.
    registered_tools: HashMap<String, RegisteredTool>,
}
19
/// A single tool invocation requested by the language model.
///
/// `id`, `name` and `arguments` are deserializable. Presumably they come from the
/// model provider's response; confirm against callers. `result` is never
/// deserialized. The registry fills it in once the call has been handled.
#[derive(Default, Deserialize)]
pub struct ToolFunctionCall {
    /// Identifier of this call; the registry copies it unchanged into the call it returns.
    pub id: String,
    /// Name of the tool to invoke; used as the lookup key in `ToolRegistry`.
    pub name: String,
    /// Raw JSON text of the arguments, parsed into the tool's `Input` type.
    pub arguments: String,
    /// Outcome of the call; `None` until the call has been handled.
    #[serde(skip)]
    pub result: Option<ToolFunctionCallResult>,
}
28
/// Outcome of handling a `ToolFunctionCall`.
pub enum ToolFunctionCallResult {
    /// No tool with the requested name is registered.
    NoSuchTool,
    /// The call's `arguments` could not be deserialized into the tool's input type.
    ParsingFailed,
    /// The tool ran and produced a view of its output.
    Finished {
        /// Type-erased view built by the tool's `output_view`.
        view: AnyView,
        /// Downcasts `view` back to the tool's concrete view type and asks it to
        /// produce its textual output.
        generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
    },
}
37
/// Description of a tool as advertised to the language model.
#[derive(Clone)]
pub struct ToolFunctionDefinition {
    /// Unique tool name the model uses to select this tool.
    pub name: String,
    /// Human-readable description used to prompt the model.
    pub description: String,
    /// JSON schema of the tool's input type (see `LanguageModelTool::definition`).
    pub parameters: RootSchema,
}
44
/// A tool that a language model can invoke through a `ToolRegistry`.
pub trait LanguageModelTool {
    /// The input type that will be passed in to `execute` when the tool is called
    /// by the language model.
    type Input: for<'de> Deserialize<'de> + JsonSchema;

    /// The output returned by executing the tool.
    type Output: 'static;

    /// View that displays the tool's result and can produce its textual output
    /// (via `ToolOutput`).
    type View: Render + ToolOutput;

    /// Returns the name of the tool.
    ///
    /// This name is exposed to the language model to allow the model to pick
    /// which tools to use. As this name is used to identify the tool within a
    /// tool registry, it should be unique.
    fn name(&self) -> String;

    /// Returns the description of the tool.
    ///
    /// This can be used to _prompt_ the model as to what the tool does.
    fn description(&self) -> String;

    /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
    ///
    /// The default implementation derives `parameters` from the JSON schema of
    /// `Self::Input`.
    fn definition(&self) -> ToolFunctionDefinition {
        let root_schema = schema_for!(Self::Input);

        ToolFunctionDefinition {
            name: self.name(),
            description: self.description(),
            parameters: root_schema,
        }
    }

    /// Executes the tool with the given input.
    ///
    /// The returned task's result, success or error, is handed to `output_view`.
    fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;

    /// Builds the view for a completed call from the parsed input and the
    /// result of `execute`.
    fn output_view(
        input: Self::Input,
        output: Result<Self::Output>,
        cx: &mut WindowContext,
    ) -> View<Self::View>;

    /// Renders the placeholder shown while a call to this tool has no result
    /// yet. The default renders an empty `div`.
    fn render_running(_cx: &mut WindowContext) -> impl IntoElement {
        div()
    }
}
91
/// Produces the textual form of a tool's result view.
///
/// The registry invokes it through `ToolFunctionCallResult::generate`.
/// Presumably the text is what gets reported back to the language model;
/// confirm against callers.
pub trait ToolOutput: Sized {
    /// Returns the textual output for this view.
    fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
}
95
/// Type-erased entry for one tool inside a `ToolRegistry`.
struct RegisteredTool {
    // Whether the tool is included in `ToolRegistry::definitions`.
    enabled: AtomicBool,
    // `TypeId` of the concrete tool type, used by the type-keyed enable/disable API.
    type_id: TypeId,
    // Parses the call's arguments, runs the tool and resolves to the completed call.
    call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
    // Renders the tool's "running" placeholder for a call without a result.
    render_running: fn(&mut WindowContext) -> gpui::AnyElement,
    // Definition advertised to the model, captured at registration time.
    definition: ToolFunctionDefinition,
}
103
104impl ToolRegistry {
105 pub fn new() -> Self {
106 Self {
107 registered_tools: HashMap::new(),
108 }
109 }
110
111 pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
112 for tool in self.registered_tools.values() {
113 if tool.type_id == TypeId::of::<T>() {
114 tool.enabled.store(is_enabled, SeqCst);
115 return;
116 }
117 }
118 }
119
120 pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
121 for tool in self.registered_tools.values() {
122 if tool.type_id == TypeId::of::<T>() {
123 return tool.enabled.load(SeqCst);
124 }
125 }
126 false
127 }
128
129 pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
130 self.registered_tools
131 .values()
132 .filter(|tool| tool.enabled.load(SeqCst))
133 .map(|tool| tool.definition.clone())
134 .collect()
135 }
136
137 pub fn render_tool_call(
138 &self,
139 tool_call: &ToolFunctionCall,
140 cx: &mut WindowContext,
141 ) -> AnyElement {
142 match &tool_call.result {
143 Some(result) => div()
144 .p_2()
145 .child(result.into_any_element(&tool_call.name))
146 .into_any_element(),
147 None => self
148 .registered_tools
149 .get(&tool_call.name)
150 .map(|tool| (tool.render_running)(cx))
151 .unwrap_or_else(|| div().into_any_element()),
152 }
153 }
154
155 pub fn register<T: 'static + LanguageModelTool>(
156 &mut self,
157 tool: T,
158 _cx: &mut WindowContext,
159 ) -> Result<()> {
160 let name = tool.name();
161 let registered_tool = RegisteredTool {
162 type_id: TypeId::of::<T>(),
163 definition: tool.definition(),
164 enabled: AtomicBool::new(true),
165 call: Box::new(
166 move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
167 let name = tool_call.name.clone();
168 let arguments = tool_call.arguments.clone();
169 let id = tool_call.id.clone();
170
171 let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
172 return Task::ready(Ok(ToolFunctionCall {
173 id,
174 name: name.clone(),
175 arguments,
176 result: Some(ToolFunctionCallResult::ParsingFailed),
177 }));
178 };
179
180 let result = tool.execute(&input, cx);
181
182 cx.spawn(move |mut cx| async move {
183 let result: Result<T::Output> = result.await;
184 let view = cx.update(|cx| T::output_view(input, result, cx))?;
185
186 Ok(ToolFunctionCall {
187 id,
188 name: name.clone(),
189 arguments,
190 result: Some(ToolFunctionCallResult::Finished {
191 view: view.into(),
192 generate_fn: generate::<T>,
193 }),
194 })
195 })
196 },
197 ),
198 render_running: render_running::<T>,
199 };
200
201 let previous = self.registered_tools.insert(name.clone(), registered_tool);
202 if previous.is_some() {
203 return Err(anyhow!("already registered a tool with name {}", name));
204 }
205
206 return Ok(());
207
208 fn render_running<T: LanguageModelTool>(cx: &mut WindowContext) -> AnyElement {
209 T::render_running(cx).into_any_element()
210 }
211
212 fn generate<T: LanguageModelTool>(
213 view: AnyView,
214 project: &mut ProjectContext,
215 cx: &mut WindowContext,
216 ) -> String {
217 view.downcast::<T::View>()
218 .unwrap()
219 .update(cx, |view, cx| T::View::generate(view, project, cx))
220 }
221 }
222
223 /// Task yields an error if the window for the given WindowContext is closed before the task completes.
224 pub fn call(
225 &self,
226 tool_call: &ToolFunctionCall,
227 cx: &mut WindowContext,
228 ) -> Task<Result<ToolFunctionCall>> {
229 let name = tool_call.name.clone();
230 let arguments = tool_call.arguments.clone();
231 let id = tool_call.id.clone();
232
233 let tool = match self.registered_tools.get(&name) {
234 Some(tool) => tool,
235 None => {
236 let name = name.clone();
237 return Task::ready(Ok(ToolFunctionCall {
238 id,
239 name: name.clone(),
240 arguments,
241 result: Some(ToolFunctionCallResult::NoSuchTool),
242 }));
243 }
244 };
245
246 (tool.call)(tool_call, cx)
247 }
248}
249
250impl ToolFunctionCallResult {
251 pub fn generate(
252 &self,
253 name: &String,
254 project: &mut ProjectContext,
255 cx: &mut WindowContext,
256 ) -> String {
257 match self {
258 ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
259 ToolFunctionCallResult::ParsingFailed => {
260 format!("Unable to parse arguments for {name}")
261 }
262 ToolFunctionCallResult::Finished { generate_fn, view } => {
263 (generate_fn)(view.clone(), project, cx)
264 }
265 }
266 }
267
268 fn into_any_element(&self, name: &String) -> AnyElement {
269 match self {
270 ToolFunctionCallResult::NoSuchTool => {
271 format!("Language Model attempted to call {name}").into_any_element()
272 }
273 ToolFunctionCallResult::ParsingFailed => {
274 format!("Language Model called {name} with bad arguments").into_any_element()
275 }
276 ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
277 }
278 }
279}
280
281impl Display for ToolFunctionDefinition {
282 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283 let schema = serde_json::to_string(&self.parameters).ok();
284 let schema = schema.unwrap_or("None".to_string());
285 write!(f, "Name: {}:\n", self.name)?;
286 write!(f, "Description: {}\n", self.description)?;
287 write!(f, "Parameters: {}", schema)
288 }
289}
290
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{div, prelude::*, Render, TestAppContext};
    use gpui::{EmptyView, View};
    use schemars::schema_for;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Arguments the model passes to `get_current_weather`.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct WeatherQuery {
        location: String,
        unit: String,
    }

    /// Fake weather tool that reports `current_weather` for every query.
    struct WeatherTool {
        current_weather: WeatherResult,
    }

    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
    struct WeatherResult {
        location: String,
        temperature: f64,
        unit: String,
    }

    /// View shown for a finished weather call.
    struct WeatherView {
        result: WeatherResult,
    }

    impl Render for WeatherView {
        fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
            div().child(format!("temperature: {}", self.result.temperature))
        }
    }

    impl ToolOutput for WeatherView {
        fn generate(&self, _output: &mut ProjectContext, _cx: &mut WindowContext) -> String {
            serde_json::to_string(&self.result).unwrap()
        }
    }

    impl LanguageModelTool for WeatherTool {
        type Input = WeatherQuery;
        type Output = WeatherResult;
        type View = WeatherView;

        fn name(&self) -> String {
            "get_current_weather".to_string()
        }

        fn description(&self) -> String {
            "Fetches the current weather for a given location.".to_string()
        }

        fn execute(
            &self,
            _input: &Self::Input,
            _cx: &mut WindowContext,
        ) -> Task<Result<Self::Output>> {
            // The fake ignores the requested location/unit and returns its canned result.
            Task::ready(Ok(self.current_weather.clone()))
        }

        fn output_view(
            _input: Self::Input,
            result: Result<Self::Output>,
            cx: &mut WindowContext,
        ) -> View<Self::View> {
            cx.new_view(|_cx| {
                let result = result.unwrap();
                WeatherView { result }
            })
        }
    }

    #[gpui::test]
    async fn test_openai_weather_example(cx: &mut TestAppContext) {
        cx.background_executor.run_until_parked();
        let (_, cx) = cx.add_window_view(|_cx| EmptyView);

        let tool = WeatherTool {
            current_weather: WeatherResult {
                location: "San Francisco".to_string(),
                temperature: 21.0,
                unit: "Celsius".to_string(),
            },
        };

        let tools = vec![tool.definition()];
        assert_eq!(tools.len(), 1);

        let expected = ToolFunctionDefinition {
            name: "get_current_weather".to_string(),
            description: "Fetches the current weather for a given location.".to_string(),
            parameters: schema_for!(WeatherQuery),
        };

        assert_eq!(tools[0].name, expected.name);
        assert_eq!(tools[0].description, expected.description);

        // The advertised parameters must be the schema derived from the tool's
        // `Input` type...
        let actual_schema = serde_json::to_value(&tools[0].parameters).unwrap();
        assert_eq!(
            actual_schema,
            serde_json::to_value(&expected.parameters).unwrap()
        );

        // ...and that schema must have exactly this JSON shape.
        assert_eq!(
            actual_schema,
            json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "title": "WeatherQuery",
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "unit": {
                        "type": "string"
                    }
                },
                "required": ["location", "unit"]
            })
        );

        let args = json!({
            "location": "San Francisco",
            "unit": "Celsius"
        });

        let query: WeatherQuery = serde_json::from_value(args).unwrap();

        let result = cx.update(|cx| tool.execute(&query, cx)).await;

        assert!(result.is_ok());
        let result = result.unwrap();

        assert_eq!(result, tool.current_weather);
    }
}