1mod action_log;
2pub mod outline;
3mod tool_registry;
4mod tool_schema;
5mod tool_working_set;
6
7use std::{fmt, fmt::Debug, fmt::Formatter, ops::Deref, sync::Arc};
8
9use anyhow::Result;
10use gpui::{
11 AnyElement, AnyWindowHandle, App, Context, Entity, IntoElement, SharedString, Task, WeakEntity,
12 Window,
13};
14use icons::IconName;
15use language_model::{
16 LanguageModel, LanguageModelImage, LanguageModelRequest, LanguageModelToolSchemaFormat,
17};
18use project::Project;
19use serde::de::DeserializeOwned;
20use workspace::Workspace;
21
22pub use crate::action_log::*;
23pub use crate::tool_registry::*;
24pub use crate::tool_schema::*;
25pub use crate::tool_working_set::*;
26
27pub fn init(cx: &mut App) {
28 ToolRegistry::default_global(cx);
29}
30
31#[derive(Debug, Clone)]
32pub enum ToolUseStatus {
33 InputStillStreaming,
34 NeedsConfirmation,
35 Pending,
36 Running,
37 Finished(SharedString),
38 Error(SharedString),
39}
40
41impl ToolUseStatus {
42 pub fn text(&self) -> SharedString {
43 match self {
44 ToolUseStatus::NeedsConfirmation => "".into(),
45 ToolUseStatus::InputStillStreaming => "".into(),
46 ToolUseStatus::Pending => "".into(),
47 ToolUseStatus::Running => "".into(),
48 ToolUseStatus::Finished(out) => out.clone(),
49 ToolUseStatus::Error(out) => out.clone(),
50 }
51 }
52
53 pub fn error(&self) -> Option<SharedString> {
54 match self {
55 ToolUseStatus::Error(out) => Some(out.clone()),
56 _ => None,
57 }
58 }
59}
60
61#[derive(Debug)]
62pub struct ToolResultOutput {
63 pub content: ToolResultContent,
64 pub output: Option<serde_json::Value>,
65}
66
67#[derive(Debug, PartialEq, Eq)]
68pub enum ToolResultContent {
69 Text(String),
70 Image(LanguageModelImage),
71}
72
73impl ToolResultContent {
74 pub fn len(&self) -> usize {
75 match self {
76 ToolResultContent::Text(str) => str.len(),
77 ToolResultContent::Image(image) => image.len(),
78 }
79 }
80
81 pub fn is_empty(&self) -> bool {
82 match self {
83 ToolResultContent::Text(str) => str.is_empty(),
84 ToolResultContent::Image(image) => image.is_empty(),
85 }
86 }
87
88 pub fn as_str(&self) -> Option<&str> {
89 match self {
90 ToolResultContent::Text(str) => Some(str),
91 ToolResultContent::Image(_) => None,
92 }
93 }
94}
95
96impl From<String> for ToolResultOutput {
97 fn from(value: String) -> Self {
98 ToolResultOutput {
99 content: ToolResultContent::Text(value),
100 output: None,
101 }
102 }
103}
104
105impl Deref for ToolResultOutput {
106 type Target = ToolResultContent;
107
108 fn deref(&self) -> &Self::Target {
109 &self.content
110 }
111}
112
113/// The result of running a tool, containing both the asynchronous output
114/// and an optional card view that can be rendered immediately.
115pub struct ToolResult {
116 /// The asynchronous task that will eventually resolve to the tool's output
117 pub output: Task<Result<ToolResultOutput>>,
118 /// An optional view to present the output of the tool.
119 pub card: Option<AnyToolCard>,
120}
121
122pub trait ToolCard: 'static + Sized {
123 fn render(
124 &mut self,
125 status: &ToolUseStatus,
126 window: &mut Window,
127 workspace: WeakEntity<Workspace>,
128 cx: &mut Context<Self>,
129 ) -> impl IntoElement;
130}
131
132#[derive(Clone)]
133pub struct AnyToolCard {
134 entity: gpui::AnyEntity,
135 render: fn(
136 entity: gpui::AnyEntity,
137 status: &ToolUseStatus,
138 window: &mut Window,
139 workspace: WeakEntity<Workspace>,
140 cx: &mut App,
141 ) -> AnyElement,
142}
143
144impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
145 fn from(entity: Entity<T>) -> Self {
146 fn downcast_render<T: ToolCard>(
147 entity: gpui::AnyEntity,
148 status: &ToolUseStatus,
149 window: &mut Window,
150 workspace: WeakEntity<Workspace>,
151 cx: &mut App,
152 ) -> AnyElement {
153 let entity = entity.downcast::<T>().unwrap();
154 entity.update(cx, |entity, cx| {
155 entity
156 .render(status, window, workspace, cx)
157 .into_any_element()
158 })
159 }
160
161 Self {
162 entity: entity.into(),
163 render: downcast_render::<T>,
164 }
165 }
166}
167
168impl AnyToolCard {
169 pub fn render(
170 &self,
171 status: &ToolUseStatus,
172 window: &mut Window,
173 workspace: WeakEntity<Workspace>,
174 cx: &mut App,
175 ) -> AnyElement {
176 (self.render)(self.entity.clone(), status, window, workspace, cx)
177 }
178}
179
180impl From<Task<Result<ToolResultOutput>>> for ToolResult {
181 /// Convert from a task to a ToolResult with no card
182 fn from(output: Task<Result<ToolResultOutput>>) -> Self {
183 Self { output, card: None }
184 }
185}
186
187#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
188pub enum ToolSource {
189 /// A native tool built-in to Zed.
190 Native,
191 /// A tool provided by a context server.
192 ContextServer { id: SharedString },
193}
194
195/// A tool that can be used by a language model.
196pub trait Tool: Send + Sync + 'static {
197 /// The input type that is accepted by the tool.
198 type Input: DeserializeOwned;
199
200 /// Returns the name of the tool.
201 fn name(&self) -> String;
202
203 /// Returns the description of the tool.
204 fn description(&self) -> String;
205
206 /// Returns the icon for the tool.
207 fn icon(&self) -> IconName;
208
209 /// Returns the source of the tool.
210 fn source(&self) -> ToolSource {
211 ToolSource::Native
212 }
213
214 /// Returns true if the tool needs the users's confirmation
215 /// before having permission to run.
216 fn needs_confirmation(&self, input: &Self::Input, cx: &App) -> bool;
217
218 /// Returns true if the tool may perform edits.
219 fn may_perform_edits(&self) -> bool;
220
221 /// Returns the JSON schema that describes the tool's input.
222 fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
223 Ok(serde_json::Value::Object(serde_json::Map::default()))
224 }
225
226 /// Returns markdown to be displayed in the UI for this tool.
227 fn ui_text(&self, input: &Self::Input) -> String;
228
229 /// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
230 /// (so information may be missing).
231 fn still_streaming_ui_text(&self, input: &Self::Input) -> String {
232 self.ui_text(input)
233 }
234
235 /// Runs the tool with the provided input.
236 fn run(
237 self: Arc<Self>,
238 input: Self::Input,
239 request: Arc<LanguageModelRequest>,
240 project: Entity<Project>,
241 action_log: Entity<ActionLog>,
242 model: Arc<dyn LanguageModel>,
243 window: Option<AnyWindowHandle>,
244 cx: &mut App,
245 ) -> ToolResult;
246
247 fn deserialize_card(
248 self: Arc<Self>,
249 _output: serde_json::Value,
250 _project: Entity<Project>,
251 _window: &mut Window,
252 _cx: &mut App,
253 ) -> Option<AnyToolCard> {
254 None
255 }
256}
257
258#[derive(Clone)]
259pub struct AnyTool {
260 inner: Arc<dyn ErasedTool>,
261}
262
263/// Copy of `Tool` where the Input type is erased.
264trait ErasedTool: Send + Sync {
265 fn name(&self) -> String;
266 fn description(&self) -> String;
267 fn icon(&self) -> IconName;
268 fn source(&self) -> ToolSource;
269 fn may_perform_edits(&self) -> bool;
270 fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
271 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
272 fn ui_text(&self, input: &serde_json::Value) -> String;
273 fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String;
274 fn run(
275 &self,
276 input: serde_json::Value,
277 request: Arc<LanguageModelRequest>,
278 project: Entity<Project>,
279 action_log: Entity<ActionLog>,
280 model: Arc<dyn LanguageModel>,
281 window: Option<AnyWindowHandle>,
282 cx: &mut App,
283 ) -> ToolResult;
284 fn deserialize_card(
285 &self,
286 output: serde_json::Value,
287 project: Entity<Project>,
288 window: &mut Window,
289 cx: &mut App,
290 ) -> Option<AnyToolCard>;
291}
292
293struct ErasedToolWrapper<T: Tool> {
294 tool: Arc<T>,
295}
296
297impl<T: Tool> ErasedTool for ErasedToolWrapper<T> {
298 fn name(&self) -> String {
299 self.tool.name()
300 }
301
302 fn description(&self) -> String {
303 self.tool.description()
304 }
305
306 fn icon(&self) -> IconName {
307 self.tool.icon()
308 }
309
310 fn source(&self) -> ToolSource {
311 self.tool.source()
312 }
313
314 fn may_perform_edits(&self) -> bool {
315 self.tool.may_perform_edits()
316 }
317
318 fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool {
319 match serde_json::from_value::<T::Input>(input.clone()) {
320 Ok(parsed_input) => self.tool.needs_confirmation(&parsed_input, cx),
321 Err(_) => true,
322 }
323 }
324
325 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
326 self.tool.input_schema(format)
327 }
328
329 fn ui_text(&self, input: &serde_json::Value) -> String {
330 match serde_json::from_value::<T::Input>(input.clone()) {
331 Ok(parsed_input) => self.tool.ui_text(&parsed_input),
332 Err(_) => "Invalid input".to_string(),
333 }
334 }
335
336 fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
337 match serde_json::from_value::<T::Input>(input.clone()) {
338 Ok(parsed_input) => self.tool.still_streaming_ui_text(&parsed_input),
339 Err(_) => "Invalid input".to_string(),
340 }
341 }
342
343 fn run(
344 &self,
345 input: serde_json::Value,
346 request: Arc<LanguageModelRequest>,
347 project: Entity<Project>,
348 action_log: Entity<ActionLog>,
349 model: Arc<dyn LanguageModel>,
350 window: Option<AnyWindowHandle>,
351 cx: &mut App,
352 ) -> ToolResult {
353 match serde_json::from_value::<T::Input>(input) {
354 Ok(parsed_input) => self.tool.clone().run(
355 parsed_input,
356 request,
357 project,
358 action_log,
359 model,
360 window,
361 cx,
362 ),
363 Err(err) => ToolResult::from(Task::ready(Err(err.into()))),
364 }
365 }
366
367 fn deserialize_card(
368 &self,
369 output: serde_json::Value,
370 project: Entity<Project>,
371 window: &mut Window,
372 cx: &mut App,
373 ) -> Option<AnyToolCard> {
374 self.tool
375 .clone()
376 .deserialize_card(output, project, window, cx)
377 }
378}
379
380impl<T: Tool> From<Arc<T>> for AnyTool {
381 fn from(tool: Arc<T>) -> Self {
382 Self {
383 inner: Arc::new(ErasedToolWrapper { tool }),
384 }
385 }
386}
387
388impl AnyTool {
389 pub fn name(&self) -> String {
390 self.inner.name()
391 }
392
393 pub fn description(&self) -> String {
394 self.inner.description()
395 }
396
397 pub fn icon(&self) -> IconName {
398 self.inner.icon()
399 }
400
401 pub fn source(&self) -> ToolSource {
402 self.inner.source()
403 }
404
405 pub fn may_perform_edits(&self) -> bool {
406 self.inner.may_perform_edits()
407 }
408
409 pub fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool {
410 self.inner.needs_confirmation(input, cx)
411 }
412
413 pub fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
414 self.inner.input_schema(format)
415 }
416
417 pub fn ui_text(&self, input: &serde_json::Value) -> String {
418 self.inner.ui_text(input)
419 }
420
421 pub fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
422 self.inner.still_streaming_ui_text(input)
423 }
424
425 pub fn run(
426 &self,
427 input: serde_json::Value,
428 request: Arc<LanguageModelRequest>,
429 project: Entity<Project>,
430 action_log: Entity<ActionLog>,
431 model: Arc<dyn LanguageModel>,
432 window: Option<AnyWindowHandle>,
433 cx: &mut App,
434 ) -> ToolResult {
435 self.inner
436 .run(input, request, project, action_log, model, window, cx)
437 }
438
439 pub fn deserialize_card(
440 &self,
441 output: serde_json::Value,
442 project: Entity<Project>,
443 window: &mut Window,
444 cx: &mut App,
445 ) -> Option<AnyToolCard> {
446 self.inner.deserialize_card(output, project, window, cx)
447 }
448}
449
450impl Debug for AnyTool {
451 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
452 f.debug_struct("Tool").field("name", &self.name()).finish()
453 }
454}