1mod tool_registry;
2mod tool_working_set;
3
4use std::sync::Arc;
5
6use anyhow::Result;
7use collections::{HashMap, HashSet};
8use gpui::Context;
9use gpui::{App, Entity, SharedString, Task};
10use language::Buffer;
11use language_model::LanguageModelRequestMessage;
12use project::Project;
13
14pub use crate::tool_registry::*;
15pub use crate::tool_working_set::*;
16
17pub fn init(cx: &mut App) {
18 ToolRegistry::default_global(cx);
19}
20
21#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
22pub enum ToolSource {
23 /// A native tool built-in to Zed.
24 Native,
25 /// A tool provided by a context server.
26 ContextServer { id: SharedString },
27}
28
29/// A tool that can be used by a language model.
30pub trait Tool: 'static + Send + Sync {
31 /// Returns the name of the tool.
32 fn name(&self) -> String;
33
34 /// Returns the description of the tool.
35 fn description(&self) -> String;
36
37 /// Returns the source of the tool.
38 fn source(&self) -> ToolSource {
39 ToolSource::Native
40 }
41
42 /// Returns the JSON schema that describes the tool's input.
43 fn input_schema(&self) -> serde_json::Value {
44 serde_json::Value::Object(serde_json::Map::default())
45 }
46
47 /// Runs the tool with the provided input.
48 fn run(
49 self: Arc<Self>,
50 input: serde_json::Value,
51 messages: &[LanguageModelRequestMessage],
52 project: Entity<Project>,
53 action_log: Entity<ActionLog>,
54 cx: &mut App,
55 ) -> Task<Result<String>>;
56}
57
58/// Tracks actions performed by tools in a thread
59#[derive(Debug)]
60pub struct ActionLog {
61 /// Buffers that user manually added to the context, and whose content has
62 /// changed since the model last saw them.
63 stale_buffers_in_context: HashSet<Entity<Buffer>>,
64 /// Buffers that we want to notify the model about when they change.
65 tracked_buffers: HashMap<Entity<Buffer>, TrackedBuffer>,
66}
67
68#[derive(Debug, Default)]
69struct TrackedBuffer {
70 version: clock::Global,
71}
72
73impl ActionLog {
74 /// Creates a new, empty action log.
75 pub fn new() -> Self {
76 Self {
77 stale_buffers_in_context: HashSet::default(),
78 tracked_buffers: HashMap::default(),
79 }
80 }
81
82 /// Track a buffer as read, so we can notify the model about user edits.
83 pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
84 let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
85 tracked_buffer.version = buffer.read(cx).version();
86 }
87
88 /// Mark a buffer as edited, so we can refresh it in the context
89 pub fn buffer_edited(&mut self, buffers: HashSet<Entity<Buffer>>, cx: &mut Context<Self>) {
90 for buffer in &buffers {
91 let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
92 tracked_buffer.version = buffer.read(cx).version();
93 }
94
95 self.stale_buffers_in_context.extend(buffers);
96 }
97
98 /// Iterate over buffers changed since last read or edited by the model
99 pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
100 self.tracked_buffers
101 .iter()
102 .filter(|(buffer, tracked)| tracked.version != buffer.read(cx).version)
103 .map(|(buffer, _)| buffer)
104 }
105
106 /// Takes and returns the set of buffers pending refresh, clearing internal state.
107 pub fn take_stale_buffers_in_context(&mut self) -> HashSet<Entity<Buffer>> {
108 std::mem::take(&mut self.stale_buffers_in_context)
109 }
110}