mod.rs

  1use super::*;
  2use acp_thread::AgentConnection;
  3use agent_client_protocol::{self as acp};
  4use anyhow::Result;
  5use assistant_tool::ActionLog;
  6use client::{Client, UserStore};
  7use fs::FakeFs;
  8use futures::channel::mpsc::UnboundedReceiver;
  9use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
 10use indoc::indoc;
 11use language_model::{
 12    fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
 13    LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult,
 14    LanguageModelToolUse, MessageContent, Role, StopReason,
 15};
 16use project::Project;
 17use prompt_store::ProjectContext;
 18use reqwest_client::ReqwestClient;
 19use schemars::JsonSchema;
 20use serde::{Deserialize, Serialize};
 21use serde_json::json;
 22use smol::stream::StreamExt;
 23use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
 24use util::path;
 25
 26mod test_tools;
 27use test_tools::*;
 28
 29#[gpui::test]
 30#[ignore = "can't run on CI yet"]
 31async fn test_echo(cx: &mut TestAppContext) {
 32    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
 33
 34    let events = thread
 35        .update(cx, |thread, cx| {
 36            thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
 37        })
 38        .collect()
 39        .await;
 40    thread.update(cx, |thread, _cx| {
 41        assert_eq!(
 42            thread.messages().last().unwrap().content,
 43            vec![MessageContent::Text("Hello".to_string())]
 44        );
 45    });
 46    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 47}
 48
 49#[gpui::test]
 50#[ignore = "can't run on CI yet"]
 51async fn test_thinking(cx: &mut TestAppContext) {
 52    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
 53
 54    let events = thread
 55        .update(cx, |thread, cx| {
 56            thread.send(
 57                model.clone(),
 58                indoc! {"
 59                    Testing:
 60
 61                    Generate a thinking step where you just think the word 'Think',
 62                    and have your final answer be 'Hello'
 63                "},
 64                cx,
 65            )
 66        })
 67        .collect()
 68        .await;
 69    thread.update(cx, |thread, _cx| {
 70        assert_eq!(
 71            thread.messages().last().unwrap().to_markdown(),
 72            indoc! {"
 73                ## assistant
 74                <think>Think</think>
 75                Hello
 76            "}
 77        )
 78    });
 79    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 80}
 81
 82#[gpui::test]
 83async fn test_system_prompt(cx: &mut TestAppContext) {
 84    let ThreadTest {
 85        model,
 86        thread,
 87        project_context,
 88        ..
 89    } = setup(cx, TestModel::Fake).await;
 90    let fake_model = model.as_fake();
 91
 92    project_context.borrow_mut().shell = "test-shell".into();
 93    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 94    thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
 95    cx.run_until_parked();
 96    let mut pending_completions = fake_model.pending_completions();
 97    assert_eq!(
 98        pending_completions.len(),
 99        1,
100        "unexpected pending completions: {:?}",
101        pending_completions
102    );
103
104    let pending_completion = pending_completions.pop().unwrap();
105    assert_eq!(pending_completion.messages[0].role, Role::System);
106
107    let system_message = &pending_completion.messages[0];
108    let system_prompt = system_message.content[0].to_str().unwrap();
109    assert!(
110        system_prompt.contains("test-shell"),
111        "unexpected system message: {:?}",
112        system_message
113    );
114    assert!(
115        system_prompt.contains("## Fixing Diagnostics"),
116        "unexpected system message: {:?}",
117        system_message
118    );
119}
120
121#[gpui::test]
122#[ignore = "can't run on CI yet"]
123async fn test_basic_tool_calls(cx: &mut TestAppContext) {
124    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
125
126    // Test a tool call that's likely to complete *before* streaming stops.
127    let events = thread
128        .update(cx, |thread, cx| {
129            thread.add_tool(EchoTool);
130            thread.send(
131                model.clone(),
132                "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
133                cx,
134            )
135        })
136        .collect()
137        .await;
138    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
139
140    // Test a tool calls that's likely to complete *after* streaming stops.
141    let events = thread
142        .update(cx, |thread, cx| {
143            thread.remove_tool(&AgentTool::name(&EchoTool));
144            thread.add_tool(DelayTool);
145            thread.send(
146                model.clone(),
147                "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
148                cx,
149            )
150        })
151        .collect()
152        .await;
153    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
154    thread.update(cx, |thread, _cx| {
155        assert!(thread
156            .messages()
157            .last()
158            .unwrap()
159            .content
160            .iter()
161            .any(|content| {
162                if let MessageContent::Text(text) = content {
163                    text.contains("Ding")
164                } else {
165                    false
166                }
167            }));
168    });
169}
170
171#[gpui::test]
172#[ignore = "can't run on CI yet"]
173async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
174    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
175
176    // Test a tool call that's likely to complete *before* streaming stops.
177    let mut events = thread.update(cx, |thread, cx| {
178        thread.add_tool(WordListTool);
179        thread.send(model.clone(), "Test the word_list tool.", cx)
180    });
181
182    let mut saw_partial_tool_use = false;
183    while let Some(event) = events.next().await {
184        if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
185            thread.update(cx, |thread, _cx| {
186                // Look for a tool use in the thread's last message
187                let last_content = thread.messages().last().unwrap().content.last().unwrap();
188                if let MessageContent::ToolUse(last_tool_use) = last_content {
189                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
190                    if tool_call.status == acp::ToolCallStatus::Pending {
191                        if !last_tool_use.is_input_complete
192                            && last_tool_use.input.get("g").is_none()
193                        {
194                            saw_partial_tool_use = true;
195                        }
196                    } else {
197                        last_tool_use
198                            .input
199                            .get("a")
200                            .expect("'a' has streamed because input is now complete");
201                        last_tool_use
202                            .input
203                            .get("g")
204                            .expect("'g' has streamed because input is now complete");
205                    }
206                } else {
207                    panic!("last content should be a tool use");
208                }
209            });
210        }
211    }
212
213    assert!(
214        saw_partial_tool_use,
215        "should see at least one partially streamed tool use in the history"
216    );
217}
218
219#[gpui::test]
220async fn test_tool_authorization(cx: &mut TestAppContext) {
221    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
222    let fake_model = model.as_fake();
223
224    let mut events = thread.update(cx, |thread, cx| {
225        thread.add_tool(ToolRequiringPermission);
226        thread.send(model.clone(), "abc", cx)
227    });
228    cx.run_until_parked();
229    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
230        LanguageModelToolUse {
231            id: "tool_id_1".into(),
232            name: ToolRequiringPermission.name().into(),
233            raw_input: "{}".into(),
234            input: json!({}),
235            is_input_complete: true,
236        },
237    ));
238    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
239        LanguageModelToolUse {
240            id: "tool_id_2".into(),
241            name: ToolRequiringPermission.name().into(),
242            raw_input: "{}".into(),
243            input: json!({}),
244            is_input_complete: true,
245        },
246    ));
247    fake_model.end_last_completion_stream();
248    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
249    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
250
251    // Approve the first
252    tool_call_auth_1
253        .response
254        .send(tool_call_auth_1.options[1].id.clone())
255        .unwrap();
256    cx.run_until_parked();
257
258    // Reject the second
259    tool_call_auth_2
260        .response
261        .send(tool_call_auth_1.options[2].id.clone())
262        .unwrap();
263    cx.run_until_parked();
264
265    let completion = fake_model.pending_completions().pop().unwrap();
266    let message = completion.messages.last().unwrap();
267    assert_eq!(
268        message.content,
269        vec![
270            MessageContent::ToolResult(LanguageModelToolResult {
271                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
272                tool_name: ToolRequiringPermission.name().into(),
273                is_error: false,
274                content: "Allowed".into(),
275                output: Some("Allowed".into())
276            }),
277            MessageContent::ToolResult(LanguageModelToolResult {
278                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
279                tool_name: ToolRequiringPermission.name().into(),
280                is_error: true,
281                content: "Permission to run tool denied by user".into(),
282                output: None
283            })
284        ]
285    );
286}
287
288#[gpui::test]
289async fn test_tool_hallucination(cx: &mut TestAppContext) {
290    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
291    let fake_model = model.as_fake();
292
293    let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
294    cx.run_until_parked();
295    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
296        LanguageModelToolUse {
297            id: "tool_id_1".into(),
298            name: "nonexistent_tool".into(),
299            raw_input: "{}".into(),
300            input: json!({}),
301            is_input_complete: true,
302        },
303    ));
304    fake_model.end_last_completion_stream();
305
306    let tool_call = expect_tool_call(&mut events).await;
307    assert_eq!(tool_call.title, "nonexistent_tool");
308    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
309    let update = expect_tool_call_update_fields(&mut events).await;
310    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
311}
312
313async fn expect_tool_call(
314    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
315) -> acp::ToolCall {
316    let event = events
317        .next()
318        .await
319        .expect("no tool call authorization event received")
320        .unwrap();
321    match event {
322        AgentResponseEvent::ToolCall(tool_call) => return tool_call,
323        event => {
324            panic!("Unexpected event {event:?}");
325        }
326    }
327}
328
329async fn expect_tool_call_update_fields(
330    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
331) -> acp::ToolCallUpdate {
332    let event = events
333        .next()
334        .await
335        .expect("no tool call authorization event received")
336        .unwrap();
337    match event {
338        AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
339            return update
340        }
341        event => {
342            panic!("Unexpected event {event:?}");
343        }
344    }
345}
346
347async fn next_tool_call_authorization(
348    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
349) -> ToolCallAuthorization {
350    loop {
351        let event = events
352            .next()
353            .await
354            .expect("no tool call authorization event received")
355            .unwrap();
356        if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
357            let permission_kinds = tool_call_authorization
358                .options
359                .iter()
360                .map(|o| o.kind)
361                .collect::<Vec<_>>();
362            assert_eq!(
363                permission_kinds,
364                vec![
365                    acp::PermissionOptionKind::AllowAlways,
366                    acp::PermissionOptionKind::AllowOnce,
367                    acp::PermissionOptionKind::RejectOnce,
368                ]
369            );
370            return tool_call_authorization;
371        }
372    }
373}
374
375#[gpui::test]
376#[ignore = "can't run on CI yet"]
377async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
378    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
379
380    // Test concurrent tool calls with different delay times
381    let events = thread
382        .update(cx, |thread, cx| {
383            thread.add_tool(DelayTool);
384            thread.send(
385                model.clone(),
386                "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
387                cx,
388            )
389        })
390        .collect()
391        .await;
392
393    let stop_reasons = stop_events(events);
394    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
395
396    thread.update(cx, |thread, _cx| {
397        let last_message = thread.messages().last().unwrap();
398        let text = last_message
399            .content
400            .iter()
401            .filter_map(|content| {
402                if let MessageContent::Text(text) = content {
403                    Some(text.as_str())
404                } else {
405                    None
406                }
407            })
408            .collect::<String>();
409
410        assert!(text.contains("Ding"));
411    });
412}
413
414#[gpui::test]
415#[ignore = "can't run on CI yet"]
416async fn test_cancellation(cx: &mut TestAppContext) {
417    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
418
419    let mut events = thread.update(cx, |thread, cx| {
420        thread.add_tool(InfiniteTool);
421        thread.add_tool(EchoTool);
422        thread.send(
423            model.clone(),
424            "Call the echo tool and then call the infinite tool, then explain their output",
425            cx,
426        )
427    });
428
429    // Wait until both tools are called.
430    let mut expected_tools = vec!["Echo", "Infinite Tool"];
431    let mut echo_id = None;
432    let mut echo_completed = false;
433    while let Some(event) = events.next().await {
434        match event.unwrap() {
435            AgentResponseEvent::ToolCall(tool_call) => {
436                assert_eq!(tool_call.title, expected_tools.remove(0));
437                if tool_call.title == "Echo" {
438                    echo_id = Some(tool_call.id);
439                }
440            }
441            AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
442                acp::ToolCallUpdate {
443                    id,
444                    fields:
445                        acp::ToolCallUpdateFields {
446                            status: Some(acp::ToolCallStatus::Completed),
447                            ..
448                        },
449                },
450            )) if Some(&id) == echo_id.as_ref() => {
451                echo_completed = true;
452            }
453            _ => {}
454        }
455
456        if expected_tools.is_empty() && echo_completed {
457            break;
458        }
459    }
460
461    // Cancel the current send and ensure that the event stream is closed, even
462    // if one of the tools is still running.
463    thread.update(cx, |thread, _cx| thread.cancel());
464    events.collect::<Vec<_>>().await;
465
466    // Ensure we can still send a new message after cancellation.
467    let events = thread
468        .update(cx, |thread, cx| {
469            thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx)
470        })
471        .collect::<Vec<_>>()
472        .await;
473    thread.update(cx, |thread, _cx| {
474        assert_eq!(
475            thread.messages().last().unwrap().content,
476            vec![MessageContent::Text("Hello".to_string())]
477        );
478    });
479    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
480}
481
482#[gpui::test]
483async fn test_refusal(cx: &mut TestAppContext) {
484    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
485    let fake_model = model.as_fake();
486
487    let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx));
488    cx.run_until_parked();
489    thread.read_with(cx, |thread, _| {
490        assert_eq!(
491            thread.to_markdown(),
492            indoc! {"
493                ## user
494                Hello
495            "}
496        );
497    });
498
499    fake_model.send_last_completion_stream_text_chunk("Hey!");
500    cx.run_until_parked();
501    thread.read_with(cx, |thread, _| {
502        assert_eq!(
503            thread.to_markdown(),
504            indoc! {"
505                ## user
506                Hello
507                ## assistant
508                Hey!
509            "}
510        );
511    });
512
513    // If the model refuses to continue, the thread should remove all the messages after the last user message.
514    fake_model
515        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
516    let events = events.collect::<Vec<_>>().await;
517    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
518    thread.read_with(cx, |thread, _| {
519        assert_eq!(thread.to_markdown(), "");
520    });
521}
522
523#[gpui::test]
524async fn test_agent_connection(cx: &mut TestAppContext) {
525    cx.update(settings::init);
526    let templates = Templates::new();
527
528    // Initialize language model system with test provider
529    cx.update(|cx| {
530        gpui_tokio::init(cx);
531        client::init_settings(cx);
532
533        let http_client = FakeHttpClient::with_404_response();
534        let clock = Arc::new(clock::FakeSystemClock::new());
535        let client = Client::new(clock, http_client, cx);
536        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
537        language_model::init(client.clone(), cx);
538        language_models::init(user_store.clone(), client.clone(), cx);
539        Project::init_settings(cx);
540        LanguageModelRegistry::test(cx);
541    });
542    cx.executor().forbid_parking();
543
544    // Create a project for new_thread
545    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
546    fake_fs.insert_tree(path!("/test"), json!({})).await;
547    let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
548    let cwd = Path::new("/test");
549
550    // Create agent and connection
551    let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
552        .await
553        .unwrap();
554    let connection = NativeAgentConnection(agent.clone());
555
556    // Test model_selector returns Some
557    let selector_opt = connection.model_selector();
558    assert!(
559        selector_opt.is_some(),
560        "agent2 should always support ModelSelector"
561    );
562    let selector = selector_opt.unwrap();
563
564    // Test list_models
565    let listed_models = cx
566        .update(|cx| {
567            let mut async_cx = cx.to_async();
568            selector.list_models(&mut async_cx)
569        })
570        .await
571        .expect("list_models should succeed");
572    assert!(!listed_models.is_empty(), "should have at least one model");
573    assert_eq!(listed_models[0].id().0, "fake");
574
575    // Create a thread using new_thread
576    let connection_rc = Rc::new(connection.clone());
577    let acp_thread = cx
578        .update(|cx| {
579            let mut async_cx = cx.to_async();
580            connection_rc.new_thread(project, cwd, &mut async_cx)
581        })
582        .await
583        .expect("new_thread should succeed");
584
585    // Get the session_id from the AcpThread
586    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
587
588    // Test selected_model returns the default
589    let model = cx
590        .update(|cx| {
591            let mut async_cx = cx.to_async();
592            selector.selected_model(&session_id, &mut async_cx)
593        })
594        .await
595        .expect("selected_model should succeed");
596    let model = model.as_fake();
597    assert_eq!(model.id().0, "fake", "should return default model");
598
599    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
600    cx.run_until_parked();
601    model.send_last_completion_stream_text_chunk("def");
602    cx.run_until_parked();
603    acp_thread.read_with(cx, |thread, cx| {
604        assert_eq!(
605            thread.to_markdown(cx),
606            indoc! {"
607                ## User
608
609                abc
610
611                ## Assistant
612
613                def
614
615            "}
616        )
617    });
618
619    // Test cancel
620    cx.update(|cx| connection.cancel(&session_id, cx));
621    request.await.expect("prompt should fail gracefully");
622
623    // Ensure that dropping the ACP thread causes the native thread to be
624    // dropped as well.
625    cx.update(|_| drop(acp_thread));
626    let result = cx
627        .update(|cx| {
628            connection.prompt(
629                acp::PromptRequest {
630                    session_id: session_id.clone(),
631                    prompt: vec!["ghi".into()],
632                },
633                cx,
634            )
635        })
636        .await;
637    assert_eq!(
638        result.as_ref().unwrap_err().to_string(),
639        "Session not found",
640        "unexpected result: {:?}",
641        result
642    );
643}
644
645#[gpui::test]
646async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
647    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
648    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
649    let fake_model = model.as_fake();
650
651    let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Think", cx));
652    cx.run_until_parked();
653
654    // Simulate streaming partial input.
655    let input = json!({});
656    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
657        LanguageModelToolUse {
658            id: "1".into(),
659            name: ThinkingTool.name().into(),
660            raw_input: input.to_string(),
661            input,
662            is_input_complete: false,
663        },
664    ));
665
666    // Input streaming completed
667    let input = json!({ "content": "Thinking hard!" });
668    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
669        LanguageModelToolUse {
670            id: "1".into(),
671            name: "thinking".into(),
672            raw_input: input.to_string(),
673            input,
674            is_input_complete: true,
675        },
676    ));
677    fake_model.end_last_completion_stream();
678    cx.run_until_parked();
679
680    let tool_call = expect_tool_call(&mut events).await;
681    assert_eq!(
682        tool_call,
683        acp::ToolCall {
684            id: acp::ToolCallId("1".into()),
685            title: "Thinking".into(),
686            kind: acp::ToolKind::Think,
687            status: acp::ToolCallStatus::Pending,
688            content: vec![],
689            locations: vec![],
690            raw_input: Some(json!({})),
691            raw_output: None,
692        }
693    );
694    let update = expect_tool_call_update_fields(&mut events).await;
695    assert_eq!(
696        update,
697        acp::ToolCallUpdate {
698            id: acp::ToolCallId("1".into()),
699            fields: acp::ToolCallUpdateFields {
700                title: Some("Thinking".into()),
701                kind: Some(acp::ToolKind::Think),
702                raw_input: Some(json!({ "content": "Thinking hard!" })),
703                ..Default::default()
704            },
705        }
706    );
707    let update = expect_tool_call_update_fields(&mut events).await;
708    assert_eq!(
709        update,
710        acp::ToolCallUpdate {
711            id: acp::ToolCallId("1".into()),
712            fields: acp::ToolCallUpdateFields {
713                status: Some(acp::ToolCallStatus::InProgress),
714                ..Default::default()
715            },
716        }
717    );
718    let update = expect_tool_call_update_fields(&mut events).await;
719    assert_eq!(
720        update,
721        acp::ToolCallUpdate {
722            id: acp::ToolCallId("1".into()),
723            fields: acp::ToolCallUpdateFields {
724                content: Some(vec!["Thinking hard!".into()]),
725                ..Default::default()
726            },
727        }
728    );
729    let update = expect_tool_call_update_fields(&mut events).await;
730    assert_eq!(
731        update,
732        acp::ToolCallUpdate {
733            id: acp::ToolCallId("1".into()),
734            fields: acp::ToolCallUpdateFields {
735                status: Some(acp::ToolCallStatus::Completed),
736                ..Default::default()
737            },
738        }
739    );
740}
741
742/// Filters out the stop events for asserting against in tests
743fn stop_events(
744    result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
745) -> Vec<acp::StopReason> {
746    result_events
747        .into_iter()
748        .filter_map(|event| match event.unwrap() {
749            AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
750            _ => None,
751        })
752        .collect()
753}
754
755struct ThreadTest {
756    model: Arc<dyn LanguageModel>,
757    thread: Entity<Thread>,
758    project_context: Rc<RefCell<ProjectContext>>,
759}
760
761enum TestModel {
762    Sonnet4,
763    Sonnet4Thinking,
764    Fake,
765}
766
767impl TestModel {
768    fn id(&self) -> LanguageModelId {
769        match self {
770            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
771            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
772            TestModel::Fake => unreachable!(),
773        }
774    }
775}
776
777async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
778    cx.executor().allow_parking();
779    cx.update(|cx| {
780        settings::init(cx);
781        Project::init_settings(cx);
782    });
783    let templates = Templates::new();
784
785    let fs = FakeFs::new(cx.background_executor.clone());
786    fs.insert_tree(path!("/test"), json!({})).await;
787    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
788
789    let model = cx
790        .update(|cx| {
791            gpui_tokio::init(cx);
792            let http_client = ReqwestClient::user_agent("agent tests").unwrap();
793            cx.set_http_client(Arc::new(http_client));
794
795            client::init_settings(cx);
796            let client = Client::production(cx);
797            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
798            language_model::init(client.clone(), cx);
799            language_models::init(user_store.clone(), client.clone(), cx);
800
801            if let TestModel::Fake = model {
802                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
803            } else {
804                let model_id = model.id();
805                let models = LanguageModelRegistry::read_global(cx);
806                let model = models
807                    .available_models(cx)
808                    .find(|model| model.id() == model_id)
809                    .unwrap();
810
811                let provider = models.provider(&model.provider_id()).unwrap();
812                let authenticated = provider.authenticate(cx);
813
814                cx.spawn(async move |_cx| {
815                    authenticated.await.unwrap();
816                    model
817                })
818            }
819        })
820        .await;
821
822    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
823    let action_log = cx.new(|_| ActionLog::new(project.clone()));
824    let thread = cx.new(|_| {
825        Thread::new(
826            project,
827            project_context.clone(),
828            action_log,
829            templates,
830            model.clone(),
831        )
832    });
833    ThreadTest {
834        model,
835        thread,
836        project_context,
837    }
838}
839
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Install a logger only when the developer opted into log output.
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}