1use super::*;
2use acp_thread::AgentConnection;
3use action_log::ActionLog;
4use agent_client_protocol::{self as acp};
5use anyhow::Result;
6use client::{Client, UserStore};
7use fs::FakeFs;
8use futures::channel::mpsc::UnboundedReceiver;
9use gpui::{AppContext, Entity, Task, TestAppContext, http_client::FakeHttpClient};
10use indoc::indoc;
11use language_model::{
12 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
13 LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
14 StopReason, fake_provider::FakeLanguageModel,
15};
16use project::Project;
17use prompt_store::ProjectContext;
18use reqwest_client::ReqwestClient;
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21use serde_json::json;
22use smol::stream::StreamExt;
23use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
24use util::path;
25
26mod test_tools;
27use test_tools::*;
28
29#[gpui::test]
30#[ignore = "can't run on CI yet"]
31async fn test_echo(cx: &mut TestAppContext) {
32 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
33
34 let events = thread
35 .update(cx, |thread, cx| {
36 thread.send("Testing: Reply with 'Hello'", cx)
37 })
38 .collect()
39 .await;
40 thread.update(cx, |thread, _cx| {
41 assert_eq!(
42 thread.messages().last().unwrap().content,
43 vec![MessageContent::Text("Hello".to_string())]
44 );
45 });
46 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
47}
48
49#[gpui::test]
50#[ignore = "can't run on CI yet"]
51async fn test_thinking(cx: &mut TestAppContext) {
52 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
53
54 let events = thread
55 .update(cx, |thread, cx| {
56 thread.send(
57 indoc! {"
58 Testing:
59
60 Generate a thinking step where you just think the word 'Think',
61 and have your final answer be 'Hello'
62 "},
63 cx,
64 )
65 })
66 .collect()
67 .await;
68 thread.update(cx, |thread, _cx| {
69 assert_eq!(
70 thread.messages().last().unwrap().to_markdown(),
71 indoc! {"
72 ## assistant
73 <think>Think</think>
74 Hello
75 "}
76 )
77 });
78 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
79}
80
81#[gpui::test]
82async fn test_system_prompt(cx: &mut TestAppContext) {
83 let ThreadTest {
84 model,
85 thread,
86 project_context,
87 ..
88 } = setup(cx, TestModel::Fake).await;
89 let fake_model = model.as_fake();
90
91 project_context.borrow_mut().shell = "test-shell".into();
92 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
93 thread.update(cx, |thread, cx| thread.send("abc", cx));
94 cx.run_until_parked();
95 let mut pending_completions = fake_model.pending_completions();
96 assert_eq!(
97 pending_completions.len(),
98 1,
99 "unexpected pending completions: {:?}",
100 pending_completions
101 );
102
103 let pending_completion = pending_completions.pop().unwrap();
104 assert_eq!(pending_completion.messages[0].role, Role::System);
105
106 let system_message = &pending_completion.messages[0];
107 let system_prompt = system_message.content[0].to_str().unwrap();
108 assert!(
109 system_prompt.contains("test-shell"),
110 "unexpected system message: {:?}",
111 system_message
112 );
113 assert!(
114 system_prompt.contains("## Fixing Diagnostics"),
115 "unexpected system message: {:?}",
116 system_message
117 );
118}
119
120#[gpui::test]
121#[ignore = "can't run on CI yet"]
122async fn test_basic_tool_calls(cx: &mut TestAppContext) {
123 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
124
125 // Test a tool call that's likely to complete *before* streaming stops.
126 let events = thread
127 .update(cx, |thread, cx| {
128 thread.add_tool(EchoTool);
129 thread.send(
130 "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
131 cx,
132 )
133 })
134 .collect()
135 .await;
136 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
137
138 // Test a tool calls that's likely to complete *after* streaming stops.
139 let events = thread
140 .update(cx, |thread, cx| {
141 thread.remove_tool(&AgentTool::name(&EchoTool));
142 thread.add_tool(DelayTool);
143 thread.send(
144 "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
145 cx,
146 )
147 })
148 .collect()
149 .await;
150 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
151 thread.update(cx, |thread, _cx| {
152 assert!(
153 thread
154 .messages()
155 .last()
156 .unwrap()
157 .content
158 .iter()
159 .any(|content| {
160 if let MessageContent::Text(text) = content {
161 text.contains("Ding")
162 } else {
163 false
164 }
165 })
166 );
167 });
168}
169
170#[gpui::test]
171#[ignore = "can't run on CI yet"]
172async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
173 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
174
175 // Test a tool call that's likely to complete *before* streaming stops.
176 let mut events = thread.update(cx, |thread, cx| {
177 thread.add_tool(WordListTool);
178 thread.send("Test the word_list tool.", cx)
179 });
180
181 let mut saw_partial_tool_use = false;
182 while let Some(event) = events.next().await {
183 if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
184 thread.update(cx, |thread, _cx| {
185 // Look for a tool use in the thread's last message
186 let last_content = thread.messages().last().unwrap().content.last().unwrap();
187 if let MessageContent::ToolUse(last_tool_use) = last_content {
188 assert_eq!(last_tool_use.name.as_ref(), "word_list");
189 if tool_call.status == acp::ToolCallStatus::Pending {
190 if !last_tool_use.is_input_complete
191 && last_tool_use.input.get("g").is_none()
192 {
193 saw_partial_tool_use = true;
194 }
195 } else {
196 last_tool_use
197 .input
198 .get("a")
199 .expect("'a' has streamed because input is now complete");
200 last_tool_use
201 .input
202 .get("g")
203 .expect("'g' has streamed because input is now complete");
204 }
205 } else {
206 panic!("last content should be a tool use");
207 }
208 });
209 }
210 }
211
212 assert!(
213 saw_partial_tool_use,
214 "should see at least one partially streamed tool use in the history"
215 );
216}
217
218#[gpui::test]
219async fn test_tool_authorization(cx: &mut TestAppContext) {
220 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
221 let fake_model = model.as_fake();
222
223 let mut events = thread.update(cx, |thread, cx| {
224 thread.add_tool(ToolRequiringPermission);
225 thread.send("abc", cx)
226 });
227 cx.run_until_parked();
228 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
229 LanguageModelToolUse {
230 id: "tool_id_1".into(),
231 name: ToolRequiringPermission.name().into(),
232 raw_input: "{}".into(),
233 input: json!({}),
234 is_input_complete: true,
235 },
236 ));
237 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
238 LanguageModelToolUse {
239 id: "tool_id_2".into(),
240 name: ToolRequiringPermission.name().into(),
241 raw_input: "{}".into(),
242 input: json!({}),
243 is_input_complete: true,
244 },
245 ));
246 fake_model.end_last_completion_stream();
247 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
248 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
249
250 // Approve the first
251 tool_call_auth_1
252 .response
253 .send(tool_call_auth_1.options[1].id.clone())
254 .unwrap();
255 cx.run_until_parked();
256
257 // Reject the second
258 tool_call_auth_2
259 .response
260 .send(tool_call_auth_1.options[2].id.clone())
261 .unwrap();
262 cx.run_until_parked();
263
264 let completion = fake_model.pending_completions().pop().unwrap();
265 let message = completion.messages.last().unwrap();
266 assert_eq!(
267 message.content,
268 vec![
269 MessageContent::ToolResult(LanguageModelToolResult {
270 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
271 tool_name: ToolRequiringPermission.name().into(),
272 is_error: false,
273 content: "Allowed".into(),
274 output: Some("Allowed".into())
275 }),
276 MessageContent::ToolResult(LanguageModelToolResult {
277 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
278 tool_name: ToolRequiringPermission.name().into(),
279 is_error: true,
280 content: "Permission to run tool denied by user".into(),
281 output: None
282 })
283 ]
284 );
285}
286
287#[gpui::test]
288async fn test_tool_hallucination(cx: &mut TestAppContext) {
289 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
290 let fake_model = model.as_fake();
291
292 let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx));
293 cx.run_until_parked();
294 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
295 LanguageModelToolUse {
296 id: "tool_id_1".into(),
297 name: "nonexistent_tool".into(),
298 raw_input: "{}".into(),
299 input: json!({}),
300 is_input_complete: true,
301 },
302 ));
303 fake_model.end_last_completion_stream();
304
305 let tool_call = expect_tool_call(&mut events).await;
306 assert_eq!(tool_call.title, "nonexistent_tool");
307 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
308 let update = expect_tool_call_update_fields(&mut events).await;
309 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
310}
311
312async fn expect_tool_call(
313 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
314) -> acp::ToolCall {
315 let event = events
316 .next()
317 .await
318 .expect("no tool call authorization event received")
319 .unwrap();
320 match event {
321 AgentResponseEvent::ToolCall(tool_call) => return tool_call,
322 event => {
323 panic!("Unexpected event {event:?}");
324 }
325 }
326}
327
328async fn expect_tool_call_update_fields(
329 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
330) -> acp::ToolCallUpdate {
331 let event = events
332 .next()
333 .await
334 .expect("no tool call authorization event received")
335 .unwrap();
336 match event {
337 AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
338 return update;
339 }
340 event => {
341 panic!("Unexpected event {event:?}");
342 }
343 }
344}
345
346async fn next_tool_call_authorization(
347 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
348) -> ToolCallAuthorization {
349 loop {
350 let event = events
351 .next()
352 .await
353 .expect("no tool call authorization event received")
354 .unwrap();
355 if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
356 let permission_kinds = tool_call_authorization
357 .options
358 .iter()
359 .map(|o| o.kind)
360 .collect::<Vec<_>>();
361 assert_eq!(
362 permission_kinds,
363 vec![
364 acp::PermissionOptionKind::AllowAlways,
365 acp::PermissionOptionKind::AllowOnce,
366 acp::PermissionOptionKind::RejectOnce,
367 ]
368 );
369 return tool_call_authorization;
370 }
371 }
372}
373
374#[gpui::test]
375#[ignore = "can't run on CI yet"]
376async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
377 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
378
379 // Test concurrent tool calls with different delay times
380 let events = thread
381 .update(cx, |thread, cx| {
382 thread.add_tool(DelayTool);
383 thread.send(
384 "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
385 cx,
386 )
387 })
388 .collect()
389 .await;
390
391 let stop_reasons = stop_events(events);
392 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
393
394 thread.update(cx, |thread, _cx| {
395 let last_message = thread.messages().last().unwrap();
396 let text = last_message
397 .content
398 .iter()
399 .filter_map(|content| {
400 if let MessageContent::Text(text) = content {
401 Some(text.as_str())
402 } else {
403 None
404 }
405 })
406 .collect::<String>();
407
408 assert!(text.contains("Ding"));
409 });
410}
411
412#[gpui::test]
413#[ignore = "can't run on CI yet"]
414async fn test_cancellation(cx: &mut TestAppContext) {
415 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
416
417 let mut events = thread.update(cx, |thread, cx| {
418 thread.add_tool(InfiniteTool);
419 thread.add_tool(EchoTool);
420 thread.send(
421 "Call the echo tool and then call the infinite tool, then explain their output",
422 cx,
423 )
424 });
425
426 // Wait until both tools are called.
427 let mut expected_tools = vec!["Echo", "Infinite Tool"];
428 let mut echo_id = None;
429 let mut echo_completed = false;
430 while let Some(event) = events.next().await {
431 match event.unwrap() {
432 AgentResponseEvent::ToolCall(tool_call) => {
433 assert_eq!(tool_call.title, expected_tools.remove(0));
434 if tool_call.title == "Echo" {
435 echo_id = Some(tool_call.id);
436 }
437 }
438 AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
439 acp::ToolCallUpdate {
440 id,
441 fields:
442 acp::ToolCallUpdateFields {
443 status: Some(acp::ToolCallStatus::Completed),
444 ..
445 },
446 },
447 )) if Some(&id) == echo_id.as_ref() => {
448 echo_completed = true;
449 }
450 _ => {}
451 }
452
453 if expected_tools.is_empty() && echo_completed {
454 break;
455 }
456 }
457
458 // Cancel the current send and ensure that the event stream is closed, even
459 // if one of the tools is still running.
460 thread.update(cx, |thread, _cx| thread.cancel());
461 events.collect::<Vec<_>>().await;
462
463 // Ensure we can still send a new message after cancellation.
464 let events = thread
465 .update(cx, |thread, cx| {
466 thread.send("Testing: reply with 'Hello' then stop.", cx)
467 })
468 .collect::<Vec<_>>()
469 .await;
470 thread.update(cx, |thread, _cx| {
471 assert_eq!(
472 thread.messages().last().unwrap().content,
473 vec![MessageContent::Text("Hello".to_string())]
474 );
475 });
476 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
477}
478
479#[gpui::test]
480async fn test_refusal(cx: &mut TestAppContext) {
481 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
482 let fake_model = model.as_fake();
483
484 let events = thread.update(cx, |thread, cx| thread.send("Hello", cx));
485 cx.run_until_parked();
486 thread.read_with(cx, |thread, _| {
487 assert_eq!(
488 thread.to_markdown(),
489 indoc! {"
490 ## user
491 Hello
492 "}
493 );
494 });
495
496 fake_model.send_last_completion_stream_text_chunk("Hey!");
497 cx.run_until_parked();
498 thread.read_with(cx, |thread, _| {
499 assert_eq!(
500 thread.to_markdown(),
501 indoc! {"
502 ## user
503 Hello
504 ## assistant
505 Hey!
506 "}
507 );
508 });
509
510 // If the model refuses to continue, the thread should remove all the messages after the last user message.
511 fake_model
512 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
513 let events = events.collect::<Vec<_>>().await;
514 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
515 thread.read_with(cx, |thread, _| {
516 assert_eq!(thread.to_markdown(), "");
517 });
518}
519
520#[gpui::test]
521async fn test_agent_connection(cx: &mut TestAppContext) {
522 cx.update(settings::init);
523 let templates = Templates::new();
524
525 // Initialize language model system with test provider
526 cx.update(|cx| {
527 gpui_tokio::init(cx);
528 client::init_settings(cx);
529
530 let http_client = FakeHttpClient::with_404_response();
531 let clock = Arc::new(clock::FakeSystemClock::new());
532 let client = Client::new(clock, http_client, cx);
533 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
534 language_model::init(client.clone(), cx);
535 language_models::init(user_store.clone(), client.clone(), cx);
536 Project::init_settings(cx);
537 LanguageModelRegistry::test(cx);
538 });
539 cx.executor().forbid_parking();
540
541 // Create a project for new_thread
542 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
543 fake_fs.insert_tree(path!("/test"), json!({})).await;
544 let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
545 let cwd = Path::new("/test");
546
547 // Create agent and connection
548 let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
549 .await
550 .unwrap();
551 let connection = NativeAgentConnection(agent.clone());
552
553 // Test model_selector returns Some
554 let selector_opt = connection.model_selector();
555 assert!(
556 selector_opt.is_some(),
557 "agent2 should always support ModelSelector"
558 );
559 let selector = selector_opt.unwrap();
560
561 // Test list_models
562 let listed_models = cx
563 .update(|cx| {
564 let mut async_cx = cx.to_async();
565 selector.list_models(&mut async_cx)
566 })
567 .await
568 .expect("list_models should succeed");
569 assert!(!listed_models.is_empty(), "should have at least one model");
570 assert_eq!(listed_models[0].id().0, "fake");
571
572 // Create a thread using new_thread
573 let connection_rc = Rc::new(connection.clone());
574 let acp_thread = cx
575 .update(|cx| {
576 let mut async_cx = cx.to_async();
577 connection_rc.new_thread(project, cwd, &mut async_cx)
578 })
579 .await
580 .expect("new_thread should succeed");
581
582 // Get the session_id from the AcpThread
583 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
584
585 // Test selected_model returns the default
586 let model = cx
587 .update(|cx| {
588 let mut async_cx = cx.to_async();
589 selector.selected_model(&session_id, &mut async_cx)
590 })
591 .await
592 .expect("selected_model should succeed");
593 let model = model.as_fake();
594 assert_eq!(model.id().0, "fake", "should return default model");
595
596 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
597 cx.run_until_parked();
598 model.send_last_completion_stream_text_chunk("def");
599 cx.run_until_parked();
600 acp_thread.read_with(cx, |thread, cx| {
601 assert_eq!(
602 thread.to_markdown(cx),
603 indoc! {"
604 ## User
605
606 abc
607
608 ## Assistant
609
610 def
611
612 "}
613 )
614 });
615
616 // Test cancel
617 cx.update(|cx| connection.cancel(&session_id, cx));
618 request.await.expect("prompt should fail gracefully");
619
620 // Ensure that dropping the ACP thread causes the native thread to be
621 // dropped as well.
622 cx.update(|_| drop(acp_thread));
623 let result = cx
624 .update(|cx| {
625 connection.prompt(
626 acp::PromptRequest {
627 session_id: session_id.clone(),
628 prompt: vec!["ghi".into()],
629 },
630 cx,
631 )
632 })
633 .await;
634 assert_eq!(
635 result.as_ref().unwrap_err().to_string(),
636 "Session not found",
637 "unexpected result: {:?}",
638 result
639 );
640}
641
642#[gpui::test]
643async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
644 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
645 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
646 let fake_model = model.as_fake();
647
648 let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx));
649 cx.run_until_parked();
650
651 // Simulate streaming partial input.
652 let input = json!({});
653 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
654 LanguageModelToolUse {
655 id: "1".into(),
656 name: ThinkingTool.name().into(),
657 raw_input: input.to_string(),
658 input,
659 is_input_complete: false,
660 },
661 ));
662
663 // Input streaming completed
664 let input = json!({ "content": "Thinking hard!" });
665 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
666 LanguageModelToolUse {
667 id: "1".into(),
668 name: "thinking".into(),
669 raw_input: input.to_string(),
670 input,
671 is_input_complete: true,
672 },
673 ));
674 fake_model.end_last_completion_stream();
675 cx.run_until_parked();
676
677 let tool_call = expect_tool_call(&mut events).await;
678 assert_eq!(
679 tool_call,
680 acp::ToolCall {
681 id: acp::ToolCallId("1".into()),
682 title: "Thinking".into(),
683 kind: acp::ToolKind::Think,
684 status: acp::ToolCallStatus::Pending,
685 content: vec![],
686 locations: vec![],
687 raw_input: Some(json!({})),
688 raw_output: None,
689 }
690 );
691 let update = expect_tool_call_update_fields(&mut events).await;
692 assert_eq!(
693 update,
694 acp::ToolCallUpdate {
695 id: acp::ToolCallId("1".into()),
696 fields: acp::ToolCallUpdateFields {
697 title: Some("Thinking".into()),
698 kind: Some(acp::ToolKind::Think),
699 raw_input: Some(json!({ "content": "Thinking hard!" })),
700 ..Default::default()
701 },
702 }
703 );
704 let update = expect_tool_call_update_fields(&mut events).await;
705 assert_eq!(
706 update,
707 acp::ToolCallUpdate {
708 id: acp::ToolCallId("1".into()),
709 fields: acp::ToolCallUpdateFields {
710 status: Some(acp::ToolCallStatus::InProgress),
711 ..Default::default()
712 },
713 }
714 );
715 let update = expect_tool_call_update_fields(&mut events).await;
716 assert_eq!(
717 update,
718 acp::ToolCallUpdate {
719 id: acp::ToolCallId("1".into()),
720 fields: acp::ToolCallUpdateFields {
721 content: Some(vec!["Thinking hard!".into()]),
722 ..Default::default()
723 },
724 }
725 );
726 let update = expect_tool_call_update_fields(&mut events).await;
727 assert_eq!(
728 update,
729 acp::ToolCallUpdate {
730 id: acp::ToolCallId("1".into()),
731 fields: acp::ToolCallUpdateFields {
732 status: Some(acp::ToolCallStatus::Completed),
733 ..Default::default()
734 },
735 }
736 );
737}
738
739/// Filters out the stop events for asserting against in tests
740fn stop_events(
741 result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
742) -> Vec<acp::StopReason> {
743 result_events
744 .into_iter()
745 .filter_map(|event| match event.unwrap() {
746 AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
747 _ => None,
748 })
749 .collect()
750}
751
752struct ThreadTest {
753 model: Arc<dyn LanguageModel>,
754 thread: Entity<Thread>,
755 project_context: Rc<RefCell<ProjectContext>>,
756}
757
758enum TestModel {
759 Sonnet4,
760 Sonnet4Thinking,
761 Fake,
762}
763
764impl TestModel {
765 fn id(&self) -> LanguageModelId {
766 match self {
767 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
768 TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
769 TestModel::Fake => unreachable!(),
770 }
771 }
772}
773
774async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
775 cx.executor().allow_parking();
776 cx.update(|cx| {
777 settings::init(cx);
778 Project::init_settings(cx);
779 });
780 let templates = Templates::new();
781
782 let fs = FakeFs::new(cx.background_executor.clone());
783 fs.insert_tree(path!("/test"), json!({})).await;
784 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
785
786 let model = cx
787 .update(|cx| {
788 gpui_tokio::init(cx);
789 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
790 cx.set_http_client(Arc::new(http_client));
791
792 client::init_settings(cx);
793 let client = Client::production(cx);
794 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
795 language_model::init(client.clone(), cx);
796 language_models::init(user_store.clone(), client.clone(), cx);
797
798 if let TestModel::Fake = model {
799 Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
800 } else {
801 let model_id = model.id();
802 let models = LanguageModelRegistry::read_global(cx);
803 let model = models
804 .available_models(cx)
805 .find(|model| model.id() == model_id)
806 .unwrap();
807
808 let provider = models.provider(&model.provider_id()).unwrap();
809 let authenticated = provider.authenticate(cx);
810
811 cx.spawn(async move |_cx| {
812 authenticated.await.unwrap();
813 model
814 })
815 }
816 })
817 .await;
818
819 let project_context = Rc::new(RefCell::new(ProjectContext::default()));
820 let action_log = cx.new(|_| ActionLog::new(project.clone()));
821 let thread = cx.new(|_| {
822 Thread::new(
823 project,
824 project_context.clone(),
825 action_log,
826 templates,
827 model.clone(),
828 )
829 });
830 ThreadTest {
831 model,
832 thread,
833 project_context,
834 }
835}
836
/// Initializes `env_logger` from a `ctor` constructor, but only when the
/// `RUST_LOG` environment variable is set to a readable value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}