test_tools.rs

  1use super::*;
  2use agent_settings::AgentSettings;
  3use anyhow::Result;
  4use gpui::{App, SharedString, Task};
  5use std::future;
  6use std::sync::atomic::{AtomicBool, Ordering};
  7
  8/// A tool that echoes its input
  9#[derive(JsonSchema, Serialize, Deserialize)]
 10pub struct EchoToolInput {
 11    /// The text to echo.
 12    pub text: String,
 13}
 14
 15pub struct EchoTool;
 16
 17impl AgentTool for EchoTool {
 18    type Input = EchoToolInput;
 19    type Output = String;
 20
 21    const NAME: &'static str = "echo";
 22
 23    fn kind() -> acp::ToolKind {
 24        acp::ToolKind::Other
 25    }
 26
 27    fn initial_title(
 28        &self,
 29        _input: Result<Self::Input, serde_json::Value>,
 30        _cx: &mut App,
 31    ) -> SharedString {
 32        "Echo".into()
 33    }
 34
 35    fn run(
 36        self: Arc<Self>,
 37        input: Self::Input,
 38        _event_stream: ToolCallEventStream,
 39        _cx: &mut App,
 40    ) -> Task<Result<String>> {
 41        Task::ready(Ok(input.text))
 42    }
 43}
 44
 45/// A tool that waits for a specified delay
 46#[derive(JsonSchema, Serialize, Deserialize)]
 47pub struct DelayToolInput {
 48    /// The delay in milliseconds.
 49    ms: u64,
 50}
 51
 52pub struct DelayTool;
 53
 54impl AgentTool for DelayTool {
 55    type Input = DelayToolInput;
 56    type Output = String;
 57
 58    const NAME: &'static str = "delay";
 59
 60    fn initial_title(
 61        &self,
 62        input: Result<Self::Input, serde_json::Value>,
 63        _cx: &mut App,
 64    ) -> SharedString {
 65        if let Ok(input) = input {
 66            format!("Delay {}ms", input.ms).into()
 67        } else {
 68            "Delay".into()
 69        }
 70    }
 71
 72    fn kind() -> acp::ToolKind {
 73        acp::ToolKind::Other
 74    }
 75
 76    fn run(
 77        self: Arc<Self>,
 78        input: Self::Input,
 79        _event_stream: ToolCallEventStream,
 80        cx: &mut App,
 81    ) -> Task<Result<String>>
 82    where
 83        Self: Sized,
 84    {
 85        let executor = cx.background_executor().clone();
 86        cx.foreground_executor().spawn(async move {
 87            executor.timer(Duration::from_millis(input.ms)).await;
 88            Ok("Ding".to_string())
 89        })
 90    }
 91}
 92
 93#[derive(JsonSchema, Serialize, Deserialize)]
 94pub struct ToolRequiringPermissionInput {}
 95
 96pub struct ToolRequiringPermission;
 97
 98impl AgentTool for ToolRequiringPermission {
 99    type Input = ToolRequiringPermissionInput;
100    type Output = String;
101
102    const NAME: &'static str = "tool_requiring_permission";
103
104    fn kind() -> acp::ToolKind {
105        acp::ToolKind::Other
106    }
107
108    fn initial_title(
109        &self,
110        _input: Result<Self::Input, serde_json::Value>,
111        _cx: &mut App,
112    ) -> SharedString {
113        "This tool requires permission".into()
114    }
115
116    fn run(
117        self: Arc<Self>,
118        _input: Self::Input,
119        event_stream: ToolCallEventStream,
120        cx: &mut App,
121    ) -> Task<Result<String>> {
122        let settings = AgentSettings::get_global(cx);
123        let decision = decide_permission_from_settings(Self::NAME, &[String::new()], settings);
124
125        let authorize = match decision {
126            ToolPermissionDecision::Allow => None,
127            ToolPermissionDecision::Deny(reason) => {
128                return Task::ready(Err(anyhow::anyhow!("{}", reason)));
129            }
130            ToolPermissionDecision::Confirm => {
131                let context = crate::ToolPermissionContext::new(
132                    "tool_requiring_permission",
133                    vec![String::new()],
134                );
135                Some(event_stream.authorize("Authorize?", context, cx))
136            }
137        };
138
139        cx.foreground_executor().spawn(async move {
140            if let Some(authorize) = authorize {
141                authorize.await?;
142            }
143            Ok("Allowed".to_string())
144        })
145    }
146}
147
148#[derive(JsonSchema, Serialize, Deserialize)]
149pub struct InfiniteToolInput {}
150
151pub struct InfiniteTool;
152
153impl AgentTool for InfiniteTool {
154    type Input = InfiniteToolInput;
155    type Output = String;
156
157    const NAME: &'static str = "infinite";
158
159    fn kind() -> acp::ToolKind {
160        acp::ToolKind::Other
161    }
162
163    fn initial_title(
164        &self,
165        _input: Result<Self::Input, serde_json::Value>,
166        _cx: &mut App,
167    ) -> SharedString {
168        "Infinite Tool".into()
169    }
170
171    fn run(
172        self: Arc<Self>,
173        _input: Self::Input,
174        _event_stream: ToolCallEventStream,
175        cx: &mut App,
176    ) -> Task<Result<String>> {
177        cx.foreground_executor().spawn(async move {
178            future::pending::<()>().await;
179            unreachable!()
180        })
181    }
182}
183
184/// A tool that loops forever but properly handles cancellation via `select!`,
185/// similar to how edit_file_tool handles cancellation.
186#[derive(JsonSchema, Serialize, Deserialize)]
187pub struct CancellationAwareToolInput {}
188
/// Test tool that waits for the user to cancel its call and records that it happened.
pub struct CancellationAwareTool {
    /// Shared flag that the tool's `run` sets to `true` after observing cancellation.
    pub was_cancelled: Arc<AtomicBool>,
}

impl CancellationAwareTool {
    /// Builds the tool together with a second handle to its (initially `false`)
    /// `was_cancelled` flag. This lets a test keep observing the flag after handing
    /// the tool off.
    pub fn new() -> (Self, Arc<AtomicBool>) {
        let flag = Arc::new(AtomicBool::new(false));
        let tool = Self {
            was_cancelled: Arc::clone(&flag),
        };
        (tool, flag)
    }
}
204
205impl AgentTool for CancellationAwareTool {
206    type Input = CancellationAwareToolInput;
207    type Output = String;
208
209    const NAME: &'static str = "cancellation_aware";
210
211    fn kind() -> acp::ToolKind {
212        acp::ToolKind::Other
213    }
214
215    fn initial_title(
216        &self,
217        _input: Result<Self::Input, serde_json::Value>,
218        _cx: &mut App,
219    ) -> SharedString {
220        "Cancellation Aware Tool".into()
221    }
222
223    fn run(
224        self: Arc<Self>,
225        _input: Self::Input,
226        event_stream: ToolCallEventStream,
227        cx: &mut App,
228    ) -> Task<Result<String>> {
229        cx.foreground_executor().spawn(async move {
230            // Wait for cancellation - this tool does nothing but wait to be cancelled
231            event_stream.cancelled_by_user().await;
232            self.was_cancelled.store(true, Ordering::SeqCst);
233            anyhow::bail!("Tool cancelled by user");
234        })
235    }
236}
237
238/// A tool that takes an object with map from letters to random words starting with that letter.
239/// All fiealds are required! Pass a word for every letter!
240#[derive(JsonSchema, Serialize, Deserialize)]
241pub struct WordListInput {
242    /// Provide a random word that starts with A.
243    a: Option<String>,
244    /// Provide a random word that starts with B.
245    b: Option<String>,
246    /// Provide a random word that starts with C.
247    c: Option<String>,
248    /// Provide a random word that starts with D.
249    d: Option<String>,
250    /// Provide a random word that starts with E.
251    e: Option<String>,
252    /// Provide a random word that starts with F.
253    f: Option<String>,
254    /// Provide a random word that starts with G.
255    g: Option<String>,
256}
257
258pub struct WordListTool;
259
260impl AgentTool for WordListTool {
261    type Input = WordListInput;
262    type Output = String;
263
264    const NAME: &'static str = "word_list";
265
266    fn kind() -> acp::ToolKind {
267        acp::ToolKind::Other
268    }
269
270    fn initial_title(
271        &self,
272        _input: Result<Self::Input, serde_json::Value>,
273        _cx: &mut App,
274    ) -> SharedString {
275        "List of random words".into()
276    }
277
278    fn run(
279        self: Arc<Self>,
280        _input: Self::Input,
281        _event_stream: ToolCallEventStream,
282        _cx: &mut App,
283    ) -> Task<Result<String>> {
284        Task::ready(Ok("ok".to_string()))
285    }
286}