1use super::*;
2use crate::templates::Templates;
3use acp_thread::AgentConnection as _;
4use agent_client_protocol as acp;
5use client::{Client, UserStore};
6use fs::FakeFs;
7use gpui::{AppContext, Entity, Task, TestAppContext};
8use indoc::indoc;
9use language_model::{
10 fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
11 LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, MessageContent,
12 StopReason,
13};
14use project::Project;
15use reqwest_client::ReqwestClient;
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use serde_json::json;
19use smol::stream::StreamExt;
20use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
21use util::path;
22
23mod test_tools;
24use test_tools::*;
25
26#[gpui::test]
27#[ignore = "temporarily disabled until it can be run on CI"]
28async fn test_echo(cx: &mut TestAppContext) {
29 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
30
31 let events = thread
32 .update(cx, |thread, cx| {
33 thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
34 })
35 .collect()
36 .await;
37 thread.update(cx, |thread, _cx| {
38 assert_eq!(
39 thread.messages().last().unwrap().content,
40 vec![MessageContent::Text("Hello".to_string())]
41 );
42 });
43 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
44}
45
46#[gpui::test]
47#[ignore = "temporarily disabled until it can be run on CI"]
48async fn test_thinking(cx: &mut TestAppContext) {
49 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
50
51 let events = thread
52 .update(cx, |thread, cx| {
53 thread.send(
54 model.clone(),
55 indoc! {"
56 Testing:
57
58 Generate a thinking step where you just think the word 'Think',
59 and have your final answer be 'Hello'
60 "},
61 cx,
62 )
63 })
64 .collect()
65 .await;
66 thread.update(cx, |thread, _cx| {
67 assert_eq!(
68 thread.messages().last().unwrap().to_markdown(),
69 indoc! {"
70 ## assistant
71 <think>Think</think>
72 Hello
73 "}
74 )
75 });
76 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
77}
78
79#[gpui::test]
80#[ignore = "temporarily disabled until it can be run on CI"]
81async fn test_basic_tool_calls(cx: &mut TestAppContext) {
82 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
83
84 // Test a tool call that's likely to complete *before* streaming stops.
85 let events = thread
86 .update(cx, |thread, cx| {
87 thread.add_tool(EchoTool);
88 thread.send(
89 model.clone(),
90 "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
91 cx,
92 )
93 })
94 .collect()
95 .await;
96 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
97
98 // Test a tool calls that's likely to complete *after* streaming stops.
99 let events = thread
100 .update(cx, |thread, cx| {
101 thread.remove_tool(&AgentTool::name(&EchoTool));
102 thread.add_tool(DelayTool);
103 thread.send(
104 model.clone(),
105 "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
106 cx,
107 )
108 })
109 .collect()
110 .await;
111 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
112 thread.update(cx, |thread, _cx| {
113 assert!(thread
114 .messages()
115 .last()
116 .unwrap()
117 .content
118 .iter()
119 .any(|content| {
120 if let MessageContent::Text(text) = content {
121 text.contains("Ding")
122 } else {
123 false
124 }
125 }));
126 });
127}
128
129#[gpui::test]
130#[ignore = "temporarily disabled until it can be run on CI"]
131async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
132 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
133
134 // Test a tool call that's likely to complete *before* streaming stops.
135 let mut events = thread.update(cx, |thread, cx| {
136 thread.add_tool(WordListTool);
137 thread.send(model.clone(), "Test the word_list tool.", cx)
138 });
139
140 let mut saw_partial_tool_use = false;
141 while let Some(event) = events.next().await {
142 if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
143 thread.update(cx, |thread, _cx| {
144 // Look for a tool use in the thread's last message
145 let last_content = thread.messages().last().unwrap().content.last().unwrap();
146 if let MessageContent::ToolUse(last_tool_use) = last_content {
147 assert_eq!(last_tool_use.name.as_ref(), "word_list");
148 if tool_call.status == acp::ToolCallStatus::Pending {
149 if !last_tool_use.is_input_complete
150 && last_tool_use.input.get("g").is_none()
151 {
152 saw_partial_tool_use = true;
153 }
154 } else {
155 last_tool_use
156 .input
157 .get("a")
158 .expect("'a' has streamed because input is now complete");
159 last_tool_use
160 .input
161 .get("g")
162 .expect("'g' has streamed because input is now complete");
163 }
164 } else {
165 panic!("last content should be a tool use");
166 }
167 });
168 }
169 }
170
171 assert!(
172 saw_partial_tool_use,
173 "should see at least one partially streamed tool use in the history"
174 );
175}
176
177#[gpui::test]
178#[ignore = "temporarily disabled until it can be run on CI"]
179async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
180 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
181
182 // Test concurrent tool calls with different delay times
183 let events = thread
184 .update(cx, |thread, cx| {
185 thread.add_tool(DelayTool);
186 thread.send(
187 model.clone(),
188 "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
189 cx,
190 )
191 })
192 .collect()
193 .await;
194
195 let stop_reasons = stop_events(events);
196 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
197
198 thread.update(cx, |thread, _cx| {
199 let last_message = thread.messages().last().unwrap();
200 let text = last_message
201 .content
202 .iter()
203 .filter_map(|content| {
204 if let MessageContent::Text(text) = content {
205 Some(text.as_str())
206 } else {
207 None
208 }
209 })
210 .collect::<String>();
211
212 assert!(text.contains("Ding"));
213 });
214}
215
216#[gpui::test]
217#[ignore = "temporarily disabled until it can be run on CI"]
218async fn test_cancellation(cx: &mut TestAppContext) {
219 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
220
221 let mut events = thread.update(cx, |thread, cx| {
222 thread.add_tool(InfiniteTool);
223 thread.add_tool(EchoTool);
224 thread.send(
225 model.clone(),
226 "Call the echo tool and then call the infinite tool, then explain their output",
227 cx,
228 )
229 });
230
231 // Wait until both tools are called.
232 let mut expected_tool_calls = vec!["echo", "infinite"];
233 let mut echo_id = None;
234 let mut echo_completed = false;
235 while let Some(event) = events.next().await {
236 match event.unwrap() {
237 AgentResponseEvent::ToolCall(tool_call) => {
238 assert_eq!(tool_call.title, expected_tool_calls.remove(0));
239 if tool_call.title == "echo" {
240 echo_id = Some(tool_call.id);
241 }
242 }
243 AgentResponseEvent::ToolCallUpdate(acp::ToolCallUpdate {
244 id,
245 fields:
246 acp::ToolCallUpdateFields {
247 status: Some(acp::ToolCallStatus::Completed),
248 ..
249 },
250 }) if Some(&id) == echo_id.as_ref() => {
251 echo_completed = true;
252 }
253 _ => {}
254 }
255
256 if expected_tool_calls.is_empty() && echo_completed {
257 break;
258 }
259 }
260
261 // Cancel the current send and ensure that the event stream is closed, even
262 // if one of the tools is still running.
263 thread.update(cx, |thread, _cx| thread.cancel());
264 events.collect::<Vec<_>>().await;
265
266 // Ensure we can still send a new message after cancellation.
267 let events = thread
268 .update(cx, |thread, cx| {
269 thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx)
270 })
271 .collect::<Vec<_>>()
272 .await;
273 thread.update(cx, |thread, _cx| {
274 assert_eq!(
275 thread.messages().last().unwrap().content,
276 vec![MessageContent::Text("Hello".to_string())]
277 );
278 });
279 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
280}
281
282#[gpui::test]
283async fn test_refusal(cx: &mut TestAppContext) {
284 let fake_model = Arc::new(FakeLanguageModel::default());
285 let ThreadTest { thread, .. } = setup(cx, TestModel::Fake(fake_model.clone())).await;
286
287 let events = thread.update(cx, |thread, cx| {
288 thread.send(fake_model.clone(), "Hello", cx)
289 });
290 cx.run_until_parked();
291 thread.read_with(cx, |thread, _| {
292 assert_eq!(
293 thread.to_markdown(),
294 indoc! {"
295 ## user
296 Hello
297 "}
298 );
299 });
300
301 fake_model.send_last_completion_stream_text_chunk("Hey!");
302 cx.run_until_parked();
303 thread.read_with(cx, |thread, _| {
304 assert_eq!(
305 thread.to_markdown(),
306 indoc! {"
307 ## user
308 Hello
309 ## assistant
310 Hey!
311 "}
312 );
313 });
314
315 // If the model refuses to continue, the thread should remove all the messages after the last user message.
316 fake_model
317 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
318 let events = events.collect::<Vec<_>>().await;
319 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
320 thread.read_with(cx, |thread, _| {
321 assert_eq!(thread.to_markdown(), "");
322 });
323}
324
325#[ignore = "temporarily disabled until it can be run on CI"]
326#[gpui::test]
327async fn test_agent_connection(cx: &mut TestAppContext) {
328 cx.executor().allow_parking();
329 cx.update(settings::init);
330 let templates = Templates::new();
331
332 // Initialize language model system with test provider
333 cx.update(|cx| {
334 gpui_tokio::init(cx);
335 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
336 cx.set_http_client(Arc::new(http_client));
337
338 client::init_settings(cx);
339 let client = Client::production(cx);
340 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
341 language_model::init(client.clone(), cx);
342 language_models::init(user_store.clone(), client.clone(), cx);
343
344 // Initialize project settings
345 Project::init_settings(cx);
346
347 // Use test registry with fake provider
348 LanguageModelRegistry::test(cx);
349 });
350
351 // Create agent and connection
352 let agent = cx.new(|_| NativeAgent::new(templates.clone()));
353 let connection = NativeAgentConnection(agent.clone());
354
355 // Test model_selector returns Some
356 let selector_opt = connection.model_selector();
357 assert!(
358 selector_opt.is_some(),
359 "agent2 should always support ModelSelector"
360 );
361 let selector = selector_opt.unwrap();
362
363 // Test list_models
364 let listed_models = cx
365 .update(|cx| {
366 let mut async_cx = cx.to_async();
367 selector.list_models(&mut async_cx)
368 })
369 .await
370 .expect("list_models should succeed");
371 assert!(!listed_models.is_empty(), "should have at least one model");
372 assert_eq!(listed_models[0].id().0, "fake");
373
374 // Create a project for new_thread
375 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
376 let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
377
378 // Create a thread using new_thread
379 let cwd = Path::new("/test");
380 let connection_rc = Rc::new(connection.clone());
381 let acp_thread = cx
382 .update(|cx| {
383 let mut async_cx = cx.to_async();
384 connection_rc.new_thread(project, cwd, &mut async_cx)
385 })
386 .await
387 .expect("new_thread should succeed");
388
389 // Get the session_id from the AcpThread
390 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
391
392 // Test selected_model returns the default
393 let selected = cx
394 .update(|cx| {
395 let mut async_cx = cx.to_async();
396 selector.selected_model(&session_id, &mut async_cx)
397 })
398 .await
399 .expect("selected_model should succeed");
400 assert_eq!(selected.id().0, "fake", "should return default model");
401
402 // The thread was created via prompt with the default model
403 // We can verify it through selected_model
404
405 // Test prompt uses the selected model
406 let prompt_request = acp::PromptRequest {
407 session_id: session_id.clone(),
408 prompt: vec![acp::ContentBlock::Text(acp::TextContent {
409 text: "Test prompt".into(),
410 annotations: None,
411 })],
412 };
413
414 let request = cx.update(|cx| connection.prompt(prompt_request, cx));
415 let request = cx.background_spawn(request);
416 smol::Timer::after(Duration::from_millis(100)).await;
417
418 // Test cancel
419 cx.update(|cx| connection.cancel(&session_id, cx));
420 request.await.expect("prompt should fail gracefully");
421}
422
423/// Filters out the stop events for asserting against in tests
424fn stop_events(
425 result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
426) -> Vec<acp::StopReason> {
427 result_events
428 .into_iter()
429 .filter_map(|event| match event.unwrap() {
430 AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
431 _ => None,
432 })
433 .collect()
434}
435
436struct ThreadTest {
437 model: Arc<dyn LanguageModel>,
438 thread: Entity<Thread>,
439}
440
441enum TestModel {
442 Sonnet4,
443 Sonnet4Thinking,
444 Fake(Arc<FakeLanguageModel>),
445}
446
447impl TestModel {
448 fn id(&self) -> LanguageModelId {
449 match self {
450 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
451 TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
452 TestModel::Fake(fake_model) => fake_model.id(),
453 }
454 }
455}
456
457async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
458 cx.executor().allow_parking();
459 cx.update(|cx| {
460 settings::init(cx);
461 Project::init_settings(cx);
462 });
463 let templates = Templates::new();
464
465 let fs = FakeFs::new(cx.background_executor.clone());
466 fs.insert_tree(path!("/test"), json!({})).await;
467 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
468
469 let model = cx
470 .update(|cx| {
471 gpui_tokio::init(cx);
472 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
473 cx.set_http_client(Arc::new(http_client));
474
475 client::init_settings(cx);
476 let client = Client::production(cx);
477 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
478 language_model::init(client.clone(), cx);
479 language_models::init(user_store.clone(), client.clone(), cx);
480
481 if let TestModel::Fake(model) = model {
482 Task::ready(model as Arc<_>)
483 } else {
484 let model_id = model.id();
485 let models = LanguageModelRegistry::read_global(cx);
486 let model = models
487 .available_models(cx)
488 .find(|model| model.id() == model_id)
489 .unwrap();
490
491 let provider = models.provider(&model.provider_id()).unwrap();
492 let authenticated = provider.authenticate(cx);
493
494 cx.spawn(async move |_cx| {
495 authenticated.await.unwrap();
496 model
497 })
498 }
499 })
500 .await;
501
502 let thread = cx.new(|_| Thread::new(project, templates, model.clone()));
503
504 ThreadTest { model, thread }
505}
506
/// Installs `env_logger` when the test binary loads (via `ctor`).
///
/// This only happens when `RUST_LOG` is set, so test output stays quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}