mod.rs

  1use super::*;
  2use crate::templates::Templates;
  3use acp_thread::AgentConnection;
  4use agent_client_protocol::{self as acp};
  5use anyhow::Result;
  6use assistant_tool::ActionLog;
  7use client::{Client, UserStore};
  8use fs::FakeFs;
  9use futures::channel::mpsc::UnboundedReceiver;
 10use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
 11use indoc::indoc;
 12use language_model::{
 13    fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
 14    LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult,
 15    LanguageModelToolUse, MessageContent, Role, StopReason,
 16};
 17use project::Project;
 18use prompt_store::ProjectContext;
 19use reqwest_client::ReqwestClient;
 20use schemars::JsonSchema;
 21use serde::{Deserialize, Serialize};
 22use serde_json::json;
 23use smol::stream::StreamExt;
 24use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
 25use util::path;
 26
 27mod test_tools;
 28use test_tools::*;
 29
 30#[gpui::test]
 31#[ignore = "can't run on CI yet"]
 32async fn test_echo(cx: &mut TestAppContext) {
 33    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
 34
 35    let events = thread
 36        .update(cx, |thread, cx| {
 37            thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
 38        })
 39        .collect()
 40        .await;
 41    thread.update(cx, |thread, _cx| {
 42        assert_eq!(
 43            thread.messages().last().unwrap().content,
 44            vec![MessageContent::Text("Hello".to_string())]
 45        );
 46    });
 47    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 48}
 49
 50#[gpui::test]
 51#[ignore = "can't run on CI yet"]
 52async fn test_thinking(cx: &mut TestAppContext) {
 53    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
 54
 55    let events = thread
 56        .update(cx, |thread, cx| {
 57            thread.send(
 58                model.clone(),
 59                indoc! {"
 60                    Testing:
 61
 62                    Generate a thinking step where you just think the word 'Think',
 63                    and have your final answer be 'Hello'
 64                "},
 65                cx,
 66            )
 67        })
 68        .collect()
 69        .await;
 70    thread.update(cx, |thread, _cx| {
 71        assert_eq!(
 72            thread.messages().last().unwrap().to_markdown(),
 73            indoc! {"
 74                ## assistant
 75                <think>Think</think>
 76                Hello
 77            "}
 78        )
 79    });
 80    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 81}
 82
 83#[gpui::test]
 84async fn test_system_prompt(cx: &mut TestAppContext) {
 85    let ThreadTest {
 86        model,
 87        thread,
 88        project_context,
 89        ..
 90    } = setup(cx, TestModel::Fake).await;
 91    let fake_model = model.as_fake();
 92
 93    project_context.borrow_mut().shell = "test-shell".into();
 94    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 95    thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
 96    cx.run_until_parked();
 97    let mut pending_completions = fake_model.pending_completions();
 98    assert_eq!(
 99        pending_completions.len(),
100        1,
101        "unexpected pending completions: {:?}",
102        pending_completions
103    );
104
105    let pending_completion = pending_completions.pop().unwrap();
106    assert_eq!(pending_completion.messages[0].role, Role::System);
107
108    let system_message = &pending_completion.messages[0];
109    let system_prompt = system_message.content[0].to_str().unwrap();
110    assert!(
111        system_prompt.contains("test-shell"),
112        "unexpected system message: {:?}",
113        system_message
114    );
115    assert!(
116        system_prompt.contains("## Fixing Diagnostics"),
117        "unexpected system message: {:?}",
118        system_message
119    );
120}
121
122#[gpui::test]
123#[ignore = "can't run on CI yet"]
124async fn test_basic_tool_calls(cx: &mut TestAppContext) {
125    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
126
127    // Test a tool call that's likely to complete *before* streaming stops.
128    let events = thread
129        .update(cx, |thread, cx| {
130            thread.add_tool(EchoTool);
131            thread.send(
132                model.clone(),
133                "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
134                cx,
135            )
136        })
137        .collect()
138        .await;
139    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
140
141    // Test a tool calls that's likely to complete *after* streaming stops.
142    let events = thread
143        .update(cx, |thread, cx| {
144            thread.remove_tool(&AgentTool::name(&EchoTool));
145            thread.add_tool(DelayTool);
146            thread.send(
147                model.clone(),
148                "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
149                cx,
150            )
151        })
152        .collect()
153        .await;
154    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
155    thread.update(cx, |thread, _cx| {
156        assert!(thread
157            .messages()
158            .last()
159            .unwrap()
160            .content
161            .iter()
162            .any(|content| {
163                if let MessageContent::Text(text) = content {
164                    text.contains("Ding")
165                } else {
166                    false
167                }
168            }));
169    });
170}
171
172#[gpui::test]
173#[ignore = "can't run on CI yet"]
174async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
175    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
176
177    // Test a tool call that's likely to complete *before* streaming stops.
178    let mut events = thread.update(cx, |thread, cx| {
179        thread.add_tool(WordListTool);
180        thread.send(model.clone(), "Test the word_list tool.", cx)
181    });
182
183    let mut saw_partial_tool_use = false;
184    while let Some(event) = events.next().await {
185        if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
186            thread.update(cx, |thread, _cx| {
187                // Look for a tool use in the thread's last message
188                let last_content = thread.messages().last().unwrap().content.last().unwrap();
189                if let MessageContent::ToolUse(last_tool_use) = last_content {
190                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
191                    if tool_call.status == acp::ToolCallStatus::Pending {
192                        if !last_tool_use.is_input_complete
193                            && last_tool_use.input.get("g").is_none()
194                        {
195                            saw_partial_tool_use = true;
196                        }
197                    } else {
198                        last_tool_use
199                            .input
200                            .get("a")
201                            .expect("'a' has streamed because input is now complete");
202                        last_tool_use
203                            .input
204                            .get("g")
205                            .expect("'g' has streamed because input is now complete");
206                    }
207                } else {
208                    panic!("last content should be a tool use");
209                }
210            });
211        }
212    }
213
214    assert!(
215        saw_partial_tool_use,
216        "should see at least one partially streamed tool use in the history"
217    );
218}
219
220#[gpui::test]
221async fn test_tool_authorization(cx: &mut TestAppContext) {
222    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
223    let fake_model = model.as_fake();
224
225    let mut events = thread.update(cx, |thread, cx| {
226        thread.add_tool(ToolRequiringPermission);
227        thread.send(model.clone(), "abc", cx)
228    });
229    cx.run_until_parked();
230    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
231        LanguageModelToolUse {
232            id: "tool_id_1".into(),
233            name: ToolRequiringPermission.name().into(),
234            raw_input: "{}".into(),
235            input: json!({}),
236            is_input_complete: true,
237        },
238    ));
239    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
240        LanguageModelToolUse {
241            id: "tool_id_2".into(),
242            name: ToolRequiringPermission.name().into(),
243            raw_input: "{}".into(),
244            input: json!({}),
245            is_input_complete: true,
246        },
247    ));
248    fake_model.end_last_completion_stream();
249    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
250    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
251
252    // Approve the first
253    tool_call_auth_1
254        .response
255        .send(tool_call_auth_1.options[1].id.clone())
256        .unwrap();
257    cx.run_until_parked();
258
259    // Reject the second
260    tool_call_auth_2
261        .response
262        .send(tool_call_auth_1.options[2].id.clone())
263        .unwrap();
264    cx.run_until_parked();
265
266    let completion = fake_model.pending_completions().pop().unwrap();
267    let message = completion.messages.last().unwrap();
268    assert_eq!(
269        message.content,
270        vec![
271            MessageContent::ToolResult(LanguageModelToolResult {
272                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
273                tool_name: ToolRequiringPermission.name().into(),
274                is_error: false,
275                content: "Allowed".into(),
276                output: None
277            }),
278            MessageContent::ToolResult(LanguageModelToolResult {
279                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
280                tool_name: ToolRequiringPermission.name().into(),
281                is_error: true,
282                content: "Permission to run tool denied by user".into(),
283                output: None
284            })
285        ]
286    );
287}
288
289#[gpui::test]
290async fn test_tool_hallucination(cx: &mut TestAppContext) {
291    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
292    let fake_model = model.as_fake();
293
294    let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
295    cx.run_until_parked();
296    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
297        LanguageModelToolUse {
298            id: "tool_id_1".into(),
299            name: "nonexistent_tool".into(),
300            raw_input: "{}".into(),
301            input: json!({}),
302            is_input_complete: true,
303        },
304    ));
305    fake_model.end_last_completion_stream();
306
307    let tool_call = expect_tool_call(&mut events).await;
308    assert_eq!(tool_call.title, "nonexistent_tool");
309    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
310    let update = expect_tool_call_update(&mut events).await;
311    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
312}
313
314async fn expect_tool_call(
315    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
316) -> acp::ToolCall {
317    let event = events
318        .next()
319        .await
320        .expect("no tool call authorization event received")
321        .unwrap();
322    match event {
323        AgentResponseEvent::ToolCall(tool_call) => return tool_call,
324        event => {
325            panic!("Unexpected event {event:?}");
326        }
327    }
328}
329
330async fn expect_tool_call_update(
331    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
332) -> acp::ToolCallUpdate {
333    let event = events
334        .next()
335        .await
336        .expect("no tool call authorization event received")
337        .unwrap();
338    match event {
339        AgentResponseEvent::ToolCallUpdate(tool_call_update) => return tool_call_update,
340        event => {
341            panic!("Unexpected event {event:?}");
342        }
343    }
344}
345
346async fn next_tool_call_authorization(
347    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
348) -> ToolCallAuthorization {
349    loop {
350        let event = events
351            .next()
352            .await
353            .expect("no tool call authorization event received")
354            .unwrap();
355        if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
356            let permission_kinds = tool_call_authorization
357                .options
358                .iter()
359                .map(|o| o.kind)
360                .collect::<Vec<_>>();
361            assert_eq!(
362                permission_kinds,
363                vec![
364                    acp::PermissionOptionKind::AllowAlways,
365                    acp::PermissionOptionKind::AllowOnce,
366                    acp::PermissionOptionKind::RejectOnce,
367                ]
368            );
369            return tool_call_authorization;
370        }
371    }
372}
373
374#[gpui::test]
375#[ignore = "can't run on CI yet"]
376async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
377    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
378
379    // Test concurrent tool calls with different delay times
380    let events = thread
381        .update(cx, |thread, cx| {
382            thread.add_tool(DelayTool);
383            thread.send(
384                model.clone(),
385                "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
386                cx,
387            )
388        })
389        .collect()
390        .await;
391
392    let stop_reasons = stop_events(events);
393    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
394
395    thread.update(cx, |thread, _cx| {
396        let last_message = thread.messages().last().unwrap();
397        let text = last_message
398            .content
399            .iter()
400            .filter_map(|content| {
401                if let MessageContent::Text(text) = content {
402                    Some(text.as_str())
403                } else {
404                    None
405                }
406            })
407            .collect::<String>();
408
409        assert!(text.contains("Ding"));
410    });
411}
412
413#[gpui::test]
414#[ignore = "can't run on CI yet"]
415async fn test_cancellation(cx: &mut TestAppContext) {
416    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
417
418    let mut events = thread.update(cx, |thread, cx| {
419        thread.add_tool(InfiniteTool);
420        thread.add_tool(EchoTool);
421        thread.send(
422            model.clone(),
423            "Call the echo tool and then call the infinite tool, then explain their output",
424            cx,
425        )
426    });
427
428    // Wait until both tools are called.
429    let mut expected_tool_calls = vec!["echo", "infinite"];
430    let mut echo_id = None;
431    let mut echo_completed = false;
432    while let Some(event) = events.next().await {
433        match event.unwrap() {
434            AgentResponseEvent::ToolCall(tool_call) => {
435                assert_eq!(tool_call.title, expected_tool_calls.remove(0));
436                if tool_call.title == "echo" {
437                    echo_id = Some(tool_call.id);
438                }
439            }
440            AgentResponseEvent::ToolCallUpdate(acp::ToolCallUpdate {
441                id,
442                fields:
443                    acp::ToolCallUpdateFields {
444                        status: Some(acp::ToolCallStatus::Completed),
445                        ..
446                    },
447            }) if Some(&id) == echo_id.as_ref() => {
448                echo_completed = true;
449            }
450            _ => {}
451        }
452
453        if expected_tool_calls.is_empty() && echo_completed {
454            break;
455        }
456    }
457
458    // Cancel the current send and ensure that the event stream is closed, even
459    // if one of the tools is still running.
460    thread.update(cx, |thread, _cx| thread.cancel());
461    events.collect::<Vec<_>>().await;
462
463    // Ensure we can still send a new message after cancellation.
464    let events = thread
465        .update(cx, |thread, cx| {
466            thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx)
467        })
468        .collect::<Vec<_>>()
469        .await;
470    thread.update(cx, |thread, _cx| {
471        assert_eq!(
472            thread.messages().last().unwrap().content,
473            vec![MessageContent::Text("Hello".to_string())]
474        );
475    });
476    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
477}
478
479#[gpui::test]
480async fn test_refusal(cx: &mut TestAppContext) {
481    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
482    let fake_model = model.as_fake();
483
484    let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx));
485    cx.run_until_parked();
486    thread.read_with(cx, |thread, _| {
487        assert_eq!(
488            thread.to_markdown(),
489            indoc! {"
490                ## user
491                Hello
492            "}
493        );
494    });
495
496    fake_model.send_last_completion_stream_text_chunk("Hey!");
497    cx.run_until_parked();
498    thread.read_with(cx, |thread, _| {
499        assert_eq!(
500            thread.to_markdown(),
501            indoc! {"
502                ## user
503                Hello
504                ## assistant
505                Hey!
506            "}
507        );
508    });
509
510    // If the model refuses to continue, the thread should remove all the messages after the last user message.
511    fake_model
512        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
513    let events = events.collect::<Vec<_>>().await;
514    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
515    thread.read_with(cx, |thread, _| {
516        assert_eq!(thread.to_markdown(), "");
517    });
518}
519
520#[gpui::test]
521async fn test_agent_connection(cx: &mut TestAppContext) {
522    cx.update(settings::init);
523    let templates = Templates::new();
524
525    // Initialize language model system with test provider
526    cx.update(|cx| {
527        gpui_tokio::init(cx);
528        client::init_settings(cx);
529
530        let http_client = FakeHttpClient::with_404_response();
531        let clock = Arc::new(clock::FakeSystemClock::new());
532        let client = Client::new(clock, http_client, cx);
533        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
534        language_model::init(client.clone(), cx);
535        language_models::init(user_store.clone(), client.clone(), cx);
536        Project::init_settings(cx);
537        LanguageModelRegistry::test(cx);
538    });
539    cx.executor().forbid_parking();
540
541    // Create a project for new_thread
542    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
543    fake_fs.insert_tree(path!("/test"), json!({})).await;
544    let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
545    let cwd = Path::new("/test");
546
547    // Create agent and connection
548    let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
549        .await
550        .unwrap();
551    let connection = NativeAgentConnection(agent.clone());
552
553    // Test model_selector returns Some
554    let selector_opt = connection.model_selector();
555    assert!(
556        selector_opt.is_some(),
557        "agent2 should always support ModelSelector"
558    );
559    let selector = selector_opt.unwrap();
560
561    // Test list_models
562    let listed_models = cx
563        .update(|cx| {
564            let mut async_cx = cx.to_async();
565            selector.list_models(&mut async_cx)
566        })
567        .await
568        .expect("list_models should succeed");
569    assert!(!listed_models.is_empty(), "should have at least one model");
570    assert_eq!(listed_models[0].id().0, "fake");
571
572    // Create a thread using new_thread
573    let connection_rc = Rc::new(connection.clone());
574    let acp_thread = cx
575        .update(|cx| {
576            let mut async_cx = cx.to_async();
577            connection_rc.new_thread(project, cwd, &mut async_cx)
578        })
579        .await
580        .expect("new_thread should succeed");
581
582    // Get the session_id from the AcpThread
583    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
584
585    // Test selected_model returns the default
586    let model = cx
587        .update(|cx| {
588            let mut async_cx = cx.to_async();
589            selector.selected_model(&session_id, &mut async_cx)
590        })
591        .await
592        .expect("selected_model should succeed");
593    let model = model.as_fake();
594    assert_eq!(model.id().0, "fake", "should return default model");
595
596    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
597    cx.run_until_parked();
598    model.send_last_completion_stream_text_chunk("def");
599    cx.run_until_parked();
600    acp_thread.read_with(cx, |thread, cx| {
601        assert_eq!(
602            thread.to_markdown(cx),
603            indoc! {"
604                ## User
605
606                abc
607
608                ## Assistant
609
610                def
611
612            "}
613        )
614    });
615
616    // Test cancel
617    cx.update(|cx| connection.cancel(&session_id, cx));
618    request.await.expect("prompt should fail gracefully");
619
620    // Ensure that dropping the ACP thread causes the native thread to be
621    // dropped as well.
622    cx.update(|_| drop(acp_thread));
623    let result = cx
624        .update(|cx| {
625            connection.prompt(
626                acp::PromptRequest {
627                    session_id: session_id.clone(),
628                    prompt: vec!["ghi".into()],
629                },
630                cx,
631            )
632        })
633        .await;
634    assert_eq!(
635        result.as_ref().unwrap_err().to_string(),
636        "Session not found",
637        "unexpected result: {:?}",
638        result
639    );
640}
641
642#[gpui::test]
643async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
644    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
645    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
646    let fake_model = model.as_fake();
647
648    let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Think", cx));
649    cx.run_until_parked();
650
651    let input = json!({ "content": "Thinking hard!" });
652    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
653        LanguageModelToolUse {
654            id: "1".into(),
655            name: ThinkingTool.name().into(),
656            raw_input: input.to_string(),
657            input,
658            is_input_complete: true,
659        },
660    ));
661    fake_model.end_last_completion_stream();
662    cx.run_until_parked();
663
664    let tool_call = expect_tool_call(&mut events).await;
665    assert_eq!(
666        tool_call,
667        acp::ToolCall {
668            id: acp::ToolCallId("1".into()),
669            title: "Thinking".into(),
670            kind: acp::ToolKind::Think,
671            status: acp::ToolCallStatus::Pending,
672            content: vec![],
673            locations: vec![],
674            raw_input: Some(json!({ "content": "Thinking hard!" })),
675            raw_output: None,
676        }
677    );
678    let update = expect_tool_call_update(&mut events).await;
679    assert_eq!(
680        update,
681        acp::ToolCallUpdate {
682            id: acp::ToolCallId("1".into()),
683            fields: acp::ToolCallUpdateFields {
684                status: Some(acp::ToolCallStatus::InProgress,),
685                ..Default::default()
686            },
687        }
688    );
689    let update = expect_tool_call_update(&mut events).await;
690    assert_eq!(
691        update,
692        acp::ToolCallUpdate {
693            id: acp::ToolCallId("1".into()),
694            fields: acp::ToolCallUpdateFields {
695                content: Some(vec!["Thinking hard!".into()]),
696                ..Default::default()
697            },
698        }
699    );
700    let update = expect_tool_call_update(&mut events).await;
701    assert_eq!(
702        update,
703        acp::ToolCallUpdate {
704            id: acp::ToolCallId("1".into()),
705            fields: acp::ToolCallUpdateFields {
706                status: Some(acp::ToolCallStatus::Completed),
707                ..Default::default()
708            },
709        }
710    );
711}
712
713/// Filters out the stop events for asserting against in tests
714fn stop_events(
715    result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
716) -> Vec<acp::StopReason> {
717    result_events
718        .into_iter()
719        .filter_map(|event| match event.unwrap() {
720            AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
721            _ => None,
722        })
723        .collect()
724}
725
/// Fixtures produced by `setup` and shared by the tests in this module.
struct ThreadTest {
    /// The language model that was handed to `Thread::new`.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// Project context shared with the thread; tests can mutate it through
    /// the `RefCell` before sending a message.
    project_context: Rc<RefCell<ProjectContext>>,
}
731
732enum TestModel {
733    Sonnet4,
734    Sonnet4Thinking,
735    Fake,
736}
737
738impl TestModel {
739    fn id(&self) -> LanguageModelId {
740        match self {
741            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
742            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
743            TestModel::Fake => unreachable!(),
744        }
745    }
746}
747
748async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
749    cx.executor().allow_parking();
750    cx.update(|cx| {
751        settings::init(cx);
752        Project::init_settings(cx);
753    });
754    let templates = Templates::new();
755
756    let fs = FakeFs::new(cx.background_executor.clone());
757    fs.insert_tree(path!("/test"), json!({})).await;
758    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
759
760    let model = cx
761        .update(|cx| {
762            gpui_tokio::init(cx);
763            let http_client = ReqwestClient::user_agent("agent tests").unwrap();
764            cx.set_http_client(Arc::new(http_client));
765
766            client::init_settings(cx);
767            let client = Client::production(cx);
768            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
769            language_model::init(client.clone(), cx);
770            language_models::init(user_store.clone(), client.clone(), cx);
771
772            if let TestModel::Fake = model {
773                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
774            } else {
775                let model_id = model.id();
776                let models = LanguageModelRegistry::read_global(cx);
777                let model = models
778                    .available_models(cx)
779                    .find(|model| model.id() == model_id)
780                    .unwrap();
781
782                let provider = models.provider(&model.provider_id()).unwrap();
783                let authenticated = provider.authenticate(cx);
784
785                cx.spawn(async move |_cx| {
786                    authenticated.await.unwrap();
787                    model
788                })
789            }
790        })
791        .await;
792
793    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
794    let action_log = cx.new(|_| ActionLog::new(project.clone()));
795    let thread = cx.new(|_| {
796        Thread::new(
797            project,
798            project_context.clone(),
799            action_log,
800            templates,
801            model.clone(),
802        )
803    });
804    ThreadTest {
805        model,
806        thread,
807        project_context,
808    }
809}
810
/// Initializes `env_logger` at load time (via `ctor`) for test builds, but
/// only when the `RUST_LOG` environment variable can be read.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}