assistant_tool.rs

  1pub mod outline;
  2mod tool_registry;
  3mod tool_schema;
  4mod tool_working_set;
  5
  6use std::fmt;
  7use std::fmt::Debug;
  8use std::fmt::Formatter;
  9use std::ops::Deref;
 10use std::sync::Arc;
 11
 12use action_log::ActionLog;
 13use anyhow::Result;
 14use gpui::AnyElement;
 15use gpui::AnyWindowHandle;
 16use gpui::Context;
 17use gpui::IntoElement;
 18use gpui::Window;
 19use gpui::{App, Entity, SharedString, Task, WeakEntity};
 20use icons::IconName;
 21use language_model::LanguageModel;
 22use language_model::LanguageModelImage;
 23use language_model::LanguageModelRequest;
 24use language_model::LanguageModelToolSchemaFormat;
 25use project::Project;
 26use workspace::Workspace;
 27
 28pub use crate::tool_registry::*;
 29pub use crate::tool_schema::*;
 30pub use crate::tool_working_set::*;
 31
 32pub fn init(cx: &mut App) {
 33    ToolRegistry::default_global(cx);
 34}
 35
 36#[derive(Debug, Clone)]
 37pub enum ToolUseStatus {
 38    InputStillStreaming,
 39    NeedsConfirmation,
 40    Pending,
 41    Running,
 42    Finished(SharedString),
 43    Error(SharedString),
 44}
 45
 46impl ToolUseStatus {
 47    pub fn text(&self) -> SharedString {
 48        match self {
 49            ToolUseStatus::NeedsConfirmation => "".into(),
 50            ToolUseStatus::InputStillStreaming => "".into(),
 51            ToolUseStatus::Pending => "".into(),
 52            ToolUseStatus::Running => "".into(),
 53            ToolUseStatus::Finished(out) => out.clone(),
 54            ToolUseStatus::Error(out) => out.clone(),
 55        }
 56    }
 57
 58    pub fn error(&self) -> Option<SharedString> {
 59        match self {
 60            ToolUseStatus::Error(out) => Some(out.clone()),
 61            _ => None,
 62        }
 63    }
 64}
 65
 66#[derive(Debug)]
 67pub struct ToolResultOutput {
 68    pub content: ToolResultContent,
 69    pub output: Option<serde_json::Value>,
 70}
 71
 72#[derive(Debug, PartialEq, Eq)]
 73pub enum ToolResultContent {
 74    Text(String),
 75    Image(LanguageModelImage),
 76}
 77
 78impl ToolResultContent {
 79    pub fn len(&self) -> usize {
 80        match self {
 81            ToolResultContent::Text(str) => str.len(),
 82            ToolResultContent::Image(image) => image.len(),
 83        }
 84    }
 85
 86    pub fn is_empty(&self) -> bool {
 87        match self {
 88            ToolResultContent::Text(str) => str.is_empty(),
 89            ToolResultContent::Image(image) => image.is_empty(),
 90        }
 91    }
 92
 93    pub fn as_str(&self) -> Option<&str> {
 94        match self {
 95            ToolResultContent::Text(str) => Some(str),
 96            ToolResultContent::Image(_) => None,
 97        }
 98    }
 99}
100
101impl From<String> for ToolResultOutput {
102    fn from(value: String) -> Self {
103        ToolResultOutput {
104            content: ToolResultContent::Text(value),
105            output: None,
106        }
107    }
108}
109
110impl Deref for ToolResultOutput {
111    type Target = ToolResultContent;
112
113    fn deref(&self) -> &Self::Target {
114        &self.content
115    }
116}
117
118/// The result of running a tool, containing both the asynchronous output
119/// and an optional card view that can be rendered immediately.
120pub struct ToolResult {
121    /// The asynchronous task that will eventually resolve to the tool's output
122    pub output: Task<Result<ToolResultOutput>>,
123    /// An optional view to present the output of the tool.
124    pub card: Option<AnyToolCard>,
125}
126
127pub trait ToolCard: 'static + Sized {
128    fn render(
129        &mut self,
130        status: &ToolUseStatus,
131        window: &mut Window,
132        workspace: WeakEntity<Workspace>,
133        cx: &mut Context<Self>,
134    ) -> impl IntoElement;
135}
136
137#[derive(Clone)]
138pub struct AnyToolCard {
139    entity: gpui::AnyEntity,
140    render: fn(
141        entity: gpui::AnyEntity,
142        status: &ToolUseStatus,
143        window: &mut Window,
144        workspace: WeakEntity<Workspace>,
145        cx: &mut App,
146    ) -> AnyElement,
147}
148
149impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
150    fn from(entity: Entity<T>) -> Self {
151        fn downcast_render<T: ToolCard>(
152            entity: gpui::AnyEntity,
153            status: &ToolUseStatus,
154            window: &mut Window,
155            workspace: WeakEntity<Workspace>,
156            cx: &mut App,
157        ) -> AnyElement {
158            let entity = entity.downcast::<T>().unwrap();
159            entity.update(cx, |entity, cx| {
160                entity
161                    .render(status, window, workspace, cx)
162                    .into_any_element()
163            })
164        }
165
166        Self {
167            entity: entity.into(),
168            render: downcast_render::<T>,
169        }
170    }
171}
172
173impl AnyToolCard {
174    pub fn render(
175        &self,
176        status: &ToolUseStatus,
177        window: &mut Window,
178        workspace: WeakEntity<Workspace>,
179        cx: &mut App,
180    ) -> AnyElement {
181        (self.render)(self.entity.clone(), status, window, workspace, cx)
182    }
183}
184
185impl From<Task<Result<ToolResultOutput>>> for ToolResult {
186    /// Convert from a task to a ToolResult with no card
187    fn from(output: Task<Result<ToolResultOutput>>) -> Self {
188        Self { output, card: None }
189    }
190}
191
192#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
193pub enum ToolSource {
194    /// A native tool built-in to Zed.
195    Native,
196    /// A tool provided by a context server.
197    ContextServer { id: SharedString },
198}
199
200/// A tool that can be used by a language model.
201pub trait Tool: 'static + Send + Sync {
202    /// Returns the name of the tool.
203    fn name(&self) -> String;
204
205    /// Returns the description of the tool.
206    fn description(&self) -> String;
207
208    /// Returns the icon for the tool.
209    fn icon(&self) -> IconName;
210
211    /// Returns the source of the tool.
212    fn source(&self) -> ToolSource {
213        ToolSource::Native
214    }
215
216    /// Returns true if the tool needs the users's confirmation
217    /// before having permission to run.
218    fn needs_confirmation(
219        &self,
220        input: &serde_json::Value,
221        project: &Entity<Project>,
222        cx: &App,
223    ) -> bool;
224
225    /// Returns true if the tool may perform edits.
226    fn may_perform_edits(&self) -> bool;
227
228    /// Returns the JSON schema that describes the tool's input.
229    fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
230        Ok(serde_json::Value::Object(serde_json::Map::default()))
231    }
232
233    /// Returns markdown to be displayed in the UI for this tool.
234    fn ui_text(&self, input: &serde_json::Value) -> String;
235
236    /// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
237    /// (so information may be missing).
238    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
239        self.ui_text(input)
240    }
241
242    /// Runs the tool with the provided input.
243    fn run(
244        self: Arc<Self>,
245        input: serde_json::Value,
246        request: Arc<LanguageModelRequest>,
247        project: Entity<Project>,
248        action_log: Entity<ActionLog>,
249        model: Arc<dyn LanguageModel>,
250        window: Option<AnyWindowHandle>,
251        cx: &mut App,
252    ) -> ToolResult;
253
254    fn deserialize_card(
255        self: Arc<Self>,
256        _output: serde_json::Value,
257        _project: Entity<Project>,
258        _window: &mut Window,
259        _cx: &mut App,
260    ) -> Option<AnyToolCard> {
261        None
262    }
263}
264
265impl Debug for dyn Tool {
266    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
267        f.debug_struct("Tool").field("name", &self.name()).finish()
268    }
269}