assistant_tool.rs

  1mod action_log;
  2mod tool_registry;
  3mod tool_schema;
  4mod tool_working_set;
  5
  6use std::fmt;
  7use std::fmt::Debug;
  8use std::fmt::Formatter;
  9use std::sync::Arc;
 10
 11use anyhow::Result;
 12use gpui::AnyElement;
 13use gpui::Context;
 14use gpui::IntoElement;
 15use gpui::Window;
 16use gpui::{App, Entity, SharedString, Task};
 17use icons::IconName;
 18use language_model::LanguageModelRequestMessage;
 19use language_model::LanguageModelToolSchemaFormat;
 20use project::Project;
 21
 22pub use crate::action_log::*;
 23pub use crate::tool_registry::*;
 24pub use crate::tool_schema::*;
 25pub use crate::tool_working_set::*;
 26
 27pub fn init(cx: &mut App) {
 28    ToolRegistry::default_global(cx);
 29}
 30
 31#[derive(Debug, Clone)]
 32pub enum ToolUseStatus {
 33    NeedsConfirmation,
 34    Pending,
 35    Running,
 36    Finished(SharedString),
 37    Error(SharedString),
 38}
 39
 40impl ToolUseStatus {
 41    pub fn text(&self) -> SharedString {
 42        match self {
 43            ToolUseStatus::NeedsConfirmation => "".into(),
 44            ToolUseStatus::Pending => "".into(),
 45            ToolUseStatus::Running => "".into(),
 46            ToolUseStatus::Finished(out) => out.clone(),
 47            ToolUseStatus::Error(out) => out.clone(),
 48        }
 49    }
 50}
 51
 52/// The result of running a tool, containing both the asynchronous output
 53/// and an optional card view that can be rendered immediately.
 54pub struct ToolResult {
 55    /// The asynchronous task that will eventually resolve to the tool's output
 56    pub output: Task<Result<String>>,
 57    /// An optional view to present the output of the tool.
 58    pub card: Option<AnyToolCard>,
 59}
 60
 61pub trait ToolCard: 'static + Sized {
 62    fn render(
 63        &mut self,
 64        status: &ToolUseStatus,
 65        window: &mut Window,
 66        cx: &mut Context<Self>,
 67    ) -> impl IntoElement;
 68}
 69
 70#[derive(Clone)]
 71pub struct AnyToolCard {
 72    entity: gpui::AnyEntity,
 73    render: fn(
 74        entity: gpui::AnyEntity,
 75        status: &ToolUseStatus,
 76        window: &mut Window,
 77        cx: &mut App,
 78    ) -> AnyElement,
 79}
 80
 81impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
 82    fn from(entity: Entity<T>) -> Self {
 83        fn downcast_render<T: ToolCard>(
 84            entity: gpui::AnyEntity,
 85            status: &ToolUseStatus,
 86            window: &mut Window,
 87            cx: &mut App,
 88        ) -> AnyElement {
 89            let entity = entity.downcast::<T>().unwrap();
 90            entity.update(cx, |entity, cx| {
 91                entity.render(status, window, cx).into_any_element()
 92            })
 93        }
 94
 95        Self {
 96            entity: entity.into(),
 97            render: downcast_render::<T>,
 98        }
 99    }
100}
101
102impl AnyToolCard {
103    pub fn render(&self, status: &ToolUseStatus, window: &mut Window, cx: &mut App) -> AnyElement {
104        (self.render)(self.entity.clone(), status, window, cx)
105    }
106}
107
108impl From<Task<Result<String>>> for ToolResult {
109    /// Convert from a task to a ToolResult with no card
110    fn from(output: Task<Result<String>>) -> Self {
111        Self { output, card: None }
112    }
113}
114
115#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
116pub enum ToolSource {
117    /// A native tool built-in to Zed.
118    Native,
119    /// A tool provided by a context server.
120    ContextServer { id: SharedString },
121}
122
123/// A tool that can be used by a language model.
124pub trait Tool: 'static + Send + Sync {
125    /// Returns the name of the tool.
126    fn name(&self) -> String;
127
128    /// Returns the description of the tool.
129    fn description(&self) -> String;
130
131    /// Returns the icon for the tool.
132    fn icon(&self) -> IconName;
133
134    /// Returns the source of the tool.
135    fn source(&self) -> ToolSource {
136        ToolSource::Native
137    }
138
139    /// Returns true iff the tool needs the users's confirmation
140    /// before having permission to run.
141    fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
142
143    /// Returns the JSON schema that describes the tool's input.
144    fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
145        Ok(serde_json::Value::Object(serde_json::Map::default()))
146    }
147
148    /// Returns markdown to be displayed in the UI for this tool.
149    fn ui_text(&self, input: &serde_json::Value) -> String;
150
151    /// Runs the tool with the provided input.
152    fn run(
153        self: Arc<Self>,
154        input: serde_json::Value,
155        messages: &[LanguageModelRequestMessage],
156        project: Entity<Project>,
157        action_log: Entity<ActionLog>,
158        cx: &mut App,
159    ) -> ToolResult;
160}
161
162impl Debug for dyn Tool {
163    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
164        f.debug_struct("Tool").field("name", &self.name()).finish()
165    }
166}