test_tools.rs

  1use super::*;
  2use agent_settings::AgentSettings;
  3use gpui::{App, SharedString, Task};
  4use std::future;
  5use std::sync::atomic::{AtomicBool, Ordering};
  6use std::time::Duration;
  7
// Input for `StreamingEchoTool`.
// NOTE(review): assuming `JsonSchema` is schemars' derive, the `///` comments
// below are emitted as schema descriptions (model-facing text), so rewording
// them is a behavior change, not just a documentation edit.
/// A streaming tool that echoes its input, used to test streaming tool
/// lifecycle (e.g. partial delivery and cleanup when the LLM stream ends
/// before `is_input_complete`).
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct StreamingEchoToolInput {
    /// The text to echo.
    pub text: String,
}
 16
 17pub struct StreamingEchoTool;
 18
 19impl AgentTool for StreamingEchoTool {
 20    type Input = StreamingEchoToolInput;
 21    type Output = String;
 22
 23    const NAME: &'static str = "streaming_echo";
 24
 25    fn supports_input_streaming() -> bool {
 26        true
 27    }
 28
 29    fn kind() -> acp::ToolKind {
 30        acp::ToolKind::Other
 31    }
 32
 33    fn initial_title(
 34        &self,
 35        _input: Result<Self::Input, serde_json::Value>,
 36        _cx: &mut App,
 37    ) -> SharedString {
 38        "Streaming Echo".into()
 39    }
 40
 41    fn run(
 42        self: Arc<Self>,
 43        mut input: ToolInput<Self::Input>,
 44        _event_stream: ToolCallEventStream,
 45        cx: &mut App,
 46    ) -> Task<Result<String, String>> {
 47        cx.spawn(async move |_cx| {
 48            while input.recv_partial().await.is_some() {}
 49            let input = input
 50                .recv()
 51                .await
 52                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
 53            Ok(input.text)
 54        })
 55    }
 56}
 57
// Input for `EchoTool`, which returns `text` unchanged.
// NOTE(review): the `///` comments presumably feed the `JsonSchema` derive's
// descriptions, so they are model-facing text; edit them deliberately.
/// A tool that echoes its input
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct EchoToolInput {
    /// The text to echo.
    pub text: String,
}
 64
 65pub struct EchoTool;
 66
 67impl AgentTool for EchoTool {
 68    type Input = EchoToolInput;
 69    type Output = String;
 70
 71    const NAME: &'static str = "echo";
 72
 73    fn kind() -> acp::ToolKind {
 74        acp::ToolKind::Other
 75    }
 76
 77    fn initial_title(
 78        &self,
 79        _input: Result<Self::Input, serde_json::Value>,
 80        _cx: &mut App,
 81    ) -> SharedString {
 82        "Echo".into()
 83    }
 84
 85    fn run(
 86        self: Arc<Self>,
 87        input: ToolInput<Self::Input>,
 88        _event_stream: ToolCallEventStream,
 89        cx: &mut App,
 90    ) -> Task<Result<String, String>> {
 91        cx.spawn(async move |_cx| {
 92            let input = input
 93                .recv()
 94                .await
 95                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
 96            Ok(input.text)
 97        })
 98    }
 99}
100
// Input for `DelayTool`. Unlike the other test inputs, `ms` is not `pub`.
// NOTE(review): the `///` comments presumably become `JsonSchema` descriptions
// (model-facing text).
/// A tool that waits for a specified delay
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct DelayToolInput {
    /// The delay in milliseconds.
    ms: u64,
}
107
108pub struct DelayTool;
109
110impl AgentTool for DelayTool {
111    type Input = DelayToolInput;
112    type Output = String;
113
114    const NAME: &'static str = "delay";
115
116    fn initial_title(
117        &self,
118        input: Result<Self::Input, serde_json::Value>,
119        _cx: &mut App,
120    ) -> SharedString {
121        if let Ok(input) = input {
122            format!("Delay {}ms", input.ms).into()
123        } else {
124            "Delay".into()
125        }
126    }
127
128    fn kind() -> acp::ToolKind {
129        acp::ToolKind::Other
130    }
131
132    fn run(
133        self: Arc<Self>,
134        input: ToolInput<Self::Input>,
135        _event_stream: ToolCallEventStream,
136        cx: &mut App,
137    ) -> Task<Result<String, String>>
138    where
139        Self: Sized,
140    {
141        let executor = cx.background_executor().clone();
142        cx.foreground_executor().spawn(async move {
143            let input = input
144                .recv()
145                .await
146                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
147            executor.timer(Duration::from_millis(input.ms)).await;
148            Ok("Ding".to_string())
149        })
150    }
151}
152
// Empty input for `ToolRequiringPermission`; the tool takes no arguments.
// A `//` comment is used on purpose: a `///` doc here would presumably be
// picked up by the `JsonSchema` derive as a model-facing description.
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct ToolRequiringPermissionInput {}
155
156pub struct ToolRequiringPermission;
157
158impl AgentTool for ToolRequiringPermission {
159    type Input = ToolRequiringPermissionInput;
160    type Output = String;
161
162    const NAME: &'static str = "tool_requiring_permission";
163
164    fn kind() -> acp::ToolKind {
165        acp::ToolKind::Other
166    }
167
168    fn initial_title(
169        &self,
170        _input: Result<Self::Input, serde_json::Value>,
171        _cx: &mut App,
172    ) -> SharedString {
173        "This tool requires permission".into()
174    }
175
176    fn run(
177        self: Arc<Self>,
178        input: ToolInput<Self::Input>,
179        event_stream: ToolCallEventStream,
180        cx: &mut App,
181    ) -> Task<Result<String, String>> {
182        cx.spawn(async move |cx| {
183            let _input = input
184                .recv()
185                .await
186                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
187
188            let decision = cx.update(|cx| {
189                decide_permission_from_settings(
190                    Self::NAME,
191                    &[String::new()],
192                    AgentSettings::get_global(cx),
193                )
194            });
195
196            let authorize = match decision {
197                ToolPermissionDecision::Allow => None,
198                ToolPermissionDecision::Deny(reason) => {
199                    return Err(reason);
200                }
201                ToolPermissionDecision::Confirm => Some(cx.update(|cx| {
202                    let context = crate::ToolPermissionContext::new(
203                        "tool_requiring_permission",
204                        vec![String::new()],
205                    );
206                    event_stream.authorize("Authorize?", context, cx)
207                })),
208            };
209
210            if let Some(authorize) = authorize {
211                authorize.await.map_err(|e| e.to_string())?;
212            }
213            Ok("Allowed".to_string())
214        })
215    }
216}
217
// Empty input for `InfiniteTool`; the tool takes no arguments.
// A `//` comment is used on purpose: a `///` doc here would presumably be
// picked up by the `JsonSchema` derive as a model-facing description.
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct InfiniteToolInput {}
220
221pub struct InfiniteTool;
222
223impl AgentTool for InfiniteTool {
224    type Input = InfiniteToolInput;
225    type Output = String;
226
227    const NAME: &'static str = "infinite";
228
229    fn kind() -> acp::ToolKind {
230        acp::ToolKind::Other
231    }
232
233    fn initial_title(
234        &self,
235        _input: Result<Self::Input, serde_json::Value>,
236        _cx: &mut App,
237    ) -> SharedString {
238        "Infinite Tool".into()
239    }
240
241    fn run(
242        self: Arc<Self>,
243        input: ToolInput<Self::Input>,
244        _event_stream: ToolCallEventStream,
245        cx: &mut App,
246    ) -> Task<Result<String, String>> {
247        cx.foreground_executor().spawn(async move {
248            let _input = input
249                .recv()
250                .await
251                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
252            future::pending::<()>().await;
253            unreachable!()
254        })
255    }
256}
257
258/// A tool that loops forever but properly handles cancellation via `select!`,
259/// similar to how edit_file_tool handles cancellation.
260#[derive(JsonSchema, Serialize, Deserialize)]
261pub struct CancellationAwareToolInput {}
262
/// Test tool that waits for user cancellation and records that it saw it.
pub struct CancellationAwareTool {
    /// Flipped to `true` by the tool once it has observed user cancellation.
    pub was_cancelled: Arc<AtomicBool>,
}

impl CancellationAwareTool {
    /// Builds the tool together with a second handle to its `was_cancelled`
    /// flag (initially `false`). The caller can keep watching that flag after
    /// the tool itself has been handed off.
    pub fn new() -> (Self, Arc<AtomicBool>) {
        let flag = Arc::new(AtomicBool::new(false));
        let tool = Self {
            was_cancelled: Arc::clone(&flag),
        };
        (tool, flag)
    }
}
278
279impl AgentTool for CancellationAwareTool {
280    type Input = CancellationAwareToolInput;
281    type Output = String;
282
283    const NAME: &'static str = "cancellation_aware";
284
285    fn kind() -> acp::ToolKind {
286        acp::ToolKind::Other
287    }
288
289    fn initial_title(
290        &self,
291        _input: Result<Self::Input, serde_json::Value>,
292        _cx: &mut App,
293    ) -> SharedString {
294        "Cancellation Aware Tool".into()
295    }
296
297    fn run(
298        self: Arc<Self>,
299        input: ToolInput<Self::Input>,
300        event_stream: ToolCallEventStream,
301        cx: &mut App,
302    ) -> Task<Result<String, String>> {
303        cx.foreground_executor().spawn(async move {
304            let _input = input
305                .recv()
306                .await
307                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
308            // Wait for cancellation - this tool does nothing but wait to be cancelled
309            event_stream.cancelled_by_user().await;
310            self.was_cancelled.store(true, Ordering::SeqCst);
311            Err("Tool cancelled by user".to_string())
312        })
313    }
314}
315
316/// A tool that takes an object with map from letters to random words starting with that letter.
317/// All fiealds are required! Pass a word for every letter!
318#[derive(JsonSchema, Serialize, Deserialize)]
319pub struct WordListInput {
320    /// Provide a random word that starts with A.
321    a: Option<String>,
322    /// Provide a random word that starts with B.
323    b: Option<String>,
324    /// Provide a random word that starts with C.
325    c: Option<String>,
326    /// Provide a random word that starts with D.
327    d: Option<String>,
328    /// Provide a random word that starts with E.
329    e: Option<String>,
330    /// Provide a random word that starts with F.
331    f: Option<String>,
332    /// Provide a random word that starts with G.
333    g: Option<String>,
334}
335
336pub struct WordListTool;
337
338impl AgentTool for WordListTool {
339    type Input = WordListInput;
340    type Output = String;
341
342    const NAME: &'static str = "word_list";
343
344    fn kind() -> acp::ToolKind {
345        acp::ToolKind::Other
346    }
347
348    fn initial_title(
349        &self,
350        _input: Result<Self::Input, serde_json::Value>,
351        _cx: &mut App,
352    ) -> SharedString {
353        "List of random words".into()
354    }
355
356    fn run(
357        self: Arc<Self>,
358        input: ToolInput<Self::Input>,
359        _event_stream: ToolCallEventStream,
360        cx: &mut App,
361    ) -> Task<Result<String, String>> {
362        cx.spawn(async move |_cx| {
363            let _input = input
364                .recv()
365                .await
366                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
367            Ok("ok".to_string())
368        })
369    }
370}