1use super::*;
2use acp_thread::AgentConnection;
3use agent_client_protocol::{self as acp};
4use anyhow::Result;
5use assistant_tool::ActionLog;
6use client::{Client, UserStore};
7use fs::FakeFs;
8use futures::channel::mpsc::UnboundedReceiver;
9use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
10use indoc::indoc;
11use language_model::{
12 fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
13 LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult,
14 LanguageModelToolUse, MessageContent, Role, StopReason,
15};
16use project::Project;
17use prompt_store::ProjectContext;
18use reqwest_client::ReqwestClient;
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21use serde_json::json;
22use smol::stream::StreamExt;
23use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
24use util::path;
25
26mod test_tools;
27use test_tools::*;
28
29#[gpui::test]
30#[ignore = "can't run on CI yet"]
31async fn test_echo(cx: &mut TestAppContext) {
32 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
33
34 let events = thread
35 .update(cx, |thread, cx| {
36 thread.send("Testing: Reply with 'Hello'", cx)
37 })
38 .collect()
39 .await;
40 thread.update(cx, |thread, _cx| {
41 assert_eq!(
42 thread.messages().last().unwrap().content,
43 vec![MessageContent::Text("Hello".to_string())]
44 );
45 });
46 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
47}
48
49#[gpui::test]
50#[ignore = "can't run on CI yet"]
51async fn test_thinking(cx: &mut TestAppContext) {
52 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
53
54 let events = thread
55 .update(cx, |thread, cx| {
56 thread.send(
57 indoc! {"
58 Testing:
59
60 Generate a thinking step where you just think the word 'Think',
61 and have your final answer be 'Hello'
62 "},
63 cx,
64 )
65 })
66 .collect()
67 .await;
68 thread.update(cx, |thread, _cx| {
69 assert_eq!(
70 thread.messages().last().unwrap().to_markdown(),
71 indoc! {"
72 ## assistant
73 <think>Think</think>
74 Hello
75 "}
76 )
77 });
78 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
79}
80
81#[gpui::test]
82async fn test_system_prompt(cx: &mut TestAppContext) {
83 let ThreadTest {
84 model,
85 thread,
86 project_context,
87 ..
88 } = setup(cx, TestModel::Fake).await;
89 let fake_model = model.as_fake();
90
91 project_context.borrow_mut().shell = "test-shell".into();
92 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
93 thread.update(cx, |thread, cx| thread.send("abc", cx));
94 cx.run_until_parked();
95 let mut pending_completions = fake_model.pending_completions();
96 assert_eq!(
97 pending_completions.len(),
98 1,
99 "unexpected pending completions: {:?}",
100 pending_completions
101 );
102
103 let pending_completion = pending_completions.pop().unwrap();
104 assert_eq!(pending_completion.messages[0].role, Role::System);
105
106 let system_message = &pending_completion.messages[0];
107 let system_prompt = system_message.content[0].to_str().unwrap();
108 assert!(
109 system_prompt.contains("test-shell"),
110 "unexpected system message: {:?}",
111 system_message
112 );
113 assert!(
114 system_prompt.contains("## Fixing Diagnostics"),
115 "unexpected system message: {:?}",
116 system_message
117 );
118}
119
120#[gpui::test]
121#[ignore = "can't run on CI yet"]
122async fn test_basic_tool_calls(cx: &mut TestAppContext) {
123 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
124
125 // Test a tool call that's likely to complete *before* streaming stops.
126 let events = thread
127 .update(cx, |thread, cx| {
128 thread.add_tool(EchoTool);
129 thread.send(
130 "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
131 cx,
132 )
133 })
134 .collect()
135 .await;
136 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
137
138 // Test a tool calls that's likely to complete *after* streaming stops.
139 let events = thread
140 .update(cx, |thread, cx| {
141 thread.remove_tool(&AgentTool::name(&EchoTool));
142 thread.add_tool(DelayTool);
143 thread.send(
144 "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
145 cx,
146 )
147 })
148 .collect()
149 .await;
150 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
151 thread.update(cx, |thread, _cx| {
152 assert!(thread
153 .messages()
154 .last()
155 .unwrap()
156 .content
157 .iter()
158 .any(|content| {
159 if let MessageContent::Text(text) = content {
160 text.contains("Ding")
161 } else {
162 false
163 }
164 }));
165 });
166}
167
168#[gpui::test]
169#[ignore = "can't run on CI yet"]
170async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
171 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
172
173 // Test a tool call that's likely to complete *before* streaming stops.
174 let mut events = thread.update(cx, |thread, cx| {
175 thread.add_tool(WordListTool);
176 thread.send("Test the word_list tool.", cx)
177 });
178
179 let mut saw_partial_tool_use = false;
180 while let Some(event) = events.next().await {
181 if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
182 thread.update(cx, |thread, _cx| {
183 // Look for a tool use in the thread's last message
184 let last_content = thread.messages().last().unwrap().content.last().unwrap();
185 if let MessageContent::ToolUse(last_tool_use) = last_content {
186 assert_eq!(last_tool_use.name.as_ref(), "word_list");
187 if tool_call.status == acp::ToolCallStatus::Pending {
188 if !last_tool_use.is_input_complete
189 && last_tool_use.input.get("g").is_none()
190 {
191 saw_partial_tool_use = true;
192 }
193 } else {
194 last_tool_use
195 .input
196 .get("a")
197 .expect("'a' has streamed because input is now complete");
198 last_tool_use
199 .input
200 .get("g")
201 .expect("'g' has streamed because input is now complete");
202 }
203 } else {
204 panic!("last content should be a tool use");
205 }
206 });
207 }
208 }
209
210 assert!(
211 saw_partial_tool_use,
212 "should see at least one partially streamed tool use in the history"
213 );
214}
215
216#[gpui::test]
217async fn test_tool_authorization(cx: &mut TestAppContext) {
218 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
219 let fake_model = model.as_fake();
220
221 let mut events = thread.update(cx, |thread, cx| {
222 thread.add_tool(ToolRequiringPermission);
223 thread.send("abc", cx)
224 });
225 cx.run_until_parked();
226 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
227 LanguageModelToolUse {
228 id: "tool_id_1".into(),
229 name: ToolRequiringPermission.name().into(),
230 raw_input: "{}".into(),
231 input: json!({}),
232 is_input_complete: true,
233 },
234 ));
235 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
236 LanguageModelToolUse {
237 id: "tool_id_2".into(),
238 name: ToolRequiringPermission.name().into(),
239 raw_input: "{}".into(),
240 input: json!({}),
241 is_input_complete: true,
242 },
243 ));
244 fake_model.end_last_completion_stream();
245 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
246 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
247
248 // Approve the first
249 tool_call_auth_1
250 .response
251 .send(tool_call_auth_1.options[1].id.clone())
252 .unwrap();
253 cx.run_until_parked();
254
255 // Reject the second
256 tool_call_auth_2
257 .response
258 .send(tool_call_auth_1.options[2].id.clone())
259 .unwrap();
260 cx.run_until_parked();
261
262 let completion = fake_model.pending_completions().pop().unwrap();
263 let message = completion.messages.last().unwrap();
264 assert_eq!(
265 message.content,
266 vec![
267 MessageContent::ToolResult(LanguageModelToolResult {
268 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
269 tool_name: ToolRequiringPermission.name().into(),
270 is_error: false,
271 content: "Allowed".into(),
272 output: Some("Allowed".into())
273 }),
274 MessageContent::ToolResult(LanguageModelToolResult {
275 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
276 tool_name: ToolRequiringPermission.name().into(),
277 is_error: true,
278 content: "Permission to run tool denied by user".into(),
279 output: None
280 })
281 ]
282 );
283}
284
285#[gpui::test]
286async fn test_tool_hallucination(cx: &mut TestAppContext) {
287 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
288 let fake_model = model.as_fake();
289
290 let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx));
291 cx.run_until_parked();
292 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
293 LanguageModelToolUse {
294 id: "tool_id_1".into(),
295 name: "nonexistent_tool".into(),
296 raw_input: "{}".into(),
297 input: json!({}),
298 is_input_complete: true,
299 },
300 ));
301 fake_model.end_last_completion_stream();
302
303 let tool_call = expect_tool_call(&mut events).await;
304 assert_eq!(tool_call.title, "nonexistent_tool");
305 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
306 let update = expect_tool_call_update_fields(&mut events).await;
307 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
308}
309
310async fn expect_tool_call(
311 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
312) -> acp::ToolCall {
313 let event = events
314 .next()
315 .await
316 .expect("no tool call authorization event received")
317 .unwrap();
318 match event {
319 AgentResponseEvent::ToolCall(tool_call) => return tool_call,
320 event => {
321 panic!("Unexpected event {event:?}");
322 }
323 }
324}
325
326async fn expect_tool_call_update_fields(
327 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
328) -> acp::ToolCallUpdate {
329 let event = events
330 .next()
331 .await
332 .expect("no tool call authorization event received")
333 .unwrap();
334 match event {
335 AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
336 return update
337 }
338 event => {
339 panic!("Unexpected event {event:?}");
340 }
341 }
342}
343
344async fn next_tool_call_authorization(
345 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
346) -> ToolCallAuthorization {
347 loop {
348 let event = events
349 .next()
350 .await
351 .expect("no tool call authorization event received")
352 .unwrap();
353 if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
354 let permission_kinds = tool_call_authorization
355 .options
356 .iter()
357 .map(|o| o.kind)
358 .collect::<Vec<_>>();
359 assert_eq!(
360 permission_kinds,
361 vec![
362 acp::PermissionOptionKind::AllowAlways,
363 acp::PermissionOptionKind::AllowOnce,
364 acp::PermissionOptionKind::RejectOnce,
365 ]
366 );
367 return tool_call_authorization;
368 }
369 }
370}
371
372#[gpui::test]
373#[ignore = "can't run on CI yet"]
374async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
375 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
376
377 // Test concurrent tool calls with different delay times
378 let events = thread
379 .update(cx, |thread, cx| {
380 thread.add_tool(DelayTool);
381 thread.send(
382 "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
383 cx,
384 )
385 })
386 .collect()
387 .await;
388
389 let stop_reasons = stop_events(events);
390 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
391
392 thread.update(cx, |thread, _cx| {
393 let last_message = thread.messages().last().unwrap();
394 let text = last_message
395 .content
396 .iter()
397 .filter_map(|content| {
398 if let MessageContent::Text(text) = content {
399 Some(text.as_str())
400 } else {
401 None
402 }
403 })
404 .collect::<String>();
405
406 assert!(text.contains("Ding"));
407 });
408}
409
410#[gpui::test]
411#[ignore = "can't run on CI yet"]
412async fn test_cancellation(cx: &mut TestAppContext) {
413 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
414
415 let mut events = thread.update(cx, |thread, cx| {
416 thread.add_tool(InfiniteTool);
417 thread.add_tool(EchoTool);
418 thread.send(
419 "Call the echo tool and then call the infinite tool, then explain their output",
420 cx,
421 )
422 });
423
424 // Wait until both tools are called.
425 let mut expected_tools = vec!["Echo", "Infinite Tool"];
426 let mut echo_id = None;
427 let mut echo_completed = false;
428 while let Some(event) = events.next().await {
429 match event.unwrap() {
430 AgentResponseEvent::ToolCall(tool_call) => {
431 assert_eq!(tool_call.title, expected_tools.remove(0));
432 if tool_call.title == "Echo" {
433 echo_id = Some(tool_call.id);
434 }
435 }
436 AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
437 acp::ToolCallUpdate {
438 id,
439 fields:
440 acp::ToolCallUpdateFields {
441 status: Some(acp::ToolCallStatus::Completed),
442 ..
443 },
444 },
445 )) if Some(&id) == echo_id.as_ref() => {
446 echo_completed = true;
447 }
448 _ => {}
449 }
450
451 if expected_tools.is_empty() && echo_completed {
452 break;
453 }
454 }
455
456 // Cancel the current send and ensure that the event stream is closed, even
457 // if one of the tools is still running.
458 thread.update(cx, |thread, _cx| thread.cancel());
459 events.collect::<Vec<_>>().await;
460
461 // Ensure we can still send a new message after cancellation.
462 let events = thread
463 .update(cx, |thread, cx| {
464 thread.send("Testing: reply with 'Hello' then stop.", cx)
465 })
466 .collect::<Vec<_>>()
467 .await;
468 thread.update(cx, |thread, _cx| {
469 assert_eq!(
470 thread.messages().last().unwrap().content,
471 vec![MessageContent::Text("Hello".to_string())]
472 );
473 });
474 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
475}
476
477#[gpui::test]
478async fn test_refusal(cx: &mut TestAppContext) {
479 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
480 let fake_model = model.as_fake();
481
482 let events = thread.update(cx, |thread, cx| thread.send("Hello", cx));
483 cx.run_until_parked();
484 thread.read_with(cx, |thread, _| {
485 assert_eq!(
486 thread.to_markdown(),
487 indoc! {"
488 ## user
489 Hello
490 "}
491 );
492 });
493
494 fake_model.send_last_completion_stream_text_chunk("Hey!");
495 cx.run_until_parked();
496 thread.read_with(cx, |thread, _| {
497 assert_eq!(
498 thread.to_markdown(),
499 indoc! {"
500 ## user
501 Hello
502 ## assistant
503 Hey!
504 "}
505 );
506 });
507
508 // If the model refuses to continue, the thread should remove all the messages after the last user message.
509 fake_model
510 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
511 let events = events.collect::<Vec<_>>().await;
512 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
513 thread.read_with(cx, |thread, _| {
514 assert_eq!(thread.to_markdown(), "");
515 });
516}
517
518#[gpui::test]
519async fn test_agent_connection(cx: &mut TestAppContext) {
520 cx.update(settings::init);
521 let templates = Templates::new();
522
523 // Initialize language model system with test provider
524 cx.update(|cx| {
525 gpui_tokio::init(cx);
526 client::init_settings(cx);
527
528 let http_client = FakeHttpClient::with_404_response();
529 let clock = Arc::new(clock::FakeSystemClock::new());
530 let client = Client::new(clock, http_client, cx);
531 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
532 language_model::init(client.clone(), cx);
533 language_models::init(user_store.clone(), client.clone(), cx);
534 Project::init_settings(cx);
535 LanguageModelRegistry::test(cx);
536 });
537 cx.executor().forbid_parking();
538
539 // Create a project for new_thread
540 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
541 fake_fs.insert_tree(path!("/test"), json!({})).await;
542 let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
543 let cwd = Path::new("/test");
544
545 // Create agent and connection
546 let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
547 .await
548 .unwrap();
549 let connection = NativeAgentConnection(agent.clone());
550
551 // Test model_selector returns Some
552 let selector_opt = connection.model_selector();
553 assert!(
554 selector_opt.is_some(),
555 "agent2 should always support ModelSelector"
556 );
557 let selector = selector_opt.unwrap();
558
559 // Test list_models
560 let listed_models = cx
561 .update(|cx| {
562 let mut async_cx = cx.to_async();
563 selector.list_models(&mut async_cx)
564 })
565 .await
566 .expect("list_models should succeed");
567 assert!(!listed_models.is_empty(), "should have at least one model");
568 assert_eq!(listed_models[0].id().0, "fake");
569
570 // Create a thread using new_thread
571 let connection_rc = Rc::new(connection.clone());
572 let acp_thread = cx
573 .update(|cx| {
574 let mut async_cx = cx.to_async();
575 connection_rc.new_thread(project, cwd, &mut async_cx)
576 })
577 .await
578 .expect("new_thread should succeed");
579
580 // Get the session_id from the AcpThread
581 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
582
583 // Test selected_model returns the default
584 let model = cx
585 .update(|cx| {
586 let mut async_cx = cx.to_async();
587 selector.selected_model(&session_id, &mut async_cx)
588 })
589 .await
590 .expect("selected_model should succeed");
591 let model = model.as_fake();
592 assert_eq!(model.id().0, "fake", "should return default model");
593
594 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
595 cx.run_until_parked();
596 model.send_last_completion_stream_text_chunk("def");
597 cx.run_until_parked();
598 acp_thread.read_with(cx, |thread, cx| {
599 assert_eq!(
600 thread.to_markdown(cx),
601 indoc! {"
602 ## User
603
604 abc
605
606 ## Assistant
607
608 def
609
610 "}
611 )
612 });
613
614 // Test cancel
615 cx.update(|cx| connection.cancel(&session_id, cx));
616 request.await.expect("prompt should fail gracefully");
617
618 // Ensure that dropping the ACP thread causes the native thread to be
619 // dropped as well.
620 cx.update(|_| drop(acp_thread));
621 let result = cx
622 .update(|cx| {
623 connection.prompt(
624 acp::PromptRequest {
625 session_id: session_id.clone(),
626 prompt: vec!["ghi".into()],
627 },
628 cx,
629 )
630 })
631 .await;
632 assert_eq!(
633 result.as_ref().unwrap_err().to_string(),
634 "Session not found",
635 "unexpected result: {:?}",
636 result
637 );
638}
639
640#[gpui::test]
641async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
642 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
643 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
644 let fake_model = model.as_fake();
645
646 let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx));
647 cx.run_until_parked();
648
649 // Simulate streaming partial input.
650 let input = json!({});
651 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
652 LanguageModelToolUse {
653 id: "1".into(),
654 name: ThinkingTool.name().into(),
655 raw_input: input.to_string(),
656 input,
657 is_input_complete: false,
658 },
659 ));
660
661 // Input streaming completed
662 let input = json!({ "content": "Thinking hard!" });
663 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
664 LanguageModelToolUse {
665 id: "1".into(),
666 name: "thinking".into(),
667 raw_input: input.to_string(),
668 input,
669 is_input_complete: true,
670 },
671 ));
672 fake_model.end_last_completion_stream();
673 cx.run_until_parked();
674
675 let tool_call = expect_tool_call(&mut events).await;
676 assert_eq!(
677 tool_call,
678 acp::ToolCall {
679 id: acp::ToolCallId("1".into()),
680 title: "Thinking".into(),
681 kind: acp::ToolKind::Think,
682 status: acp::ToolCallStatus::Pending,
683 content: vec![],
684 locations: vec![],
685 raw_input: Some(json!({})),
686 raw_output: None,
687 }
688 );
689 let update = expect_tool_call_update_fields(&mut events).await;
690 assert_eq!(
691 update,
692 acp::ToolCallUpdate {
693 id: acp::ToolCallId("1".into()),
694 fields: acp::ToolCallUpdateFields {
695 title: Some("Thinking".into()),
696 kind: Some(acp::ToolKind::Think),
697 raw_input: Some(json!({ "content": "Thinking hard!" })),
698 ..Default::default()
699 },
700 }
701 );
702 let update = expect_tool_call_update_fields(&mut events).await;
703 assert_eq!(
704 update,
705 acp::ToolCallUpdate {
706 id: acp::ToolCallId("1".into()),
707 fields: acp::ToolCallUpdateFields {
708 status: Some(acp::ToolCallStatus::InProgress),
709 ..Default::default()
710 },
711 }
712 );
713 let update = expect_tool_call_update_fields(&mut events).await;
714 assert_eq!(
715 update,
716 acp::ToolCallUpdate {
717 id: acp::ToolCallId("1".into()),
718 fields: acp::ToolCallUpdateFields {
719 content: Some(vec!["Thinking hard!".into()]),
720 ..Default::default()
721 },
722 }
723 );
724 let update = expect_tool_call_update_fields(&mut events).await;
725 assert_eq!(
726 update,
727 acp::ToolCallUpdate {
728 id: acp::ToolCallId("1".into()),
729 fields: acp::ToolCallUpdateFields {
730 status: Some(acp::ToolCallStatus::Completed),
731 ..Default::default()
732 },
733 }
734 );
735}
736
737/// Filters out the stop events for asserting against in tests
738fn stop_events(
739 result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
740) -> Vec<acp::StopReason> {
741 result_events
742 .into_iter()
743 .filter_map(|event| match event.unwrap() {
744 AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
745 _ => None,
746 })
747 .collect()
748}
749
750struct ThreadTest {
751 model: Arc<dyn LanguageModel>,
752 thread: Entity<Thread>,
753 project_context: Rc<RefCell<ProjectContext>>,
754}
755
756enum TestModel {
757 Sonnet4,
758 Sonnet4Thinking,
759 Fake,
760}
761
762impl TestModel {
763 fn id(&self) -> LanguageModelId {
764 match self {
765 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
766 TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
767 TestModel::Fake => unreachable!(),
768 }
769 }
770}
771
772async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
773 cx.executor().allow_parking();
774 cx.update(|cx| {
775 settings::init(cx);
776 Project::init_settings(cx);
777 });
778 let templates = Templates::new();
779
780 let fs = FakeFs::new(cx.background_executor.clone());
781 fs.insert_tree(path!("/test"), json!({})).await;
782 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
783
784 let model = cx
785 .update(|cx| {
786 gpui_tokio::init(cx);
787 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
788 cx.set_http_client(Arc::new(http_client));
789
790 client::init_settings(cx);
791 let client = Client::production(cx);
792 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
793 language_model::init(client.clone(), cx);
794 language_models::init(user_store.clone(), client.clone(), cx);
795
796 if let TestModel::Fake = model {
797 Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
798 } else {
799 let model_id = model.id();
800 let models = LanguageModelRegistry::read_global(cx);
801 let model = models
802 .available_models(cx)
803 .find(|model| model.id() == model_id)
804 .unwrap();
805
806 let provider = models.provider(&model.provider_id()).unwrap();
807 let authenticated = provider.authenticate(cx);
808
809 cx.spawn(async move |_cx| {
810 authenticated.await.unwrap();
811 model
812 })
813 }
814 })
815 .await;
816
817 let project_context = Rc::new(RefCell::new(ProjectContext::default()));
818 let action_log = cx.new(|_| ActionLog::new(project.clone()));
819 let thread = cx.new(|_| {
820 Thread::new(
821 project,
822 project_context.clone(),
823 action_log,
824 templates,
825 model.clone(),
826 )
827 });
828 ThreadTest {
829 model,
830 thread,
831 project_context,
832 }
833}
834
/// Installs `env_logger` at process start-up, but only when `RUST_LOG` is
/// set to a readable value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Without a usable RUST_LOG value there is nothing to configure.
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}