test_tools.rs

  1use super::*;
  2use anyhow::Result;
  3use gpui::{App, SharedString, Task};
  4use std::future;
  5use std::sync::atomic::{AtomicBool, Ordering};
  6
  7/// A tool that echoes its input
  8#[derive(JsonSchema, Serialize, Deserialize)]
  9pub struct EchoToolInput {
 10    /// The text to echo.
 11    pub text: String,
 12}
 13
 14pub struct EchoTool;
 15
 16impl AgentTool for EchoTool {
 17    type Input = EchoToolInput;
 18    type Output = String;
 19
 20    fn name() -> &'static str {
 21        "echo"
 22    }
 23
 24    fn kind() -> acp::ToolKind {
 25        acp::ToolKind::Other
 26    }
 27
 28    fn initial_title(
 29        &self,
 30        _input: Result<Self::Input, serde_json::Value>,
 31        _cx: &mut App,
 32    ) -> SharedString {
 33        "Echo".into()
 34    }
 35
 36    fn run(
 37        self: Arc<Self>,
 38        input: Self::Input,
 39        _event_stream: ToolCallEventStream,
 40        _cx: &mut App,
 41    ) -> Task<Result<String>> {
 42        Task::ready(Ok(input.text))
 43    }
 44}
 45
 46/// A tool that waits for a specified delay
 47#[derive(JsonSchema, Serialize, Deserialize)]
 48pub struct DelayToolInput {
 49    /// The delay in milliseconds.
 50    ms: u64,
 51}
 52
 53pub struct DelayTool;
 54
 55impl AgentTool for DelayTool {
 56    type Input = DelayToolInput;
 57    type Output = String;
 58
 59    fn name() -> &'static str {
 60        "delay"
 61    }
 62
 63    fn initial_title(
 64        &self,
 65        input: Result<Self::Input, serde_json::Value>,
 66        _cx: &mut App,
 67    ) -> SharedString {
 68        if let Ok(input) = input {
 69            format!("Delay {}ms", input.ms).into()
 70        } else {
 71            "Delay".into()
 72        }
 73    }
 74
 75    fn kind() -> acp::ToolKind {
 76        acp::ToolKind::Other
 77    }
 78
 79    fn run(
 80        self: Arc<Self>,
 81        input: Self::Input,
 82        _event_stream: ToolCallEventStream,
 83        cx: &mut App,
 84    ) -> Task<Result<String>>
 85    where
 86        Self: Sized,
 87    {
 88        cx.foreground_executor().spawn(async move {
 89            smol::Timer::after(Duration::from_millis(input.ms)).await;
 90            Ok("Ding".to_string())
 91        })
 92    }
 93}
 94
 95#[derive(JsonSchema, Serialize, Deserialize)]
 96pub struct ToolRequiringPermissionInput {}
 97
 98pub struct ToolRequiringPermission;
 99
100impl AgentTool for ToolRequiringPermission {
101    type Input = ToolRequiringPermissionInput;
102    type Output = String;
103
104    fn name() -> &'static str {
105        "tool_requiring_permission"
106    }
107
108    fn kind() -> acp::ToolKind {
109        acp::ToolKind::Other
110    }
111
112    fn initial_title(
113        &self,
114        _input: Result<Self::Input, serde_json::Value>,
115        _cx: &mut App,
116    ) -> SharedString {
117        "This tool requires permission".into()
118    }
119
120    fn run(
121        self: Arc<Self>,
122        _input: Self::Input,
123        event_stream: ToolCallEventStream,
124        cx: &mut App,
125    ) -> Task<Result<String>> {
126        let authorize = event_stream.authorize("Authorize?", cx);
127        cx.foreground_executor().spawn(async move {
128            authorize.await?;
129            Ok("Allowed".to_string())
130        })
131    }
132}
133
134#[derive(JsonSchema, Serialize, Deserialize)]
135pub struct InfiniteToolInput {}
136
137pub struct InfiniteTool;
138
139impl AgentTool for InfiniteTool {
140    type Input = InfiniteToolInput;
141    type Output = String;
142
143    fn name() -> &'static str {
144        "infinite"
145    }
146
147    fn kind() -> acp::ToolKind {
148        acp::ToolKind::Other
149    }
150
151    fn initial_title(
152        &self,
153        _input: Result<Self::Input, serde_json::Value>,
154        _cx: &mut App,
155    ) -> SharedString {
156        "Infinite Tool".into()
157    }
158
159    fn run(
160        self: Arc<Self>,
161        _input: Self::Input,
162        _event_stream: ToolCallEventStream,
163        cx: &mut App,
164    ) -> Task<Result<String>> {
165        cx.foreground_executor().spawn(async move {
166            future::pending::<()>().await;
167            unreachable!()
168        })
169    }
170}
171
172/// A tool that loops forever but properly handles cancellation via `select!`,
173/// similar to how edit_file_tool handles cancellation.
174#[derive(JsonSchema, Serialize, Deserialize)]
175pub struct CancellationAwareToolInput {}
176
177pub struct CancellationAwareTool {
178    pub was_cancelled: Arc<AtomicBool>,
179}
180
181impl CancellationAwareTool {
182    pub fn new() -> (Self, Arc<AtomicBool>) {
183        let was_cancelled = Arc::new(AtomicBool::new(false));
184        (
185            Self {
186                was_cancelled: was_cancelled.clone(),
187            },
188            was_cancelled,
189        )
190    }
191}
192
193impl AgentTool for CancellationAwareTool {
194    type Input = CancellationAwareToolInput;
195    type Output = String;
196
197    fn name() -> &'static str {
198        "cancellation_aware"
199    }
200
201    fn kind() -> acp::ToolKind {
202        acp::ToolKind::Other
203    }
204
205    fn initial_title(
206        &self,
207        _input: Result<Self::Input, serde_json::Value>,
208        _cx: &mut App,
209    ) -> SharedString {
210        "Cancellation Aware Tool".into()
211    }
212
213    fn run(
214        self: Arc<Self>,
215        _input: Self::Input,
216        event_stream: ToolCallEventStream,
217        cx: &mut App,
218    ) -> Task<Result<String>> {
219        cx.foreground_executor().spawn(async move {
220            // Wait for cancellation - this tool does nothing but wait to be cancelled
221            event_stream.cancelled_by_user().await;
222            self.was_cancelled.store(true, Ordering::SeqCst);
223            anyhow::bail!("Tool cancelled by user");
224        })
225    }
226}
227
228/// A tool that takes an object with map from letters to random words starting with that letter.
229/// All fiealds are required! Pass a word for every letter!
230#[derive(JsonSchema, Serialize, Deserialize)]
231pub struct WordListInput {
232    /// Provide a random word that starts with A.
233    a: Option<String>,
234    /// Provide a random word that starts with B.
235    b: Option<String>,
236    /// Provide a random word that starts with C.
237    c: Option<String>,
238    /// Provide a random word that starts with D.
239    d: Option<String>,
240    /// Provide a random word that starts with E.
241    e: Option<String>,
242    /// Provide a random word that starts with F.
243    f: Option<String>,
244    /// Provide a random word that starts with G.
245    g: Option<String>,
246}
247
248pub struct WordListTool;
249
250impl AgentTool for WordListTool {
251    type Input = WordListInput;
252    type Output = String;
253
254    fn name() -> &'static str {
255        "word_list"
256    }
257
258    fn kind() -> acp::ToolKind {
259        acp::ToolKind::Other
260    }
261
262    fn initial_title(
263        &self,
264        _input: Result<Self::Input, serde_json::Value>,
265        _cx: &mut App,
266    ) -> SharedString {
267        "List of random words".into()
268    }
269
270    fn run(
271        self: Arc<Self>,
272        _input: Self::Input,
273        _event_stream: ToolCallEventStream,
274        _cx: &mut App,
275    ) -> Task<Result<String>> {
276        Task::ready(Ok("ok".to_string()))
277    }
278}