test_tools.rs

  1use super::*;
  2use agent_settings::AgentSettings;
  3use gpui::{App, SharedString, Task};
  4use std::future;
  5use std::sync::atomic::{AtomicBool, Ordering};
  6use std::time::Duration;
  7
  8/// A tool that echoes its input
  9#[derive(JsonSchema, Serialize, Deserialize)]
 10pub struct EchoToolInput {
 11    /// The text to echo.
 12    pub text: String,
 13}
 14
 15pub struct EchoTool;
 16
 17impl AgentTool for EchoTool {
 18    type Input = EchoToolInput;
 19    type Output = String;
 20
 21    const NAME: &'static str = "echo";
 22
 23    fn kind() -> acp::ToolKind {
 24        acp::ToolKind::Other
 25    }
 26
 27    fn initial_title(
 28        &self,
 29        _input: Result<Self::Input, serde_json::Value>,
 30        _cx: &mut App,
 31    ) -> SharedString {
 32        "Echo".into()
 33    }
 34
 35    fn run(
 36        self: Arc<Self>,
 37        input: ToolInput<Self::Input>,
 38        _event_stream: ToolCallEventStream,
 39        cx: &mut App,
 40    ) -> Task<Result<String, String>> {
 41        cx.spawn(async move |_cx| {
 42            let input = input
 43                .recv()
 44                .await
 45                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
 46            Ok(input.text)
 47        })
 48    }
 49}
 50
 51/// A tool that waits for a specified delay
 52#[derive(JsonSchema, Serialize, Deserialize)]
 53pub struct DelayToolInput {
 54    /// The delay in milliseconds.
 55    ms: u64,
 56}
 57
 58pub struct DelayTool;
 59
 60impl AgentTool for DelayTool {
 61    type Input = DelayToolInput;
 62    type Output = String;
 63
 64    const NAME: &'static str = "delay";
 65
 66    fn initial_title(
 67        &self,
 68        input: Result<Self::Input, serde_json::Value>,
 69        _cx: &mut App,
 70    ) -> SharedString {
 71        if let Ok(input) = input {
 72            format!("Delay {}ms", input.ms).into()
 73        } else {
 74            "Delay".into()
 75        }
 76    }
 77
 78    fn kind() -> acp::ToolKind {
 79        acp::ToolKind::Other
 80    }
 81
 82    fn run(
 83        self: Arc<Self>,
 84        input: ToolInput<Self::Input>,
 85        _event_stream: ToolCallEventStream,
 86        cx: &mut App,
 87    ) -> Task<Result<String, String>>
 88    where
 89        Self: Sized,
 90    {
 91        let executor = cx.background_executor().clone();
 92        cx.foreground_executor().spawn(async move {
 93            let input = input
 94                .recv()
 95                .await
 96                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
 97            executor.timer(Duration::from_millis(input.ms)).await;
 98            Ok("Ding".to_string())
 99        })
100    }
101}
102
103#[derive(JsonSchema, Serialize, Deserialize)]
104pub struct ToolRequiringPermissionInput {}
105
106pub struct ToolRequiringPermission;
107
108impl AgentTool for ToolRequiringPermission {
109    type Input = ToolRequiringPermissionInput;
110    type Output = String;
111
112    const NAME: &'static str = "tool_requiring_permission";
113
114    fn kind() -> acp::ToolKind {
115        acp::ToolKind::Other
116    }
117
118    fn initial_title(
119        &self,
120        _input: Result<Self::Input, serde_json::Value>,
121        _cx: &mut App,
122    ) -> SharedString {
123        "This tool requires permission".into()
124    }
125
126    fn run(
127        self: Arc<Self>,
128        input: ToolInput<Self::Input>,
129        event_stream: ToolCallEventStream,
130        cx: &mut App,
131    ) -> Task<Result<String, String>> {
132        cx.spawn(async move |cx| {
133            let _input = input
134                .recv()
135                .await
136                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
137
138            let decision = cx.update(|cx| {
139                decide_permission_from_settings(
140                    Self::NAME,
141                    &[String::new()],
142                    AgentSettings::get_global(cx),
143                )
144            });
145
146            let authorize = match decision {
147                ToolPermissionDecision::Allow => None,
148                ToolPermissionDecision::Deny(reason) => {
149                    return Err(reason);
150                }
151                ToolPermissionDecision::Confirm => Some(cx.update(|cx| {
152                    let context = crate::ToolPermissionContext::new(
153                        "tool_requiring_permission",
154                        vec![String::new()],
155                    );
156                    event_stream.authorize("Authorize?", context, cx)
157                })),
158            };
159
160            if let Some(authorize) = authorize {
161                authorize.await.map_err(|e| e.to_string())?;
162            }
163            Ok("Allowed".to_string())
164        })
165    }
166}
167
168#[derive(JsonSchema, Serialize, Deserialize)]
169pub struct InfiniteToolInput {}
170
171pub struct InfiniteTool;
172
173impl AgentTool for InfiniteTool {
174    type Input = InfiniteToolInput;
175    type Output = String;
176
177    const NAME: &'static str = "infinite";
178
179    fn kind() -> acp::ToolKind {
180        acp::ToolKind::Other
181    }
182
183    fn initial_title(
184        &self,
185        _input: Result<Self::Input, serde_json::Value>,
186        _cx: &mut App,
187    ) -> SharedString {
188        "Infinite Tool".into()
189    }
190
191    fn run(
192        self: Arc<Self>,
193        input: ToolInput<Self::Input>,
194        _event_stream: ToolCallEventStream,
195        cx: &mut App,
196    ) -> Task<Result<String, String>> {
197        cx.foreground_executor().spawn(async move {
198            let _input = input
199                .recv()
200                .await
201                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
202            future::pending::<()>().await;
203            unreachable!()
204        })
205    }
206}
207
208/// A tool that loops forever but properly handles cancellation via `select!`,
209/// similar to how edit_file_tool handles cancellation.
210#[derive(JsonSchema, Serialize, Deserialize)]
211pub struct CancellationAwareToolInput {}
212
/// Test tool that exposes, through a shared flag, whether it has been cancelled.
pub struct CancellationAwareTool {
    /// Shared flag; `new` hands out a second handle to the same `AtomicBool`.
    pub was_cancelled: Arc<AtomicBool>,
}

impl CancellationAwareTool {
    /// Builds the tool together with a second handle to its `was_cancelled` flag.
    ///
    /// The flag starts out `false`. Both handles point at the same `AtomicBool`,
    /// so a store through one is visible through the other.
    pub fn new() -> (Self, Arc<AtomicBool>) {
        let flag = Arc::new(AtomicBool::new(false));
        let tool = Self {
            was_cancelled: Arc::clone(&flag),
        };
        (tool, flag)
    }
}
228
229impl AgentTool for CancellationAwareTool {
230    type Input = CancellationAwareToolInput;
231    type Output = String;
232
233    const NAME: &'static str = "cancellation_aware";
234
235    fn kind() -> acp::ToolKind {
236        acp::ToolKind::Other
237    }
238
239    fn initial_title(
240        &self,
241        _input: Result<Self::Input, serde_json::Value>,
242        _cx: &mut App,
243    ) -> SharedString {
244        "Cancellation Aware Tool".into()
245    }
246
247    fn run(
248        self: Arc<Self>,
249        input: ToolInput<Self::Input>,
250        event_stream: ToolCallEventStream,
251        cx: &mut App,
252    ) -> Task<Result<String, String>> {
253        cx.foreground_executor().spawn(async move {
254            let _input = input
255                .recv()
256                .await
257                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
258            // Wait for cancellation - this tool does nothing but wait to be cancelled
259            event_stream.cancelled_by_user().await;
260            self.was_cancelled.store(true, Ordering::SeqCst);
261            Err("Tool cancelled by user".to_string())
262        })
263    }
264}
265
266/// A tool that takes an object with map from letters to random words starting with that letter.
267/// All fiealds are required! Pass a word for every letter!
268#[derive(JsonSchema, Serialize, Deserialize)]
269pub struct WordListInput {
270    /// Provide a random word that starts with A.
271    a: Option<String>,
272    /// Provide a random word that starts with B.
273    b: Option<String>,
274    /// Provide a random word that starts with C.
275    c: Option<String>,
276    /// Provide a random word that starts with D.
277    d: Option<String>,
278    /// Provide a random word that starts with E.
279    e: Option<String>,
280    /// Provide a random word that starts with F.
281    f: Option<String>,
282    /// Provide a random word that starts with G.
283    g: Option<String>,
284}
285
286pub struct WordListTool;
287
288impl AgentTool for WordListTool {
289    type Input = WordListInput;
290    type Output = String;
291
292    const NAME: &'static str = "word_list";
293
294    fn kind() -> acp::ToolKind {
295        acp::ToolKind::Other
296    }
297
298    fn initial_title(
299        &self,
300        _input: Result<Self::Input, serde_json::Value>,
301        _cx: &mut App,
302    ) -> SharedString {
303        "List of random words".into()
304    }
305
306    fn run(
307        self: Arc<Self>,
308        input: ToolInput<Self::Input>,
309        _event_stream: ToolCallEventStream,
310        cx: &mut App,
311    ) -> Task<Result<String, String>> {
312        cx.spawn(async move |_cx| {
313            let _input = input
314                .recv()
315                .await
316                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
317            Ok("ok".to_string())
318        })
319    }
320}