assistant_tool.rs

 1mod tool_registry;
 2mod tool_working_set;
 3
 4use std::sync::Arc;
 5
 6use anyhow::Result;
 7use collections::HashSet;
 8use gpui::Context;
 9use gpui::{App, Entity, SharedString, Task};
10use language::Buffer;
11use language_model::LanguageModelRequestMessage;
12use project::Project;
13
14pub use crate::tool_registry::*;
15pub use crate::tool_working_set::*;
16
17pub fn init(cx: &mut App) {
18    ToolRegistry::default_global(cx);
19}
20
21#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
22pub enum ToolSource {
23    /// A native tool built-in to Zed.
24    Native,
25    /// A tool provided by a context server.
26    ContextServer { id: SharedString },
27}
28
29/// A tool that can be used by a language model.
30pub trait Tool: 'static + Send + Sync {
31    /// Returns the name of the tool.
32    fn name(&self) -> String;
33
34    /// Returns the description of the tool.
35    fn description(&self) -> String;
36
37    /// Returns the source of the tool.
38    fn source(&self) -> ToolSource {
39        ToolSource::Native
40    }
41
42    /// Returns the JSON schema that describes the tool's input.
43    fn input_schema(&self) -> serde_json::Value {
44        serde_json::Value::Object(serde_json::Map::default())
45    }
46
47    /// Runs the tool with the provided input.
48    fn run(
49        self: Arc<Self>,
50        input: serde_json::Value,
51        messages: &[LanguageModelRequestMessage],
52        project: Entity<Project>,
53        action_log: Entity<ActionLog>,
54        cx: &mut App,
55    ) -> Task<Result<String>>;
56}
57
58/// Tracks actions performed by tools in a thread
59#[derive(Debug)]
60pub struct ActionLog {
61    changed_buffers: HashSet<Entity<Buffer>>,
62    pending_refresh: HashSet<Entity<Buffer>>,
63}
64
65impl ActionLog {
66    /// Creates a new, empty action log.
67    pub fn new() -> Self {
68        Self {
69            changed_buffers: HashSet::default(),
70            pending_refresh: HashSet::default(),
71        }
72    }
73
74    /// Registers buffers that have changed and need refreshing.
75    pub fn notify_buffers_changed(
76        &mut self,
77        buffers: HashSet<Entity<Buffer>>,
78        _cx: &mut Context<Self>,
79    ) {
80        self.changed_buffers.extend(buffers.clone());
81        self.pending_refresh.extend(buffers);
82    }
83
84    /// Takes and returns the set of buffers pending refresh, clearing internal state.
85    pub fn take_pending_refresh_buffers(&mut self) -> HashSet<Entity<Buffer>> {
86        std::mem::take(&mut self.pending_refresh)
87    }
88}