assistant_tool.rs

  1mod action_log;
  2pub mod outline;
  3mod tool_registry;
  4mod tool_schema;
  5mod tool_working_set;
  6
  7use std::fmt;
  8use std::fmt::Debug;
  9use std::fmt::Formatter;
 10use std::ops::Deref;
 11use std::sync::Arc;
 12
 13use anyhow::Result;
 14use gpui::AnyElement;
 15use gpui::AnyWindowHandle;
 16use gpui::Context;
 17use gpui::IntoElement;
 18use gpui::Window;
 19use gpui::{App, Entity, SharedString, Task, WeakEntity};
 20use icons::IconName;
 21use language_model::LanguageModel;
 22use language_model::LanguageModelImage;
 23use language_model::LanguageModelRequest;
 24use language_model::LanguageModelToolSchemaFormat;
 25use project::Project;
 26use workspace::Workspace;
 27
 28pub use crate::action_log::*;
 29pub use crate::tool_registry::*;
 30pub use crate::tool_schema::*;
 31pub use crate::tool_working_set::*;
 32
 33pub fn init(cx: &mut App) {
 34    ToolRegistry::default_global(cx);
 35}
 36
 37#[derive(Debug, Clone)]
 38pub enum ToolUseStatus {
 39    InputStillStreaming,
 40    NeedsConfirmation,
 41    Pending,
 42    Running,
 43    Finished(SharedString),
 44    Error(SharedString),
 45}
 46
 47impl ToolUseStatus {
 48    pub fn text(&self) -> SharedString {
 49        match self {
 50            ToolUseStatus::NeedsConfirmation => "".into(),
 51            ToolUseStatus::InputStillStreaming => "".into(),
 52            ToolUseStatus::Pending => "".into(),
 53            ToolUseStatus::Running => "".into(),
 54            ToolUseStatus::Finished(out) => out.clone(),
 55            ToolUseStatus::Error(out) => out.clone(),
 56        }
 57    }
 58
 59    pub fn error(&self) -> Option<SharedString> {
 60        match self {
 61            ToolUseStatus::Error(out) => Some(out.clone()),
 62            _ => None,
 63        }
 64    }
 65}
 66
 67#[derive(Debug)]
 68pub struct ToolResultOutput {
 69    pub content: ToolResultContent,
 70    pub output: Option<serde_json::Value>,
 71}
 72
 73#[derive(Clone, Debug, PartialEq, Eq)]
 74pub enum ToolResultContent {
 75    Text(String),
 76    Image(LanguageModelImage),
 77}
 78
 79impl ToolResultContent {
 80    pub fn len(&self) -> usize {
 81        match self {
 82            ToolResultContent::Text(str) => str.len(),
 83            ToolResultContent::Image(image) => image.len(),
 84        }
 85    }
 86
 87    pub fn is_empty(&self) -> bool {
 88        match self {
 89            ToolResultContent::Text(str) => str.is_empty(),
 90            ToolResultContent::Image(image) => image.is_empty(),
 91        }
 92    }
 93
 94    pub fn as_str(&self) -> Option<&str> {
 95        match self {
 96            ToolResultContent::Text(str) => Some(str),
 97            ToolResultContent::Image(_) => None,
 98        }
 99    }
100}
101
102impl From<String> for ToolResultOutput {
103    fn from(value: String) -> Self {
104        ToolResultOutput {
105            content: ToolResultContent::Text(value),
106            output: None,
107        }
108    }
109}
110
111impl Deref for ToolResultOutput {
112    type Target = ToolResultContent;
113
114    fn deref(&self) -> &Self::Target {
115        &self.content
116    }
117}
118
119/// The result of running a tool, containing both the asynchronous output
120/// and an optional card view that can be rendered immediately.
121pub struct ToolResult {
122    /// The asynchronous task that will eventually resolve to the tool's output
123    pub output: Task<Result<ToolResultOutput>>,
124    /// An optional view to present the output of the tool.
125    pub card: Option<AnyToolCard>,
126}
127
128pub trait ToolCard: 'static + Sized {
129    fn render(
130        &mut self,
131        status: &ToolUseStatus,
132        window: &mut Window,
133        workspace: WeakEntity<Workspace>,
134        cx: &mut Context<Self>,
135    ) -> impl IntoElement;
136}
137
138#[derive(Debug, Clone)]
139#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq, Eq))]
140pub struct AnyToolCard {
141    entity: gpui::AnyEntity,
142    render: fn(
143        entity: gpui::AnyEntity,
144        status: &ToolUseStatus,
145        window: &mut Window,
146        workspace: WeakEntity<Workspace>,
147        cx: &mut App,
148    ) -> AnyElement,
149}
150
151impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
152    fn from(entity: Entity<T>) -> Self {
153        fn downcast_render<T: ToolCard>(
154            entity: gpui::AnyEntity,
155            status: &ToolUseStatus,
156            window: &mut Window,
157            workspace: WeakEntity<Workspace>,
158            cx: &mut App,
159        ) -> AnyElement {
160            let entity = entity.downcast::<T>().unwrap();
161            entity.update(cx, |entity, cx| {
162                entity
163                    .render(status, window, workspace, cx)
164                    .into_any_element()
165            })
166        }
167
168        Self {
169            entity: entity.into(),
170            render: downcast_render::<T>,
171        }
172    }
173}
174
175impl AnyToolCard {
176    pub fn render(
177        &self,
178        status: &ToolUseStatus,
179        window: &mut Window,
180        workspace: WeakEntity<Workspace>,
181        cx: &mut App,
182    ) -> AnyElement {
183        (self.render)(self.entity.clone(), status, window, workspace, cx)
184    }
185}
186
187impl From<Task<Result<ToolResultOutput>>> for ToolResult {
188    /// Convert from a task to a ToolResult with no card
189    fn from(output: Task<Result<ToolResultOutput>>) -> Self {
190        Self { output, card: None }
191    }
192}
193
194#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
195pub enum ToolSource {
196    /// A native tool built-in to Zed.
197    Native,
198    /// A tool provided by a context server.
199    ContextServer { id: SharedString },
200}
201
202/// A tool that can be used by a language model.
203pub trait Tool: 'static + Send + Sync {
204    /// Returns the name of the tool.
205    fn name(&self) -> String;
206
207    /// Returns the description of the tool.
208    fn description(&self) -> String;
209
210    /// Returns the icon for the tool.
211    fn icon(&self) -> IconName;
212
213    /// Returns the source of the tool.
214    fn source(&self) -> ToolSource {
215        ToolSource::Native
216    }
217
218    /// Returns true if the tool needs the users's confirmation
219    /// before having permission to run.
220    fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
221
222    /// Returns true if the tool may perform edits.
223    fn may_perform_edits(&self) -> bool;
224
225    /// Returns the JSON schema that describes the tool's input.
226    fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
227        Ok(serde_json::Value::Object(serde_json::Map::default()))
228    }
229
230    /// Returns markdown to be displayed in the UI for this tool.
231    fn ui_text(&self, input: &serde_json::Value) -> String;
232
233    /// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
234    /// (so information may be missing).
235    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
236        self.ui_text(input)
237    }
238
239    /// Runs the tool with the provided input.
240    fn run(
241        self: Arc<Self>,
242        input: serde_json::Value,
243        request: Arc<LanguageModelRequest>,
244        project: Entity<Project>,
245        action_log: Entity<ActionLog>,
246        model: Arc<dyn LanguageModel>,
247        window: Option<AnyWindowHandle>,
248        cx: &mut App,
249    ) -> ToolResult;
250
251    fn deserialize_card(
252        self: Arc<Self>,
253        _output: serde_json::Value,
254        _project: Entity<Project>,
255        _window: &mut Window,
256        _cx: &mut App,
257    ) -> Option<AnyToolCard> {
258        None
259    }
260}
261
262impl Debug for dyn Tool {
263    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
264        f.debug_struct("Tool").field("name", &self.name()).finish()
265    }
266}