assistant_tool.rs

  1mod action_log;
  2mod tool_registry;
  3mod tool_schema;
  4mod tool_working_set;
  5
  6use std::fmt;
  7use std::fmt::Debug;
  8use std::fmt::Formatter;
  9use std::sync::Arc;
 10
 11use anyhow::Result;
 12use gpui::AnyElement;
 13use gpui::AnyWindowHandle;
 14use gpui::Context;
 15use gpui::IntoElement;
 16use gpui::Window;
 17use gpui::{App, Entity, SharedString, Task, WeakEntity};
 18use icons::IconName;
 19use language_model::LanguageModelRequestMessage;
 20use language_model::LanguageModelToolSchemaFormat;
 21use project::Project;
 22use workspace::Workspace;
 23
 24pub use crate::action_log::*;
 25pub use crate::tool_registry::*;
 26pub use crate::tool_schema::*;
 27pub use crate::tool_working_set::*;
 28
 29pub fn init(cx: &mut App) {
 30    ToolRegistry::default_global(cx);
 31}
 32
 33#[derive(Debug, Clone)]
 34pub enum ToolUseStatus {
 35    InputStillStreaming,
 36    NeedsConfirmation,
 37    Pending,
 38    Running,
 39    Finished(SharedString),
 40    Error(SharedString),
 41}
 42
 43impl ToolUseStatus {
 44    pub fn text(&self) -> SharedString {
 45        match self {
 46            ToolUseStatus::NeedsConfirmation => "".into(),
 47            ToolUseStatus::InputStillStreaming => "".into(),
 48            ToolUseStatus::Pending => "".into(),
 49            ToolUseStatus::Running => "".into(),
 50            ToolUseStatus::Finished(out) => out.clone(),
 51            ToolUseStatus::Error(out) => out.clone(),
 52        }
 53    }
 54}
 55
 56/// The result of running a tool, containing both the asynchronous output
 57/// and an optional card view that can be rendered immediately.
 58pub struct ToolResult {
 59    /// The asynchronous task that will eventually resolve to the tool's output
 60    pub output: Task<Result<String>>,
 61    /// An optional view to present the output of the tool.
 62    pub card: Option<AnyToolCard>,
 63}
 64
 65pub trait ToolCard: 'static + Sized {
 66    fn render(
 67        &mut self,
 68        status: &ToolUseStatus,
 69        window: &mut Window,
 70        workspace: WeakEntity<Workspace>,
 71        cx: &mut Context<Self>,
 72    ) -> impl IntoElement;
 73}
 74
 75#[derive(Clone)]
 76pub struct AnyToolCard {
 77    entity: gpui::AnyEntity,
 78    render: fn(
 79        entity: gpui::AnyEntity,
 80        status: &ToolUseStatus,
 81        window: &mut Window,
 82        workspace: WeakEntity<Workspace>,
 83        cx: &mut App,
 84    ) -> AnyElement,
 85}
 86
 87impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
 88    fn from(entity: Entity<T>) -> Self {
 89        fn downcast_render<T: ToolCard>(
 90            entity: gpui::AnyEntity,
 91            status: &ToolUseStatus,
 92            window: &mut Window,
 93            workspace: WeakEntity<Workspace>,
 94            cx: &mut App,
 95        ) -> AnyElement {
 96            let entity = entity.downcast::<T>().unwrap();
 97            entity.update(cx, |entity, cx| {
 98                entity
 99                    .render(status, window, workspace, cx)
100                    .into_any_element()
101            })
102        }
103
104        Self {
105            entity: entity.into(),
106            render: downcast_render::<T>,
107        }
108    }
109}
110
111impl AnyToolCard {
112    pub fn render(
113        &self,
114        status: &ToolUseStatus,
115        window: &mut Window,
116        workspace: WeakEntity<Workspace>,
117        cx: &mut App,
118    ) -> AnyElement {
119        (self.render)(self.entity.clone(), status, window, workspace, cx)
120    }
121}
122
123impl From<Task<Result<String>>> for ToolResult {
124    /// Convert from a task to a ToolResult with no card
125    fn from(output: Task<Result<String>>) -> Self {
126        Self { output, card: None }
127    }
128}
129
130#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
131pub enum ToolSource {
132    /// A native tool built-in to Zed.
133    Native,
134    /// A tool provided by a context server.
135    ContextServer { id: SharedString },
136}
137
138/// A tool that can be used by a language model.
139pub trait Tool: 'static + Send + Sync {
140    /// Returns the name of the tool.
141    fn name(&self) -> String;
142
143    /// Returns the description of the tool.
144    fn description(&self) -> String;
145
146    /// Returns the icon for the tool.
147    fn icon(&self) -> IconName;
148
149    /// Returns the source of the tool.
150    fn source(&self) -> ToolSource {
151        ToolSource::Native
152    }
153
154    /// Returns true iff the tool needs the users's confirmation
155    /// before having permission to run.
156    fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
157
158    /// Returns the JSON schema that describes the tool's input.
159    fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
160        Ok(serde_json::Value::Object(serde_json::Map::default()))
161    }
162
163    /// Returns markdown to be displayed in the UI for this tool.
164    fn ui_text(&self, input: &serde_json::Value) -> String;
165
166    /// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
167    /// (so information may be missing).
168    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
169        self.ui_text(input)
170    }
171
172    /// Runs the tool with the provided input.
173    fn run(
174        self: Arc<Self>,
175        input: serde_json::Value,
176        messages: &[LanguageModelRequestMessage],
177        project: Entity<Project>,
178        action_log: Entity<ActionLog>,
179        window: Option<AnyWindowHandle>,
180        cx: &mut App,
181    ) -> ToolResult;
182}
183
184impl Debug for dyn Tool {
185    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
186        f.debug_struct("Tool").field("name", &self.name()).finish()
187    }
188}