mod.rs

  1use super::*;
  2use acp_thread::AgentConnection;
  3use agent_client_protocol::{self as acp};
  4use anyhow::Result;
  5use assistant_tool::ActionLog;
  6use client::{Client, UserStore};
  7use fs::FakeFs;
  8use futures::channel::mpsc::UnboundedReceiver;
  9use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
 10use indoc::indoc;
 11use language_model::{
 12    fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
 13    LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult,
 14    LanguageModelToolUse, MessageContent, Role, StopReason,
 15};
 16use project::Project;
 17use prompt_store::ProjectContext;
 18use reqwest_client::ReqwestClient;
 19use schemars::JsonSchema;
 20use serde::{Deserialize, Serialize};
 21use serde_json::json;
 22use smol::stream::StreamExt;
 23use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
 24use util::path;
 25
 26mod test_tools;
 27use test_tools::*;
 28
 29#[gpui::test]
 30#[ignore = "can't run on CI yet"]
 31async fn test_echo(cx: &mut TestAppContext) {
 32    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
 33
 34    let events = thread
 35        .update(cx, |thread, cx| {
 36            thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
 37        })
 38        .collect()
 39        .await;
 40    thread.update(cx, |thread, _cx| {
 41        assert_eq!(
 42            thread.messages().last().unwrap().content,
 43            vec![MessageContent::Text("Hello".to_string())]
 44        );
 45    });
 46    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 47}
 48
 49#[gpui::test]
 50#[ignore = "can't run on CI yet"]
 51async fn test_thinking(cx: &mut TestAppContext) {
 52    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
 53
 54    let events = thread
 55        .update(cx, |thread, cx| {
 56            thread.send(
 57                model.clone(),
 58                indoc! {"
 59                    Testing:
 60
 61                    Generate a thinking step where you just think the word 'Think',
 62                    and have your final answer be 'Hello'
 63                "},
 64                cx,
 65            )
 66        })
 67        .collect()
 68        .await;
 69    thread.update(cx, |thread, _cx| {
 70        assert_eq!(
 71            thread.messages().last().unwrap().to_markdown(),
 72            indoc! {"
 73                ## assistant
 74                <think>Think</think>
 75                Hello
 76            "}
 77        )
 78    });
 79    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 80}
 81
 82#[gpui::test]
 83async fn test_system_prompt(cx: &mut TestAppContext) {
 84    let ThreadTest {
 85        model,
 86        thread,
 87        project_context,
 88        ..
 89    } = setup(cx, TestModel::Fake).await;
 90    let fake_model = model.as_fake();
 91
 92    project_context.borrow_mut().shell = "test-shell".into();
 93    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 94    thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
 95    cx.run_until_parked();
 96    let mut pending_completions = fake_model.pending_completions();
 97    assert_eq!(
 98        pending_completions.len(),
 99        1,
100        "unexpected pending completions: {:?}",
101        pending_completions
102    );
103
104    let pending_completion = pending_completions.pop().unwrap();
105    assert_eq!(pending_completion.messages[0].role, Role::System);
106
107    let system_message = &pending_completion.messages[0];
108    let system_prompt = system_message.content[0].to_str().unwrap();
109    assert!(
110        system_prompt.contains("test-shell"),
111        "unexpected system message: {:?}",
112        system_message
113    );
114    assert!(
115        system_prompt.contains("## Fixing Diagnostics"),
116        "unexpected system message: {:?}",
117        system_message
118    );
119}
120
121#[gpui::test]
122#[ignore = "can't run on CI yet"]
123async fn test_basic_tool_calls(cx: &mut TestAppContext) {
124    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
125
126    // Test a tool call that's likely to complete *before* streaming stops.
127    let events = thread
128        .update(cx, |thread, cx| {
129            thread.add_tool(EchoTool);
130            thread.send(
131                model.clone(),
132                "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
133                cx,
134            )
135        })
136        .collect()
137        .await;
138    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
139
140    // Test a tool calls that's likely to complete *after* streaming stops.
141    let events = thread
142        .update(cx, |thread, cx| {
143            thread.remove_tool(&AgentTool::name(&EchoTool));
144            thread.add_tool(DelayTool);
145            thread.send(
146                model.clone(),
147                "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
148                cx,
149            )
150        })
151        .collect()
152        .await;
153    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
154    thread.update(cx, |thread, _cx| {
155        assert!(thread
156            .messages()
157            .last()
158            .unwrap()
159            .content
160            .iter()
161            .any(|content| {
162                if let MessageContent::Text(text) = content {
163                    text.contains("Ding")
164                } else {
165                    false
166                }
167            }));
168    });
169}
170
171#[gpui::test]
172#[ignore = "can't run on CI yet"]
173async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
174    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
175
176    // Test a tool call that's likely to complete *before* streaming stops.
177    let mut events = thread.update(cx, |thread, cx| {
178        thread.add_tool(WordListTool);
179        thread.send(model.clone(), "Test the word_list tool.", cx)
180    });
181
182    let mut saw_partial_tool_use = false;
183    while let Some(event) = events.next().await {
184        if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
185            thread.update(cx, |thread, _cx| {
186                // Look for a tool use in the thread's last message
187                let last_content = thread.messages().last().unwrap().content.last().unwrap();
188                if let MessageContent::ToolUse(last_tool_use) = last_content {
189                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
190                    if tool_call.status == acp::ToolCallStatus::Pending {
191                        if !last_tool_use.is_input_complete
192                            && last_tool_use.input.get("g").is_none()
193                        {
194                            saw_partial_tool_use = true;
195                        }
196                    } else {
197                        last_tool_use
198                            .input
199                            .get("a")
200                            .expect("'a' has streamed because input is now complete");
201                        last_tool_use
202                            .input
203                            .get("g")
204                            .expect("'g' has streamed because input is now complete");
205                    }
206                } else {
207                    panic!("last content should be a tool use");
208                }
209            });
210        }
211    }
212
213    assert!(
214        saw_partial_tool_use,
215        "should see at least one partially streamed tool use in the history"
216    );
217}
218
219#[gpui::test]
220async fn test_tool_authorization(cx: &mut TestAppContext) {
221    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
222    let fake_model = model.as_fake();
223
224    let mut events = thread.update(cx, |thread, cx| {
225        thread.add_tool(ToolRequiringPermission);
226        thread.send(model.clone(), "abc", cx)
227    });
228    cx.run_until_parked();
229    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
230        LanguageModelToolUse {
231            id: "tool_id_1".into(),
232            name: ToolRequiringPermission.name().into(),
233            raw_input: "{}".into(),
234            input: json!({}),
235            is_input_complete: true,
236        },
237    ));
238    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
239        LanguageModelToolUse {
240            id: "tool_id_2".into(),
241            name: ToolRequiringPermission.name().into(),
242            raw_input: "{}".into(),
243            input: json!({}),
244            is_input_complete: true,
245        },
246    ));
247    fake_model.end_last_completion_stream();
248    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
249    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
250
251    // Approve the first
252    tool_call_auth_1
253        .response
254        .send(tool_call_auth_1.options[1].id.clone())
255        .unwrap();
256    cx.run_until_parked();
257
258    // Reject the second
259    tool_call_auth_2
260        .response
261        .send(tool_call_auth_1.options[2].id.clone())
262        .unwrap();
263    cx.run_until_parked();
264
265    let completion = fake_model.pending_completions().pop().unwrap();
266    let message = completion.messages.last().unwrap();
267    assert_eq!(
268        message.content,
269        vec![
270            MessageContent::ToolResult(LanguageModelToolResult {
271                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
272                tool_name: ToolRequiringPermission.name().into(),
273                is_error: false,
274                content: "Allowed".into(),
275                output: Some("Allowed".into())
276            }),
277            MessageContent::ToolResult(LanguageModelToolResult {
278                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
279                tool_name: ToolRequiringPermission.name().into(),
280                is_error: true,
281                content: "Permission to run tool denied by user".into(),
282                output: None
283            })
284        ]
285    );
286}
287
288#[gpui::test]
289async fn test_tool_hallucination(cx: &mut TestAppContext) {
290    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
291    let fake_model = model.as_fake();
292
293    let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
294    cx.run_until_parked();
295    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
296        LanguageModelToolUse {
297            id: "tool_id_1".into(),
298            name: "nonexistent_tool".into(),
299            raw_input: "{}".into(),
300            input: json!({}),
301            is_input_complete: true,
302        },
303    ));
304    fake_model.end_last_completion_stream();
305
306    let tool_call = expect_tool_call(&mut events).await;
307    assert_eq!(tool_call.title, "nonexistent_tool");
308    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
309    let update = expect_tool_call_update(&mut events).await;
310    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
311}
312
313async fn expect_tool_call(
314    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
315) -> acp::ToolCall {
316    let event = events
317        .next()
318        .await
319        .expect("no tool call authorization event received")
320        .unwrap();
321    match event {
322        AgentResponseEvent::ToolCall(tool_call) => return tool_call,
323        event => {
324            panic!("Unexpected event {event:?}");
325        }
326    }
327}
328
329async fn expect_tool_call_update(
330    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
331) -> acp::ToolCallUpdate {
332    let event = events
333        .next()
334        .await
335        .expect("no tool call authorization event received")
336        .unwrap();
337    match event {
338        AgentResponseEvent::ToolCallUpdate(tool_call_update) => return tool_call_update,
339        event => {
340            panic!("Unexpected event {event:?}");
341        }
342    }
343}
344
345async fn next_tool_call_authorization(
346    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
347) -> ToolCallAuthorization {
348    loop {
349        let event = events
350            .next()
351            .await
352            .expect("no tool call authorization event received")
353            .unwrap();
354        if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
355            let permission_kinds = tool_call_authorization
356                .options
357                .iter()
358                .map(|o| o.kind)
359                .collect::<Vec<_>>();
360            assert_eq!(
361                permission_kinds,
362                vec![
363                    acp::PermissionOptionKind::AllowAlways,
364                    acp::PermissionOptionKind::AllowOnce,
365                    acp::PermissionOptionKind::RejectOnce,
366                ]
367            );
368            return tool_call_authorization;
369        }
370    }
371}
372
373#[gpui::test]
374#[ignore = "can't run on CI yet"]
375async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
376    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
377
378    // Test concurrent tool calls with different delay times
379    let events = thread
380        .update(cx, |thread, cx| {
381            thread.add_tool(DelayTool);
382            thread.send(
383                model.clone(),
384                "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
385                cx,
386            )
387        })
388        .collect()
389        .await;
390
391    let stop_reasons = stop_events(events);
392    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
393
394    thread.update(cx, |thread, _cx| {
395        let last_message = thread.messages().last().unwrap();
396        let text = last_message
397            .content
398            .iter()
399            .filter_map(|content| {
400                if let MessageContent::Text(text) = content {
401                    Some(text.as_str())
402                } else {
403                    None
404                }
405            })
406            .collect::<String>();
407
408        assert!(text.contains("Ding"));
409    });
410}
411
412#[gpui::test]
413#[ignore = "can't run on CI yet"]
414async fn test_cancellation(cx: &mut TestAppContext) {
415    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
416
417    let mut events = thread.update(cx, |thread, cx| {
418        thread.add_tool(InfiniteTool);
419        thread.add_tool(EchoTool);
420        thread.send(
421            model.clone(),
422            "Call the echo tool and then call the infinite tool, then explain their output",
423            cx,
424        )
425    });
426
427    // Wait until both tools are called.
428    let mut expected_tool_calls = vec!["echo", "infinite"];
429    let mut echo_id = None;
430    let mut echo_completed = false;
431    while let Some(event) = events.next().await {
432        match event.unwrap() {
433            AgentResponseEvent::ToolCall(tool_call) => {
434                assert_eq!(tool_call.title, expected_tool_calls.remove(0));
435                if tool_call.title == "echo" {
436                    echo_id = Some(tool_call.id);
437                }
438            }
439            AgentResponseEvent::ToolCallUpdate(acp::ToolCallUpdate {
440                id,
441                fields:
442                    acp::ToolCallUpdateFields {
443                        status: Some(acp::ToolCallStatus::Completed),
444                        ..
445                    },
446            }) if Some(&id) == echo_id.as_ref() => {
447                echo_completed = true;
448            }
449            _ => {}
450        }
451
452        if expected_tool_calls.is_empty() && echo_completed {
453            break;
454        }
455    }
456
457    // Cancel the current send and ensure that the event stream is closed, even
458    // if one of the tools is still running.
459    thread.update(cx, |thread, _cx| thread.cancel());
460    events.collect::<Vec<_>>().await;
461
462    // Ensure we can still send a new message after cancellation.
463    let events = thread
464        .update(cx, |thread, cx| {
465            thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx)
466        })
467        .collect::<Vec<_>>()
468        .await;
469    thread.update(cx, |thread, _cx| {
470        assert_eq!(
471            thread.messages().last().unwrap().content,
472            vec![MessageContent::Text("Hello".to_string())]
473        );
474    });
475    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
476}
477
478#[gpui::test]
479async fn test_refusal(cx: &mut TestAppContext) {
480    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
481    let fake_model = model.as_fake();
482
483    let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx));
484    cx.run_until_parked();
485    thread.read_with(cx, |thread, _| {
486        assert_eq!(
487            thread.to_markdown(),
488            indoc! {"
489                ## user
490                Hello
491            "}
492        );
493    });
494
495    fake_model.send_last_completion_stream_text_chunk("Hey!");
496    cx.run_until_parked();
497    thread.read_with(cx, |thread, _| {
498        assert_eq!(
499            thread.to_markdown(),
500            indoc! {"
501                ## user
502                Hello
503                ## assistant
504                Hey!
505            "}
506        );
507    });
508
509    // If the model refuses to continue, the thread should remove all the messages after the last user message.
510    fake_model
511        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
512    let events = events.collect::<Vec<_>>().await;
513    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
514    thread.read_with(cx, |thread, _| {
515        assert_eq!(thread.to_markdown(), "");
516    });
517}
518
519#[gpui::test]
520async fn test_agent_connection(cx: &mut TestAppContext) {
521    cx.update(settings::init);
522    let templates = Templates::new();
523
524    // Initialize language model system with test provider
525    cx.update(|cx| {
526        gpui_tokio::init(cx);
527        client::init_settings(cx);
528
529        let http_client = FakeHttpClient::with_404_response();
530        let clock = Arc::new(clock::FakeSystemClock::new());
531        let client = Client::new(clock, http_client, cx);
532        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
533        language_model::init(client.clone(), cx);
534        language_models::init(user_store.clone(), client.clone(), cx);
535        Project::init_settings(cx);
536        LanguageModelRegistry::test(cx);
537    });
538    cx.executor().forbid_parking();
539
540    // Create a project for new_thread
541    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
542    fake_fs.insert_tree(path!("/test"), json!({})).await;
543    let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
544    let cwd = Path::new("/test");
545
546    // Create agent and connection
547    let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
548        .await
549        .unwrap();
550    let connection = NativeAgentConnection(agent.clone());
551
552    // Test model_selector returns Some
553    let selector_opt = connection.model_selector();
554    assert!(
555        selector_opt.is_some(),
556        "agent2 should always support ModelSelector"
557    );
558    let selector = selector_opt.unwrap();
559
560    // Test list_models
561    let listed_models = cx
562        .update(|cx| {
563            let mut async_cx = cx.to_async();
564            selector.list_models(&mut async_cx)
565        })
566        .await
567        .expect("list_models should succeed");
568    assert!(!listed_models.is_empty(), "should have at least one model");
569    assert_eq!(listed_models[0].id().0, "fake");
570
571    // Create a thread using new_thread
572    let connection_rc = Rc::new(connection.clone());
573    let acp_thread = cx
574        .update(|cx| {
575            let mut async_cx = cx.to_async();
576            connection_rc.new_thread(project, cwd, &mut async_cx)
577        })
578        .await
579        .expect("new_thread should succeed");
580
581    // Get the session_id from the AcpThread
582    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
583
584    // Test selected_model returns the default
585    let model = cx
586        .update(|cx| {
587            let mut async_cx = cx.to_async();
588            selector.selected_model(&session_id, &mut async_cx)
589        })
590        .await
591        .expect("selected_model should succeed");
592    let model = model.as_fake();
593    assert_eq!(model.id().0, "fake", "should return default model");
594
595    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
596    cx.run_until_parked();
597    model.send_last_completion_stream_text_chunk("def");
598    cx.run_until_parked();
599    acp_thread.read_with(cx, |thread, cx| {
600        assert_eq!(
601            thread.to_markdown(cx),
602            indoc! {"
603                ## User
604
605                abc
606
607                ## Assistant
608
609                def
610
611            "}
612        )
613    });
614
615    // Test cancel
616    cx.update(|cx| connection.cancel(&session_id, cx));
617    request.await.expect("prompt should fail gracefully");
618
619    // Ensure that dropping the ACP thread causes the native thread to be
620    // dropped as well.
621    cx.update(|_| drop(acp_thread));
622    let result = cx
623        .update(|cx| {
624            connection.prompt(
625                acp::PromptRequest {
626                    session_id: session_id.clone(),
627                    prompt: vec!["ghi".into()],
628                },
629                cx,
630            )
631        })
632        .await;
633    assert_eq!(
634        result.as_ref().unwrap_err().to_string(),
635        "Session not found",
636        "unexpected result: {:?}",
637        result
638    );
639}
640
641#[gpui::test]
642async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
643    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
644    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
645    let fake_model = model.as_fake();
646
647    let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Think", cx));
648    cx.run_until_parked();
649
650    // Simulate streaming partial input.
651    let input = json!({});
652    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
653        LanguageModelToolUse {
654            id: "1".into(),
655            name: ThinkingTool.name().into(),
656            raw_input: input.to_string(),
657            input,
658            is_input_complete: false,
659        },
660    ));
661
662    // Input streaming completed
663    let input = json!({ "content": "Thinking hard!" });
664    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
665        LanguageModelToolUse {
666            id: "1".into(),
667            name: ThinkingTool.name().into(),
668            raw_input: input.to_string(),
669            input,
670            is_input_complete: true,
671        },
672    ));
673    fake_model.end_last_completion_stream();
674    cx.run_until_parked();
675
676    let tool_call = expect_tool_call(&mut events).await;
677    assert_eq!(
678        tool_call,
679        acp::ToolCall {
680            id: acp::ToolCallId("1".into()),
681            title: "thinking".into(),
682            kind: acp::ToolKind::Think,
683            status: acp::ToolCallStatus::Pending,
684            content: vec![],
685            locations: vec![],
686            raw_input: Some(json!({})),
687            raw_output: None,
688        }
689    );
690    let update = expect_tool_call_update(&mut events).await;
691    assert_eq!(
692        update,
693        acp::ToolCallUpdate {
694            id: acp::ToolCallId("1".into()),
695            fields: acp::ToolCallUpdateFields {
696                title: Some("Thinking".into()),
697                kind: Some(acp::ToolKind::Think),
698                raw_input: Some(json!({ "content": "Thinking hard!" })),
699                ..Default::default()
700            },
701        }
702    );
703    let update = expect_tool_call_update(&mut events).await;
704    assert_eq!(
705        update,
706        acp::ToolCallUpdate {
707            id: acp::ToolCallId("1".into()),
708            fields: acp::ToolCallUpdateFields {
709                status: Some(acp::ToolCallStatus::InProgress),
710                ..Default::default()
711            },
712        }
713    );
714    let update = expect_tool_call_update(&mut events).await;
715    assert_eq!(
716        update,
717        acp::ToolCallUpdate {
718            id: acp::ToolCallId("1".into()),
719            fields: acp::ToolCallUpdateFields {
720                content: Some(vec!["Thinking hard!".into()]),
721                ..Default::default()
722            },
723        }
724    );
725    let update = expect_tool_call_update(&mut events).await;
726    assert_eq!(
727        update,
728        acp::ToolCallUpdate {
729            id: acp::ToolCallId("1".into()),
730            fields: acp::ToolCallUpdateFields {
731                status: Some(acp::ToolCallStatus::Completed),
732                ..Default::default()
733            },
734        }
735    );
736}
737
738/// Filters out the stop events for asserting against in tests
739fn stop_events(
740    result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
741) -> Vec<acp::StopReason> {
742    result_events
743        .into_iter()
744        .filter_map(|event| match event.unwrap() {
745            AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
746            _ => None,
747        })
748        .collect()
749}
750
751struct ThreadTest {
752    model: Arc<dyn LanguageModel>,
753    thread: Entity<Thread>,
754    project_context: Rc<RefCell<ProjectContext>>,
755}
756
757enum TestModel {
758    Sonnet4,
759    Sonnet4Thinking,
760    Fake,
761}
762
763impl TestModel {
764    fn id(&self) -> LanguageModelId {
765        match self {
766            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
767            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
768            TestModel::Fake => unreachable!(),
769        }
770    }
771}
772
773async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
774    cx.executor().allow_parking();
775    cx.update(|cx| {
776        settings::init(cx);
777        Project::init_settings(cx);
778    });
779    let templates = Templates::new();
780
781    let fs = FakeFs::new(cx.background_executor.clone());
782    fs.insert_tree(path!("/test"), json!({})).await;
783    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
784
785    let model = cx
786        .update(|cx| {
787            gpui_tokio::init(cx);
788            let http_client = ReqwestClient::user_agent("agent tests").unwrap();
789            cx.set_http_client(Arc::new(http_client));
790
791            client::init_settings(cx);
792            let client = Client::production(cx);
793            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
794            language_model::init(client.clone(), cx);
795            language_models::init(user_store.clone(), client.clone(), cx);
796
797            if let TestModel::Fake = model {
798                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
799            } else {
800                let model_id = model.id();
801                let models = LanguageModelRegistry::read_global(cx);
802                let model = models
803                    .available_models(cx)
804                    .find(|model| model.id() == model_id)
805                    .unwrap();
806
807                let provider = models.provider(&model.provider_id()).unwrap();
808                let authenticated = provider.authenticate(cx);
809
810                cx.spawn(async move |_cx| {
811                    authenticated.await.unwrap();
812                    model
813                })
814            }
815        })
816        .await;
817
818    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
819    let action_log = cx.new(|_| ActionLog::new(project.clone()));
820    let thread = cx.new(|_| {
821        Thread::new(
822            project,
823            project_context.clone(),
824            action_log,
825            templates,
826            model.clone(),
827        )
828    });
829    ThreadTest {
830        model,
831        thread,
832        project_context,
833    }
834}
835
836#[cfg(test)]
837#[ctor::ctor]
838fn init_logger() {
839    if std::env::var("RUST_LOG").is_ok() {
840        env_logger::init();
841    }
842}