assistant_tool.rs

  1mod tool_registry;
  2mod tool_working_set;
  3
  4use std::fmt::{self, Debug, Formatter};
  5use std::sync::Arc;
  6
  7use anyhow::Result;
  8use collections::{HashMap, HashSet};
  9use gpui::{App, Context, Entity, SharedString, Task};
 10use icons::IconName;
 11use language::Buffer;
 12use language_model::LanguageModelRequestMessage;
 13use project::Project;
 14
 15pub use crate::tool_registry::*;
 16pub use crate::tool_working_set::*;
 17
 18pub fn init(cx: &mut App) {
 19    ToolRegistry::default_global(cx);
 20}
 21
 22#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
 23pub enum ToolSource {
 24    /// A native tool built-in to Zed.
 25    Native,
 26    /// A tool provided by a context server.
 27    ContextServer { id: SharedString },
 28}
 29
 30/// A tool that can be used by a language model.
 31pub trait Tool: 'static + Send + Sync {
 32    /// Returns the name of the tool.
 33    fn name(&self) -> String;
 34
 35    /// Returns the description of the tool.
 36    fn description(&self) -> String;
 37
 38    /// Returns the icon for the tool.
 39    fn icon(&self) -> IconName;
 40
 41    /// Returns the source of the tool.
 42    fn source(&self) -> ToolSource {
 43        ToolSource::Native
 44    }
 45
 46    /// Returns true iff the tool needs the users's confirmation
 47    /// before having permission to run.
 48    fn needs_confirmation(&self) -> bool;
 49
 50    /// Returns the JSON schema that describes the tool's input.
 51    fn input_schema(&self) -> serde_json::Value {
 52        serde_json::Value::Object(serde_json::Map::default())
 53    }
 54
 55    /// Returns markdown to be displayed in the UI for this tool.
 56    fn ui_text(&self, input: &serde_json::Value) -> String;
 57
 58    /// Runs the tool with the provided input.
 59    fn run(
 60        self: Arc<Self>,
 61        input: serde_json::Value,
 62        messages: &[LanguageModelRequestMessage],
 63        project: Entity<Project>,
 64        action_log: Entity<ActionLog>,
 65        cx: &mut App,
 66    ) -> Task<Result<String>>;
 67}
 68
 69impl Debug for dyn Tool {
 70    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 71        f.debug_struct("Tool").field("name", &self.name()).finish()
 72    }
 73}
 74
 75/// Tracks actions performed by tools in a thread
 76#[derive(Debug)]
 77pub struct ActionLog {
 78    /// Buffers that user manually added to the context, and whose content has
 79    /// changed since the model last saw them.
 80    stale_buffers_in_context: HashSet<Entity<Buffer>>,
 81    /// Buffers that we want to notify the model about when they change.
 82    tracked_buffers: HashMap<Entity<Buffer>, TrackedBuffer>,
 83    /// Has the model edited a file since it last checked diagnostics?
 84    edited_since_project_diagnostics_check: bool,
 85}
 86
 87#[derive(Debug, Default)]
 88struct TrackedBuffer {
 89    version: clock::Global,
 90}
 91
 92impl ActionLog {
 93    /// Creates a new, empty action log.
 94    pub fn new() -> Self {
 95        Self {
 96            stale_buffers_in_context: HashSet::default(),
 97            tracked_buffers: HashMap::default(),
 98            edited_since_project_diagnostics_check: false,
 99        }
100    }
101
102    /// Track a buffer as read, so we can notify the model about user edits.
103    pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
104        let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
105        tracked_buffer.version = buffer.read(cx).version();
106    }
107
108    /// Mark a buffer as edited, so we can refresh it in the context
109    pub fn buffer_edited(&mut self, buffers: HashSet<Entity<Buffer>>, cx: &mut Context<Self>) {
110        for buffer in &buffers {
111            let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
112            tracked_buffer.version = buffer.read(cx).version();
113        }
114
115        self.stale_buffers_in_context.extend(buffers);
116        self.edited_since_project_diagnostics_check = true;
117    }
118
119    /// Notifies a diagnostics check
120    pub fn checked_project_diagnostics(&mut self) {
121        self.edited_since_project_diagnostics_check = false;
122    }
123
124    /// Iterate over buffers changed since last read or edited by the model
125    pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
126        self.tracked_buffers
127            .iter()
128            .filter(|(buffer, tracked)| tracked.version != buffer.read(cx).version)
129            .map(|(buffer, _)| buffer)
130    }
131
132    /// Returns true if any files have been edited since the last project diagnostics check
133    pub fn has_edited_files_since_project_diagnostics_check(&self) -> bool {
134        self.edited_since_project_diagnostics_check
135    }
136
137    /// Takes and returns the set of buffers pending refresh, clearing internal state.
138    pub fn take_stale_buffers_in_context(&mut self) -> HashSet<Entity<Buffer>> {
139        std::mem::take(&mut self.stale_buffers_in_context)
140    }
141}