1mod tool_registry;
2mod tool_working_set;
3
4use std::sync::Arc;
5
6use anyhow::Result;
7use collections::{HashMap, HashSet};
8use gpui::{App, Context, Entity, SharedString, Task};
9use language::Buffer;
10use language_model::LanguageModelRequestMessage;
11use project::Project;
12
13pub use crate::tool_registry::*;
14pub use crate::tool_working_set::*;
15
16pub fn init(cx: &mut App) {
17 ToolRegistry::default_global(cx);
18}
19
20#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
21pub enum ToolSource {
22 /// A native tool built-in to Zed.
23 Native,
24 /// A tool provided by a context server.
25 ContextServer { id: SharedString },
26}
27
28/// A tool that can be used by a language model.
29pub trait Tool: 'static + Send + Sync {
30 /// Returns the name of the tool.
31 fn name(&self) -> String;
32
33 /// Returns the description of the tool.
34 fn description(&self) -> String;
35
36 /// Returns the source of the tool.
37 fn source(&self) -> ToolSource {
38 ToolSource::Native
39 }
40
41 /// Returns the JSON schema that describes the tool's input.
42 fn input_schema(&self) -> serde_json::Value {
43 serde_json::Value::Object(serde_json::Map::default())
44 }
45
46 /// Returns markdown to be displayed in the UI for this tool.
47 fn ui_text(&self, input: &serde_json::Value) -> String;
48
49 /// Runs the tool with the provided input.
50 fn run(
51 self: Arc<Self>,
52 input: serde_json::Value,
53 messages: &[LanguageModelRequestMessage],
54 project: Entity<Project>,
55 action_log: Entity<ActionLog>,
56 cx: &mut App,
57 ) -> Task<Result<String>>;
58}
59
60/// Tracks actions performed by tools in a thread
61#[derive(Debug)]
62pub struct ActionLog {
63 /// Buffers that user manually added to the context, and whose content has
64 /// changed since the model last saw them.
65 stale_buffers_in_context: HashSet<Entity<Buffer>>,
66 /// Buffers that we want to notify the model about when they change.
67 tracked_buffers: HashMap<Entity<Buffer>, TrackedBuffer>,
68}
69
70#[derive(Debug, Default)]
71struct TrackedBuffer {
72 version: clock::Global,
73}
74
75impl ActionLog {
76 /// Creates a new, empty action log.
77 pub fn new() -> Self {
78 Self {
79 stale_buffers_in_context: HashSet::default(),
80 tracked_buffers: HashMap::default(),
81 }
82 }
83
84 /// Track a buffer as read, so we can notify the model about user edits.
85 pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
86 let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
87 tracked_buffer.version = buffer.read(cx).version();
88 }
89
90 /// Mark a buffer as edited, so we can refresh it in the context
91 pub fn buffer_edited(&mut self, buffers: HashSet<Entity<Buffer>>, cx: &mut Context<Self>) {
92 for buffer in &buffers {
93 let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
94 tracked_buffer.version = buffer.read(cx).version();
95 }
96
97 self.stale_buffers_in_context.extend(buffers);
98 }
99
100 /// Iterate over buffers changed since last read or edited by the model
101 pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
102 self.tracked_buffers
103 .iter()
104 .filter(|(buffer, tracked)| tracked.version != buffer.read(cx).version)
105 .map(|(buffer, _)| buffer)
106 }
107
108 /// Takes and returns the set of buffers pending refresh, clearing internal state.
109 pub fn take_stale_buffers_in_context(&mut self) -> HashSet<Entity<Buffer>> {
110 std::mem::take(&mut self.stale_buffers_in_context)
111 }
112}