assistant_tool.rs

  1mod action_log;
  2pub mod outline;
  3mod tool_registry;
  4mod tool_schema;
  5mod tool_working_set;
  6
  7use std::fmt;
  8use std::fmt::Debug;
  9use std::fmt::Formatter;
 10use std::ops::Deref;
 11use std::sync::Arc;
 12
 13use anyhow::Result;
 14use gpui::AnyElement;
 15use gpui::AnyWindowHandle;
 16use gpui::Context;
 17use gpui::IntoElement;
 18use gpui::Window;
 19use gpui::{App, Entity, SharedString, Task, WeakEntity};
 20use icons::IconName;
 21use language_model::LanguageModel;
 22use language_model::LanguageModelImage;
 23use language_model::LanguageModelRequest;
 24use language_model::LanguageModelToolSchemaFormat;
 25use project::Project;
 26use workspace::Workspace;
 27
 28pub use crate::action_log::*;
 29pub use crate::tool_registry::*;
 30pub use crate::tool_schema::*;
 31pub use crate::tool_working_set::*;
 32
 33pub fn init(cx: &mut App) {
 34    ToolRegistry::default_global(cx);
 35}
 36
 37#[derive(Debug, Clone)]
 38pub enum ToolUseStatus {
 39    InputStillStreaming,
 40    NeedsConfirmation,
 41    Pending,
 42    Running,
 43    Finished(SharedString),
 44    Error(SharedString),
 45}
 46
 47impl ToolUseStatus {
 48    pub fn text(&self) -> SharedString {
 49        match self {
 50            ToolUseStatus::NeedsConfirmation => "".into(),
 51            ToolUseStatus::InputStillStreaming => "".into(),
 52            ToolUseStatus::Pending => "".into(),
 53            ToolUseStatus::Running => "".into(),
 54            ToolUseStatus::Finished(out) => out.clone(),
 55            ToolUseStatus::Error(out) => out.clone(),
 56        }
 57    }
 58
 59    pub fn error(&self) -> Option<SharedString> {
 60        match self {
 61            ToolUseStatus::Error(out) => Some(out.clone()),
 62            _ => None,
 63        }
 64    }
 65}
 66
 67#[derive(Debug)]
 68pub struct ToolResultOutput {
 69    pub content: ToolResultContent,
 70    pub output: Option<serde_json::Value>,
 71}
 72
 73#[derive(Debug, PartialEq, Eq)]
 74pub enum ToolResultContent {
 75    Text(String),
 76    Image(LanguageModelImage),
 77}
 78
 79impl ToolResultContent {
 80    pub fn len(&self) -> usize {
 81        match self {
 82            ToolResultContent::Text(str) => str.len(),
 83            ToolResultContent::Image(image) => image.len(),
 84        }
 85    }
 86
 87    pub fn is_empty(&self) -> bool {
 88        match self {
 89            ToolResultContent::Text(str) => str.is_empty(),
 90            ToolResultContent::Image(image) => image.is_empty(),
 91        }
 92    }
 93
 94    pub fn as_str(&self) -> Option<&str> {
 95        match self {
 96            ToolResultContent::Text(str) => Some(str),
 97            ToolResultContent::Image(_) => None,
 98        }
 99    }
100}
101
102impl From<String> for ToolResultOutput {
103    fn from(value: String) -> Self {
104        ToolResultOutput {
105            content: ToolResultContent::Text(value),
106            output: None,
107        }
108    }
109}
110
111impl Deref for ToolResultOutput {
112    type Target = ToolResultContent;
113
114    fn deref(&self) -> &Self::Target {
115        &self.content
116    }
117}
118
119/// The result of running a tool, containing both the asynchronous output
120/// and an optional card view that can be rendered immediately.
121pub struct ToolResult {
122    /// The asynchronous task that will eventually resolve to the tool's output
123    pub output: Task<Result<ToolResultOutput>>,
124    /// An optional view to present the output of the tool.
125    pub card: Option<AnyToolCard>,
126}
127
128pub trait ToolCard: 'static + Sized {
129    fn render(
130        &mut self,
131        status: &ToolUseStatus,
132        window: &mut Window,
133        workspace: WeakEntity<Workspace>,
134        cx: &mut Context<Self>,
135    ) -> impl IntoElement;
136}
137
138#[derive(Clone)]
139pub struct AnyToolCard {
140    entity: gpui::AnyEntity,
141    render: fn(
142        entity: gpui::AnyEntity,
143        status: &ToolUseStatus,
144        window: &mut Window,
145        workspace: WeakEntity<Workspace>,
146        cx: &mut App,
147    ) -> AnyElement,
148}
149
150impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
151    fn from(entity: Entity<T>) -> Self {
152        fn downcast_render<T: ToolCard>(
153            entity: gpui::AnyEntity,
154            status: &ToolUseStatus,
155            window: &mut Window,
156            workspace: WeakEntity<Workspace>,
157            cx: &mut App,
158        ) -> AnyElement {
159            let entity = entity.downcast::<T>().unwrap();
160            entity.update(cx, |entity, cx| {
161                entity
162                    .render(status, window, workspace, cx)
163                    .into_any_element()
164            })
165        }
166
167        Self {
168            entity: entity.into(),
169            render: downcast_render::<T>,
170        }
171    }
172}
173
174impl AnyToolCard {
175    pub fn render(
176        &self,
177        status: &ToolUseStatus,
178        window: &mut Window,
179        workspace: WeakEntity<Workspace>,
180        cx: &mut App,
181    ) -> AnyElement {
182        (self.render)(self.entity.clone(), status, window, workspace, cx)
183    }
184}
185
186impl From<Task<Result<ToolResultOutput>>> for ToolResult {
187    /// Convert from a task to a ToolResult with no card
188    fn from(output: Task<Result<ToolResultOutput>>) -> Self {
189        Self { output, card: None }
190    }
191}
192
193#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
194pub enum ToolSource {
195    /// A native tool built-in to Zed.
196    Native,
197    /// A tool provided by a context server.
198    ContextServer { id: SharedString },
199}
200
201/// A tool that can be used by a language model.
202pub trait Tool: 'static + Send + Sync {
203    /// Returns the name of the tool.
204    fn name(&self) -> String;
205
206    /// Returns the description of the tool.
207    fn description(&self) -> String;
208
209    /// Returns the icon for the tool.
210    fn icon(&self) -> IconName;
211
212    /// Returns the source of the tool.
213    fn source(&self) -> ToolSource {
214        ToolSource::Native
215    }
216
217    /// Returns true if the tool needs the users's confirmation
218    /// before having permission to run.
219    fn needs_confirmation(
220        &self,
221        input: &serde_json::Value,
222        project: &Entity<Project>,
223        cx: &App,
224    ) -> bool;
225
226    /// Returns true if the tool may perform edits.
227    fn may_perform_edits(&self) -> bool;
228
229    /// Returns the JSON schema that describes the tool's input.
230    fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
231        Ok(serde_json::Value::Object(serde_json::Map::default()))
232    }
233
234    /// Returns markdown to be displayed in the UI for this tool.
235    fn ui_text(&self, input: &serde_json::Value) -> String;
236
237    /// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
238    /// (so information may be missing).
239    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
240        self.ui_text(input)
241    }
242
243    /// Runs the tool with the provided input.
244    fn run(
245        self: Arc<Self>,
246        input: serde_json::Value,
247        request: Arc<LanguageModelRequest>,
248        project: Entity<Project>,
249        action_log: Entity<ActionLog>,
250        model: Arc<dyn LanguageModel>,
251        window: Option<AnyWindowHandle>,
252        cx: &mut App,
253    ) -> ToolResult;
254
255    fn deserialize_card(
256        self: Arc<Self>,
257        _output: serde_json::Value,
258        _project: Entity<Project>,
259        _window: &mut Window,
260        _cx: &mut App,
261    ) -> Option<AnyToolCard> {
262        None
263    }
264}
265
266impl Debug for dyn Tool {
267    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
268        f.debug_struct("Tool").field("name", &self.name()).finish()
269    }
270}