assistant_tool.rs

  1mod action_log;
  2mod tool_registry;
  3mod tool_schema;
  4mod tool_working_set;
  5
  6use std::fmt;
  7use std::fmt::Debug;
  8use std::fmt::Formatter;
  9use std::sync::Arc;
 10
 11use anyhow::Result;
 12use gpui::AnyElement;
 13use gpui::Context;
 14use gpui::IntoElement;
 15use gpui::Window;
 16use gpui::{App, Entity, SharedString, Task};
 17use icons::IconName;
 18use language_model::LanguageModelRequestMessage;
 19use language_model::LanguageModelToolSchemaFormat;
 20use project::Project;
 21
 22pub use crate::action_log::*;
 23pub use crate::tool_registry::*;
 24pub use crate::tool_schema::*;
 25pub use crate::tool_working_set::*;
 26
 27pub fn init(cx: &mut App) {
 28    ToolRegistry::default_global(cx);
 29}
 30
 31#[derive(Debug, Clone)]
 32pub enum ToolUseStatus {
 33    InputStillStreaming,
 34    NeedsConfirmation,
 35    Pending,
 36    Running,
 37    Finished(SharedString),
 38    Error(SharedString),
 39}
 40
 41impl ToolUseStatus {
 42    pub fn text(&self) -> SharedString {
 43        match self {
 44            ToolUseStatus::NeedsConfirmation => "".into(),
 45            ToolUseStatus::InputStillStreaming => "".into(),
 46            ToolUseStatus::Pending => "".into(),
 47            ToolUseStatus::Running => "".into(),
 48            ToolUseStatus::Finished(out) => out.clone(),
 49            ToolUseStatus::Error(out) => out.clone(),
 50        }
 51    }
 52}
 53
 54/// The result of running a tool, containing both the asynchronous output
 55/// and an optional card view that can be rendered immediately.
 56pub struct ToolResult {
 57    /// The asynchronous task that will eventually resolve to the tool's output
 58    pub output: Task<Result<String>>,
 59    /// An optional view to present the output of the tool.
 60    pub card: Option<AnyToolCard>,
 61}
 62
 63pub trait ToolCard: 'static + Sized {
 64    fn render(
 65        &mut self,
 66        status: &ToolUseStatus,
 67        window: &mut Window,
 68        cx: &mut Context<Self>,
 69    ) -> impl IntoElement;
 70}
 71
 72#[derive(Clone)]
 73pub struct AnyToolCard {
 74    entity: gpui::AnyEntity,
 75    render: fn(
 76        entity: gpui::AnyEntity,
 77        status: &ToolUseStatus,
 78        window: &mut Window,
 79        cx: &mut App,
 80    ) -> AnyElement,
 81}
 82
 83impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
 84    fn from(entity: Entity<T>) -> Self {
 85        fn downcast_render<T: ToolCard>(
 86            entity: gpui::AnyEntity,
 87            status: &ToolUseStatus,
 88            window: &mut Window,
 89            cx: &mut App,
 90        ) -> AnyElement {
 91            let entity = entity.downcast::<T>().unwrap();
 92            entity.update(cx, |entity, cx| {
 93                entity.render(status, window, cx).into_any_element()
 94            })
 95        }
 96
 97        Self {
 98            entity: entity.into(),
 99            render: downcast_render::<T>,
100        }
101    }
102}
103
104impl AnyToolCard {
105    pub fn render(&self, status: &ToolUseStatus, window: &mut Window, cx: &mut App) -> AnyElement {
106        (self.render)(self.entity.clone(), status, window, cx)
107    }
108}
109
110impl From<Task<Result<String>>> for ToolResult {
111    /// Convert from a task to a ToolResult with no card
112    fn from(output: Task<Result<String>>) -> Self {
113        Self { output, card: None }
114    }
115}
116
117#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
118pub enum ToolSource {
119    /// A native tool built-in to Zed.
120    Native,
121    /// A tool provided by a context server.
122    ContextServer { id: SharedString },
123}
124
125/// A tool that can be used by a language model.
126pub trait Tool: 'static + Send + Sync {
127    /// Returns the name of the tool.
128    fn name(&self) -> String;
129
130    /// Returns the description of the tool.
131    fn description(&self) -> String;
132
133    /// Returns the icon for the tool.
134    fn icon(&self) -> IconName;
135
136    /// Returns the source of the tool.
137    fn source(&self) -> ToolSource {
138        ToolSource::Native
139    }
140
141    /// Returns true iff the tool needs the users's confirmation
142    /// before having permission to run.
143    fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
144
145    /// Returns the JSON schema that describes the tool's input.
146    fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
147        Ok(serde_json::Value::Object(serde_json::Map::default()))
148    }
149
150    /// Returns markdown to be displayed in the UI for this tool.
151    fn ui_text(&self, input: &serde_json::Value) -> String;
152
153    /// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
154    /// (so information may be missing).
155    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
156        self.ui_text(input)
157    }
158
159    /// Runs the tool with the provided input.
160    fn run(
161        self: Arc<Self>,
162        input: serde_json::Value,
163        messages: &[LanguageModelRequestMessage],
164        project: Entity<Project>,
165        action_log: Entity<ActionLog>,
166        cx: &mut App,
167    ) -> ToolResult;
168}
169
170impl Debug for dyn Tool {
171    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
172        f.debug_struct("Tool").field("name", &self.name()).finish()
173    }
174}