test_tools.rs

  1use super::*;
  2use agent_settings::AgentSettings;
  3use gpui::{App, SharedString, Task};
  4use std::future;
  5use std::sync::Mutex;
  6use std::sync::atomic::{AtomicBool, Ordering};
  7use std::time::Duration;
  8
/// A streaming tool that echoes its input, used to test streaming tool
/// lifecycle (e.g. partial delivery and cleanup when the LLM stream ends
/// before `is_input_complete`).
// NOTE(review): assuming `JsonSchema` is the schemars derive, the `///` text on
// this type and its field is emitted as schema descriptions, i.e. it is
// runtime-visible and should only be edited deliberately.
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct StreamingEchoToolInput {
    /// The text to echo.
    pub text: String,
}
 17
 18pub struct StreamingEchoTool {
 19    wait_until_complete_rx: Mutex<Option<oneshot::Receiver<()>>>,
 20}
 21
 22impl StreamingEchoTool {
 23    pub fn new() -> Self {
 24        Self {
 25            wait_until_complete_rx: Mutex::new(None),
 26        }
 27    }
 28
 29    pub fn with_wait_until_complete(mut self, receiver: oneshot::Receiver<()>) -> Self {
 30        self.wait_until_complete_rx = Mutex::new(Some(receiver));
 31        self
 32    }
 33}
 34
 35impl AgentTool for StreamingEchoTool {
 36    type Input = StreamingEchoToolInput;
 37    type Output = String;
 38
 39    const NAME: &'static str = "streaming_echo";
 40
 41    fn supports_input_streaming() -> bool {
 42        true
 43    }
 44
 45    fn kind() -> acp::ToolKind {
 46        acp::ToolKind::Other
 47    }
 48
 49    fn initial_title(
 50        &self,
 51        _input: Result<Self::Input, serde_json::Value>,
 52        _cx: &mut App,
 53    ) -> SharedString {
 54        "Streaming Echo".into()
 55    }
 56
 57    fn run(
 58        self: Arc<Self>,
 59        mut input: ToolInput<Self::Input>,
 60        _event_stream: ToolCallEventStream,
 61        cx: &mut App,
 62    ) -> Task<Result<String, String>> {
 63        let wait_until_complete_rx = self.wait_until_complete_rx.lock().unwrap().take();
 64        cx.spawn(async move |_cx| {
 65            while input.recv_partial().await.is_some() {}
 66            let input = input
 67                .recv()
 68                .await
 69                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
 70            if let Some(rx) = wait_until_complete_rx {
 71                rx.await.ok();
 72            }
 73            Ok(input.text)
 74        })
 75    }
 76}
 77
 78/// A streaming tool that echoes its input, used to test streaming tool
 79/// lifecycle (e.g. partial delivery and cleanup when the LLM stream ends
 80/// before `is_input_complete`).
 81#[derive(JsonSchema, Serialize, Deserialize)]
 82pub struct StreamingFailingEchoToolInput {
 83    /// The text to echo.
 84    pub text: String,
 85}
 86
 87pub struct StreamingFailingEchoTool {
 88    pub receive_chunks_until_failure: usize,
 89}
 90
 91impl AgentTool for StreamingFailingEchoTool {
 92    type Input = StreamingFailingEchoToolInput;
 93
 94    type Output = String;
 95
 96    const NAME: &'static str = "streaming_failing_echo";
 97
 98    fn kind() -> acp::ToolKind {
 99        acp::ToolKind::Other
100    }
101
102    fn supports_input_streaming() -> bool {
103        true
104    }
105
106    fn initial_title(
107        &self,
108        _input: Result<Self::Input, serde_json::Value>,
109        _cx: &mut App,
110    ) -> SharedString {
111        "echo".into()
112    }
113
114    fn run(
115        self: Arc<Self>,
116        mut input: ToolInput<Self::Input>,
117        _event_stream: ToolCallEventStream,
118        cx: &mut App,
119    ) -> Task<Result<Self::Output, Self::Output>> {
120        cx.spawn(async move |_cx| {
121            for _ in 0..self.receive_chunks_until_failure {
122                let _ = input.recv_partial().await;
123            }
124            Err("failed".into())
125        })
126    }
127}
128
/// A tool that echoes its input
// NOTE(review): assuming `JsonSchema` is the schemars derive, the `///` text on
// this input type is emitted as schema descriptions, so it is left verbatim.
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct EchoToolInput {
    /// The text to echo.
    pub text: String,
}

/// Non-streaming test tool whose `run` waits for the complete input and
/// returns its `text` unchanged.
pub struct EchoTool;
137
138impl AgentTool for EchoTool {
139    type Input = EchoToolInput;
140    type Output = String;
141
142    const NAME: &'static str = "echo";
143
144    fn kind() -> acp::ToolKind {
145        acp::ToolKind::Other
146    }
147
148    fn initial_title(
149        &self,
150        _input: Result<Self::Input, serde_json::Value>,
151        _cx: &mut App,
152    ) -> SharedString {
153        "Echo".into()
154    }
155
156    fn run(
157        self: Arc<Self>,
158        input: ToolInput<Self::Input>,
159        _event_stream: ToolCallEventStream,
160        cx: &mut App,
161    ) -> Task<Result<String, String>> {
162        cx.spawn(async move |_cx| {
163            let input = input
164                .recv()
165                .await
166                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
167            Ok(input.text)
168        })
169    }
170}
171
/// A tool that waits for a specified delay
// NOTE(review): assuming `JsonSchema` is the schemars derive, the `///` text on
// this input type is emitted as schema descriptions, so it is left verbatim.
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct DelayToolInput {
    /// The delay in milliseconds.
    ms: u64,
}

/// Test tool that waits `ms` milliseconds on the background executor's timer
/// and then returns `"Ding"`.
pub struct DelayTool;
180
181impl AgentTool for DelayTool {
182    type Input = DelayToolInput;
183    type Output = String;
184
185    const NAME: &'static str = "delay";
186
187    fn initial_title(
188        &self,
189        input: Result<Self::Input, serde_json::Value>,
190        _cx: &mut App,
191    ) -> SharedString {
192        if let Ok(input) = input {
193            format!("Delay {}ms", input.ms).into()
194        } else {
195            "Delay".into()
196        }
197    }
198
199    fn kind() -> acp::ToolKind {
200        acp::ToolKind::Other
201    }
202
203    fn run(
204        self: Arc<Self>,
205        input: ToolInput<Self::Input>,
206        _event_stream: ToolCallEventStream,
207        cx: &mut App,
208    ) -> Task<Result<String, String>>
209    where
210        Self: Sized,
211    {
212        let executor = cx.background_executor().clone();
213        cx.foreground_executor().spawn(async move {
214            let input = input
215                .recv()
216                .await
217                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
218            executor.timer(Duration::from_millis(input.ms)).await;
219            Ok("Ding".to_string())
220        })
221    }
222}
223
// Empty input: the tool takes no arguments. (Plain `//` comment so that no
// schema description is introduced.)
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct ToolRequiringPermissionInput {}

/// Test tool that consults the global `AgentSettings` permission rules when it
/// runs. On a deny decision it fails with the deny reason; on a confirm decision
/// it requests authorization through its event stream. Once allowed it returns
/// `"Allowed"`.
pub struct ToolRequiringPermission;
228
229impl AgentTool for ToolRequiringPermission {
230    type Input = ToolRequiringPermissionInput;
231    type Output = String;
232
233    const NAME: &'static str = "tool_requiring_permission";
234
235    fn kind() -> acp::ToolKind {
236        acp::ToolKind::Other
237    }
238
239    fn initial_title(
240        &self,
241        _input: Result<Self::Input, serde_json::Value>,
242        _cx: &mut App,
243    ) -> SharedString {
244        "This tool requires permission".into()
245    }
246
247    fn run(
248        self: Arc<Self>,
249        input: ToolInput<Self::Input>,
250        event_stream: ToolCallEventStream,
251        cx: &mut App,
252    ) -> Task<Result<String, String>> {
253        cx.spawn(async move |cx| {
254            let _input = input
255                .recv()
256                .await
257                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
258
259            let decision = cx.update(|cx| {
260                decide_permission_from_settings(
261                    Self::NAME,
262                    &[String::new()],
263                    AgentSettings::get_global(cx),
264                )
265            });
266
267            let authorize = match decision {
268                ToolPermissionDecision::Allow => None,
269                ToolPermissionDecision::Deny(reason) => {
270                    return Err(reason);
271                }
272                ToolPermissionDecision::Confirm => Some(cx.update(|cx| {
273                    let context = crate::ToolPermissionContext::new(
274                        "tool_requiring_permission",
275                        vec![String::new()],
276                    );
277                    event_stream.authorize("Authorize?", context, cx)
278                })),
279            };
280
281            if let Some(authorize) = authorize {
282                authorize.await.map_err(|e| e.to_string())?;
283            }
284            Ok("Allowed".to_string())
285        })
286    }
287}
288
// Empty input: the tool takes no arguments. (Plain `//` comment so that no
// schema description is introduced.)
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct InfiniteToolInput {}

/// Test tool that, once its input arrives, awaits a future that never
/// resolves. Its task therefore never completes on its own.
pub struct InfiniteTool;
293
294impl AgentTool for InfiniteTool {
295    type Input = InfiniteToolInput;
296    type Output = String;
297
298    const NAME: &'static str = "infinite";
299
300    fn kind() -> acp::ToolKind {
301        acp::ToolKind::Other
302    }
303
304    fn initial_title(
305        &self,
306        _input: Result<Self::Input, serde_json::Value>,
307        _cx: &mut App,
308    ) -> SharedString {
309        "Infinite Tool".into()
310    }
311
312    fn run(
313        self: Arc<Self>,
314        input: ToolInput<Self::Input>,
315        _event_stream: ToolCallEventStream,
316        cx: &mut App,
317    ) -> Task<Result<String, String>> {
318        cx.foreground_executor().spawn(async move {
319            let _input = input
320                .recv()
321                .await
322                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
323            future::pending::<()>().await;
324            unreachable!()
325        })
326    }
327}
328
329/// A tool that loops forever but properly handles cancellation via `select!`,
330/// similar to how edit_file_tool handles cancellation.
331#[derive(JsonSchema, Serialize, Deserialize)]
332pub struct CancellationAwareToolInput {}
333
334pub struct CancellationAwareTool {
335    pub was_cancelled: Arc<AtomicBool>,
336}
337
338impl CancellationAwareTool {
339    pub fn new() -> (Self, Arc<AtomicBool>) {
340        let was_cancelled = Arc::new(AtomicBool::new(false));
341        (
342            Self {
343                was_cancelled: was_cancelled.clone(),
344            },
345            was_cancelled,
346        )
347    }
348}
349
350impl AgentTool for CancellationAwareTool {
351    type Input = CancellationAwareToolInput;
352    type Output = String;
353
354    const NAME: &'static str = "cancellation_aware";
355
356    fn kind() -> acp::ToolKind {
357        acp::ToolKind::Other
358    }
359
360    fn initial_title(
361        &self,
362        _input: Result<Self::Input, serde_json::Value>,
363        _cx: &mut App,
364    ) -> SharedString {
365        "Cancellation Aware Tool".into()
366    }
367
368    fn run(
369        self: Arc<Self>,
370        input: ToolInput<Self::Input>,
371        event_stream: ToolCallEventStream,
372        cx: &mut App,
373    ) -> Task<Result<String, String>> {
374        cx.foreground_executor().spawn(async move {
375            let _input = input
376                .recv()
377                .await
378                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
379            // Wait for cancellation - this tool does nothing but wait to be cancelled
380            event_stream.cancelled_by_user().await;
381            self.was_cancelled.store(true, Ordering::SeqCst);
382            Err("Tool cancelled by user".to_string())
383        })
384    }
385}
386
387/// A tool that takes an object with map from letters to random words starting with that letter.
388/// All fiealds are required! Pass a word for every letter!
389#[derive(JsonSchema, Serialize, Deserialize)]
390pub struct WordListInput {
391    /// Provide a random word that starts with A.
392    a: Option<String>,
393    /// Provide a random word that starts with B.
394    b: Option<String>,
395    /// Provide a random word that starts with C.
396    c: Option<String>,
397    /// Provide a random word that starts with D.
398    d: Option<String>,
399    /// Provide a random word that starts with E.
400    e: Option<String>,
401    /// Provide a random word that starts with F.
402    f: Option<String>,
403    /// Provide a random word that starts with G.
404    g: Option<String>,
405}
406
407pub struct WordListTool;
408
409impl AgentTool for WordListTool {
410    type Input = WordListInput;
411    type Output = String;
412
413    const NAME: &'static str = "word_list";
414
415    fn kind() -> acp::ToolKind {
416        acp::ToolKind::Other
417    }
418
419    fn initial_title(
420        &self,
421        _input: Result<Self::Input, serde_json::Value>,
422        _cx: &mut App,
423    ) -> SharedString {
424        "List of random words".into()
425    }
426
427    fn run(
428        self: Arc<Self>,
429        input: ToolInput<Self::Input>,
430        _event_stream: ToolCallEventStream,
431        cx: &mut App,
432    ) -> Task<Result<String, String>> {
433        cx.spawn(async move |_cx| {
434            let _input = input
435                .recv()
436                .await
437                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
438            Ok("ok".to_string())
439        })
440    }
441}