assistant_tool.rs

  1mod tool_registry;
  2mod tool_working_set;
  3
  4use std::sync::Arc;
  5
  6use anyhow::Result;
  7use collections::{HashMap, HashSet};
  8use gpui::{App, Context, Entity, SharedString, Task};
  9use language::Buffer;
 10use language_model::LanguageModelRequestMessage;
 11use project::Project;
 12
 13pub use crate::tool_registry::*;
 14pub use crate::tool_working_set::*;
 15
 16pub fn init(cx: &mut App) {
 17    ToolRegistry::default_global(cx);
 18}
 19
 20#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
 21pub enum ToolSource {
 22    /// A native tool built-in to Zed.
 23    Native,
 24    /// A tool provided by a context server.
 25    ContextServer { id: SharedString },
 26}
 27
 28/// A tool that can be used by a language model.
 29pub trait Tool: 'static + Send + Sync {
 30    /// Returns the name of the tool.
 31    fn name(&self) -> String;
 32
 33    /// Returns the description of the tool.
 34    fn description(&self) -> String;
 35
 36    /// Returns the source of the tool.
 37    fn source(&self) -> ToolSource {
 38        ToolSource::Native
 39    }
 40
 41    /// Returns the JSON schema that describes the tool's input.
 42    fn input_schema(&self) -> serde_json::Value {
 43        serde_json::Value::Object(serde_json::Map::default())
 44    }
 45
 46    /// Returns markdown to be displayed in the UI for this tool.
 47    fn ui_text(&self, input: &serde_json::Value) -> String;
 48
 49    /// Runs the tool with the provided input.
 50    fn run(
 51        self: Arc<Self>,
 52        input: serde_json::Value,
 53        messages: &[LanguageModelRequestMessage],
 54        project: Entity<Project>,
 55        action_log: Entity<ActionLog>,
 56        cx: &mut App,
 57    ) -> Task<Result<String>>;
 58}
 59
 60/// Tracks actions performed by tools in a thread
 61#[derive(Debug)]
 62pub struct ActionLog {
 63    /// Buffers that user manually added to the context, and whose content has
 64    /// changed since the model last saw them.
 65    stale_buffers_in_context: HashSet<Entity<Buffer>>,
 66    /// Buffers that we want to notify the model about when they change.
 67    tracked_buffers: HashMap<Entity<Buffer>, TrackedBuffer>,
 68}
 69
 70#[derive(Debug, Default)]
 71struct TrackedBuffer {
 72    version: clock::Global,
 73}
 74
 75impl ActionLog {
 76    /// Creates a new, empty action log.
 77    pub fn new() -> Self {
 78        Self {
 79            stale_buffers_in_context: HashSet::default(),
 80            tracked_buffers: HashMap::default(),
 81        }
 82    }
 83
 84    /// Track a buffer as read, so we can notify the model about user edits.
 85    pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
 86        let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
 87        tracked_buffer.version = buffer.read(cx).version();
 88    }
 89
 90    /// Mark a buffer as edited, so we can refresh it in the context
 91    pub fn buffer_edited(&mut self, buffers: HashSet<Entity<Buffer>>, cx: &mut Context<Self>) {
 92        for buffer in &buffers {
 93            let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
 94            tracked_buffer.version = buffer.read(cx).version();
 95        }
 96
 97        self.stale_buffers_in_context.extend(buffers);
 98    }
 99
100    /// Iterate over buffers changed since last read or edited by the model
101    pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
102        self.tracked_buffers
103            .iter()
104            .filter(|(buffer, tracked)| tracked.version != buffer.read(cx).version)
105            .map(|(buffer, _)| buffer)
106    }
107
108    /// Takes and returns the set of buffers pending refresh, clearing internal state.
109    pub fn take_stale_buffers_in_context(&mut self) -> HashSet<Entity<Buffer>> {
110        std::mem::take(&mut self.stale_buffers_in_context)
111    }
112}