mod.rs

  1use super::*;
  2use acp_thread::AgentConnection;
  3use action_log::ActionLog;
  4use agent_client_protocol::{self as acp};
  5use anyhow::Result;
  6use client::{Client, UserStore};
  7use fs::FakeFs;
  8use futures::channel::mpsc::UnboundedReceiver;
  9use gpui::{AppContext, Entity, Task, TestAppContext, http_client::FakeHttpClient};
 10use indoc::indoc;
 11use language_model::{
 12    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
 13    LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
 14    StopReason, fake_provider::FakeLanguageModel,
 15};
 16use project::Project;
 17use prompt_store::ProjectContext;
 18use reqwest_client::ReqwestClient;
 19use schemars::JsonSchema;
 20use serde::{Deserialize, Serialize};
 21use serde_json::json;
 22use smol::stream::StreamExt;
 23use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
 24use util::path;
 25
 26mod test_tools;
 27use test_tools::*;
 28
 29#[gpui::test]
 30#[ignore = "can't run on CI yet"]
 31async fn test_echo(cx: &mut TestAppContext) {
 32    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 33
 34    let events = thread
 35        .update(cx, |thread, cx| {
 36            thread.send("Testing: Reply with 'Hello'", cx)
 37        })
 38        .collect()
 39        .await;
 40    thread.update(cx, |thread, _cx| {
 41        assert_eq!(
 42            thread.messages().last().unwrap().content,
 43            vec![MessageContent::Text("Hello".to_string())]
 44        );
 45    });
 46    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 47}
 48
 49#[gpui::test]
 50#[ignore = "can't run on CI yet"]
 51async fn test_thinking(cx: &mut TestAppContext) {
 52    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
 53
 54    let events = thread
 55        .update(cx, |thread, cx| {
 56            thread.send(
 57                indoc! {"
 58                    Testing:
 59
 60                    Generate a thinking step where you just think the word 'Think',
 61                    and have your final answer be 'Hello'
 62                "},
 63                cx,
 64            )
 65        })
 66        .collect()
 67        .await;
 68    thread.update(cx, |thread, _cx| {
 69        assert_eq!(
 70            thread.messages().last().unwrap().to_markdown(),
 71            indoc! {"
 72                ## assistant
 73                <think>Think</think>
 74                Hello
 75            "}
 76        )
 77    });
 78    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 79}
 80
 81#[gpui::test]
 82async fn test_system_prompt(cx: &mut TestAppContext) {
 83    let ThreadTest {
 84        model,
 85        thread,
 86        project_context,
 87        ..
 88    } = setup(cx, TestModel::Fake).await;
 89    let fake_model = model.as_fake();
 90
 91    project_context.borrow_mut().shell = "test-shell".into();
 92    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 93    thread.update(cx, |thread, cx| thread.send("abc", cx));
 94    cx.run_until_parked();
 95    let mut pending_completions = fake_model.pending_completions();
 96    assert_eq!(
 97        pending_completions.len(),
 98        1,
 99        "unexpected pending completions: {:?}",
100        pending_completions
101    );
102
103    let pending_completion = pending_completions.pop().unwrap();
104    assert_eq!(pending_completion.messages[0].role, Role::System);
105
106    let system_message = &pending_completion.messages[0];
107    let system_prompt = system_message.content[0].to_str().unwrap();
108    assert!(
109        system_prompt.contains("test-shell"),
110        "unexpected system message: {:?}",
111        system_message
112    );
113    assert!(
114        system_prompt.contains("## Fixing Diagnostics"),
115        "unexpected system message: {:?}",
116        system_message
117    );
118}
119
120#[gpui::test]
121#[ignore = "can't run on CI yet"]
122async fn test_basic_tool_calls(cx: &mut TestAppContext) {
123    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
124
125    // Test a tool call that's likely to complete *before* streaming stops.
126    let events = thread
127        .update(cx, |thread, cx| {
128            thread.add_tool(EchoTool);
129            thread.send(
130                "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
131                cx,
132            )
133        })
134        .collect()
135        .await;
136    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
137
138    // Test a tool calls that's likely to complete *after* streaming stops.
139    let events = thread
140        .update(cx, |thread, cx| {
141            thread.remove_tool(&AgentTool::name(&EchoTool));
142            thread.add_tool(DelayTool);
143            thread.send(
144                "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
145                cx,
146            )
147        })
148        .collect()
149        .await;
150    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
151    thread.update(cx, |thread, _cx| {
152        assert!(
153            thread
154                .messages()
155                .last()
156                .unwrap()
157                .content
158                .iter()
159                .any(|content| {
160                    if let MessageContent::Text(text) = content {
161                        text.contains("Ding")
162                    } else {
163                        false
164                    }
165                })
166        );
167    });
168}
169
170#[gpui::test]
171#[ignore = "can't run on CI yet"]
172async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
173    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
174
175    // Test a tool call that's likely to complete *before* streaming stops.
176    let mut events = thread.update(cx, |thread, cx| {
177        thread.add_tool(WordListTool);
178        thread.send("Test the word_list tool.", cx)
179    });
180
181    let mut saw_partial_tool_use = false;
182    while let Some(event) = events.next().await {
183        if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
184            thread.update(cx, |thread, _cx| {
185                // Look for a tool use in the thread's last message
186                let last_content = thread.messages().last().unwrap().content.last().unwrap();
187                if let MessageContent::ToolUse(last_tool_use) = last_content {
188                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
189                    if tool_call.status == acp::ToolCallStatus::Pending {
190                        if !last_tool_use.is_input_complete
191                            && last_tool_use.input.get("g").is_none()
192                        {
193                            saw_partial_tool_use = true;
194                        }
195                    } else {
196                        last_tool_use
197                            .input
198                            .get("a")
199                            .expect("'a' has streamed because input is now complete");
200                        last_tool_use
201                            .input
202                            .get("g")
203                            .expect("'g' has streamed because input is now complete");
204                    }
205                } else {
206                    panic!("last content should be a tool use");
207                }
208            });
209        }
210    }
211
212    assert!(
213        saw_partial_tool_use,
214        "should see at least one partially streamed tool use in the history"
215    );
216}
217
218#[gpui::test]
219async fn test_tool_authorization(cx: &mut TestAppContext) {
220    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
221    let fake_model = model.as_fake();
222
223    let mut events = thread.update(cx, |thread, cx| {
224        thread.add_tool(ToolRequiringPermission);
225        thread.send("abc", cx)
226    });
227    cx.run_until_parked();
228    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
229        LanguageModelToolUse {
230            id: "tool_id_1".into(),
231            name: ToolRequiringPermission.name().into(),
232            raw_input: "{}".into(),
233            input: json!({}),
234            is_input_complete: true,
235        },
236    ));
237    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
238        LanguageModelToolUse {
239            id: "tool_id_2".into(),
240            name: ToolRequiringPermission.name().into(),
241            raw_input: "{}".into(),
242            input: json!({}),
243            is_input_complete: true,
244        },
245    ));
246    fake_model.end_last_completion_stream();
247    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
248    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
249
250    // Approve the first
251    tool_call_auth_1
252        .response
253        .send(tool_call_auth_1.options[1].id.clone())
254        .unwrap();
255    cx.run_until_parked();
256
257    // Reject the second
258    tool_call_auth_2
259        .response
260        .send(tool_call_auth_1.options[2].id.clone())
261        .unwrap();
262    cx.run_until_parked();
263
264    let completion = fake_model.pending_completions().pop().unwrap();
265    let message = completion.messages.last().unwrap();
266    assert_eq!(
267        message.content,
268        vec![
269            MessageContent::ToolResult(LanguageModelToolResult {
270                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
271                tool_name: ToolRequiringPermission.name().into(),
272                is_error: false,
273                content: "Allowed".into(),
274                output: Some("Allowed".into())
275            }),
276            MessageContent::ToolResult(LanguageModelToolResult {
277                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
278                tool_name: ToolRequiringPermission.name().into(),
279                is_error: true,
280                content: "Permission to run tool denied by user".into(),
281                output: None
282            })
283        ]
284    );
285}
286
287#[gpui::test]
288async fn test_tool_hallucination(cx: &mut TestAppContext) {
289    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
290    let fake_model = model.as_fake();
291
292    let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx));
293    cx.run_until_parked();
294    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
295        LanguageModelToolUse {
296            id: "tool_id_1".into(),
297            name: "nonexistent_tool".into(),
298            raw_input: "{}".into(),
299            input: json!({}),
300            is_input_complete: true,
301        },
302    ));
303    fake_model.end_last_completion_stream();
304
305    let tool_call = expect_tool_call(&mut events).await;
306    assert_eq!(tool_call.title, "nonexistent_tool");
307    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
308    let update = expect_tool_call_update_fields(&mut events).await;
309    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
310}
311
312async fn expect_tool_call(
313    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
314) -> acp::ToolCall {
315    let event = events
316        .next()
317        .await
318        .expect("no tool call authorization event received")
319        .unwrap();
320    match event {
321        AgentResponseEvent::ToolCall(tool_call) => return tool_call,
322        event => {
323            panic!("Unexpected event {event:?}");
324        }
325    }
326}
327
328async fn expect_tool_call_update_fields(
329    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
330) -> acp::ToolCallUpdate {
331    let event = events
332        .next()
333        .await
334        .expect("no tool call authorization event received")
335        .unwrap();
336    match event {
337        AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
338            return update;
339        }
340        event => {
341            panic!("Unexpected event {event:?}");
342        }
343    }
344}
345
346async fn next_tool_call_authorization(
347    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
348) -> ToolCallAuthorization {
349    loop {
350        let event = events
351            .next()
352            .await
353            .expect("no tool call authorization event received")
354            .unwrap();
355        if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
356            let permission_kinds = tool_call_authorization
357                .options
358                .iter()
359                .map(|o| o.kind)
360                .collect::<Vec<_>>();
361            assert_eq!(
362                permission_kinds,
363                vec![
364                    acp::PermissionOptionKind::AllowAlways,
365                    acp::PermissionOptionKind::AllowOnce,
366                    acp::PermissionOptionKind::RejectOnce,
367                ]
368            );
369            return tool_call_authorization;
370        }
371    }
372}
373
374#[gpui::test]
375#[ignore = "can't run on CI yet"]
376async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
377    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
378
379    // Test concurrent tool calls with different delay times
380    let events = thread
381        .update(cx, |thread, cx| {
382            thread.add_tool(DelayTool);
383            thread.send(
384                "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
385                cx,
386            )
387        })
388        .collect()
389        .await;
390
391    let stop_reasons = stop_events(events);
392    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
393
394    thread.update(cx, |thread, _cx| {
395        let last_message = thread.messages().last().unwrap();
396        let text = last_message
397            .content
398            .iter()
399            .filter_map(|content| {
400                if let MessageContent::Text(text) = content {
401                    Some(text.as_str())
402                } else {
403                    None
404                }
405            })
406            .collect::<String>();
407
408        assert!(text.contains("Ding"));
409    });
410}
411
412#[gpui::test]
413#[ignore = "can't run on CI yet"]
414async fn test_cancellation(cx: &mut TestAppContext) {
415    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
416
417    let mut events = thread.update(cx, |thread, cx| {
418        thread.add_tool(InfiniteTool);
419        thread.add_tool(EchoTool);
420        thread.send(
421            "Call the echo tool and then call the infinite tool, then explain their output",
422            cx,
423        )
424    });
425
426    // Wait until both tools are called.
427    let mut expected_tools = vec!["Echo", "Infinite Tool"];
428    let mut echo_id = None;
429    let mut echo_completed = false;
430    while let Some(event) = events.next().await {
431        match event.unwrap() {
432            AgentResponseEvent::ToolCall(tool_call) => {
433                assert_eq!(tool_call.title, expected_tools.remove(0));
434                if tool_call.title == "Echo" {
435                    echo_id = Some(tool_call.id);
436                }
437            }
438            AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
439                acp::ToolCallUpdate {
440                    id,
441                    fields:
442                        acp::ToolCallUpdateFields {
443                            status: Some(acp::ToolCallStatus::Completed),
444                            ..
445                        },
446                },
447            )) if Some(&id) == echo_id.as_ref() => {
448                echo_completed = true;
449            }
450            _ => {}
451        }
452
453        if expected_tools.is_empty() && echo_completed {
454            break;
455        }
456    }
457
458    // Cancel the current send and ensure that the event stream is closed, even
459    // if one of the tools is still running.
460    thread.update(cx, |thread, _cx| thread.cancel());
461    events.collect::<Vec<_>>().await;
462
463    // Ensure we can still send a new message after cancellation.
464    let events = thread
465        .update(cx, |thread, cx| {
466            thread.send("Testing: reply with 'Hello' then stop.", cx)
467        })
468        .collect::<Vec<_>>()
469        .await;
470    thread.update(cx, |thread, _cx| {
471        assert_eq!(
472            thread.messages().last().unwrap().content,
473            vec![MessageContent::Text("Hello".to_string())]
474        );
475    });
476    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
477}
478
479#[gpui::test]
480async fn test_refusal(cx: &mut TestAppContext) {
481    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
482    let fake_model = model.as_fake();
483
484    let events = thread.update(cx, |thread, cx| thread.send("Hello", cx));
485    cx.run_until_parked();
486    thread.read_with(cx, |thread, _| {
487        assert_eq!(
488            thread.to_markdown(),
489            indoc! {"
490                ## user
491                Hello
492            "}
493        );
494    });
495
496    fake_model.send_last_completion_stream_text_chunk("Hey!");
497    cx.run_until_parked();
498    thread.read_with(cx, |thread, _| {
499        assert_eq!(
500            thread.to_markdown(),
501            indoc! {"
502                ## user
503                Hello
504                ## assistant
505                Hey!
506            "}
507        );
508    });
509
510    // If the model refuses to continue, the thread should remove all the messages after the last user message.
511    fake_model
512        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
513    let events = events.collect::<Vec<_>>().await;
514    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
515    thread.read_with(cx, |thread, _| {
516        assert_eq!(thread.to_markdown(), "");
517    });
518}
519
520#[gpui::test]
521async fn test_agent_connection(cx: &mut TestAppContext) {
522    cx.update(settings::init);
523    let templates = Templates::new();
524
525    // Initialize language model system with test provider
526    cx.update(|cx| {
527        gpui_tokio::init(cx);
528        client::init_settings(cx);
529
530        let http_client = FakeHttpClient::with_404_response();
531        let clock = Arc::new(clock::FakeSystemClock::new());
532        let client = Client::new(clock, http_client, cx);
533        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
534        language_model::init(client.clone(), cx);
535        language_models::init(user_store.clone(), client.clone(), cx);
536        Project::init_settings(cx);
537        LanguageModelRegistry::test(cx);
538    });
539    cx.executor().forbid_parking();
540
541    // Create a project for new_thread
542    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
543    fake_fs.insert_tree(path!("/test"), json!({})).await;
544    let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
545    let cwd = Path::new("/test");
546
547    // Create agent and connection
548    let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
549        .await
550        .unwrap();
551    let connection = NativeAgentConnection(agent.clone());
552
553    // Test model_selector returns Some
554    let selector_opt = connection.model_selector();
555    assert!(
556        selector_opt.is_some(),
557        "agent2 should always support ModelSelector"
558    );
559    let selector = selector_opt.unwrap();
560
561    // Test list_models
562    let listed_models = cx
563        .update(|cx| {
564            let mut async_cx = cx.to_async();
565            selector.list_models(&mut async_cx)
566        })
567        .await
568        .expect("list_models should succeed");
569    assert!(!listed_models.is_empty(), "should have at least one model");
570    assert_eq!(listed_models[0].id().0, "fake");
571
572    // Create a thread using new_thread
573    let connection_rc = Rc::new(connection.clone());
574    let acp_thread = cx
575        .update(|cx| {
576            let mut async_cx = cx.to_async();
577            connection_rc.new_thread(project, cwd, &mut async_cx)
578        })
579        .await
580        .expect("new_thread should succeed");
581
582    // Get the session_id from the AcpThread
583    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
584
585    // Test selected_model returns the default
586    let model = cx
587        .update(|cx| {
588            let mut async_cx = cx.to_async();
589            selector.selected_model(&session_id, &mut async_cx)
590        })
591        .await
592        .expect("selected_model should succeed");
593    let model = model.as_fake();
594    assert_eq!(model.id().0, "fake", "should return default model");
595
596    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
597    cx.run_until_parked();
598    model.send_last_completion_stream_text_chunk("def");
599    cx.run_until_parked();
600    acp_thread.read_with(cx, |thread, cx| {
601        assert_eq!(
602            thread.to_markdown(cx),
603            indoc! {"
604                ## User
605
606                abc
607
608                ## Assistant
609
610                def
611
612            "}
613        )
614    });
615
616    // Test cancel
617    cx.update(|cx| connection.cancel(&session_id, cx));
618    request.await.expect("prompt should fail gracefully");
619
620    // Ensure that dropping the ACP thread causes the native thread to be
621    // dropped as well.
622    cx.update(|_| drop(acp_thread));
623    let result = cx
624        .update(|cx| {
625            connection.prompt(
626                acp::PromptRequest {
627                    session_id: session_id.clone(),
628                    prompt: vec!["ghi".into()],
629                },
630                cx,
631            )
632        })
633        .await;
634    assert_eq!(
635        result.as_ref().unwrap_err().to_string(),
636        "Session not found",
637        "unexpected result: {:?}",
638        result
639    );
640}
641
642#[gpui::test]
643async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
644    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
645    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
646    let fake_model = model.as_fake();
647
648    let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx));
649    cx.run_until_parked();
650
651    // Simulate streaming partial input.
652    let input = json!({});
653    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
654        LanguageModelToolUse {
655            id: "1".into(),
656            name: ThinkingTool.name().into(),
657            raw_input: input.to_string(),
658            input,
659            is_input_complete: false,
660        },
661    ));
662
663    // Input streaming completed
664    let input = json!({ "content": "Thinking hard!" });
665    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
666        LanguageModelToolUse {
667            id: "1".into(),
668            name: "thinking".into(),
669            raw_input: input.to_string(),
670            input,
671            is_input_complete: true,
672        },
673    ));
674    fake_model.end_last_completion_stream();
675    cx.run_until_parked();
676
677    let tool_call = expect_tool_call(&mut events).await;
678    assert_eq!(
679        tool_call,
680        acp::ToolCall {
681            id: acp::ToolCallId("1".into()),
682            title: "Thinking".into(),
683            kind: acp::ToolKind::Think,
684            status: acp::ToolCallStatus::Pending,
685            content: vec![],
686            locations: vec![],
687            raw_input: Some(json!({})),
688            raw_output: None,
689        }
690    );
691    let update = expect_tool_call_update_fields(&mut events).await;
692    assert_eq!(
693        update,
694        acp::ToolCallUpdate {
695            id: acp::ToolCallId("1".into()),
696            fields: acp::ToolCallUpdateFields {
697                title: Some("Thinking".into()),
698                kind: Some(acp::ToolKind::Think),
699                raw_input: Some(json!({ "content": "Thinking hard!" })),
700                ..Default::default()
701            },
702        }
703    );
704    let update = expect_tool_call_update_fields(&mut events).await;
705    assert_eq!(
706        update,
707        acp::ToolCallUpdate {
708            id: acp::ToolCallId("1".into()),
709            fields: acp::ToolCallUpdateFields {
710                status: Some(acp::ToolCallStatus::InProgress),
711                ..Default::default()
712            },
713        }
714    );
715    let update = expect_tool_call_update_fields(&mut events).await;
716    assert_eq!(
717        update,
718        acp::ToolCallUpdate {
719            id: acp::ToolCallId("1".into()),
720            fields: acp::ToolCallUpdateFields {
721                content: Some(vec!["Thinking hard!".into()]),
722                ..Default::default()
723            },
724        }
725    );
726    let update = expect_tool_call_update_fields(&mut events).await;
727    assert_eq!(
728        update,
729        acp::ToolCallUpdate {
730            id: acp::ToolCallId("1".into()),
731            fields: acp::ToolCallUpdateFields {
732                status: Some(acp::ToolCallStatus::Completed),
733                ..Default::default()
734            },
735        }
736    );
737}
738
739/// Filters out the stop events for asserting against in tests
740fn stop_events(
741    result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
742) -> Vec<acp::StopReason> {
743    result_events
744        .into_iter()
745        .filter_map(|event| match event.unwrap() {
746            AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
747            _ => None,
748        })
749        .collect()
750}
751
752struct ThreadTest {
753    model: Arc<dyn LanguageModel>,
754    thread: Entity<Thread>,
755    project_context: Rc<RefCell<ProjectContext>>,
756}
757
758enum TestModel {
759    Sonnet4,
760    Sonnet4Thinking,
761    Fake,
762}
763
764impl TestModel {
765    fn id(&self) -> LanguageModelId {
766        match self {
767            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
768            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
769            TestModel::Fake => unreachable!(),
770        }
771    }
772}
773
774async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
775    cx.executor().allow_parking();
776    cx.update(|cx| {
777        settings::init(cx);
778        Project::init_settings(cx);
779    });
780    let templates = Templates::new();
781
782    let fs = FakeFs::new(cx.background_executor.clone());
783    fs.insert_tree(path!("/test"), json!({})).await;
784    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
785
786    let model = cx
787        .update(|cx| {
788            gpui_tokio::init(cx);
789            let http_client = ReqwestClient::user_agent("agent tests").unwrap();
790            cx.set_http_client(Arc::new(http_client));
791
792            client::init_settings(cx);
793            let client = Client::production(cx);
794            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
795            language_model::init(client.clone(), cx);
796            language_models::init(user_store.clone(), client.clone(), cx);
797
798            if let TestModel::Fake = model {
799                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
800            } else {
801                let model_id = model.id();
802                let models = LanguageModelRegistry::read_global(cx);
803                let model = models
804                    .available_models(cx)
805                    .find(|model| model.id() == model_id)
806                    .unwrap();
807
808                let provider = models.provider(&model.provider_id()).unwrap();
809                let authenticated = provider.authenticate(cx);
810
811                cx.spawn(async move |_cx| {
812                    authenticated.await.unwrap();
813                    model
814                })
815            }
816        })
817        .await;
818
819    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
820    let action_log = cx.new(|_| ActionLog::new(project.clone()));
821    let thread = cx.new(|_| {
822        Thread::new(
823            project,
824            project_context.clone(),
825            action_log,
826            templates,
827            model.clone(),
828        )
829    });
830    ThreadTest {
831        model,
832        thread,
833        project_context,
834    }
835}
836
/// Initializes `env_logger` when the `RUST_LOG` environment variable can be
/// read; otherwise does nothing. Registered through `ctor`.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}