1use super::*;
2use crate::templates::Templates;
3use acp_thread::AgentConnection as _;
4use agent_client_protocol as acp;
5use client::{Client, UserStore};
6use gpui::{AppContext, Entity, TestAppContext};
7use language_model::{
8 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
9 LanguageModelRegistry, MessageContent, StopReason,
10};
11use project::Project;
12use reqwest_client::ReqwestClient;
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15use smol::stream::StreamExt;
16use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
17
18mod test_tools;
19use test_tools::*;
20
21#[gpui::test]
22async fn test_echo(cx: &mut TestAppContext) {
23 let ThreadTest { model, thread, .. } = setup(cx).await;
24
25 let events = thread
26 .update(cx, |thread, cx| {
27 thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
28 })
29 .collect()
30 .await;
31 thread.update(cx, |thread, _cx| {
32 assert_eq!(
33 thread.messages().last().unwrap().content,
34 vec![MessageContent::Text("Hello".to_string())]
35 );
36 });
37 assert_eq!(stop_events(events), vec![StopReason::EndTurn]);
38}
39
40#[gpui::test]
41async fn test_basic_tool_calls(cx: &mut TestAppContext) {
42 let ThreadTest { model, thread, .. } = setup(cx).await;
43
44 // Test a tool call that's likely to complete *before* streaming stops.
45 let events = thread
46 .update(cx, |thread, cx| {
47 thread.add_tool(EchoTool);
48 thread.send(
49 model.clone(),
50 "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
51 cx,
52 )
53 })
54 .collect()
55 .await;
56 assert_eq!(
57 stop_events(events),
58 vec![StopReason::ToolUse, StopReason::EndTurn]
59 );
60
61 // Test a tool calls that's likely to complete *after* streaming stops.
62 let events = thread
63 .update(cx, |thread, cx| {
64 thread.remove_tool(&AgentTool::name(&EchoTool));
65 thread.add_tool(DelayTool);
66 thread.send(
67 model.clone(),
68 "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
69 cx,
70 )
71 })
72 .collect()
73 .await;
74 assert_eq!(
75 stop_events(events),
76 vec![StopReason::ToolUse, StopReason::EndTurn]
77 );
78 thread.update(cx, |thread, _cx| {
79 assert!(thread
80 .messages()
81 .last()
82 .unwrap()
83 .content
84 .iter()
85 .any(|content| {
86 if let MessageContent::Text(text) = content {
87 text.contains("Ding")
88 } else {
89 false
90 }
91 }));
92 });
93}
94
95#[gpui::test]
96async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
97 let ThreadTest { model, thread, .. } = setup(cx).await;
98
99 // Test a tool call that's likely to complete *before* streaming stops.
100 let mut events = thread.update(cx, |thread, cx| {
101 thread.add_tool(WordListTool);
102 thread.send(model.clone(), "Test the word_list tool.", cx)
103 });
104
105 let mut saw_partial_tool_use = false;
106 while let Some(event) = events.next().await {
107 if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event {
108 thread.update(cx, |thread, _cx| {
109 // Look for a tool use in the thread's last message
110 let last_content = thread.messages().last().unwrap().content.last().unwrap();
111 if let MessageContent::ToolUse(last_tool_use) = last_content {
112 assert_eq!(last_tool_use.name.as_ref(), "word_list");
113 if tool_use_event.is_input_complete {
114 last_tool_use
115 .input
116 .get("a")
117 .expect("'a' has streamed because input is now complete");
118 last_tool_use
119 .input
120 .get("g")
121 .expect("'g' has streamed because input is now complete");
122 } else {
123 if !last_tool_use.is_input_complete
124 && last_tool_use.input.get("g").is_none()
125 {
126 saw_partial_tool_use = true;
127 }
128 }
129 } else {
130 panic!("last content should be a tool use");
131 }
132 });
133 }
134 }
135
136 assert!(
137 saw_partial_tool_use,
138 "should see at least one partially streamed tool use in the history"
139 );
140}
141
142#[gpui::test]
143async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
144 let ThreadTest { model, thread, .. } = setup(cx).await;
145
146 // Test concurrent tool calls with different delay times
147 let events = thread
148 .update(cx, |thread, cx| {
149 thread.add_tool(DelayTool);
150 thread.send(
151 model.clone(),
152 "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
153 cx,
154 )
155 })
156 .collect()
157 .await;
158
159 let stop_reasons = stop_events(events);
160 if stop_reasons.len() == 2 {
161 assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]);
162 } else if stop_reasons.len() == 3 {
163 assert_eq!(
164 stop_reasons,
165 vec![
166 StopReason::ToolUse,
167 StopReason::ToolUse,
168 StopReason::EndTurn
169 ]
170 );
171 } else {
172 panic!("Expected either 1 or 2 tool uses followed by end turn");
173 }
174
175 thread.update(cx, |thread, _cx| {
176 let last_message = thread.messages().last().unwrap();
177 let text = last_message
178 .content
179 .iter()
180 .filter_map(|content| {
181 if let MessageContent::Text(text) = content {
182 Some(text.as_str())
183 } else {
184 None
185 }
186 })
187 .collect::<String>();
188
189 assert!(text.contains("Ding"));
190 });
191}
192
193#[gpui::test]
194async fn test_agent_connection(cx: &mut TestAppContext) {
195 cx.executor().allow_parking();
196 cx.update(settings::init);
197 let templates = Templates::new();
198
199 // Initialize language model system with test provider
200 cx.update(|cx| {
201 gpui_tokio::init(cx);
202 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
203 cx.set_http_client(Arc::new(http_client));
204
205 client::init_settings(cx);
206 let client = Client::production(cx);
207 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
208 language_model::init(client.clone(), cx);
209 language_models::init(user_store.clone(), client.clone(), cx);
210
211 // Initialize project settings
212 Project::init_settings(cx);
213
214 // Use test registry with fake provider
215 LanguageModelRegistry::test(cx);
216 });
217
218 // Create agent and connection
219 let agent = cx.new(|_| NativeAgent::new(templates.clone()));
220 let connection = NativeAgentConnection(agent.clone());
221
222 // Test model_selector returns Some
223 let selector_opt = connection.model_selector();
224 assert!(
225 selector_opt.is_some(),
226 "agent2 should always support ModelSelector"
227 );
228 let selector = selector_opt.unwrap();
229
230 // Test list_models
231 let listed_models = cx
232 .update(|cx| {
233 let mut async_cx = cx.to_async();
234 selector.list_models(&mut async_cx)
235 })
236 .await
237 .expect("list_models should succeed");
238 assert!(!listed_models.is_empty(), "should have at least one model");
239 assert_eq!(listed_models[0].id().0, "fake");
240
241 // Create a project for new_thread
242 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
243 let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
244
245 // Create a thread using new_thread
246 let cwd = Path::new("/test");
247 let connection_rc = Rc::new(connection.clone());
248 let acp_thread = cx
249 .update(|cx| {
250 let mut async_cx = cx.to_async();
251 connection_rc.new_thread(project, cwd, &mut async_cx)
252 })
253 .await
254 .expect("new_thread should succeed");
255
256 // Get the session_id from the AcpThread
257 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
258
259 // Test selected_model returns the default
260 let selected = cx
261 .update(|cx| {
262 let mut async_cx = cx.to_async();
263 selector.selected_model(&session_id, &mut async_cx)
264 })
265 .await
266 .expect("selected_model should succeed");
267 assert_eq!(selected.id().0, "fake", "should return default model");
268
269 // The thread was created via prompt with the default model
270 // We can verify it through selected_model
271
272 // Test prompt uses the selected model
273 let prompt_request = acp::PromptRequest {
274 session_id: session_id.clone(),
275 prompt: vec![acp::ContentBlock::Text(acp::TextContent {
276 text: "Test prompt".into(),
277 annotations: None,
278 })],
279 };
280
281 cx.update(|cx| connection.prompt(prompt_request, cx))
282 .await
283 .expect("prompt should succeed");
284
285 // The prompt was sent successfully
286
287 // Test cancel
288 cx.update(|cx| connection.cancel(&session_id, cx));
289
290 // After cancel, selected_model should fail
291 let result = cx
292 .update(|cx| {
293 let mut async_cx = cx.to_async();
294 selector.selected_model(&session_id, &mut async_cx)
295 })
296 .await;
297 assert!(result.is_err(), "selected_model should fail after cancel");
298
299 // Test error case: invalid session
300 let invalid_session = acp::SessionId("invalid".into());
301 let result = cx
302 .update(|cx| {
303 let mut async_cx = cx.to_async();
304 selector.selected_model(&invalid_session, &mut async_cx)
305 })
306 .await;
307 assert!(result.is_err(), "should fail for invalid session");
308 if let Err(e) = result {
309 assert!(
310 e.to_string().contains("Session not found"),
311 "should have correct error message"
312 );
313 }
314}
315
316/// Filters out the stop events for asserting against in tests
317fn stop_events(
318 result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
319) -> Vec<StopReason> {
320 result_events
321 .into_iter()
322 .filter_map(|event| match event.unwrap() {
323 LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason),
324 _ => None,
325 })
326 .collect()
327}
328
/// Fixture returned by `setup`: the language model under test and a thread
/// entity that was constructed with that same model.
struct ThreadTest {
    /// The model the thread was created with; tests pass clones of it to `Thread::send`.
    model: Arc<dyn LanguageModel>,
    /// The thread entity the tests drive and inspect.
    thread: Entity<Thread>,
}
333
334async fn setup(cx: &mut TestAppContext) -> ThreadTest {
335 cx.executor().allow_parking();
336 cx.update(settings::init);
337 let templates = Templates::new();
338
339 let model = cx
340 .update(|cx| {
341 gpui_tokio::init(cx);
342 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
343 cx.set_http_client(Arc::new(http_client));
344
345 client::init_settings(cx);
346 let client = Client::production(cx);
347 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
348 language_model::init(client.clone(), cx);
349 language_models::init(user_store.clone(), client.clone(), cx);
350
351 let models = LanguageModelRegistry::read_global(cx);
352 let model = models
353 .available_models(cx)
354 .find(|model| model.id().0 == "claude-3-7-sonnet-latest")
355 .unwrap();
356
357 let provider = models.provider(&model.provider_id()).unwrap();
358 let authenticated = provider.authenticate(cx);
359
360 cx.spawn(async move |_cx| {
361 authenticated.await.unwrap();
362 model
363 })
364 })
365 .await;
366
367 let thread = cx.new(|_| Thread::new(templates, model.clone()));
368
369 ThreadTest { model, thread }
370}
371
/// Runs once before the tests start and installs `env_logger`, but only when
/// `RUST_LOG` is set (and readable as Unicode); otherwise logging stays off.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}