1use super::*;
2use acp_thread::AgentConnection;
3use agent_client_protocol::{self as acp};
4use anyhow::Result;
5use assistant_tool::ActionLog;
6use client::{Client, UserStore};
7use fs::FakeFs;
8use futures::channel::mpsc::UnboundedReceiver;
9use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
10use indoc::indoc;
11use language_model::{
12 fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
13 LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult,
14 LanguageModelToolUse, MessageContent, Role, StopReason,
15};
16use project::Project;
17use prompt_store::ProjectContext;
18use reqwest_client::ReqwestClient;
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21use serde_json::json;
22use smol::stream::StreamExt;
23use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
24use util::path;
25
26mod test_tools;
27use test_tools::*;
28
29#[gpui::test]
30#[ignore = "can't run on CI yet"]
31async fn test_echo(cx: &mut TestAppContext) {
32 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
33
34 let events = thread
35 .update(cx, |thread, cx| {
36 thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
37 })
38 .collect()
39 .await;
40 thread.update(cx, |thread, _cx| {
41 assert_eq!(
42 thread.messages().last().unwrap().content,
43 vec![MessageContent::Text("Hello".to_string())]
44 );
45 });
46 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
47}
48
49#[gpui::test]
50#[ignore = "can't run on CI yet"]
51async fn test_thinking(cx: &mut TestAppContext) {
52 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
53
54 let events = thread
55 .update(cx, |thread, cx| {
56 thread.send(
57 model.clone(),
58 indoc! {"
59 Testing:
60
61 Generate a thinking step where you just think the word 'Think',
62 and have your final answer be 'Hello'
63 "},
64 cx,
65 )
66 })
67 .collect()
68 .await;
69 thread.update(cx, |thread, _cx| {
70 assert_eq!(
71 thread.messages().last().unwrap().to_markdown(),
72 indoc! {"
73 ## assistant
74 <think>Think</think>
75 Hello
76 "}
77 )
78 });
79 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
80}
81
82#[gpui::test]
83async fn test_system_prompt(cx: &mut TestAppContext) {
84 let ThreadTest {
85 model,
86 thread,
87 project_context,
88 ..
89 } = setup(cx, TestModel::Fake).await;
90 let fake_model = model.as_fake();
91
92 project_context.borrow_mut().shell = "test-shell".into();
93 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
94 thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
95 cx.run_until_parked();
96 let mut pending_completions = fake_model.pending_completions();
97 assert_eq!(
98 pending_completions.len(),
99 1,
100 "unexpected pending completions: {:?}",
101 pending_completions
102 );
103
104 let pending_completion = pending_completions.pop().unwrap();
105 assert_eq!(pending_completion.messages[0].role, Role::System);
106
107 let system_message = &pending_completion.messages[0];
108 let system_prompt = system_message.content[0].to_str().unwrap();
109 assert!(
110 system_prompt.contains("test-shell"),
111 "unexpected system message: {:?}",
112 system_message
113 );
114 assert!(
115 system_prompt.contains("## Fixing Diagnostics"),
116 "unexpected system message: {:?}",
117 system_message
118 );
119}
120
121#[gpui::test]
122#[ignore = "can't run on CI yet"]
123async fn test_basic_tool_calls(cx: &mut TestAppContext) {
124 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
125
126 // Test a tool call that's likely to complete *before* streaming stops.
127 let events = thread
128 .update(cx, |thread, cx| {
129 thread.add_tool(EchoTool);
130 thread.send(
131 model.clone(),
132 "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
133 cx,
134 )
135 })
136 .collect()
137 .await;
138 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
139
140 // Test a tool calls that's likely to complete *after* streaming stops.
141 let events = thread
142 .update(cx, |thread, cx| {
143 thread.remove_tool(&AgentTool::name(&EchoTool));
144 thread.add_tool(DelayTool);
145 thread.send(
146 model.clone(),
147 "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
148 cx,
149 )
150 })
151 .collect()
152 .await;
153 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
154 thread.update(cx, |thread, _cx| {
155 assert!(thread
156 .messages()
157 .last()
158 .unwrap()
159 .content
160 .iter()
161 .any(|content| {
162 if let MessageContent::Text(text) = content {
163 text.contains("Ding")
164 } else {
165 false
166 }
167 }));
168 });
169}
170
171#[gpui::test]
172#[ignore = "can't run on CI yet"]
173async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
174 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
175
176 // Test a tool call that's likely to complete *before* streaming stops.
177 let mut events = thread.update(cx, |thread, cx| {
178 thread.add_tool(WordListTool);
179 thread.send(model.clone(), "Test the word_list tool.", cx)
180 });
181
182 let mut saw_partial_tool_use = false;
183 while let Some(event) = events.next().await {
184 if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
185 thread.update(cx, |thread, _cx| {
186 // Look for a tool use in the thread's last message
187 let last_content = thread.messages().last().unwrap().content.last().unwrap();
188 if let MessageContent::ToolUse(last_tool_use) = last_content {
189 assert_eq!(last_tool_use.name.as_ref(), "word_list");
190 if tool_call.status == acp::ToolCallStatus::Pending {
191 if !last_tool_use.is_input_complete
192 && last_tool_use.input.get("g").is_none()
193 {
194 saw_partial_tool_use = true;
195 }
196 } else {
197 last_tool_use
198 .input
199 .get("a")
200 .expect("'a' has streamed because input is now complete");
201 last_tool_use
202 .input
203 .get("g")
204 .expect("'g' has streamed because input is now complete");
205 }
206 } else {
207 panic!("last content should be a tool use");
208 }
209 });
210 }
211 }
212
213 assert!(
214 saw_partial_tool_use,
215 "should see at least one partially streamed tool use in the history"
216 );
217}
218
219#[gpui::test]
220async fn test_tool_authorization(cx: &mut TestAppContext) {
221 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
222 let fake_model = model.as_fake();
223
224 let mut events = thread.update(cx, |thread, cx| {
225 thread.add_tool(ToolRequiringPermission);
226 thread.send(model.clone(), "abc", cx)
227 });
228 cx.run_until_parked();
229 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
230 LanguageModelToolUse {
231 id: "tool_id_1".into(),
232 name: ToolRequiringPermission.name().into(),
233 raw_input: "{}".into(),
234 input: json!({}),
235 is_input_complete: true,
236 },
237 ));
238 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
239 LanguageModelToolUse {
240 id: "tool_id_2".into(),
241 name: ToolRequiringPermission.name().into(),
242 raw_input: "{}".into(),
243 input: json!({}),
244 is_input_complete: true,
245 },
246 ));
247 fake_model.end_last_completion_stream();
248 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
249 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
250
251 // Approve the first
252 tool_call_auth_1
253 .response
254 .send(tool_call_auth_1.options[1].id.clone())
255 .unwrap();
256 cx.run_until_parked();
257
258 // Reject the second
259 tool_call_auth_2
260 .response
261 .send(tool_call_auth_1.options[2].id.clone())
262 .unwrap();
263 cx.run_until_parked();
264
265 let completion = fake_model.pending_completions().pop().unwrap();
266 let message = completion.messages.last().unwrap();
267 assert_eq!(
268 message.content,
269 vec![
270 MessageContent::ToolResult(LanguageModelToolResult {
271 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
272 tool_name: ToolRequiringPermission.name().into(),
273 is_error: false,
274 content: "Allowed".into(),
275 output: Some("Allowed".into())
276 }),
277 MessageContent::ToolResult(LanguageModelToolResult {
278 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
279 tool_name: ToolRequiringPermission.name().into(),
280 is_error: true,
281 content: "Permission to run tool denied by user".into(),
282 output: None
283 })
284 ]
285 );
286}
287
288#[gpui::test]
289async fn test_tool_hallucination(cx: &mut TestAppContext) {
290 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
291 let fake_model = model.as_fake();
292
293 let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
294 cx.run_until_parked();
295 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
296 LanguageModelToolUse {
297 id: "tool_id_1".into(),
298 name: "nonexistent_tool".into(),
299 raw_input: "{}".into(),
300 input: json!({}),
301 is_input_complete: true,
302 },
303 ));
304 fake_model.end_last_completion_stream();
305
306 let tool_call = expect_tool_call(&mut events).await;
307 assert_eq!(tool_call.title, "nonexistent_tool");
308 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
309 let update = expect_tool_call_update_fields(&mut events).await;
310 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
311}
312
313async fn expect_tool_call(
314 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
315) -> acp::ToolCall {
316 let event = events
317 .next()
318 .await
319 .expect("no tool call authorization event received")
320 .unwrap();
321 match event {
322 AgentResponseEvent::ToolCall(tool_call) => return tool_call,
323 event => {
324 panic!("Unexpected event {event:?}");
325 }
326 }
327}
328
329async fn expect_tool_call_update_fields(
330 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
331) -> acp::ToolCallUpdate {
332 let event = events
333 .next()
334 .await
335 .expect("no tool call authorization event received")
336 .unwrap();
337 match event {
338 AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
339 return update
340 }
341 event => {
342 panic!("Unexpected event {event:?}");
343 }
344 }
345}
346
347async fn next_tool_call_authorization(
348 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
349) -> ToolCallAuthorization {
350 loop {
351 let event = events
352 .next()
353 .await
354 .expect("no tool call authorization event received")
355 .unwrap();
356 if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
357 let permission_kinds = tool_call_authorization
358 .options
359 .iter()
360 .map(|o| o.kind)
361 .collect::<Vec<_>>();
362 assert_eq!(
363 permission_kinds,
364 vec![
365 acp::PermissionOptionKind::AllowAlways,
366 acp::PermissionOptionKind::AllowOnce,
367 acp::PermissionOptionKind::RejectOnce,
368 ]
369 );
370 return tool_call_authorization;
371 }
372 }
373}
374
375#[gpui::test]
376#[ignore = "can't run on CI yet"]
377async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
378 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
379
380 // Test concurrent tool calls with different delay times
381 let events = thread
382 .update(cx, |thread, cx| {
383 thread.add_tool(DelayTool);
384 thread.send(
385 model.clone(),
386 "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
387 cx,
388 )
389 })
390 .collect()
391 .await;
392
393 let stop_reasons = stop_events(events);
394 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
395
396 thread.update(cx, |thread, _cx| {
397 let last_message = thread.messages().last().unwrap();
398 let text = last_message
399 .content
400 .iter()
401 .filter_map(|content| {
402 if let MessageContent::Text(text) = content {
403 Some(text.as_str())
404 } else {
405 None
406 }
407 })
408 .collect::<String>();
409
410 assert!(text.contains("Ding"));
411 });
412}
413
414#[gpui::test]
415#[ignore = "can't run on CI yet"]
416async fn test_cancellation(cx: &mut TestAppContext) {
417 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
418
419 let mut events = thread.update(cx, |thread, cx| {
420 thread.add_tool(InfiniteTool);
421 thread.add_tool(EchoTool);
422 thread.send(
423 model.clone(),
424 "Call the echo tool and then call the infinite tool, then explain their output",
425 cx,
426 )
427 });
428
429 // Wait until both tools are called.
430 let mut expected_tools = vec!["Echo", "Infinite Tool"];
431 let mut echo_id = None;
432 let mut echo_completed = false;
433 while let Some(event) = events.next().await {
434 match event.unwrap() {
435 AgentResponseEvent::ToolCall(tool_call) => {
436 assert_eq!(tool_call.title, expected_tools.remove(0));
437 if tool_call.title == "Echo" {
438 echo_id = Some(tool_call.id);
439 }
440 }
441 AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
442 acp::ToolCallUpdate {
443 id,
444 fields:
445 acp::ToolCallUpdateFields {
446 status: Some(acp::ToolCallStatus::Completed),
447 ..
448 },
449 },
450 )) if Some(&id) == echo_id.as_ref() => {
451 echo_completed = true;
452 }
453 _ => {}
454 }
455
456 if expected_tools.is_empty() && echo_completed {
457 break;
458 }
459 }
460
461 // Cancel the current send and ensure that the event stream is closed, even
462 // if one of the tools is still running.
463 thread.update(cx, |thread, _cx| thread.cancel());
464 events.collect::<Vec<_>>().await;
465
466 // Ensure we can still send a new message after cancellation.
467 let events = thread
468 .update(cx, |thread, cx| {
469 thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx)
470 })
471 .collect::<Vec<_>>()
472 .await;
473 thread.update(cx, |thread, _cx| {
474 assert_eq!(
475 thread.messages().last().unwrap().content,
476 vec![MessageContent::Text("Hello".to_string())]
477 );
478 });
479 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
480}
481
482#[gpui::test]
483async fn test_refusal(cx: &mut TestAppContext) {
484 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
485 let fake_model = model.as_fake();
486
487 let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx));
488 cx.run_until_parked();
489 thread.read_with(cx, |thread, _| {
490 assert_eq!(
491 thread.to_markdown(),
492 indoc! {"
493 ## user
494 Hello
495 "}
496 );
497 });
498
499 fake_model.send_last_completion_stream_text_chunk("Hey!");
500 cx.run_until_parked();
501 thread.read_with(cx, |thread, _| {
502 assert_eq!(
503 thread.to_markdown(),
504 indoc! {"
505 ## user
506 Hello
507 ## assistant
508 Hey!
509 "}
510 );
511 });
512
513 // If the model refuses to continue, the thread should remove all the messages after the last user message.
514 fake_model
515 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
516 let events = events.collect::<Vec<_>>().await;
517 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
518 thread.read_with(cx, |thread, _| {
519 assert_eq!(thread.to_markdown(), "");
520 });
521}
522
523#[gpui::test]
524async fn test_agent_connection(cx: &mut TestAppContext) {
525 cx.update(settings::init);
526 let templates = Templates::new();
527
528 // Initialize language model system with test provider
529 cx.update(|cx| {
530 gpui_tokio::init(cx);
531 client::init_settings(cx);
532
533 let http_client = FakeHttpClient::with_404_response();
534 let clock = Arc::new(clock::FakeSystemClock::new());
535 let client = Client::new(clock, http_client, cx);
536 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
537 language_model::init(client.clone(), cx);
538 language_models::init(user_store.clone(), client.clone(), cx);
539 Project::init_settings(cx);
540 LanguageModelRegistry::test(cx);
541 });
542 cx.executor().forbid_parking();
543
544 // Create a project for new_thread
545 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
546 fake_fs.insert_tree(path!("/test"), json!({})).await;
547 let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
548 let cwd = Path::new("/test");
549
550 // Create agent and connection
551 let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
552 .await
553 .unwrap();
554 let connection = NativeAgentConnection(agent.clone());
555
556 // Test model_selector returns Some
557 let selector_opt = connection.model_selector();
558 assert!(
559 selector_opt.is_some(),
560 "agent2 should always support ModelSelector"
561 );
562 let selector = selector_opt.unwrap();
563
564 // Test list_models
565 let listed_models = cx
566 .update(|cx| {
567 let mut async_cx = cx.to_async();
568 selector.list_models(&mut async_cx)
569 })
570 .await
571 .expect("list_models should succeed");
572 assert!(!listed_models.is_empty(), "should have at least one model");
573 assert_eq!(listed_models[0].id().0, "fake");
574
575 // Create a thread using new_thread
576 let connection_rc = Rc::new(connection.clone());
577 let acp_thread = cx
578 .update(|cx| {
579 let mut async_cx = cx.to_async();
580 connection_rc.new_thread(project, cwd, &mut async_cx)
581 })
582 .await
583 .expect("new_thread should succeed");
584
585 // Get the session_id from the AcpThread
586 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
587
588 // Test selected_model returns the default
589 let model = cx
590 .update(|cx| {
591 let mut async_cx = cx.to_async();
592 selector.selected_model(&session_id, &mut async_cx)
593 })
594 .await
595 .expect("selected_model should succeed");
596 let model = model.as_fake();
597 assert_eq!(model.id().0, "fake", "should return default model");
598
599 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
600 cx.run_until_parked();
601 model.send_last_completion_stream_text_chunk("def");
602 cx.run_until_parked();
603 acp_thread.read_with(cx, |thread, cx| {
604 assert_eq!(
605 thread.to_markdown(cx),
606 indoc! {"
607 ## User
608
609 abc
610
611 ## Assistant
612
613 def
614
615 "}
616 )
617 });
618
619 // Test cancel
620 cx.update(|cx| connection.cancel(&session_id, cx));
621 request.await.expect("prompt should fail gracefully");
622
623 // Ensure that dropping the ACP thread causes the native thread to be
624 // dropped as well.
625 cx.update(|_| drop(acp_thread));
626 let result = cx
627 .update(|cx| {
628 connection.prompt(
629 acp::PromptRequest {
630 session_id: session_id.clone(),
631 prompt: vec!["ghi".into()],
632 },
633 cx,
634 )
635 })
636 .await;
637 assert_eq!(
638 result.as_ref().unwrap_err().to_string(),
639 "Session not found",
640 "unexpected result: {:?}",
641 result
642 );
643}
644
645#[gpui::test]
646async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
647 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
648 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
649 let fake_model = model.as_fake();
650
651 let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Think", cx));
652 cx.run_until_parked();
653
654 // Simulate streaming partial input.
655 let input = json!({});
656 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
657 LanguageModelToolUse {
658 id: "1".into(),
659 name: ThinkingTool.name().into(),
660 raw_input: input.to_string(),
661 input,
662 is_input_complete: false,
663 },
664 ));
665
666 // Input streaming completed
667 let input = json!({ "content": "Thinking hard!" });
668 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
669 LanguageModelToolUse {
670 id: "1".into(),
671 name: "thinking".into(),
672 raw_input: input.to_string(),
673 input,
674 is_input_complete: true,
675 },
676 ));
677 fake_model.end_last_completion_stream();
678 cx.run_until_parked();
679
680 let tool_call = expect_tool_call(&mut events).await;
681 assert_eq!(
682 tool_call,
683 acp::ToolCall {
684 id: acp::ToolCallId("1".into()),
685 title: "Thinking".into(),
686 kind: acp::ToolKind::Think,
687 status: acp::ToolCallStatus::Pending,
688 content: vec![],
689 locations: vec![],
690 raw_input: Some(json!({})),
691 raw_output: None,
692 }
693 );
694 let update = expect_tool_call_update_fields(&mut events).await;
695 assert_eq!(
696 update,
697 acp::ToolCallUpdate {
698 id: acp::ToolCallId("1".into()),
699 fields: acp::ToolCallUpdateFields {
700 title: Some("Thinking".into()),
701 kind: Some(acp::ToolKind::Think),
702 raw_input: Some(json!({ "content": "Thinking hard!" })),
703 ..Default::default()
704 },
705 }
706 );
707 let update = expect_tool_call_update_fields(&mut events).await;
708 assert_eq!(
709 update,
710 acp::ToolCallUpdate {
711 id: acp::ToolCallId("1".into()),
712 fields: acp::ToolCallUpdateFields {
713 status: Some(acp::ToolCallStatus::InProgress),
714 ..Default::default()
715 },
716 }
717 );
718 let update = expect_tool_call_update_fields(&mut events).await;
719 assert_eq!(
720 update,
721 acp::ToolCallUpdate {
722 id: acp::ToolCallId("1".into()),
723 fields: acp::ToolCallUpdateFields {
724 content: Some(vec!["Thinking hard!".into()]),
725 ..Default::default()
726 },
727 }
728 );
729 let update = expect_tool_call_update_fields(&mut events).await;
730 assert_eq!(
731 update,
732 acp::ToolCallUpdate {
733 id: acp::ToolCallId("1".into()),
734 fields: acp::ToolCallUpdateFields {
735 status: Some(acp::ToolCallStatus::Completed),
736 ..Default::default()
737 },
738 }
739 );
740}
741
742/// Filters out the stop events for asserting against in tests
743fn stop_events(
744 result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
745) -> Vec<acp::StopReason> {
746 result_events
747 .into_iter()
748 .filter_map(|event| match event.unwrap() {
749 AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
750 _ => None,
751 })
752 .collect()
753}
754
755struct ThreadTest {
756 model: Arc<dyn LanguageModel>,
757 thread: Entity<Thread>,
758 project_context: Rc<RefCell<ProjectContext>>,
759}
760
/// Selects which language model `setup` wires into the thread.
enum TestModel {
    /// Claude Sonnet 4, looked up in the model registry and authenticated by `setup`.
    Sonnet4,
    /// Claude Sonnet 4 with thinking, looked up and authenticated like `Sonnet4`.
    Sonnet4Thinking,
    /// An in-process `FakeLanguageModel` that the test drives by hand.
    Fake,
}
766
767impl TestModel {
768 fn id(&self) -> LanguageModelId {
769 match self {
770 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
771 TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
772 TestModel::Fake => unreachable!(),
773 }
774 }
775}
776
777async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
778 cx.executor().allow_parking();
779 cx.update(|cx| {
780 settings::init(cx);
781 Project::init_settings(cx);
782 });
783 let templates = Templates::new();
784
785 let fs = FakeFs::new(cx.background_executor.clone());
786 fs.insert_tree(path!("/test"), json!({})).await;
787 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
788
789 let model = cx
790 .update(|cx| {
791 gpui_tokio::init(cx);
792 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
793 cx.set_http_client(Arc::new(http_client));
794
795 client::init_settings(cx);
796 let client = Client::production(cx);
797 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
798 language_model::init(client.clone(), cx);
799 language_models::init(user_store.clone(), client.clone(), cx);
800
801 if let TestModel::Fake = model {
802 Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
803 } else {
804 let model_id = model.id();
805 let models = LanguageModelRegistry::read_global(cx);
806 let model = models
807 .available_models(cx)
808 .find(|model| model.id() == model_id)
809 .unwrap();
810
811 let provider = models.provider(&model.provider_id()).unwrap();
812 let authenticated = provider.authenticate(cx);
813
814 cx.spawn(async move |_cx| {
815 authenticated.await.unwrap();
816 model
817 })
818 }
819 })
820 .await;
821
822 let project_context = Rc::new(RefCell::new(ProjectContext::default()));
823 let action_log = cx.new(|_| ActionLog::new(project.clone()));
824 let thread = cx.new(|_| {
825 Thread::new(
826 project,
827 project_context.clone(),
828 action_log,
829 templates,
830 model.clone(),
831 )
832 });
833 ThreadTest {
834 model,
835 thread,
836 project_context,
837 }
838}
839
/// Runs before any test. It turns on `env_logger` output only when `RUST_LOG` is set.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}