1use super::*;
2use crate::templates::Templates;
3use acp_thread::AgentConnection;
4use agent_client_protocol as acp;
5use client::{Client, UserStore};
6use fs::FakeFs;
7use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
8use indoc::indoc;
9use language_model::{
10 fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
11 LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, MessageContent,
12 StopReason,
13};
14use project::Project;
15use reqwest_client::ReqwestClient;
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use serde_json::json;
19use smol::stream::StreamExt;
20use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
21use util::path;
22
23mod test_tools;
24use test_tools::*;
25
26#[gpui::test]
27#[ignore = "temporarily disabled until it can be run on CI"]
28async fn test_echo(cx: &mut TestAppContext) {
29 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
30
31 let events = thread
32 .update(cx, |thread, cx| {
33 thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
34 })
35 .collect()
36 .await;
37 thread.update(cx, |thread, _cx| {
38 assert_eq!(
39 thread.messages().last().unwrap().content,
40 vec![MessageContent::Text("Hello".to_string())]
41 );
42 });
43 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
44}
45
46#[gpui::test]
47#[ignore = "temporarily disabled until it can be run on CI"]
48async fn test_thinking(cx: &mut TestAppContext) {
49 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
50
51 let events = thread
52 .update(cx, |thread, cx| {
53 thread.send(
54 model.clone(),
55 indoc! {"
56 Testing:
57
58 Generate a thinking step where you just think the word 'Think',
59 and have your final answer be 'Hello'
60 "},
61 cx,
62 )
63 })
64 .collect()
65 .await;
66 thread.update(cx, |thread, _cx| {
67 assert_eq!(
68 thread.messages().last().unwrap().to_markdown(),
69 indoc! {"
70 ## assistant
71 <think>Think</think>
72 Hello
73 "}
74 )
75 });
76 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
77}
78
79#[gpui::test]
80#[ignore = "temporarily disabled until it can be run on CI"]
81async fn test_basic_tool_calls(cx: &mut TestAppContext) {
82 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
83
84 // Test a tool call that's likely to complete *before* streaming stops.
85 let events = thread
86 .update(cx, |thread, cx| {
87 thread.add_tool(EchoTool);
88 thread.send(
89 model.clone(),
90 "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
91 cx,
92 )
93 })
94 .collect()
95 .await;
96 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
97
98 // Test a tool calls that's likely to complete *after* streaming stops.
99 let events = thread
100 .update(cx, |thread, cx| {
101 thread.remove_tool(&AgentTool::name(&EchoTool));
102 thread.add_tool(DelayTool);
103 thread.send(
104 model.clone(),
105 "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
106 cx,
107 )
108 })
109 .collect()
110 .await;
111 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
112 thread.update(cx, |thread, _cx| {
113 assert!(thread
114 .messages()
115 .last()
116 .unwrap()
117 .content
118 .iter()
119 .any(|content| {
120 if let MessageContent::Text(text) = content {
121 text.contains("Ding")
122 } else {
123 false
124 }
125 }));
126 });
127}
128
129#[gpui::test]
130#[ignore = "temporarily disabled until it can be run on CI"]
131async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
132 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
133
134 // Test a tool call that's likely to complete *before* streaming stops.
135 let mut events = thread.update(cx, |thread, cx| {
136 thread.add_tool(WordListTool);
137 thread.send(model.clone(), "Test the word_list tool.", cx)
138 });
139
140 let mut saw_partial_tool_use = false;
141 while let Some(event) = events.next().await {
142 if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
143 thread.update(cx, |thread, _cx| {
144 // Look for a tool use in the thread's last message
145 let last_content = thread.messages().last().unwrap().content.last().unwrap();
146 if let MessageContent::ToolUse(last_tool_use) = last_content {
147 assert_eq!(last_tool_use.name.as_ref(), "word_list");
148 if tool_call.status == acp::ToolCallStatus::Pending {
149 if !last_tool_use.is_input_complete
150 && last_tool_use.input.get("g").is_none()
151 {
152 saw_partial_tool_use = true;
153 }
154 } else {
155 last_tool_use
156 .input
157 .get("a")
158 .expect("'a' has streamed because input is now complete");
159 last_tool_use
160 .input
161 .get("g")
162 .expect("'g' has streamed because input is now complete");
163 }
164 } else {
165 panic!("last content should be a tool use");
166 }
167 });
168 }
169 }
170
171 assert!(
172 saw_partial_tool_use,
173 "should see at least one partially streamed tool use in the history"
174 );
175}
176
177#[gpui::test]
178#[ignore = "temporarily disabled until it can be run on CI"]
179async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
180 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
181
182 // Test concurrent tool calls with different delay times
183 let events = thread
184 .update(cx, |thread, cx| {
185 thread.add_tool(DelayTool);
186 thread.send(
187 model.clone(),
188 "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
189 cx,
190 )
191 })
192 .collect()
193 .await;
194
195 let stop_reasons = stop_events(events);
196 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
197
198 thread.update(cx, |thread, _cx| {
199 let last_message = thread.messages().last().unwrap();
200 let text = last_message
201 .content
202 .iter()
203 .filter_map(|content| {
204 if let MessageContent::Text(text) = content {
205 Some(text.as_str())
206 } else {
207 None
208 }
209 })
210 .collect::<String>();
211
212 assert!(text.contains("Ding"));
213 });
214}
215
216#[gpui::test]
217#[ignore = "temporarily disabled until it can be run on CI"]
218async fn test_cancellation(cx: &mut TestAppContext) {
219 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
220
221 let mut events = thread.update(cx, |thread, cx| {
222 thread.add_tool(InfiniteTool);
223 thread.add_tool(EchoTool);
224 thread.send(
225 model.clone(),
226 "Call the echo tool and then call the infinite tool, then explain their output",
227 cx,
228 )
229 });
230
231 // Wait until both tools are called.
232 let mut expected_tool_calls = vec!["echo", "infinite"];
233 let mut echo_id = None;
234 let mut echo_completed = false;
235 while let Some(event) = events.next().await {
236 match event.unwrap() {
237 AgentResponseEvent::ToolCall(tool_call) => {
238 assert_eq!(tool_call.title, expected_tool_calls.remove(0));
239 if tool_call.title == "echo" {
240 echo_id = Some(tool_call.id);
241 }
242 }
243 AgentResponseEvent::ToolCallUpdate(acp::ToolCallUpdate {
244 id,
245 fields:
246 acp::ToolCallUpdateFields {
247 status: Some(acp::ToolCallStatus::Completed),
248 ..
249 },
250 }) if Some(&id) == echo_id.as_ref() => {
251 echo_completed = true;
252 }
253 _ => {}
254 }
255
256 if expected_tool_calls.is_empty() && echo_completed {
257 break;
258 }
259 }
260
261 // Cancel the current send and ensure that the event stream is closed, even
262 // if one of the tools is still running.
263 thread.update(cx, |thread, _cx| thread.cancel());
264 events.collect::<Vec<_>>().await;
265
266 // Ensure we can still send a new message after cancellation.
267 let events = thread
268 .update(cx, |thread, cx| {
269 thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx)
270 })
271 .collect::<Vec<_>>()
272 .await;
273 thread.update(cx, |thread, _cx| {
274 assert_eq!(
275 thread.messages().last().unwrap().content,
276 vec![MessageContent::Text("Hello".to_string())]
277 );
278 });
279 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
280}
281
282#[gpui::test]
283async fn test_refusal(cx: &mut TestAppContext) {
284 let fake_model = Arc::new(FakeLanguageModel::default());
285 let ThreadTest { thread, .. } = setup(cx, TestModel::Fake(fake_model.clone())).await;
286
287 let events = thread.update(cx, |thread, cx| {
288 thread.send(fake_model.clone(), "Hello", cx)
289 });
290 cx.run_until_parked();
291 thread.read_with(cx, |thread, _| {
292 assert_eq!(
293 thread.to_markdown(),
294 indoc! {"
295 ## user
296 Hello
297 "}
298 );
299 });
300
301 fake_model.send_last_completion_stream_text_chunk("Hey!");
302 cx.run_until_parked();
303 thread.read_with(cx, |thread, _| {
304 assert_eq!(
305 thread.to_markdown(),
306 indoc! {"
307 ## user
308 Hello
309 ## assistant
310 Hey!
311 "}
312 );
313 });
314
315 // If the model refuses to continue, the thread should remove all the messages after the last user message.
316 fake_model
317 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
318 let events = events.collect::<Vec<_>>().await;
319 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
320 thread.read_with(cx, |thread, _| {
321 assert_eq!(thread.to_markdown(), "");
322 });
323}
324
325#[gpui::test]
326async fn test_agent_connection(cx: &mut TestAppContext) {
327 cx.update(settings::init);
328 let templates = Templates::new();
329
330 // Initialize language model system with test provider
331 cx.update(|cx| {
332 gpui_tokio::init(cx);
333 client::init_settings(cx);
334
335 let http_client = FakeHttpClient::with_404_response();
336 let clock = Arc::new(clock::FakeSystemClock::new());
337 let client = Client::new(clock, http_client, cx);
338 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
339 language_model::init(client.clone(), cx);
340 language_models::init(user_store.clone(), client.clone(), cx);
341 Project::init_settings(cx);
342 LanguageModelRegistry::test(cx);
343 });
344 cx.executor().forbid_parking();
345
346 // Create agent and connection
347 let agent = cx.new(|_| NativeAgent::new(templates.clone()));
348 let connection = NativeAgentConnection(agent.clone());
349
350 // Test model_selector returns Some
351 let selector_opt = connection.model_selector();
352 assert!(
353 selector_opt.is_some(),
354 "agent2 should always support ModelSelector"
355 );
356 let selector = selector_opt.unwrap();
357
358 // Test list_models
359 let listed_models = cx
360 .update(|cx| {
361 let mut async_cx = cx.to_async();
362 selector.list_models(&mut async_cx)
363 })
364 .await
365 .expect("list_models should succeed");
366 assert!(!listed_models.is_empty(), "should have at least one model");
367 assert_eq!(listed_models[0].id().0, "fake");
368
369 // Create a project for new_thread
370 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
371 let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
372
373 // Create a thread using new_thread
374 let cwd = Path::new("/test");
375 let connection_rc = Rc::new(connection.clone());
376 let acp_thread = cx
377 .update(|cx| {
378 let mut async_cx = cx.to_async();
379 connection_rc.new_thread(project, cwd, &mut async_cx)
380 })
381 .await
382 .expect("new_thread should succeed");
383
384 // Get the session_id from the AcpThread
385 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
386
387 // Test selected_model returns the default
388 let model = cx
389 .update(|cx| {
390 let mut async_cx = cx.to_async();
391 selector.selected_model(&session_id, &mut async_cx)
392 })
393 .await
394 .expect("selected_model should succeed");
395 let model = model.as_fake();
396 assert_eq!(model.id().0, "fake", "should return default model");
397
398 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
399 cx.run_until_parked();
400 model.send_last_completion_stream_text_chunk("def");
401 cx.run_until_parked();
402 acp_thread.read_with(cx, |thread, cx| {
403 assert_eq!(
404 thread.to_markdown(cx),
405 indoc! {"
406 ## User
407
408 abc
409
410 ## Assistant
411
412 def
413
414 "}
415 )
416 });
417
418 // Test cancel
419 cx.update(|cx| connection.cancel(&session_id, cx));
420 request.await.expect("prompt should fail gracefully");
421
422 // Ensure that dropping the ACP thread causes the native thread to be
423 // dropped as well.
424 cx.update(|_| drop(acp_thread));
425 let result = cx
426 .update(|cx| {
427 connection.prompt(
428 acp::PromptRequest {
429 session_id: session_id.clone(),
430 prompt: vec!["ghi".into()],
431 },
432 cx,
433 )
434 })
435 .await;
436 assert_eq!(
437 result.as_ref().unwrap_err().to_string(),
438 "Session not found",
439 "unexpected result: {:?}",
440 result
441 );
442}
443
444/// Filters out the stop events for asserting against in tests
445fn stop_events(
446 result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
447) -> Vec<acp::StopReason> {
448 result_events
449 .into_iter()
450 .filter_map(|event| match event.unwrap() {
451 AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
452 _ => None,
453 })
454 .collect()
455}
456
/// Handles returned by `setup` for a single test.
struct ThreadTest {
    /// The model the thread was created with; tests pass it back into `Thread::send`.
    model: Arc<dyn LanguageModel>,
    /// The agent thread under test.
    thread: Entity<Thread>,
}
461
/// Selects the language model that `setup` backs the test thread with.
enum TestModel {
    /// The `claude-sonnet-4-latest` model, looked up in the language model registry
    /// after its provider authenticates.
    Sonnet4,
    /// The `claude-sonnet-4-thinking-latest` model, looked up the same way.
    Sonnet4Thinking,
    /// A caller-supplied fake model, used directly with no registry lookup or authentication.
    Fake(Arc<FakeLanguageModel>),
}
467
468impl TestModel {
469 fn id(&self) -> LanguageModelId {
470 match self {
471 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
472 TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
473 TestModel::Fake(fake_model) => fake_model.id(),
474 }
475 }
476}
477
478async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
479 cx.executor().allow_parking();
480 cx.update(|cx| {
481 settings::init(cx);
482 Project::init_settings(cx);
483 });
484 let templates = Templates::new();
485
486 let fs = FakeFs::new(cx.background_executor.clone());
487 fs.insert_tree(path!("/test"), json!({})).await;
488 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
489
490 let model = cx
491 .update(|cx| {
492 gpui_tokio::init(cx);
493 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
494 cx.set_http_client(Arc::new(http_client));
495
496 client::init_settings(cx);
497 let client = Client::production(cx);
498 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
499 language_model::init(client.clone(), cx);
500 language_models::init(user_store.clone(), client.clone(), cx);
501
502 if let TestModel::Fake(model) = model {
503 Task::ready(model as Arc<_>)
504 } else {
505 let model_id = model.id();
506 let models = LanguageModelRegistry::read_global(cx);
507 let model = models
508 .available_models(cx)
509 .find(|model| model.id() == model_id)
510 .unwrap();
511
512 let provider = models.provider(&model.provider_id()).unwrap();
513 let authenticated = provider.authenticate(cx);
514
515 cx.spawn(async move |_cx| {
516 authenticated.await.unwrap();
517 model
518 })
519 }
520 })
521 .await;
522
523 let thread = cx.new(|_| Thread::new(project, templates, model.clone()));
524
525 ThreadTest { model, thread }
526}
527
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Install env_logger only when RUST_LOG is set (and readable as Unicode), so test
    // output stays quiet by default.
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}