assistant_tool.rs

  1mod action_log;
  2pub mod outline;
  3mod tool_registry;
  4mod tool_schema;
  5mod tool_working_set;
  6
  7use std::fmt;
  8use std::fmt::Debug;
  9use std::fmt::Formatter;
 10use std::ops::Deref;
 11use std::sync::Arc;
 12
 13use anyhow::Result;
 14use gpui::AnyElement;
 15use gpui::AnyWindowHandle;
 16use gpui::Context;
 17use gpui::IntoElement;
 18use gpui::Window;
 19use gpui::{App, Entity, SharedString, Task, WeakEntity};
 20use icons::IconName;
 21use language_model::LanguageModel;
 22use language_model::LanguageModelRequest;
 23use language_model::LanguageModelToolSchemaFormat;
 24use project::Project;
 25use workspace::Workspace;
 26
 27pub use crate::action_log::*;
 28pub use crate::tool_registry::*;
 29pub use crate::tool_schema::*;
 30pub use crate::tool_working_set::*;
 31
 32pub fn init(cx: &mut App) {
 33    ToolRegistry::default_global(cx);
 34}
 35
 36#[derive(Debug, Clone)]
 37pub enum ToolUseStatus {
 38    InputStillStreaming,
 39    NeedsConfirmation,
 40    Pending,
 41    Running,
 42    Finished(SharedString),
 43    Error(SharedString),
 44}
 45
 46impl ToolUseStatus {
 47    pub fn text(&self) -> SharedString {
 48        match self {
 49            ToolUseStatus::NeedsConfirmation => "".into(),
 50            ToolUseStatus::InputStillStreaming => "".into(),
 51            ToolUseStatus::Pending => "".into(),
 52            ToolUseStatus::Running => "".into(),
 53            ToolUseStatus::Finished(out) => out.clone(),
 54            ToolUseStatus::Error(out) => out.clone(),
 55        }
 56    }
 57
 58    pub fn error(&self) -> Option<SharedString> {
 59        match self {
 60            ToolUseStatus::Error(out) => Some(out.clone()),
 61            _ => None,
 62        }
 63    }
 64}
 65
 66#[derive(Debug)]
 67pub struct ToolResultOutput {
 68    pub content: String,
 69    pub output: Option<serde_json::Value>,
 70}
 71
 72impl From<String> for ToolResultOutput {
 73    fn from(value: String) -> Self {
 74        ToolResultOutput {
 75            content: value,
 76            output: None,
 77        }
 78    }
 79}
 80
 81impl Deref for ToolResultOutput {
 82    type Target = String;
 83
 84    fn deref(&self) -> &Self::Target {
 85        &self.content
 86    }
 87}
 88
 89/// The result of running a tool, containing both the asynchronous output
 90/// and an optional card view that can be rendered immediately.
 91pub struct ToolResult {
 92    /// The asynchronous task that will eventually resolve to the tool's output
 93    pub output: Task<Result<ToolResultOutput>>,
 94    /// An optional view to present the output of the tool.
 95    pub card: Option<AnyToolCard>,
 96}
 97
 98pub trait ToolCard: 'static + Sized {
 99    fn render(
100        &mut self,
101        status: &ToolUseStatus,
102        window: &mut Window,
103        workspace: WeakEntity<Workspace>,
104        cx: &mut Context<Self>,
105    ) -> impl IntoElement;
106}
107
108#[derive(Clone)]
109pub struct AnyToolCard {
110    entity: gpui::AnyEntity,
111    render: fn(
112        entity: gpui::AnyEntity,
113        status: &ToolUseStatus,
114        window: &mut Window,
115        workspace: WeakEntity<Workspace>,
116        cx: &mut App,
117    ) -> AnyElement,
118}
119
120impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
121    fn from(entity: Entity<T>) -> Self {
122        fn downcast_render<T: ToolCard>(
123            entity: gpui::AnyEntity,
124            status: &ToolUseStatus,
125            window: &mut Window,
126            workspace: WeakEntity<Workspace>,
127            cx: &mut App,
128        ) -> AnyElement {
129            let entity = entity.downcast::<T>().unwrap();
130            entity.update(cx, |entity, cx| {
131                entity
132                    .render(status, window, workspace, cx)
133                    .into_any_element()
134            })
135        }
136
137        Self {
138            entity: entity.into(),
139            render: downcast_render::<T>,
140        }
141    }
142}
143
144impl AnyToolCard {
145    pub fn render(
146        &self,
147        status: &ToolUseStatus,
148        window: &mut Window,
149        workspace: WeakEntity<Workspace>,
150        cx: &mut App,
151    ) -> AnyElement {
152        (self.render)(self.entity.clone(), status, window, workspace, cx)
153    }
154}
155
156impl From<Task<Result<ToolResultOutput>>> for ToolResult {
157    /// Convert from a task to a ToolResult with no card
158    fn from(output: Task<Result<ToolResultOutput>>) -> Self {
159        Self { output, card: None }
160    }
161}
162
163#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
164pub enum ToolSource {
165    /// A native tool built-in to Zed.
166    Native,
167    /// A tool provided by a context server.
168    ContextServer { id: SharedString },
169}
170
171/// A tool that can be used by a language model.
172pub trait Tool: 'static + Send + Sync {
173    /// Returns the name of the tool.
174    fn name(&self) -> String;
175
176    /// Returns the description of the tool.
177    fn description(&self) -> String;
178
179    /// Returns the icon for the tool.
180    fn icon(&self) -> IconName;
181
182    /// Returns the source of the tool.
183    fn source(&self) -> ToolSource {
184        ToolSource::Native
185    }
186
187    /// Returns true iff the tool needs the users's confirmation
188    /// before having permission to run.
189    fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
190
191    /// Returns the JSON schema that describes the tool's input.
192    fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
193        Ok(serde_json::Value::Object(serde_json::Map::default()))
194    }
195
196    /// Returns markdown to be displayed in the UI for this tool.
197    fn ui_text(&self, input: &serde_json::Value) -> String;
198
199    /// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
200    /// (so information may be missing).
201    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
202        self.ui_text(input)
203    }
204
205    /// Runs the tool with the provided input.
206    fn run(
207        self: Arc<Self>,
208        input: serde_json::Value,
209        request: Arc<LanguageModelRequest>,
210        project: Entity<Project>,
211        action_log: Entity<ActionLog>,
212        model: Arc<dyn LanguageModel>,
213        window: Option<AnyWindowHandle>,
214        cx: &mut App,
215    ) -> ToolResult;
216
217    fn deserialize_card(
218        self: Arc<Self>,
219        _output: serde_json::Value,
220        _project: Entity<Project>,
221        _window: &mut Window,
222        _cx: &mut App,
223    ) -> Option<AnyToolCard> {
224        None
225    }
226}
227
228impl Debug for dyn Tool {
229    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
230        f.debug_struct("Tool").field("name", &self.name()).finish()
231    }
232}