1use super::*;
2use acp_thread::AgentConnection;
3use agent_client_protocol::{self as acp};
4use anyhow::Result;
5use assistant_tool::ActionLog;
6use client::{Client, UserStore};
7use fs::FakeFs;
8use futures::channel::mpsc::UnboundedReceiver;
9use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
10use indoc::indoc;
11use language_model::{
12 fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
13 LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult,
14 LanguageModelToolUse, MessageContent, Role, StopReason,
15};
16use project::Project;
17use prompt_store::ProjectContext;
18use reqwest_client::ReqwestClient;
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21use serde_json::json;
22use smol::stream::StreamExt;
23use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
24use util::path;
25
26mod test_tools;
27use test_tools::*;
28
29#[gpui::test]
30#[ignore = "can't run on CI yet"]
31async fn test_echo(cx: &mut TestAppContext) {
32 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
33
34 let events = thread
35 .update(cx, |thread, cx| {
36 thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
37 })
38 .collect()
39 .await;
40 thread.update(cx, |thread, _cx| {
41 assert_eq!(
42 thread.messages().last().unwrap().content,
43 vec![MessageContent::Text("Hello".to_string())]
44 );
45 });
46 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
47}
48
49#[gpui::test]
50#[ignore = "can't run on CI yet"]
51async fn test_thinking(cx: &mut TestAppContext) {
52 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
53
54 let events = thread
55 .update(cx, |thread, cx| {
56 thread.send(
57 model.clone(),
58 indoc! {"
59 Testing:
60
61 Generate a thinking step where you just think the word 'Think',
62 and have your final answer be 'Hello'
63 "},
64 cx,
65 )
66 })
67 .collect()
68 .await;
69 thread.update(cx, |thread, _cx| {
70 assert_eq!(
71 thread.messages().last().unwrap().to_markdown(),
72 indoc! {"
73 ## assistant
74 <think>Think</think>
75 Hello
76 "}
77 )
78 });
79 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
80}
81
82#[gpui::test]
83async fn test_system_prompt(cx: &mut TestAppContext) {
84 let ThreadTest {
85 model,
86 thread,
87 project_context,
88 ..
89 } = setup(cx, TestModel::Fake).await;
90 let fake_model = model.as_fake();
91
92 project_context.borrow_mut().shell = "test-shell".into();
93 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
94 thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
95 cx.run_until_parked();
96 let mut pending_completions = fake_model.pending_completions();
97 assert_eq!(
98 pending_completions.len(),
99 1,
100 "unexpected pending completions: {:?}",
101 pending_completions
102 );
103
104 let pending_completion = pending_completions.pop().unwrap();
105 assert_eq!(pending_completion.messages[0].role, Role::System);
106
107 let system_message = &pending_completion.messages[0];
108 let system_prompt = system_message.content[0].to_str().unwrap();
109 assert!(
110 system_prompt.contains("test-shell"),
111 "unexpected system message: {:?}",
112 system_message
113 );
114 assert!(
115 system_prompt.contains("## Fixing Diagnostics"),
116 "unexpected system message: {:?}",
117 system_message
118 );
119}
120
121#[gpui::test]
122#[ignore = "can't run on CI yet"]
123async fn test_basic_tool_calls(cx: &mut TestAppContext) {
124 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
125
126 // Test a tool call that's likely to complete *before* streaming stops.
127 let events = thread
128 .update(cx, |thread, cx| {
129 thread.add_tool(EchoTool);
130 thread.send(
131 model.clone(),
132 "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
133 cx,
134 )
135 })
136 .collect()
137 .await;
138 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
139
140 // Test a tool calls that's likely to complete *after* streaming stops.
141 let events = thread
142 .update(cx, |thread, cx| {
143 thread.remove_tool(&AgentTool::name(&EchoTool));
144 thread.add_tool(DelayTool);
145 thread.send(
146 model.clone(),
147 "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
148 cx,
149 )
150 })
151 .collect()
152 .await;
153 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
154 thread.update(cx, |thread, _cx| {
155 assert!(thread
156 .messages()
157 .last()
158 .unwrap()
159 .content
160 .iter()
161 .any(|content| {
162 if let MessageContent::Text(text) = content {
163 text.contains("Ding")
164 } else {
165 false
166 }
167 }));
168 });
169}
170
171#[gpui::test]
172#[ignore = "can't run on CI yet"]
173async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
174 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
175
176 // Test a tool call that's likely to complete *before* streaming stops.
177 let mut events = thread.update(cx, |thread, cx| {
178 thread.add_tool(WordListTool);
179 thread.send(model.clone(), "Test the word_list tool.", cx)
180 });
181
182 let mut saw_partial_tool_use = false;
183 while let Some(event) = events.next().await {
184 if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
185 thread.update(cx, |thread, _cx| {
186 // Look for a tool use in the thread's last message
187 let last_content = thread.messages().last().unwrap().content.last().unwrap();
188 if let MessageContent::ToolUse(last_tool_use) = last_content {
189 assert_eq!(last_tool_use.name.as_ref(), "word_list");
190 if tool_call.status == acp::ToolCallStatus::Pending {
191 if !last_tool_use.is_input_complete
192 && last_tool_use.input.get("g").is_none()
193 {
194 saw_partial_tool_use = true;
195 }
196 } else {
197 last_tool_use
198 .input
199 .get("a")
200 .expect("'a' has streamed because input is now complete");
201 last_tool_use
202 .input
203 .get("g")
204 .expect("'g' has streamed because input is now complete");
205 }
206 } else {
207 panic!("last content should be a tool use");
208 }
209 });
210 }
211 }
212
213 assert!(
214 saw_partial_tool_use,
215 "should see at least one partially streamed tool use in the history"
216 );
217}
218
219#[gpui::test]
220async fn test_tool_authorization(cx: &mut TestAppContext) {
221 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
222 let fake_model = model.as_fake();
223
224 let mut events = thread.update(cx, |thread, cx| {
225 thread.add_tool(ToolRequiringPermission);
226 thread.send(model.clone(), "abc", cx)
227 });
228 cx.run_until_parked();
229 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
230 LanguageModelToolUse {
231 id: "tool_id_1".into(),
232 name: ToolRequiringPermission.name().into(),
233 raw_input: "{}".into(),
234 input: json!({}),
235 is_input_complete: true,
236 },
237 ));
238 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
239 LanguageModelToolUse {
240 id: "tool_id_2".into(),
241 name: ToolRequiringPermission.name().into(),
242 raw_input: "{}".into(),
243 input: json!({}),
244 is_input_complete: true,
245 },
246 ));
247 fake_model.end_last_completion_stream();
248 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
249 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
250
251 // Approve the first
252 tool_call_auth_1
253 .response
254 .send(tool_call_auth_1.options[1].id.clone())
255 .unwrap();
256 cx.run_until_parked();
257
258 // Reject the second
259 tool_call_auth_2
260 .response
261 .send(tool_call_auth_1.options[2].id.clone())
262 .unwrap();
263 cx.run_until_parked();
264
265 let completion = fake_model.pending_completions().pop().unwrap();
266 let message = completion.messages.last().unwrap();
267 assert_eq!(
268 message.content,
269 vec![
270 MessageContent::ToolResult(LanguageModelToolResult {
271 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
272 tool_name: ToolRequiringPermission.name().into(),
273 is_error: false,
274 content: "Allowed".into(),
275 output: Some("Allowed".into())
276 }),
277 MessageContent::ToolResult(LanguageModelToolResult {
278 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
279 tool_name: ToolRequiringPermission.name().into(),
280 is_error: true,
281 content: "Permission to run tool denied by user".into(),
282 output: None
283 })
284 ]
285 );
286}
287
288#[gpui::test]
289async fn test_tool_hallucination(cx: &mut TestAppContext) {
290 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
291 let fake_model = model.as_fake();
292
293 let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
294 cx.run_until_parked();
295 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
296 LanguageModelToolUse {
297 id: "tool_id_1".into(),
298 name: "nonexistent_tool".into(),
299 raw_input: "{}".into(),
300 input: json!({}),
301 is_input_complete: true,
302 },
303 ));
304 fake_model.end_last_completion_stream();
305
306 let tool_call = expect_tool_call(&mut events).await;
307 assert_eq!(tool_call.title, "nonexistent_tool");
308 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
309 let update = expect_tool_call_update(&mut events).await;
310 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
311}
312
313async fn expect_tool_call(
314 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
315) -> acp::ToolCall {
316 let event = events
317 .next()
318 .await
319 .expect("no tool call authorization event received")
320 .unwrap();
321 match event {
322 AgentResponseEvent::ToolCall(tool_call) => return tool_call,
323 event => {
324 panic!("Unexpected event {event:?}");
325 }
326 }
327}
328
329async fn expect_tool_call_update(
330 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
331) -> acp::ToolCallUpdate {
332 let event = events
333 .next()
334 .await
335 .expect("no tool call authorization event received")
336 .unwrap();
337 match event {
338 AgentResponseEvent::ToolCallUpdate(tool_call_update) => return tool_call_update,
339 event => {
340 panic!("Unexpected event {event:?}");
341 }
342 }
343}
344
345async fn next_tool_call_authorization(
346 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
347) -> ToolCallAuthorization {
348 loop {
349 let event = events
350 .next()
351 .await
352 .expect("no tool call authorization event received")
353 .unwrap();
354 if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
355 let permission_kinds = tool_call_authorization
356 .options
357 .iter()
358 .map(|o| o.kind)
359 .collect::<Vec<_>>();
360 assert_eq!(
361 permission_kinds,
362 vec![
363 acp::PermissionOptionKind::AllowAlways,
364 acp::PermissionOptionKind::AllowOnce,
365 acp::PermissionOptionKind::RejectOnce,
366 ]
367 );
368 return tool_call_authorization;
369 }
370 }
371}
372
373#[gpui::test]
374#[ignore = "can't run on CI yet"]
375async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
376 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
377
378 // Test concurrent tool calls with different delay times
379 let events = thread
380 .update(cx, |thread, cx| {
381 thread.add_tool(DelayTool);
382 thread.send(
383 model.clone(),
384 "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
385 cx,
386 )
387 })
388 .collect()
389 .await;
390
391 let stop_reasons = stop_events(events);
392 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
393
394 thread.update(cx, |thread, _cx| {
395 let last_message = thread.messages().last().unwrap();
396 let text = last_message
397 .content
398 .iter()
399 .filter_map(|content| {
400 if let MessageContent::Text(text) = content {
401 Some(text.as_str())
402 } else {
403 None
404 }
405 })
406 .collect::<String>();
407
408 assert!(text.contains("Ding"));
409 });
410}
411
412#[gpui::test]
413#[ignore = "can't run on CI yet"]
414async fn test_cancellation(cx: &mut TestAppContext) {
415 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
416
417 let mut events = thread.update(cx, |thread, cx| {
418 thread.add_tool(InfiniteTool);
419 thread.add_tool(EchoTool);
420 thread.send(
421 model.clone(),
422 "Call the echo tool and then call the infinite tool, then explain their output",
423 cx,
424 )
425 });
426
427 // Wait until both tools are called.
428 let mut expected_tool_calls = vec!["echo", "infinite"];
429 let mut echo_id = None;
430 let mut echo_completed = false;
431 while let Some(event) = events.next().await {
432 match event.unwrap() {
433 AgentResponseEvent::ToolCall(tool_call) => {
434 assert_eq!(tool_call.title, expected_tool_calls.remove(0));
435 if tool_call.title == "echo" {
436 echo_id = Some(tool_call.id);
437 }
438 }
439 AgentResponseEvent::ToolCallUpdate(acp::ToolCallUpdate {
440 id,
441 fields:
442 acp::ToolCallUpdateFields {
443 status: Some(acp::ToolCallStatus::Completed),
444 ..
445 },
446 }) if Some(&id) == echo_id.as_ref() => {
447 echo_completed = true;
448 }
449 _ => {}
450 }
451
452 if expected_tool_calls.is_empty() && echo_completed {
453 break;
454 }
455 }
456
457 // Cancel the current send and ensure that the event stream is closed, even
458 // if one of the tools is still running.
459 thread.update(cx, |thread, _cx| thread.cancel());
460 events.collect::<Vec<_>>().await;
461
462 // Ensure we can still send a new message after cancellation.
463 let events = thread
464 .update(cx, |thread, cx| {
465 thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx)
466 })
467 .collect::<Vec<_>>()
468 .await;
469 thread.update(cx, |thread, _cx| {
470 assert_eq!(
471 thread.messages().last().unwrap().content,
472 vec![MessageContent::Text("Hello".to_string())]
473 );
474 });
475 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
476}
477
478#[gpui::test]
479async fn test_refusal(cx: &mut TestAppContext) {
480 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
481 let fake_model = model.as_fake();
482
483 let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx));
484 cx.run_until_parked();
485 thread.read_with(cx, |thread, _| {
486 assert_eq!(
487 thread.to_markdown(),
488 indoc! {"
489 ## user
490 Hello
491 "}
492 );
493 });
494
495 fake_model.send_last_completion_stream_text_chunk("Hey!");
496 cx.run_until_parked();
497 thread.read_with(cx, |thread, _| {
498 assert_eq!(
499 thread.to_markdown(),
500 indoc! {"
501 ## user
502 Hello
503 ## assistant
504 Hey!
505 "}
506 );
507 });
508
509 // If the model refuses to continue, the thread should remove all the messages after the last user message.
510 fake_model
511 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
512 let events = events.collect::<Vec<_>>().await;
513 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
514 thread.read_with(cx, |thread, _| {
515 assert_eq!(thread.to_markdown(), "");
516 });
517}
518
519#[gpui::test]
520async fn test_agent_connection(cx: &mut TestAppContext) {
521 cx.update(settings::init);
522 let templates = Templates::new();
523
524 // Initialize language model system with test provider
525 cx.update(|cx| {
526 gpui_tokio::init(cx);
527 client::init_settings(cx);
528
529 let http_client = FakeHttpClient::with_404_response();
530 let clock = Arc::new(clock::FakeSystemClock::new());
531 let client = Client::new(clock, http_client, cx);
532 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
533 language_model::init(client.clone(), cx);
534 language_models::init(user_store.clone(), client.clone(), cx);
535 Project::init_settings(cx);
536 LanguageModelRegistry::test(cx);
537 });
538 cx.executor().forbid_parking();
539
540 // Create a project for new_thread
541 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
542 fake_fs.insert_tree(path!("/test"), json!({})).await;
543 let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
544 let cwd = Path::new("/test");
545
546 // Create agent and connection
547 let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
548 .await
549 .unwrap();
550 let connection = NativeAgentConnection(agent.clone());
551
552 // Test model_selector returns Some
553 let selector_opt = connection.model_selector();
554 assert!(
555 selector_opt.is_some(),
556 "agent2 should always support ModelSelector"
557 );
558 let selector = selector_opt.unwrap();
559
560 // Test list_models
561 let listed_models = cx
562 .update(|cx| {
563 let mut async_cx = cx.to_async();
564 selector.list_models(&mut async_cx)
565 })
566 .await
567 .expect("list_models should succeed");
568 assert!(!listed_models.is_empty(), "should have at least one model");
569 assert_eq!(listed_models[0].id().0, "fake");
570
571 // Create a thread using new_thread
572 let connection_rc = Rc::new(connection.clone());
573 let acp_thread = cx
574 .update(|cx| {
575 let mut async_cx = cx.to_async();
576 connection_rc.new_thread(project, cwd, &mut async_cx)
577 })
578 .await
579 .expect("new_thread should succeed");
580
581 // Get the session_id from the AcpThread
582 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
583
584 // Test selected_model returns the default
585 let model = cx
586 .update(|cx| {
587 let mut async_cx = cx.to_async();
588 selector.selected_model(&session_id, &mut async_cx)
589 })
590 .await
591 .expect("selected_model should succeed");
592 let model = model.as_fake();
593 assert_eq!(model.id().0, "fake", "should return default model");
594
595 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
596 cx.run_until_parked();
597 model.send_last_completion_stream_text_chunk("def");
598 cx.run_until_parked();
599 acp_thread.read_with(cx, |thread, cx| {
600 assert_eq!(
601 thread.to_markdown(cx),
602 indoc! {"
603 ## User
604
605 abc
606
607 ## Assistant
608
609 def
610
611 "}
612 )
613 });
614
615 // Test cancel
616 cx.update(|cx| connection.cancel(&session_id, cx));
617 request.await.expect("prompt should fail gracefully");
618
619 // Ensure that dropping the ACP thread causes the native thread to be
620 // dropped as well.
621 cx.update(|_| drop(acp_thread));
622 let result = cx
623 .update(|cx| {
624 connection.prompt(
625 acp::PromptRequest {
626 session_id: session_id.clone(),
627 prompt: vec!["ghi".into()],
628 },
629 cx,
630 )
631 })
632 .await;
633 assert_eq!(
634 result.as_ref().unwrap_err().to_string(),
635 "Session not found",
636 "unexpected result: {:?}",
637 result
638 );
639}
640
641#[gpui::test]
642async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
643 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
644 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
645 let fake_model = model.as_fake();
646
647 let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Think", cx));
648 cx.run_until_parked();
649
650 // Simulate streaming partial input.
651 let input = json!({});
652 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
653 LanguageModelToolUse {
654 id: "1".into(),
655 name: ThinkingTool.name().into(),
656 raw_input: input.to_string(),
657 input,
658 is_input_complete: false,
659 },
660 ));
661
662 // Input streaming completed
663 let input = json!({ "content": "Thinking hard!" });
664 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
665 LanguageModelToolUse {
666 id: "1".into(),
667 name: ThinkingTool.name().into(),
668 raw_input: input.to_string(),
669 input,
670 is_input_complete: true,
671 },
672 ));
673 fake_model.end_last_completion_stream();
674 cx.run_until_parked();
675
676 let tool_call = expect_tool_call(&mut events).await;
677 assert_eq!(
678 tool_call,
679 acp::ToolCall {
680 id: acp::ToolCallId("1".into()),
681 title: "thinking".into(),
682 kind: acp::ToolKind::Think,
683 status: acp::ToolCallStatus::Pending,
684 content: vec![],
685 locations: vec![],
686 raw_input: Some(json!({})),
687 raw_output: None,
688 }
689 );
690 let update = expect_tool_call_update(&mut events).await;
691 assert_eq!(
692 update,
693 acp::ToolCallUpdate {
694 id: acp::ToolCallId("1".into()),
695 fields: acp::ToolCallUpdateFields {
696 title: Some("Thinking".into()),
697 kind: Some(acp::ToolKind::Think),
698 raw_input: Some(json!({ "content": "Thinking hard!" })),
699 ..Default::default()
700 },
701 }
702 );
703 let update = expect_tool_call_update(&mut events).await;
704 assert_eq!(
705 update,
706 acp::ToolCallUpdate {
707 id: acp::ToolCallId("1".into()),
708 fields: acp::ToolCallUpdateFields {
709 status: Some(acp::ToolCallStatus::InProgress),
710 ..Default::default()
711 },
712 }
713 );
714 let update = expect_tool_call_update(&mut events).await;
715 assert_eq!(
716 update,
717 acp::ToolCallUpdate {
718 id: acp::ToolCallId("1".into()),
719 fields: acp::ToolCallUpdateFields {
720 content: Some(vec!["Thinking hard!".into()]),
721 ..Default::default()
722 },
723 }
724 );
725 let update = expect_tool_call_update(&mut events).await;
726 assert_eq!(
727 update,
728 acp::ToolCallUpdate {
729 id: acp::ToolCallId("1".into()),
730 fields: acp::ToolCallUpdateFields {
731 status: Some(acp::ToolCallStatus::Completed),
732 ..Default::default()
733 },
734 }
735 );
736}
737
738/// Filters out the stop events for asserting against in tests
739fn stop_events(
740 result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
741) -> Vec<acp::StopReason> {
742 result_events
743 .into_iter()
744 .filter_map(|event| match event.unwrap() {
745 AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
746 _ => None,
747 })
748 .collect()
749}
750
751struct ThreadTest {
752 model: Arc<dyn LanguageModel>,
753 thread: Entity<Thread>,
754 project_context: Rc<RefCell<ProjectContext>>,
755}
756
/// Selects which language model `setup` wires into the thread under test.
enum TestModel {
    /// Looked up in the model registry as `claude-sonnet-4-latest`.
    Sonnet4,
    /// Looked up in the model registry as `claude-sonnet-4-thinking-latest`.
    Sonnet4Thinking,
    /// A `FakeLanguageModel` driven directly by the test.
    Fake,
}
762
763impl TestModel {
764 fn id(&self) -> LanguageModelId {
765 match self {
766 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
767 TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
768 TestModel::Fake => unreachable!(),
769 }
770 }
771}
772
773async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
774 cx.executor().allow_parking();
775 cx.update(|cx| {
776 settings::init(cx);
777 Project::init_settings(cx);
778 });
779 let templates = Templates::new();
780
781 let fs = FakeFs::new(cx.background_executor.clone());
782 fs.insert_tree(path!("/test"), json!({})).await;
783 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
784
785 let model = cx
786 .update(|cx| {
787 gpui_tokio::init(cx);
788 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
789 cx.set_http_client(Arc::new(http_client));
790
791 client::init_settings(cx);
792 let client = Client::production(cx);
793 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
794 language_model::init(client.clone(), cx);
795 language_models::init(user_store.clone(), client.clone(), cx);
796
797 if let TestModel::Fake = model {
798 Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
799 } else {
800 let model_id = model.id();
801 let models = LanguageModelRegistry::read_global(cx);
802 let model = models
803 .available_models(cx)
804 .find(|model| model.id() == model_id)
805 .unwrap();
806
807 let provider = models.provider(&model.provider_id()).unwrap();
808 let authenticated = provider.authenticate(cx);
809
810 cx.spawn(async move |_cx| {
811 authenticated.await.unwrap();
812 model
813 })
814 }
815 })
816 .await;
817
818 let project_context = Rc::new(RefCell::new(ProjectContext::default()));
819 let action_log = cx.new(|_| ActionLog::new(project.clone()));
820 let thread = cx.new(|_| {
821 Thread::new(
822 project,
823 project_context.clone(),
824 action_log,
825 templates,
826 model.clone(),
827 )
828 });
829 ThreadTest {
830 model,
831 thread,
832 project_context,
833 }
834}
835
/// Initializes `env_logger` via `ctor`, but only when the `RUST_LOG`
/// environment variable is set to a readable value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}