assistant_tool.rs

  1mod tool_registry;
  2mod tool_working_set;
  3
  4use std::sync::Arc;
  5
  6use anyhow::Result;
  7use collections::{HashMap, HashSet};
  8use gpui::Context;
  9use gpui::{App, Entity, SharedString, Task};
 10use language::Buffer;
 11use language_model::LanguageModelRequestMessage;
 12use project::Project;
 13
 14pub use crate::tool_registry::*;
 15pub use crate::tool_working_set::*;
 16
 17pub fn init(cx: &mut App) {
 18    ToolRegistry::default_global(cx);
 19}
 20
 21#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
 22pub enum ToolSource {
 23    /// A native tool built-in to Zed.
 24    Native,
 25    /// A tool provided by a context server.
 26    ContextServer { id: SharedString },
 27}
 28
 29/// A tool that can be used by a language model.
 30pub trait Tool: 'static + Send + Sync {
 31    /// Returns the name of the tool.
 32    fn name(&self) -> String;
 33
 34    /// Returns the description of the tool.
 35    fn description(&self) -> String;
 36
 37    /// Returns the source of the tool.
 38    fn source(&self) -> ToolSource {
 39        ToolSource::Native
 40    }
 41
 42    /// Returns the JSON schema that describes the tool's input.
 43    fn input_schema(&self) -> serde_json::Value {
 44        serde_json::Value::Object(serde_json::Map::default())
 45    }
 46
 47    /// Runs the tool with the provided input.
 48    fn run(
 49        self: Arc<Self>,
 50        input: serde_json::Value,
 51        messages: &[LanguageModelRequestMessage],
 52        project: Entity<Project>,
 53        action_log: Entity<ActionLog>,
 54        cx: &mut App,
 55    ) -> Task<Result<String>>;
 56}
 57
 58/// Tracks actions performed by tools in a thread
 59#[derive(Debug)]
 60pub struct ActionLog {
 61    /// Buffers that user manually added to the context, and whose content has
 62    /// changed since the model last saw them.
 63    stale_buffers_in_context: HashSet<Entity<Buffer>>,
 64    /// Buffers that we want to notify the model about when they change.
 65    tracked_buffers: HashMap<Entity<Buffer>, TrackedBuffer>,
 66}
 67
 68#[derive(Debug, Default)]
 69struct TrackedBuffer {
 70    version: clock::Global,
 71}
 72
 73impl ActionLog {
 74    /// Creates a new, empty action log.
 75    pub fn new() -> Self {
 76        Self {
 77            stale_buffers_in_context: HashSet::default(),
 78            tracked_buffers: HashMap::default(),
 79        }
 80    }
 81
 82    /// Track a buffer as read, so we can notify the model about user edits.
 83    pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
 84        let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
 85        tracked_buffer.version = buffer.read(cx).version();
 86    }
 87
 88    /// Mark a buffer as edited, so we can refresh it in the context
 89    pub fn buffer_edited(&mut self, buffers: HashSet<Entity<Buffer>>, cx: &mut Context<Self>) {
 90        for buffer in &buffers {
 91            let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
 92            tracked_buffer.version = buffer.read(cx).version();
 93        }
 94
 95        self.stale_buffers_in_context.extend(buffers);
 96    }
 97
 98    /// Iterate over buffers changed since last read or edited by the model
 99    pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
100        self.tracked_buffers
101            .iter()
102            .filter(|(buffer, tracked)| tracked.version != buffer.read(cx).version)
103            .map(|(buffer, _)| buffer)
104    }
105
106    /// Takes and returns the set of buffers pending refresh, clearing internal state.
107    pub fn take_stale_buffers_in_context(&mut self) -> HashSet<Entity<Buffer>> {
108        std::mem::take(&mut self.stale_buffers_in_context)
109    }
110}