test_tools.rs

  1use super::*;
  2use anyhow::Result;
  3use gpui::{App, SharedString, Task};
  4use std::future;
  5use std::sync::atomic::{AtomicBool, Ordering};
  6
  7/// A tool that echoes its input
  8#[derive(JsonSchema, Serialize, Deserialize)]
  9pub struct EchoToolInput {
 10    /// The text to echo.
 11    pub text: String,
 12}
 13
 14pub struct EchoTool;
 15
 16impl AgentTool for EchoTool {
 17    type Input = EchoToolInput;
 18    type Output = String;
 19
 20    fn name() -> &'static str {
 21        "echo"
 22    }
 23
 24    fn kind() -> acp::ToolKind {
 25        acp::ToolKind::Other
 26    }
 27
 28    fn initial_title(
 29        &self,
 30        _input: Result<Self::Input, serde_json::Value>,
 31        _cx: &mut App,
 32    ) -> SharedString {
 33        "Echo".into()
 34    }
 35
 36    fn run(
 37        self: Arc<Self>,
 38        input: Self::Input,
 39        _event_stream: ToolCallEventStream,
 40        _cx: &mut App,
 41    ) -> Task<Result<String>> {
 42        Task::ready(Ok(input.text))
 43    }
 44}
 45
 46/// A tool that waits for a specified delay
 47#[derive(JsonSchema, Serialize, Deserialize)]
 48pub struct DelayToolInput {
 49    /// The delay in milliseconds.
 50    ms: u64,
 51}
 52
 53pub struct DelayTool;
 54
 55impl AgentTool for DelayTool {
 56    type Input = DelayToolInput;
 57    type Output = String;
 58
 59    fn name() -> &'static str {
 60        "delay"
 61    }
 62
 63    fn initial_title(
 64        &self,
 65        input: Result<Self::Input, serde_json::Value>,
 66        _cx: &mut App,
 67    ) -> SharedString {
 68        if let Ok(input) = input {
 69            format!("Delay {}ms", input.ms).into()
 70        } else {
 71            "Delay".into()
 72        }
 73    }
 74
 75    fn kind() -> acp::ToolKind {
 76        acp::ToolKind::Other
 77    }
 78
 79    fn run(
 80        self: Arc<Self>,
 81        input: Self::Input,
 82        _event_stream: ToolCallEventStream,
 83        cx: &mut App,
 84    ) -> Task<Result<String>>
 85    where
 86        Self: Sized,
 87    {
 88        let executor = cx.background_executor().clone();
 89        cx.foreground_executor().spawn(async move {
 90            executor.timer(Duration::from_millis(input.ms)).await;
 91            Ok("Ding".to_string())
 92        })
 93    }
 94}
 95
 96#[derive(JsonSchema, Serialize, Deserialize)]
 97pub struct ToolRequiringPermissionInput {}
 98
 99pub struct ToolRequiringPermission;
100
101impl AgentTool for ToolRequiringPermission {
102    type Input = ToolRequiringPermissionInput;
103    type Output = String;
104
105    fn name() -> &'static str {
106        "tool_requiring_permission"
107    }
108
109    fn kind() -> acp::ToolKind {
110        acp::ToolKind::Other
111    }
112
113    fn initial_title(
114        &self,
115        _input: Result<Self::Input, serde_json::Value>,
116        _cx: &mut App,
117    ) -> SharedString {
118        "This tool requires permission".into()
119    }
120
121    fn run(
122        self: Arc<Self>,
123        _input: Self::Input,
124        event_stream: ToolCallEventStream,
125        cx: &mut App,
126    ) -> Task<Result<String>> {
127        let authorize = event_stream.authorize("Authorize?", cx);
128        cx.foreground_executor().spawn(async move {
129            authorize.await?;
130            Ok("Allowed".to_string())
131        })
132    }
133}
134
135#[derive(JsonSchema, Serialize, Deserialize)]
136pub struct InfiniteToolInput {}
137
138pub struct InfiniteTool;
139
140impl AgentTool for InfiniteTool {
141    type Input = InfiniteToolInput;
142    type Output = String;
143
144    fn name() -> &'static str {
145        "infinite"
146    }
147
148    fn kind() -> acp::ToolKind {
149        acp::ToolKind::Other
150    }
151
152    fn initial_title(
153        &self,
154        _input: Result<Self::Input, serde_json::Value>,
155        _cx: &mut App,
156    ) -> SharedString {
157        "Infinite Tool".into()
158    }
159
160    fn run(
161        self: Arc<Self>,
162        _input: Self::Input,
163        _event_stream: ToolCallEventStream,
164        cx: &mut App,
165    ) -> Task<Result<String>> {
166        cx.foreground_executor().spawn(async move {
167            future::pending::<()>().await;
168            unreachable!()
169        })
170    }
171}
172
173/// A tool that loops forever but properly handles cancellation via `select!`,
174/// similar to how edit_file_tool handles cancellation.
175#[derive(JsonSchema, Serialize, Deserialize)]
176pub struct CancellationAwareToolInput {}
177
/// Test tool that sets `was_cancelled` once the user cancels its tool call.
pub struct CancellationAwareTool {
    /// Becomes `true` after cancellation is observed. It is shared with the
    /// handle returned by [`CancellationAwareTool::new`].
    pub was_cancelled: Arc<AtomicBool>,
}

impl CancellationAwareTool {
    /// Builds the tool and also returns a second handle to its
    /// `was_cancelled` flag. This lets a test keep observing the flag after
    /// the tool itself has been handed off. The flag starts out `false`.
    pub fn new() -> (Self, Arc<AtomicBool>) {
        let flag = Arc::new(AtomicBool::default());
        let tool = Self {
            was_cancelled: Arc::clone(&flag),
        };
        (tool, flag)
    }
}
193
194impl AgentTool for CancellationAwareTool {
195    type Input = CancellationAwareToolInput;
196    type Output = String;
197
198    fn name() -> &'static str {
199        "cancellation_aware"
200    }
201
202    fn kind() -> acp::ToolKind {
203        acp::ToolKind::Other
204    }
205
206    fn initial_title(
207        &self,
208        _input: Result<Self::Input, serde_json::Value>,
209        _cx: &mut App,
210    ) -> SharedString {
211        "Cancellation Aware Tool".into()
212    }
213
214    fn run(
215        self: Arc<Self>,
216        _input: Self::Input,
217        event_stream: ToolCallEventStream,
218        cx: &mut App,
219    ) -> Task<Result<String>> {
220        cx.foreground_executor().spawn(async move {
221            // Wait for cancellation - this tool does nothing but wait to be cancelled
222            event_stream.cancelled_by_user().await;
223            self.was_cancelled.store(true, Ordering::SeqCst);
224            anyhow::bail!("Tool cancelled by user");
225        })
226    }
227}
228
229/// A tool that takes an object with map from letters to random words starting with that letter.
230/// All fiealds are required! Pass a word for every letter!
231#[derive(JsonSchema, Serialize, Deserialize)]
232pub struct WordListInput {
233    /// Provide a random word that starts with A.
234    a: Option<String>,
235    /// Provide a random word that starts with B.
236    b: Option<String>,
237    /// Provide a random word that starts with C.
238    c: Option<String>,
239    /// Provide a random word that starts with D.
240    d: Option<String>,
241    /// Provide a random word that starts with E.
242    e: Option<String>,
243    /// Provide a random word that starts with F.
244    f: Option<String>,
245    /// Provide a random word that starts with G.
246    g: Option<String>,
247}
248
249pub struct WordListTool;
250
251impl AgentTool for WordListTool {
252    type Input = WordListInput;
253    type Output = String;
254
255    fn name() -> &'static str {
256        "word_list"
257    }
258
259    fn kind() -> acp::ToolKind {
260        acp::ToolKind::Other
261    }
262
263    fn initial_title(
264        &self,
265        _input: Result<Self::Input, serde_json::Value>,
266        _cx: &mut App,
267    ) -> SharedString {
268        "List of random words".into()
269    }
270
271    fn run(
272        self: Arc<Self>,
273        _input: Self::Input,
274        _event_stream: ToolCallEventStream,
275        _cx: &mut App,
276    ) -> Task<Result<String>> {
277        Task::ready(Ok("ok".to_string()))
278    }
279}