1mod tool_registry;
2mod tool_working_set;
3
4use std::fmt::{self, Debug, Formatter};
5use std::sync::Arc;
6
7use anyhow::Result;
8use collections::{HashMap, HashSet};
9use gpui::{App, Context, Entity, SharedString, Task};
10use icons::IconName;
11use language::Buffer;
12use language_model::LanguageModelRequestMessage;
13use project::Project;
14
15pub use crate::tool_registry::*;
16pub use crate::tool_working_set::*;
17
18pub fn init(cx: &mut App) {
19 ToolRegistry::default_global(cx);
20}
21
22#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
23pub enum ToolSource {
24 /// A native tool built-in to Zed.
25 Native,
26 /// A tool provided by a context server.
27 ContextServer { id: SharedString },
28}
29
30/// A tool that can be used by a language model.
31pub trait Tool: 'static + Send + Sync {
32 /// Returns the name of the tool.
33 fn name(&self) -> String;
34
35 /// Returns the description of the tool.
36 fn description(&self) -> String;
37
38 /// Returns the icon for the tool.
39 fn icon(&self) -> IconName;
40
41 /// Returns the source of the tool.
42 fn source(&self) -> ToolSource {
43 ToolSource::Native
44 }
45
46 /// Returns true iff the tool needs the users's confirmation
47 /// before having permission to run.
48 fn needs_confirmation(&self) -> bool;
49
50 /// Returns the JSON schema that describes the tool's input.
51 fn input_schema(&self) -> serde_json::Value {
52 serde_json::Value::Object(serde_json::Map::default())
53 }
54
55 /// Returns markdown to be displayed in the UI for this tool.
56 fn ui_text(&self, input: &serde_json::Value) -> String;
57
58 /// Runs the tool with the provided input.
59 fn run(
60 self: Arc<Self>,
61 input: serde_json::Value,
62 messages: &[LanguageModelRequestMessage],
63 project: Entity<Project>,
64 action_log: Entity<ActionLog>,
65 cx: &mut App,
66 ) -> Task<Result<String>>;
67}
68
69impl Debug for dyn Tool {
70 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
71 f.debug_struct("Tool").field("name", &self.name()).finish()
72 }
73}
74
75/// Tracks actions performed by tools in a thread
76#[derive(Debug)]
77pub struct ActionLog {
78 /// Buffers that user manually added to the context, and whose content has
79 /// changed since the model last saw them.
80 stale_buffers_in_context: HashSet<Entity<Buffer>>,
81 /// Buffers that we want to notify the model about when they change.
82 tracked_buffers: HashMap<Entity<Buffer>, TrackedBuffer>,
83 /// Has the model edited a file since it last checked diagnostics?
84 edited_since_project_diagnostics_check: bool,
85}
86
87#[derive(Debug, Default)]
88struct TrackedBuffer {
89 version: clock::Global,
90}
91
92impl ActionLog {
93 /// Creates a new, empty action log.
94 pub fn new() -> Self {
95 Self {
96 stale_buffers_in_context: HashSet::default(),
97 tracked_buffers: HashMap::default(),
98 edited_since_project_diagnostics_check: false,
99 }
100 }
101
102 /// Track a buffer as read, so we can notify the model about user edits.
103 pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
104 let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
105 tracked_buffer.version = buffer.read(cx).version();
106 }
107
108 /// Mark a buffer as edited, so we can refresh it in the context
109 pub fn buffer_edited(&mut self, buffers: HashSet<Entity<Buffer>>, cx: &mut Context<Self>) {
110 for buffer in &buffers {
111 let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
112 tracked_buffer.version = buffer.read(cx).version();
113 }
114
115 self.stale_buffers_in_context.extend(buffers);
116 self.edited_since_project_diagnostics_check = true;
117 }
118
119 /// Notifies a diagnostics check
120 pub fn checked_project_diagnostics(&mut self) {
121 self.edited_since_project_diagnostics_check = false;
122 }
123
124 /// Iterate over buffers changed since last read or edited by the model
125 pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
126 self.tracked_buffers
127 .iter()
128 .filter(|(buffer, tracked)| tracked.version != buffer.read(cx).version)
129 .map(|(buffer, _)| buffer)
130 }
131
132 /// Returns true if any files have been edited since the last project diagnostics check
133 pub fn has_edited_files_since_project_diagnostics_check(&self) -> bool {
134 self.edited_since_project_diagnostics_check
135 }
136
137 /// Takes and returns the set of buffers pending refresh, clearing internal state.
138 pub fn take_stale_buffers_in_context(&mut self) -> HashSet<Entity<Buffer>> {
139 std::mem::take(&mut self.stale_buffers_in_context)
140 }
141}