1mod tool_registry;
2mod tool_working_set;
3
4use std::sync::Arc;
5
6use anyhow::Result;
7use collections::HashSet;
8use gpui::Context;
9use gpui::{App, Entity, SharedString, Task};
10use language::Buffer;
11use language_model::LanguageModelRequestMessage;
12use project::Project;
13
14pub use crate::tool_registry::*;
15pub use crate::tool_working_set::*;
16
17pub fn init(cx: &mut App) {
18 ToolRegistry::default_global(cx);
19}
20
21#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
22pub enum ToolSource {
23 /// A native tool built-in to Zed.
24 Native,
25 /// A tool provided by a context server.
26 ContextServer { id: SharedString },
27}
28
29/// A tool that can be used by a language model.
30pub trait Tool: 'static + Send + Sync {
31 /// Returns the name of the tool.
32 fn name(&self) -> String;
33
34 /// Returns the description of the tool.
35 fn description(&self) -> String;
36
37 /// Returns the source of the tool.
38 fn source(&self) -> ToolSource {
39 ToolSource::Native
40 }
41
42 /// Returns the JSON schema that describes the tool's input.
43 fn input_schema(&self) -> serde_json::Value {
44 serde_json::Value::Object(serde_json::Map::default())
45 }
46
47 /// Runs the tool with the provided input.
48 fn run(
49 self: Arc<Self>,
50 input: serde_json::Value,
51 messages: &[LanguageModelRequestMessage],
52 project: Entity<Project>,
53 action_log: Entity<ActionLog>,
54 cx: &mut App,
55 ) -> Task<Result<String>>;
56}
57
58/// Tracks actions performed by tools in a thread
59#[derive(Debug)]
60pub struct ActionLog {
61 changed_buffers: HashSet<Entity<Buffer>>,
62 pending_refresh: HashSet<Entity<Buffer>>,
63}
64
65impl ActionLog {
66 /// Creates a new, empty action log.
67 pub fn new() -> Self {
68 Self {
69 changed_buffers: HashSet::default(),
70 pending_refresh: HashSet::default(),
71 }
72 }
73
74 /// Registers buffers that have changed and need refreshing.
75 pub fn notify_buffers_changed(
76 &mut self,
77 buffers: HashSet<Entity<Buffer>>,
78 _cx: &mut Context<Self>,
79 ) {
80 self.changed_buffers.extend(buffers.clone());
81 self.pending_refresh.extend(buffers);
82 }
83
84 /// Takes and returns the set of buffers pending refresh, clearing internal state.
85 pub fn take_pending_refresh_buffers(&mut self) -> HashSet<Entity<Buffer>> {
86 std::mem::take(&mut self.pending_refresh)
87 }
88}