assistant_tool.rs

  1mod action_log;
  2pub mod outline;
  3mod tool_registry;
  4mod tool_schema;
  5mod tool_working_set;
  6
  7use std::{fmt, fmt::Debug, fmt::Formatter, ops::Deref, sync::Arc};
  8
  9use anyhow::Result;
 10use gpui::{
 11    AnyElement, AnyWindowHandle, App, Context, Entity, IntoElement, SharedString, Task, WeakEntity,
 12    Window,
 13};
 14use icons::IconName;
 15use language_model::{
 16    LanguageModel, LanguageModelImage, LanguageModelRequest, LanguageModelToolSchemaFormat,
 17};
 18use project::Project;
 19use serde::de::DeserializeOwned;
 20use workspace::Workspace;
 21
 22pub use crate::action_log::*;
 23pub use crate::tool_registry::*;
 24pub use crate::tool_schema::*;
 25pub use crate::tool_working_set::*;
 26
 27pub fn init(cx: &mut App) {
 28    ToolRegistry::default_global(cx);
 29}
 30
 31#[derive(Debug, Clone)]
 32pub enum ToolUseStatus {
 33    InputStillStreaming,
 34    NeedsConfirmation,
 35    Pending,
 36    Running,
 37    Finished(SharedString),
 38    Error(SharedString),
 39}
 40
 41impl ToolUseStatus {
 42    pub fn text(&self) -> SharedString {
 43        match self {
 44            ToolUseStatus::NeedsConfirmation => "".into(),
 45            ToolUseStatus::InputStillStreaming => "".into(),
 46            ToolUseStatus::Pending => "".into(),
 47            ToolUseStatus::Running => "".into(),
 48            ToolUseStatus::Finished(out) => out.clone(),
 49            ToolUseStatus::Error(out) => out.clone(),
 50        }
 51    }
 52
 53    pub fn error(&self) -> Option<SharedString> {
 54        match self {
 55            ToolUseStatus::Error(out) => Some(out.clone()),
 56            _ => None,
 57        }
 58    }
 59}
 60
 61#[derive(Debug)]
 62pub struct ToolResultOutput {
 63    pub content: ToolResultContent,
 64    pub output: Option<serde_json::Value>,
 65}
 66
 67#[derive(Debug, PartialEq, Eq)]
 68pub enum ToolResultContent {
 69    Text(String),
 70    Image(LanguageModelImage),
 71}
 72
 73impl ToolResultContent {
 74    pub fn len(&self) -> usize {
 75        match self {
 76            ToolResultContent::Text(str) => str.len(),
 77            ToolResultContent::Image(image) => image.len(),
 78        }
 79    }
 80
 81    pub fn is_empty(&self) -> bool {
 82        match self {
 83            ToolResultContent::Text(str) => str.is_empty(),
 84            ToolResultContent::Image(image) => image.is_empty(),
 85        }
 86    }
 87
 88    pub fn as_str(&self) -> Option<&str> {
 89        match self {
 90            ToolResultContent::Text(str) => Some(str),
 91            ToolResultContent::Image(_) => None,
 92        }
 93    }
 94}
 95
 96impl From<String> for ToolResultOutput {
 97    fn from(value: String) -> Self {
 98        ToolResultOutput {
 99            content: ToolResultContent::Text(value),
100            output: None,
101        }
102    }
103}
104
105impl Deref for ToolResultOutput {
106    type Target = ToolResultContent;
107
108    fn deref(&self) -> &Self::Target {
109        &self.content
110    }
111}
112
113/// The result of running a tool, containing both the asynchronous output
114/// and an optional card view that can be rendered immediately.
115pub struct ToolResult {
116    /// The asynchronous task that will eventually resolve to the tool's output
117    pub output: Task<Result<ToolResultOutput>>,
118    /// An optional view to present the output of the tool.
119    pub card: Option<AnyToolCard>,
120}
121
122pub trait ToolCard: 'static + Sized {
123    fn render(
124        &mut self,
125        status: &ToolUseStatus,
126        window: &mut Window,
127        workspace: WeakEntity<Workspace>,
128        cx: &mut Context<Self>,
129    ) -> impl IntoElement;
130}
131
132#[derive(Clone)]
133pub struct AnyToolCard {
134    entity: gpui::AnyEntity,
135    render: fn(
136        entity: gpui::AnyEntity,
137        status: &ToolUseStatus,
138        window: &mut Window,
139        workspace: WeakEntity<Workspace>,
140        cx: &mut App,
141    ) -> AnyElement,
142}
143
144impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
145    fn from(entity: Entity<T>) -> Self {
146        fn downcast_render<T: ToolCard>(
147            entity: gpui::AnyEntity,
148            status: &ToolUseStatus,
149            window: &mut Window,
150            workspace: WeakEntity<Workspace>,
151            cx: &mut App,
152        ) -> AnyElement {
153            let entity = entity.downcast::<T>().unwrap();
154            entity.update(cx, |entity, cx| {
155                entity
156                    .render(status, window, workspace, cx)
157                    .into_any_element()
158            })
159        }
160
161        Self {
162            entity: entity.into(),
163            render: downcast_render::<T>,
164        }
165    }
166}
167
168impl AnyToolCard {
169    pub fn render(
170        &self,
171        status: &ToolUseStatus,
172        window: &mut Window,
173        workspace: WeakEntity<Workspace>,
174        cx: &mut App,
175    ) -> AnyElement {
176        (self.render)(self.entity.clone(), status, window, workspace, cx)
177    }
178}
179
180impl From<Task<Result<ToolResultOutput>>> for ToolResult {
181    /// Convert from a task to a ToolResult with no card
182    fn from(output: Task<Result<ToolResultOutput>>) -> Self {
183        Self { output, card: None }
184    }
185}
186
187#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
188pub enum ToolSource {
189    /// A native tool built-in to Zed.
190    Native,
191    /// A tool provided by a context server.
192    ContextServer { id: SharedString },
193}
194
195/// A tool that can be used by a language model.
196pub trait Tool: Send + Sync + 'static {
197    /// The input type that is accepted by the tool.
198    type Input: DeserializeOwned;
199
200    /// Returns the name of the tool.
201    fn name(&self) -> String;
202
203    /// Returns the description of the tool.
204    fn description(&self) -> String;
205
206    /// Returns the icon for the tool.
207    fn icon(&self) -> IconName;
208
209    /// Returns the source of the tool.
210    fn source(&self) -> ToolSource {
211        ToolSource::Native
212    }
213
214    /// Returns true if the tool needs the users's confirmation
215    /// before having permission to run.
216    fn needs_confirmation(&self, input: &Self::Input, cx: &App) -> bool;
217
218    /// Returns true if the tool may perform edits.
219    fn may_perform_edits(&self) -> bool;
220
221    /// Returns the JSON schema that describes the tool's input.
222    fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
223        Ok(serde_json::Value::Object(serde_json::Map::default()))
224    }
225
226    /// Returns markdown to be displayed in the UI for this tool.
227    fn ui_text(&self, input: &Self::Input) -> String;
228
229    /// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
230    /// (so information may be missing).
231    fn still_streaming_ui_text(&self, input: &Self::Input) -> String {
232        self.ui_text(input)
233    }
234
235    /// Runs the tool with the provided input.
236    fn run(
237        self: Arc<Self>,
238        input: Self::Input,
239        request: Arc<LanguageModelRequest>,
240        project: Entity<Project>,
241        action_log: Entity<ActionLog>,
242        model: Arc<dyn LanguageModel>,
243        window: Option<AnyWindowHandle>,
244        cx: &mut App,
245    ) -> ToolResult;
246
247    fn deserialize_card(
248        self: Arc<Self>,
249        _output: serde_json::Value,
250        _project: Entity<Project>,
251        _window: &mut Window,
252        _cx: &mut App,
253    ) -> Option<AnyToolCard> {
254        None
255    }
256}
257
258#[derive(Clone)]
259pub struct AnyTool {
260    inner: Arc<dyn ErasedTool>,
261}
262
263/// Copy of `Tool` where the Input type is erased.
264trait ErasedTool: Send + Sync {
265    fn name(&self) -> String;
266    fn description(&self) -> String;
267    fn icon(&self) -> IconName;
268    fn source(&self) -> ToolSource;
269    fn may_perform_edits(&self) -> bool;
270    fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
271    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
272    fn ui_text(&self, input: &serde_json::Value) -> String;
273    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String;
274    fn run(
275        &self,
276        input: serde_json::Value,
277        request: Arc<LanguageModelRequest>,
278        project: Entity<Project>,
279        action_log: Entity<ActionLog>,
280        model: Arc<dyn LanguageModel>,
281        window: Option<AnyWindowHandle>,
282        cx: &mut App,
283    ) -> ToolResult;
284    fn deserialize_card(
285        &self,
286        output: serde_json::Value,
287        project: Entity<Project>,
288        window: &mut Window,
289        cx: &mut App,
290    ) -> Option<AnyToolCard>;
291}
292
/// Adapter that implements [`ErasedTool`] for a concrete [`Tool`].
///
/// No trait bound is placed on the struct itself; the `T: Tool` requirement
/// lives on the impls that need it.
struct ErasedToolWrapper<T> {
    /// The wrapped concrete tool.
    tool: Arc<T>,
}
296
297impl<T: Tool> ErasedTool for ErasedToolWrapper<T> {
298    fn name(&self) -> String {
299        self.tool.name()
300    }
301
302    fn description(&self) -> String {
303        self.tool.description()
304    }
305
306    fn icon(&self) -> IconName {
307        self.tool.icon()
308    }
309
310    fn source(&self) -> ToolSource {
311        self.tool.source()
312    }
313
314    fn may_perform_edits(&self) -> bool {
315        self.tool.may_perform_edits()
316    }
317
318    fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool {
319        match serde_json::from_value::<T::Input>(input.clone()) {
320            Ok(parsed_input) => self.tool.needs_confirmation(&parsed_input, cx),
321            Err(_) => true,
322        }
323    }
324
325    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
326        self.tool.input_schema(format)
327    }
328
329    fn ui_text(&self, input: &serde_json::Value) -> String {
330        match serde_json::from_value::<T::Input>(input.clone()) {
331            Ok(parsed_input) => self.tool.ui_text(&parsed_input),
332            Err(_) => "Invalid input".to_string(),
333        }
334    }
335
336    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
337        match serde_json::from_value::<T::Input>(input.clone()) {
338            Ok(parsed_input) => self.tool.still_streaming_ui_text(&parsed_input),
339            Err(_) => "Invalid input".to_string(),
340        }
341    }
342
343    fn run(
344        &self,
345        input: serde_json::Value,
346        request: Arc<LanguageModelRequest>,
347        project: Entity<Project>,
348        action_log: Entity<ActionLog>,
349        model: Arc<dyn LanguageModel>,
350        window: Option<AnyWindowHandle>,
351        cx: &mut App,
352    ) -> ToolResult {
353        match serde_json::from_value::<T::Input>(input) {
354            Ok(parsed_input) => self.tool.clone().run(
355                parsed_input,
356                request,
357                project,
358                action_log,
359                model,
360                window,
361                cx,
362            ),
363            Err(err) => ToolResult::from(Task::ready(Err(err.into()))),
364        }
365    }
366
367    fn deserialize_card(
368        &self,
369        output: serde_json::Value,
370        project: Entity<Project>,
371        window: &mut Window,
372        cx: &mut App,
373    ) -> Option<AnyToolCard> {
374        self.tool
375            .clone()
376            .deserialize_card(output, project, window, cx)
377    }
378}
379
380impl<T: Tool> From<Arc<T>> for AnyTool {
381    fn from(tool: Arc<T>) -> Self {
382        Self {
383            inner: Arc::new(ErasedToolWrapper { tool }),
384        }
385    }
386}
387
388impl AnyTool {
389    pub fn name(&self) -> String {
390        self.inner.name()
391    }
392
393    pub fn description(&self) -> String {
394        self.inner.description()
395    }
396
397    pub fn icon(&self) -> IconName {
398        self.inner.icon()
399    }
400
401    pub fn source(&self) -> ToolSource {
402        self.inner.source()
403    }
404
405    pub fn may_perform_edits(&self) -> bool {
406        self.inner.may_perform_edits()
407    }
408
409    pub fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool {
410        self.inner.needs_confirmation(input, cx)
411    }
412
413    pub fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
414        self.inner.input_schema(format)
415    }
416
417    pub fn ui_text(&self, input: &serde_json::Value) -> String {
418        self.inner.ui_text(input)
419    }
420
421    pub fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
422        self.inner.still_streaming_ui_text(input)
423    }
424
425    pub fn run(
426        &self,
427        input: serde_json::Value,
428        request: Arc<LanguageModelRequest>,
429        project: Entity<Project>,
430        action_log: Entity<ActionLog>,
431        model: Arc<dyn LanguageModel>,
432        window: Option<AnyWindowHandle>,
433        cx: &mut App,
434    ) -> ToolResult {
435        self.inner
436            .run(input, request, project, action_log, model, window, cx)
437    }
438
439    pub fn deserialize_card(
440        &self,
441        output: serde_json::Value,
442        project: Entity<Project>,
443        window: &mut Window,
444        cx: &mut App,
445    ) -> Option<AnyToolCard> {
446        self.inner.deserialize_card(output, project, window, cx)
447    }
448}
449
450impl Debug for AnyTool {
451    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
452        f.debug_struct("Tool").field("name", &self.name()).finish()
453    }
454}