mod.rs

  1use super::*;
  2use acp_thread::AgentConnection;
  3use agent_client_protocol::{self as acp};
  4use anyhow::Result;
  5use assistant_tool::ActionLog;
  6use client::{Client, UserStore};
  7use fs::FakeFs;
  8use futures::channel::mpsc::UnboundedReceiver;
  9use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
 10use indoc::indoc;
 11use language_model::{
 12    fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
 13    LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult,
 14    LanguageModelToolUse, MessageContent, Role, StopReason,
 15};
 16use project::Project;
 17use prompt_store::ProjectContext;
 18use reqwest_client::ReqwestClient;
 19use schemars::JsonSchema;
 20use serde::{Deserialize, Serialize};
 21use serde_json::json;
 22use smol::stream::StreamExt;
 23use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
 24use util::path;
 25
 26mod test_tools;
 27use test_tools::*;
 28
 29#[gpui::test]
 30#[ignore = "can't run on CI yet"]
 31async fn test_echo(cx: &mut TestAppContext) {
 32    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
 33
 34    let events = thread
 35        .update(cx, |thread, cx| {
 36            thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
 37        })
 38        .collect()
 39        .await;
 40    thread.update(cx, |thread, _cx| {
 41        assert_eq!(
 42            thread.messages().last().unwrap().content,
 43            vec![MessageContent::Text("Hello".to_string())]
 44        );
 45    });
 46    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 47}
 48
 49#[gpui::test]
 50#[ignore = "can't run on CI yet"]
 51async fn test_thinking(cx: &mut TestAppContext) {
 52    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
 53
 54    let events = thread
 55        .update(cx, |thread, cx| {
 56            thread.send(
 57                model.clone(),
 58                indoc! {"
 59                    Testing:
 60
 61                    Generate a thinking step where you just think the word 'Think',
 62                    and have your final answer be 'Hello'
 63                "},
 64                cx,
 65            )
 66        })
 67        .collect()
 68        .await;
 69    thread.update(cx, |thread, _cx| {
 70        assert_eq!(
 71            thread.messages().last().unwrap().to_markdown(),
 72            indoc! {"
 73                ## assistant
 74                <think>Think</think>
 75                Hello
 76            "}
 77        )
 78    });
 79    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 80}
 81
 82#[gpui::test]
 83async fn test_system_prompt(cx: &mut TestAppContext) {
 84    let ThreadTest {
 85        model,
 86        thread,
 87        project_context,
 88        ..
 89    } = setup(cx, TestModel::Fake).await;
 90    let fake_model = model.as_fake();
 91
 92    project_context.borrow_mut().shell = "test-shell".into();
 93    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 94    thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
 95    cx.run_until_parked();
 96    let mut pending_completions = fake_model.pending_completions();
 97    assert_eq!(
 98        pending_completions.len(),
 99        1,
100        "unexpected pending completions: {:?}",
101        pending_completions
102    );
103
104    let pending_completion = pending_completions.pop().unwrap();
105    assert_eq!(pending_completion.messages[0].role, Role::System);
106
107    let system_message = &pending_completion.messages[0];
108    let system_prompt = system_message.content[0].to_str().unwrap();
109    assert!(
110        system_prompt.contains("test-shell"),
111        "unexpected system message: {:?}",
112        system_message
113    );
114    assert!(
115        system_prompt.contains("## Fixing Diagnostics"),
116        "unexpected system message: {:?}",
117        system_message
118    );
119}
120
121#[gpui::test]
122#[ignore = "can't run on CI yet"]
123async fn test_basic_tool_calls(cx: &mut TestAppContext) {
124    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
125
126    // Test a tool call that's likely to complete *before* streaming stops.
127    let events = thread
128        .update(cx, |thread, cx| {
129            thread.add_tool(EchoTool);
130            thread.send(
131                model.clone(),
132                "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
133                cx,
134            )
135        })
136        .collect()
137        .await;
138    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
139
140    // Test a tool calls that's likely to complete *after* streaming stops.
141    let events = thread
142        .update(cx, |thread, cx| {
143            thread.remove_tool(&AgentTool::name(&EchoTool));
144            thread.add_tool(DelayTool);
145            thread.send(
146                model.clone(),
147                "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
148                cx,
149            )
150        })
151        .collect()
152        .await;
153    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
154    thread.update(cx, |thread, _cx| {
155        assert!(thread
156            .messages()
157            .last()
158            .unwrap()
159            .content
160            .iter()
161            .any(|content| {
162                if let MessageContent::Text(text) = content {
163                    text.contains("Ding")
164                } else {
165                    false
166                }
167            }));
168    });
169}
170
171#[gpui::test]
172#[ignore = "can't run on CI yet"]
173async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
174    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
175
176    // Test a tool call that's likely to complete *before* streaming stops.
177    let mut events = thread.update(cx, |thread, cx| {
178        thread.add_tool(WordListTool);
179        thread.send(model.clone(), "Test the word_list tool.", cx)
180    });
181
182    let mut saw_partial_tool_use = false;
183    while let Some(event) = events.next().await {
184        if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
185            thread.update(cx, |thread, _cx| {
186                // Look for a tool use in the thread's last message
187                let last_content = thread.messages().last().unwrap().content.last().unwrap();
188                if let MessageContent::ToolUse(last_tool_use) = last_content {
189                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
190                    if tool_call.status == acp::ToolCallStatus::Pending {
191                        if !last_tool_use.is_input_complete
192                            && last_tool_use.input.get("g").is_none()
193                        {
194                            saw_partial_tool_use = true;
195                        }
196                    } else {
197                        last_tool_use
198                            .input
199                            .get("a")
200                            .expect("'a' has streamed because input is now complete");
201                        last_tool_use
202                            .input
203                            .get("g")
204                            .expect("'g' has streamed because input is now complete");
205                    }
206                } else {
207                    panic!("last content should be a tool use");
208                }
209            });
210        }
211    }
212
213    assert!(
214        saw_partial_tool_use,
215        "should see at least one partially streamed tool use in the history"
216    );
217}
218
219#[gpui::test]
220async fn test_tool_authorization(cx: &mut TestAppContext) {
221    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
222    let fake_model = model.as_fake();
223
224    let mut events = thread.update(cx, |thread, cx| {
225        thread.add_tool(ToolRequiringPermission);
226        thread.send(model.clone(), "abc", cx)
227    });
228    cx.run_until_parked();
229    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
230        LanguageModelToolUse {
231            id: "tool_id_1".into(),
232            name: ToolRequiringPermission.name().into(),
233            raw_input: "{}".into(),
234            input: json!({}),
235            is_input_complete: true,
236        },
237    ));
238    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
239        LanguageModelToolUse {
240            id: "tool_id_2".into(),
241            name: ToolRequiringPermission.name().into(),
242            raw_input: "{}".into(),
243            input: json!({}),
244            is_input_complete: true,
245        },
246    ));
247    fake_model.end_last_completion_stream();
248    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
249    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
250
251    // Approve the first
252    tool_call_auth_1
253        .response
254        .send(tool_call_auth_1.options[1].id.clone())
255        .unwrap();
256    cx.run_until_parked();
257
258    // Reject the second
259    tool_call_auth_2
260        .response
261        .send(tool_call_auth_1.options[2].id.clone())
262        .unwrap();
263    cx.run_until_parked();
264
265    let completion = fake_model.pending_completions().pop().unwrap();
266    let message = completion.messages.last().unwrap();
267    assert_eq!(
268        message.content,
269        vec![
270            MessageContent::ToolResult(LanguageModelToolResult {
271                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
272                tool_name: ToolRequiringPermission.name().into(),
273                is_error: false,
274                content: "Allowed".into(),
275                output: Some("Allowed".into())
276            }),
277            MessageContent::ToolResult(LanguageModelToolResult {
278                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
279                tool_name: ToolRequiringPermission.name().into(),
280                is_error: true,
281                content: "Permission to run tool denied by user".into(),
282                output: None
283            })
284        ]
285    );
286}
287
288#[gpui::test]
289async fn test_tool_hallucination(cx: &mut TestAppContext) {
290    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
291    let fake_model = model.as_fake();
292
293    let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
294    cx.run_until_parked();
295    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
296        LanguageModelToolUse {
297            id: "tool_id_1".into(),
298            name: "nonexistent_tool".into(),
299            raw_input: "{}".into(),
300            input: json!({}),
301            is_input_complete: true,
302        },
303    ));
304    fake_model.end_last_completion_stream();
305
306    let tool_call = expect_tool_call(&mut events).await;
307    assert_eq!(tool_call.title, "nonexistent_tool");
308    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
309    let update = expect_tool_call_update(&mut events).await;
310    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
311}
312
313async fn expect_tool_call(
314    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
315) -> acp::ToolCall {
316    let event = events
317        .next()
318        .await
319        .expect("no tool call authorization event received")
320        .unwrap();
321    match event {
322        AgentResponseEvent::ToolCall(tool_call) => return tool_call,
323        event => {
324            panic!("Unexpected event {event:?}");
325        }
326    }
327}
328
329async fn expect_tool_call_update(
330    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
331) -> acp::ToolCallUpdate {
332    let event = events
333        .next()
334        .await
335        .expect("no tool call authorization event received")
336        .unwrap();
337    match event {
338        AgentResponseEvent::ToolCallUpdate(tool_call_update) => return tool_call_update,
339        event => {
340            panic!("Unexpected event {event:?}");
341        }
342    }
343}
344
345async fn next_tool_call_authorization(
346    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
347) -> ToolCallAuthorization {
348    loop {
349        let event = events
350            .next()
351            .await
352            .expect("no tool call authorization event received")
353            .unwrap();
354        if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
355            let permission_kinds = tool_call_authorization
356                .options
357                .iter()
358                .map(|o| o.kind)
359                .collect::<Vec<_>>();
360            assert_eq!(
361                permission_kinds,
362                vec![
363                    acp::PermissionOptionKind::AllowAlways,
364                    acp::PermissionOptionKind::AllowOnce,
365                    acp::PermissionOptionKind::RejectOnce,
366                ]
367            );
368            return tool_call_authorization;
369        }
370    }
371}
372
373#[gpui::test]
374#[ignore = "can't run on CI yet"]
375async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
376    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
377
378    // Test concurrent tool calls with different delay times
379    let events = thread
380        .update(cx, |thread, cx| {
381            thread.add_tool(DelayTool);
382            thread.send(
383                model.clone(),
384                "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
385                cx,
386            )
387        })
388        .collect()
389        .await;
390
391    let stop_reasons = stop_events(events);
392    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
393
394    thread.update(cx, |thread, _cx| {
395        let last_message = thread.messages().last().unwrap();
396        let text = last_message
397            .content
398            .iter()
399            .filter_map(|content| {
400                if let MessageContent::Text(text) = content {
401                    Some(text.as_str())
402                } else {
403                    None
404                }
405            })
406            .collect::<String>();
407
408        assert!(text.contains("Ding"));
409    });
410}
411
412#[gpui::test]
413#[ignore = "can't run on CI yet"]
414async fn test_cancellation(cx: &mut TestAppContext) {
415    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
416
417    let mut events = thread.update(cx, |thread, cx| {
418        thread.add_tool(InfiniteTool);
419        thread.add_tool(EchoTool);
420        thread.send(
421            model.clone(),
422            "Call the echo tool and then call the infinite tool, then explain their output",
423            cx,
424        )
425    });
426
427    // Wait until both tools are called.
428    let mut expected_tool_calls = vec!["echo", "infinite"];
429    let mut echo_id = None;
430    let mut echo_completed = false;
431    while let Some(event) = events.next().await {
432        match event.unwrap() {
433            AgentResponseEvent::ToolCall(tool_call) => {
434                assert_eq!(tool_call.title, expected_tool_calls.remove(0));
435                if tool_call.title == "echo" {
436                    echo_id = Some(tool_call.id);
437                }
438            }
439            AgentResponseEvent::ToolCallUpdate(acp::ToolCallUpdate {
440                id,
441                fields:
442                    acp::ToolCallUpdateFields {
443                        status: Some(acp::ToolCallStatus::Completed),
444                        ..
445                    },
446            }) if Some(&id) == echo_id.as_ref() => {
447                echo_completed = true;
448            }
449            _ => {}
450        }
451
452        if expected_tool_calls.is_empty() && echo_completed {
453            break;
454        }
455    }
456
457    // Cancel the current send and ensure that the event stream is closed, even
458    // if one of the tools is still running.
459    thread.update(cx, |thread, _cx| thread.cancel());
460    events.collect::<Vec<_>>().await;
461
462    // Ensure we can still send a new message after cancellation.
463    let events = thread
464        .update(cx, |thread, cx| {
465            thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx)
466        })
467        .collect::<Vec<_>>()
468        .await;
469    thread.update(cx, |thread, _cx| {
470        assert_eq!(
471            thread.messages().last().unwrap().content,
472            vec![MessageContent::Text("Hello".to_string())]
473        );
474    });
475    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
476}
477
478#[gpui::test]
479async fn test_refusal(cx: &mut TestAppContext) {
480    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
481    let fake_model = model.as_fake();
482
483    let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx));
484    cx.run_until_parked();
485    thread.read_with(cx, |thread, _| {
486        assert_eq!(
487            thread.to_markdown(),
488            indoc! {"
489                ## user
490                Hello
491            "}
492        );
493    });
494
495    fake_model.send_last_completion_stream_text_chunk("Hey!");
496    cx.run_until_parked();
497    thread.read_with(cx, |thread, _| {
498        assert_eq!(
499            thread.to_markdown(),
500            indoc! {"
501                ## user
502                Hello
503                ## assistant
504                Hey!
505            "}
506        );
507    });
508
509    // If the model refuses to continue, the thread should remove all the messages after the last user message.
510    fake_model
511        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
512    let events = events.collect::<Vec<_>>().await;
513    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
514    thread.read_with(cx, |thread, _| {
515        assert_eq!(thread.to_markdown(), "");
516    });
517}
518
519#[gpui::test]
520async fn test_agent_connection(cx: &mut TestAppContext) {
521    cx.update(settings::init);
522    let templates = Templates::new();
523
524    // Initialize language model system with test provider
525    cx.update(|cx| {
526        gpui_tokio::init(cx);
527        client::init_settings(cx);
528
529        let http_client = FakeHttpClient::with_404_response();
530        let clock = Arc::new(clock::FakeSystemClock::new());
531        let client = Client::new(clock, http_client, cx);
532        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
533        language_model::init(client.clone(), cx);
534        language_models::init(user_store.clone(), client.clone(), cx);
535        Project::init_settings(cx);
536        LanguageModelRegistry::test(cx);
537    });
538    cx.executor().forbid_parking();
539
540    // Create a project for new_thread
541    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
542    fake_fs.insert_tree(path!("/test"), json!({})).await;
543    let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
544    let cwd = Path::new("/test");
545
546    // Create agent and connection
547    let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
548        .await
549        .unwrap();
550    let connection = NativeAgentConnection(agent.clone());
551
552    // Test model_selector returns Some
553    let selector_opt = connection.model_selector();
554    assert!(
555        selector_opt.is_some(),
556        "agent2 should always support ModelSelector"
557    );
558    let selector = selector_opt.unwrap();
559
560    // Test list_models
561    let listed_models = cx
562        .update(|cx| {
563            let mut async_cx = cx.to_async();
564            selector.list_models(&mut async_cx)
565        })
566        .await
567        .expect("list_models should succeed");
568    assert!(!listed_models.is_empty(), "should have at least one model");
569    assert_eq!(listed_models[0].id().0, "fake");
570
571    // Create a thread using new_thread
572    let connection_rc = Rc::new(connection.clone());
573    let acp_thread = cx
574        .update(|cx| {
575            let mut async_cx = cx.to_async();
576            connection_rc.new_thread(project, cwd, &mut async_cx)
577        })
578        .await
579        .expect("new_thread should succeed");
580
581    // Get the session_id from the AcpThread
582    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
583
584    // Test selected_model returns the default
585    let model = cx
586        .update(|cx| {
587            let mut async_cx = cx.to_async();
588            selector.selected_model(&session_id, &mut async_cx)
589        })
590        .await
591        .expect("selected_model should succeed");
592    let model = model.as_fake();
593    assert_eq!(model.id().0, "fake", "should return default model");
594
595    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
596    cx.run_until_parked();
597    model.send_last_completion_stream_text_chunk("def");
598    cx.run_until_parked();
599    acp_thread.read_with(cx, |thread, cx| {
600        assert_eq!(
601            thread.to_markdown(cx),
602            indoc! {"
603                ## User
604
605                abc
606
607                ## Assistant
608
609                def
610
611            "}
612        )
613    });
614
615    // Test cancel
616    cx.update(|cx| connection.cancel(&session_id, cx));
617    request.await.expect("prompt should fail gracefully");
618
619    // Ensure that dropping the ACP thread causes the native thread to be
620    // dropped as well.
621    cx.update(|_| drop(acp_thread));
622    let result = cx
623        .update(|cx| {
624            connection.prompt(
625                acp::PromptRequest {
626                    session_id: session_id.clone(),
627                    prompt: vec!["ghi".into()],
628                },
629                cx,
630            )
631        })
632        .await;
633    assert_eq!(
634        result.as_ref().unwrap_err().to_string(),
635        "Session not found",
636        "unexpected result: {:?}",
637        result
638    );
639}
640
641#[gpui::test]
642async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
643    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
644    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
645    let fake_model = model.as_fake();
646
647    let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Think", cx));
648    cx.run_until_parked();
649
650    let input = json!({ "content": "Thinking hard!" });
651    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
652        LanguageModelToolUse {
653            id: "1".into(),
654            name: ThinkingTool.name().into(),
655            raw_input: input.to_string(),
656            input,
657            is_input_complete: true,
658        },
659    ));
660    fake_model.end_last_completion_stream();
661    cx.run_until_parked();
662
663    let tool_call = expect_tool_call(&mut events).await;
664    assert_eq!(
665        tool_call,
666        acp::ToolCall {
667            id: acp::ToolCallId("1".into()),
668            title: "Thinking".into(),
669            kind: acp::ToolKind::Think,
670            status: acp::ToolCallStatus::Pending,
671            content: vec![],
672            locations: vec![],
673            raw_input: Some(json!({ "content": "Thinking hard!" })),
674            raw_output: None,
675        }
676    );
677    let update = expect_tool_call_update(&mut events).await;
678    assert_eq!(
679        update,
680        acp::ToolCallUpdate {
681            id: acp::ToolCallId("1".into()),
682            fields: acp::ToolCallUpdateFields {
683                status: Some(acp::ToolCallStatus::InProgress,),
684                ..Default::default()
685            },
686        }
687    );
688    let update = expect_tool_call_update(&mut events).await;
689    assert_eq!(
690        update,
691        acp::ToolCallUpdate {
692            id: acp::ToolCallId("1".into()),
693            fields: acp::ToolCallUpdateFields {
694                content: Some(vec!["Thinking hard!".into()]),
695                ..Default::default()
696            },
697        }
698    );
699    let update = expect_tool_call_update(&mut events).await;
700    assert_eq!(
701        update,
702        acp::ToolCallUpdate {
703            id: acp::ToolCallId("1".into()),
704            fields: acp::ToolCallUpdateFields {
705                status: Some(acp::ToolCallStatus::Completed),
706                ..Default::default()
707            },
708        }
709    );
710}
711
712/// Filters out the stop events for asserting against in tests
713fn stop_events(
714    result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
715) -> Vec<acp::StopReason> {
716    result_events
717        .into_iter()
718        .filter_map(|event| match event.unwrap() {
719            AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
720            _ => None,
721        })
722        .collect()
723}
724
725struct ThreadTest {
726    model: Arc<dyn LanguageModel>,
727    thread: Entity<Thread>,
728    project_context: Rc<RefCell<ProjectContext>>,
729}
730
/// Which model `setup` builds the thread around.
enum TestModel {
    /// Registry model `claude-sonnet-4-latest`; `setup` authenticates its provider.
    Sonnet4,
    /// Registry model `claude-sonnet-4-thinking-latest`; `setup` authenticates its provider.
    Sonnet4Thinking,
    /// An in-process `FakeLanguageModel` whose stream the test drives manually.
    Fake,
}
736
737impl TestModel {
738    fn id(&self) -> LanguageModelId {
739        match self {
740            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
741            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
742            TestModel::Fake => unreachable!(),
743        }
744    }
745}
746
747async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
748    cx.executor().allow_parking();
749    cx.update(|cx| {
750        settings::init(cx);
751        Project::init_settings(cx);
752    });
753    let templates = Templates::new();
754
755    let fs = FakeFs::new(cx.background_executor.clone());
756    fs.insert_tree(path!("/test"), json!({})).await;
757    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
758
759    let model = cx
760        .update(|cx| {
761            gpui_tokio::init(cx);
762            let http_client = ReqwestClient::user_agent("agent tests").unwrap();
763            cx.set_http_client(Arc::new(http_client));
764
765            client::init_settings(cx);
766            let client = Client::production(cx);
767            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
768            language_model::init(client.clone(), cx);
769            language_models::init(user_store.clone(), client.clone(), cx);
770
771            if let TestModel::Fake = model {
772                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
773            } else {
774                let model_id = model.id();
775                let models = LanguageModelRegistry::read_global(cx);
776                let model = models
777                    .available_models(cx)
778                    .find(|model| model.id() == model_id)
779                    .unwrap();
780
781                let provider = models.provider(&model.provider_id()).unwrap();
782                let authenticated = provider.authenticate(cx);
783
784                cx.spawn(async move |_cx| {
785                    authenticated.await.unwrap();
786                    model
787                })
788            }
789        })
790        .await;
791
792    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
793    let action_log = cx.new(|_| ActionLog::new(project.clone()));
794    let thread = cx.new(|_| {
795        Thread::new(
796            project,
797            project_context.clone(),
798            action_log,
799            templates,
800            model.clone(),
801        )
802    });
803    ThreadTest {
804        model,
805        thread,
806        project_context,
807    }
808}
809
/// Load-time hook (via `ctor`) that turns on `env_logger` output for the test
/// binary only when `RUST_LOG` is set, keeping test runs quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}