assistant_tool.rs

  1mod action_log;
  2pub mod outline;
  3mod tool_registry;
  4mod tool_schema;
  5mod tool_working_set;
  6
  7use std::fmt;
  8use std::fmt::Debug;
  9use std::fmt::Formatter;
 10use std::ops::Deref;
 11use std::sync::Arc;
 12
 13use anyhow::Result;
 14use gpui::AnyElement;
 15use gpui::AnyWindowHandle;
 16use gpui::Context;
 17use gpui::IntoElement;
 18use gpui::Window;
 19use gpui::{App, Entity, SharedString, Task, WeakEntity};
 20use icons::IconName;
 21use language_model::LanguageModelRequestMessage;
 22use language_model::LanguageModelToolSchemaFormat;
 23use project::Project;
 24use workspace::Workspace;
 25
 26pub use crate::action_log::*;
 27pub use crate::tool_registry::*;
 28pub use crate::tool_schema::*;
 29pub use crate::tool_working_set::*;
 30
 31pub fn init(cx: &mut App) {
 32    ToolRegistry::default_global(cx);
 33}
 34
 35#[derive(Debug, Clone)]
 36pub enum ToolUseStatus {
 37    InputStillStreaming,
 38    NeedsConfirmation,
 39    Pending,
 40    Running,
 41    Finished(SharedString),
 42    Error(SharedString),
 43}
 44
 45impl ToolUseStatus {
 46    pub fn text(&self) -> SharedString {
 47        match self {
 48            ToolUseStatus::NeedsConfirmation => "".into(),
 49            ToolUseStatus::InputStillStreaming => "".into(),
 50            ToolUseStatus::Pending => "".into(),
 51            ToolUseStatus::Running => "".into(),
 52            ToolUseStatus::Finished(out) => out.clone(),
 53            ToolUseStatus::Error(out) => out.clone(),
 54        }
 55    }
 56
 57    pub fn error(&self) -> Option<SharedString> {
 58        match self {
 59            ToolUseStatus::Error(out) => Some(out.clone()),
 60            _ => None,
 61        }
 62    }
 63}
 64
 65#[derive(Debug)]
 66pub struct ToolResultOutput {
 67    pub content: String,
 68    pub output: Option<serde_json::Value>,
 69}
 70
 71impl From<String> for ToolResultOutput {
 72    fn from(value: String) -> Self {
 73        ToolResultOutput {
 74            content: value,
 75            output: None,
 76        }
 77    }
 78}
 79
 80impl Deref for ToolResultOutput {
 81    type Target = String;
 82
 83    fn deref(&self) -> &Self::Target {
 84        &self.content
 85    }
 86}
 87
 88/// The result of running a tool, containing both the asynchronous output
 89/// and an optional card view that can be rendered immediately.
 90pub struct ToolResult {
 91    /// The asynchronous task that will eventually resolve to the tool's output
 92    pub output: Task<Result<ToolResultOutput>>,
 93    /// An optional view to present the output of the tool.
 94    pub card: Option<AnyToolCard>,
 95}
 96
 97pub trait ToolCard: 'static + Sized {
 98    fn render(
 99        &mut self,
100        status: &ToolUseStatus,
101        window: &mut Window,
102        workspace: WeakEntity<Workspace>,
103        cx: &mut Context<Self>,
104    ) -> impl IntoElement;
105}
106
107#[derive(Clone)]
108pub struct AnyToolCard {
109    entity: gpui::AnyEntity,
110    render: fn(
111        entity: gpui::AnyEntity,
112        status: &ToolUseStatus,
113        window: &mut Window,
114        workspace: WeakEntity<Workspace>,
115        cx: &mut App,
116    ) -> AnyElement,
117}
118
119impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
120    fn from(entity: Entity<T>) -> Self {
121        fn downcast_render<T: ToolCard>(
122            entity: gpui::AnyEntity,
123            status: &ToolUseStatus,
124            window: &mut Window,
125            workspace: WeakEntity<Workspace>,
126            cx: &mut App,
127        ) -> AnyElement {
128            let entity = entity.downcast::<T>().unwrap();
129            entity.update(cx, |entity, cx| {
130                entity
131                    .render(status, window, workspace, cx)
132                    .into_any_element()
133            })
134        }
135
136        Self {
137            entity: entity.into(),
138            render: downcast_render::<T>,
139        }
140    }
141}
142
143impl AnyToolCard {
144    pub fn render(
145        &self,
146        status: &ToolUseStatus,
147        window: &mut Window,
148        workspace: WeakEntity<Workspace>,
149        cx: &mut App,
150    ) -> AnyElement {
151        (self.render)(self.entity.clone(), status, window, workspace, cx)
152    }
153}
154
155impl From<Task<Result<ToolResultOutput>>> for ToolResult {
156    /// Convert from a task to a ToolResult with no card
157    fn from(output: Task<Result<ToolResultOutput>>) -> Self {
158        Self { output, card: None }
159    }
160}
161
162#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
163pub enum ToolSource {
164    /// A native tool built-in to Zed.
165    Native,
166    /// A tool provided by a context server.
167    ContextServer { id: SharedString },
168}
169
170/// A tool that can be used by a language model.
171pub trait Tool: 'static + Send + Sync {
172    /// Returns the name of the tool.
173    fn name(&self) -> String;
174
175    /// Returns the description of the tool.
176    fn description(&self) -> String;
177
178    /// Returns the icon for the tool.
179    fn icon(&self) -> IconName;
180
181    /// Returns the source of the tool.
182    fn source(&self) -> ToolSource {
183        ToolSource::Native
184    }
185
186    /// Returns true iff the tool needs the users's confirmation
187    /// before having permission to run.
188    fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
189
190    /// Returns the JSON schema that describes the tool's input.
191    fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
192        Ok(serde_json::Value::Object(serde_json::Map::default()))
193    }
194
195    /// Returns markdown to be displayed in the UI for this tool.
196    fn ui_text(&self, input: &serde_json::Value) -> String;
197
198    /// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
199    /// (so information may be missing).
200    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
201        self.ui_text(input)
202    }
203
204    /// Runs the tool with the provided input.
205    fn run(
206        self: Arc<Self>,
207        input: serde_json::Value,
208        messages: &[LanguageModelRequestMessage],
209        project: Entity<Project>,
210        action_log: Entity<ActionLog>,
211        window: Option<AnyWindowHandle>,
212        cx: &mut App,
213    ) -> ToolResult;
214
215    fn deserialize_card(
216        self: Arc<Self>,
217        _output: serde_json::Value,
218        _project: Entity<Project>,
219        _window: &mut Window,
220        _cx: &mut App,
221    ) -> Option<AnyToolCard> {
222        None
223    }
224}
225
226impl Debug for dyn Tool {
227    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
228        f.debug_struct("Tool").field("name", &self.name()).finish()
229    }
230}