test_tools.rs

  1use super::*;
  2use agent_settings::AgentSettings;
  3use gpui::{App, SharedString, Task};
  4use std::future;
  5use std::sync::Mutex;
  6use std::sync::atomic::{AtomicBool, Ordering};
  7use std::time::Duration;
  8
  9/// A streaming tool that echoes its input, used to test streaming tool
 10/// lifecycle (e.g. partial delivery and cleanup when the LLM stream ends
 11/// before `is_input_complete`).
 12#[derive(JsonSchema, Serialize, Deserialize)]
 13pub struct StreamingEchoToolInput {
 14    /// The text to echo.
 15    pub text: String,
 16}
 17
 18pub struct StreamingEchoTool {
 19    wait_until_complete_rx: Mutex<Option<oneshot::Receiver<()>>>,
 20}
 21
 22impl StreamingEchoTool {
 23    pub fn new() -> Self {
 24        Self {
 25            wait_until_complete_rx: Mutex::new(None),
 26        }
 27    }
 28
 29    pub fn with_wait_until_complete(mut self, receiver: oneshot::Receiver<()>) -> Self {
 30        self.wait_until_complete_rx = Mutex::new(Some(receiver));
 31        self
 32    }
 33}
 34
 35impl AgentTool for StreamingEchoTool {
 36    type Input = StreamingEchoToolInput;
 37    type Output = String;
 38
 39    const NAME: &'static str = "streaming_echo";
 40
 41    fn supports_input_streaming() -> bool {
 42        true
 43    }
 44
 45    fn kind() -> acp::ToolKind {
 46        acp::ToolKind::Other
 47    }
 48
 49    fn initial_title(
 50        &self,
 51        _input: Result<Self::Input, serde_json::Value>,
 52        _cx: &mut App,
 53    ) -> SharedString {
 54        "Streaming Echo".into()
 55    }
 56
 57    fn run(
 58        self: Arc<Self>,
 59        input: ToolInput<Self::Input>,
 60        _event_stream: ToolCallEventStream,
 61        cx: &mut App,
 62    ) -> Task<Result<String, String>> {
 63        let wait_until_complete_rx = self.wait_until_complete_rx.lock().unwrap().take();
 64        cx.spawn(async move |_cx| {
 65            let input = input
 66                .recv()
 67                .await
 68                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
 69            if let Some(rx) = wait_until_complete_rx {
 70                rx.await.ok();
 71            }
 72            Ok(input.text)
 73        })
 74    }
 75}
 76
 77#[derive(JsonSchema, Serialize, Deserialize)]
 78pub struct StreamingJsonErrorContextToolInput {
 79    /// The text to echo.
 80    pub text: String,
 81}
 82
 83pub struct StreamingJsonErrorContextTool;
 84
 85impl AgentTool for StreamingJsonErrorContextTool {
 86    type Input = StreamingJsonErrorContextToolInput;
 87    type Output = String;
 88
 89    const NAME: &'static str = "streaming_json_error_context";
 90
 91    fn supports_input_streaming() -> bool {
 92        true
 93    }
 94
 95    fn kind() -> acp::ToolKind {
 96        acp::ToolKind::Other
 97    }
 98
 99    fn initial_title(
100        &self,
101        _input: Result<Self::Input, serde_json::Value>,
102        _cx: &mut App,
103    ) -> SharedString {
104        "Streaming JSON Error Context".into()
105    }
106
107    fn run(
108        self: Arc<Self>,
109        mut input: ToolInput<Self::Input>,
110        _event_stream: ToolCallEventStream,
111        cx: &mut App,
112    ) -> Task<Result<String, String>> {
113        cx.spawn(async move |_cx| {
114            let mut last_partial_text = None;
115
116            loop {
117                match input.next().await {
118                    Ok(ToolInputPayload::Partial(partial)) => {
119                        if let Some(text) = partial.get("text").and_then(|value| value.as_str()) {
120                            last_partial_text = Some(text.to_string());
121                        }
122                    }
123                    Ok(ToolInputPayload::Full(input)) => return Ok(input.text),
124                    Ok(ToolInputPayload::InvalidJson { error_message }) => {
125                        let partial_text = last_partial_text.unwrap_or_default();
126                        return Err(format!(
127                            "Saw partial text '{partial_text}' before invalid JSON: {error_message}"
128                        ));
129                    }
130                    Err(error) => {
131                        return Err(format!("Failed to receive tool input: {error}"));
132                    }
133                }
134            }
135        })
136    }
137}
138
139/// A streaming tool that echoes its input, used to test streaming tool
140/// lifecycle (e.g. partial delivery and cleanup when the LLM stream ends
141/// before `is_input_complete`).
142#[derive(JsonSchema, Serialize, Deserialize)]
143pub struct StreamingFailingEchoToolInput {
144    /// The text to echo.
145    pub text: String,
146}
147
148pub struct StreamingFailingEchoTool {
149    pub receive_chunks_until_failure: usize,
150}
151
152impl AgentTool for StreamingFailingEchoTool {
153    type Input = StreamingFailingEchoToolInput;
154
155    type Output = String;
156
157    const NAME: &'static str = "streaming_failing_echo";
158
159    fn kind() -> acp::ToolKind {
160        acp::ToolKind::Other
161    }
162
163    fn supports_input_streaming() -> bool {
164        true
165    }
166
167    fn initial_title(
168        &self,
169        _input: Result<Self::Input, serde_json::Value>,
170        _cx: &mut App,
171    ) -> SharedString {
172        "echo".into()
173    }
174
175    fn run(
176        self: Arc<Self>,
177        mut input: ToolInput<Self::Input>,
178        _event_stream: ToolCallEventStream,
179        cx: &mut App,
180    ) -> Task<Result<Self::Output, Self::Output>> {
181        cx.spawn(async move |_cx| {
182            for _ in 0..self.receive_chunks_until_failure {
183                let _ = input.next().await;
184            }
185            Err("failed".into())
186        })
187    }
188}
189
190/// A tool that echoes its input
191#[derive(JsonSchema, Serialize, Deserialize)]
192pub struct EchoToolInput {
193    /// The text to echo.
194    pub text: String,
195}
196
197pub struct EchoTool;
198
199impl AgentTool for EchoTool {
200    type Input = EchoToolInput;
201    type Output = String;
202
203    const NAME: &'static str = "echo";
204
205    fn kind() -> acp::ToolKind {
206        acp::ToolKind::Other
207    }
208
209    fn initial_title(
210        &self,
211        _input: Result<Self::Input, serde_json::Value>,
212        _cx: &mut App,
213    ) -> SharedString {
214        "Echo".into()
215    }
216
217    fn run(
218        self: Arc<Self>,
219        input: ToolInput<Self::Input>,
220        _event_stream: ToolCallEventStream,
221        cx: &mut App,
222    ) -> Task<Result<String, String>> {
223        cx.spawn(async move |_cx| {
224            let input = input
225                .recv()
226                .await
227                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
228            Ok(input.text)
229        })
230    }
231}
232
233/// A tool that waits for a specified delay
234#[derive(JsonSchema, Serialize, Deserialize)]
235pub struct DelayToolInput {
236    /// The delay in milliseconds.
237    ms: u64,
238}
239
240pub struct DelayTool;
241
242impl AgentTool for DelayTool {
243    type Input = DelayToolInput;
244    type Output = String;
245
246    const NAME: &'static str = "delay";
247
248    fn initial_title(
249        &self,
250        input: Result<Self::Input, serde_json::Value>,
251        _cx: &mut App,
252    ) -> SharedString {
253        if let Ok(input) = input {
254            format!("Delay {}ms", input.ms).into()
255        } else {
256            "Delay".into()
257        }
258    }
259
260    fn kind() -> acp::ToolKind {
261        acp::ToolKind::Other
262    }
263
264    fn run(
265        self: Arc<Self>,
266        input: ToolInput<Self::Input>,
267        _event_stream: ToolCallEventStream,
268        cx: &mut App,
269    ) -> Task<Result<String, String>>
270    where
271        Self: Sized,
272    {
273        let executor = cx.background_executor().clone();
274        cx.foreground_executor().spawn(async move {
275            let input = input
276                .recv()
277                .await
278                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
279            executor.timer(Duration::from_millis(input.ms)).await;
280            Ok("Ding".to_string())
281        })
282    }
283}
284
285#[derive(JsonSchema, Serialize, Deserialize)]
286pub struct ToolRequiringPermissionInput {}
287
288pub struct ToolRequiringPermission;
289
290impl AgentTool for ToolRequiringPermission {
291    type Input = ToolRequiringPermissionInput;
292    type Output = String;
293
294    const NAME: &'static str = "tool_requiring_permission";
295
296    fn kind() -> acp::ToolKind {
297        acp::ToolKind::Other
298    }
299
300    fn initial_title(
301        &self,
302        _input: Result<Self::Input, serde_json::Value>,
303        _cx: &mut App,
304    ) -> SharedString {
305        "This tool requires permission".into()
306    }
307
308    fn run(
309        self: Arc<Self>,
310        input: ToolInput<Self::Input>,
311        event_stream: ToolCallEventStream,
312        cx: &mut App,
313    ) -> Task<Result<String, String>> {
314        cx.spawn(async move |cx| {
315            let _input = input
316                .recv()
317                .await
318                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
319
320            let decision = cx.update(|cx| {
321                decide_permission_from_settings(
322                    Self::NAME,
323                    &[String::new()],
324                    AgentSettings::get_global(cx),
325                )
326            });
327
328            let authorize = match decision {
329                ToolPermissionDecision::Allow => None,
330                ToolPermissionDecision::Deny(reason) => {
331                    return Err(reason);
332                }
333                ToolPermissionDecision::Confirm => Some(cx.update(|cx| {
334                    let context = crate::ToolPermissionContext::new(
335                        "tool_requiring_permission",
336                        vec![String::new()],
337                    );
338                    event_stream.authorize("Authorize?", context, cx)
339                })),
340            };
341
342            if let Some(authorize) = authorize {
343                authorize.await.map_err(|e| e.to_string())?;
344            }
345            Ok("Allowed".to_string())
346        })
347    }
348}
349
350#[derive(JsonSchema, Serialize, Deserialize)]
351pub struct InfiniteToolInput {}
352
353pub struct InfiniteTool;
354
355impl AgentTool for InfiniteTool {
356    type Input = InfiniteToolInput;
357    type Output = String;
358
359    const NAME: &'static str = "infinite";
360
361    fn kind() -> acp::ToolKind {
362        acp::ToolKind::Other
363    }
364
365    fn initial_title(
366        &self,
367        _input: Result<Self::Input, serde_json::Value>,
368        _cx: &mut App,
369    ) -> SharedString {
370        "Infinite Tool".into()
371    }
372
373    fn run(
374        self: Arc<Self>,
375        input: ToolInput<Self::Input>,
376        _event_stream: ToolCallEventStream,
377        cx: &mut App,
378    ) -> Task<Result<String, String>> {
379        cx.foreground_executor().spawn(async move {
380            let _input = input
381                .recv()
382                .await
383                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
384            future::pending::<()>().await;
385            unreachable!()
386        })
387    }
388}
389
390/// A tool that loops forever but properly handles cancellation via `select!`,
391/// similar to how edit_file_tool handles cancellation.
392#[derive(JsonSchema, Serialize, Deserialize)]
393pub struct CancellationAwareToolInput {}
394
/// Test tool that reports, through a shared flag, whether its `run` task saw
/// a user cancellation.
pub struct CancellationAwareTool {
    /// Set to `true` by `run` once the user cancels the tool call.
    pub was_cancelled: Arc<AtomicBool>,
}

impl CancellationAwareTool {
    /// Builds the tool along with a second handle to the same cancellation
    /// flag (initially `false`), so the caller can inspect the flag after
    /// handing the tool off.
    pub fn new() -> (Self, Arc<AtomicBool>) {
        let flag = Arc::new(AtomicBool::new(false));
        let tool = Self {
            was_cancelled: Arc::clone(&flag),
        };
        (tool, flag)
    }
}
410
411impl AgentTool for CancellationAwareTool {
412    type Input = CancellationAwareToolInput;
413    type Output = String;
414
415    const NAME: &'static str = "cancellation_aware";
416
417    fn kind() -> acp::ToolKind {
418        acp::ToolKind::Other
419    }
420
421    fn initial_title(
422        &self,
423        _input: Result<Self::Input, serde_json::Value>,
424        _cx: &mut App,
425    ) -> SharedString {
426        "Cancellation Aware Tool".into()
427    }
428
429    fn run(
430        self: Arc<Self>,
431        input: ToolInput<Self::Input>,
432        event_stream: ToolCallEventStream,
433        cx: &mut App,
434    ) -> Task<Result<String, String>> {
435        cx.foreground_executor().spawn(async move {
436            let _input = input
437                .recv()
438                .await
439                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
440            // Wait for cancellation - this tool does nothing but wait to be cancelled
441            event_stream.cancelled_by_user().await;
442            self.was_cancelled.store(true, Ordering::SeqCst);
443            Err("Tool cancelled by user".to_string())
444        })
445    }
446}
447
448/// A tool that takes an object with map from letters to random words starting with that letter.
449/// All fiealds are required! Pass a word for every letter!
450#[derive(JsonSchema, Serialize, Deserialize)]
451pub struct WordListInput {
452    /// Provide a random word that starts with A.
453    a: Option<String>,
454    /// Provide a random word that starts with B.
455    b: Option<String>,
456    /// Provide a random word that starts with C.
457    c: Option<String>,
458    /// Provide a random word that starts with D.
459    d: Option<String>,
460    /// Provide a random word that starts with E.
461    e: Option<String>,
462    /// Provide a random word that starts with F.
463    f: Option<String>,
464    /// Provide a random word that starts with G.
465    g: Option<String>,
466}
467
468pub struct WordListTool;
469
470impl AgentTool for WordListTool {
471    type Input = WordListInput;
472    type Output = String;
473
474    const NAME: &'static str = "word_list";
475
476    fn kind() -> acp::ToolKind {
477        acp::ToolKind::Other
478    }
479
480    fn initial_title(
481        &self,
482        _input: Result<Self::Input, serde_json::Value>,
483        _cx: &mut App,
484    ) -> SharedString {
485        "List of random words".into()
486    }
487
488    fn run(
489        self: Arc<Self>,
490        input: ToolInput<Self::Input>,
491        _event_stream: ToolCallEventStream,
492        cx: &mut App,
493    ) -> Task<Result<String, String>> {
494        cx.spawn(async move |_cx| {
495            let _input = input
496                .recv()
497                .await
498                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
499            Ok("ok".to_string())
500        })
501    }
502}