1mod tool_registry;
2mod tool_working_set;
3
4use anyhow::Result;
5use collections::{HashMap, HashSet};
6use gpui::{App, Context, Entity, SharedString, Task};
7use language::Buffer;
8use language_model::LanguageModelRequestMessage;
9use project::Project;
10use std::fmt::{self, Debug, Formatter};
11use std::sync::Arc;
12use ui::IconName;
13
14pub use crate::tool_registry::*;
15pub use crate::tool_working_set::*;
16
17pub fn init(cx: &mut App) {
18 ToolRegistry::default_global(cx);
19}
20
21#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
22pub enum ToolSource {
23 /// A native tool built-in to Zed.
24 Native,
25 /// A tool provided by a context server.
26 ContextServer { id: SharedString },
27}
28
29/// A tool that can be used by a language model.
30pub trait Tool: 'static + Send + Sync {
31 /// Returns the name of the tool.
32 fn name(&self) -> String;
33
34 /// Returns the description of the tool.
35 fn description(&self) -> String;
36
37 /// Returns the icon for the tool.
38 fn icon(&self) -> IconName;
39
40 /// Returns the source of the tool.
41 fn source(&self) -> ToolSource {
42 ToolSource::Native
43 }
44
45 /// Returns true iff the tool needs the users's confirmation
46 /// before having permission to run.
47 fn needs_confirmation(&self) -> bool;
48
49 /// Returns the JSON schema that describes the tool's input.
50 fn input_schema(&self) -> serde_json::Value {
51 serde_json::Value::Object(serde_json::Map::default())
52 }
53
54 /// Returns markdown to be displayed in the UI for this tool.
55 fn ui_text(&self, input: &serde_json::Value) -> String;
56
57 /// Runs the tool with the provided input.
58 fn run(
59 self: Arc<Self>,
60 input: serde_json::Value,
61 messages: &[LanguageModelRequestMessage],
62 project: Entity<Project>,
63 action_log: Entity<ActionLog>,
64 cx: &mut App,
65 ) -> Task<Result<String>>;
66}
67
68impl Debug for dyn Tool {
69 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
70 f.debug_struct("Tool").field("name", &self.name()).finish()
71 }
72}
73
74/// Tracks actions performed by tools in a thread
75#[derive(Debug)]
76pub struct ActionLog {
77 /// Buffers that user manually added to the context, and whose content has
78 /// changed since the model last saw them.
79 stale_buffers_in_context: HashSet<Entity<Buffer>>,
80 /// Buffers that we want to notify the model about when they change.
81 tracked_buffers: HashMap<Entity<Buffer>, TrackedBuffer>,
82}
83
84#[derive(Debug, Default)]
85struct TrackedBuffer {
86 version: clock::Global,
87}
88
89impl ActionLog {
90 /// Creates a new, empty action log.
91 pub fn new() -> Self {
92 Self {
93 stale_buffers_in_context: HashSet::default(),
94 tracked_buffers: HashMap::default(),
95 }
96 }
97
98 /// Track a buffer as read, so we can notify the model about user edits.
99 pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
100 let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
101 tracked_buffer.version = buffer.read(cx).version();
102 }
103
104 /// Mark a buffer as edited, so we can refresh it in the context
105 pub fn buffer_edited(&mut self, buffers: HashSet<Entity<Buffer>>, cx: &mut Context<Self>) {
106 for buffer in &buffers {
107 let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
108 tracked_buffer.version = buffer.read(cx).version();
109 }
110
111 self.stale_buffers_in_context.extend(buffers);
112 }
113
114 /// Iterate over buffers changed since last read or edited by the model
115 pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
116 self.tracked_buffers
117 .iter()
118 .filter(|(buffer, tracked)| tracked.version != buffer.read(cx).version)
119 .map(|(buffer, _)| buffer)
120 }
121
122 /// Takes and returns the set of buffers pending refresh, clearing internal state.
123 pub fn take_stale_buffers_in_context(&mut self) -> HashSet<Entity<Buffer>> {
124 std::mem::take(&mut self.stale_buffers_in_context)
125 }
126}