mod.rs

  1use super::*;
  2use acp_thread::AgentConnection;
  3use action_log::ActionLog;
  4use agent_client_protocol::{self as acp};
  5use anyhow::Result;
  6use client::{Client, UserStore};
  7use fs::{FakeFs, Fs};
  8use futures::channel::mpsc::UnboundedReceiver;
  9use gpui::{
 10    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
 11};
 12use indoc::indoc;
 13use language_model::{
 14    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
 15    LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
 16    StopReason, fake_provider::FakeLanguageModel,
 17};
 18use project::Project;
 19use prompt_store::ProjectContext;
 20use reqwest_client::ReqwestClient;
 21use schemars::JsonSchema;
 22use serde::{Deserialize, Serialize};
 23use serde_json::json;
 24use settings::SettingsStore;
 25use smol::stream::StreamExt;
 26use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
 27use util::path;
 28
 29mod test_tools;
 30use test_tools::*;
 31
 32#[gpui::test]
 33#[ignore = "can't run on CI yet"]
 34async fn test_echo(cx: &mut TestAppContext) {
 35    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 36
 37    let events = thread
 38        .update(cx, |thread, cx| {
 39            thread.send("Testing: Reply with 'Hello'", cx)
 40        })
 41        .collect()
 42        .await;
 43    thread.update(cx, |thread, _cx| {
 44        assert_eq!(
 45            thread.messages().last().unwrap().content,
 46            vec![MessageContent::Text("Hello".to_string())]
 47        );
 48    });
 49    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 50}
 51
 52#[gpui::test]
 53#[ignore = "can't run on CI yet"]
 54async fn test_thinking(cx: &mut TestAppContext) {
 55    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
 56
 57    let events = thread
 58        .update(cx, |thread, cx| {
 59            thread.send(
 60                indoc! {"
 61                    Testing:
 62
 63                    Generate a thinking step where you just think the word 'Think',
 64                    and have your final answer be 'Hello'
 65                "},
 66                cx,
 67            )
 68        })
 69        .collect()
 70        .await;
 71    thread.update(cx, |thread, _cx| {
 72        assert_eq!(
 73            thread.messages().last().unwrap().to_markdown(),
 74            indoc! {"
 75                ## assistant
 76                <think>Think</think>
 77                Hello
 78            "}
 79        )
 80    });
 81    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 82}
 83
 84#[gpui::test]
 85async fn test_system_prompt(cx: &mut TestAppContext) {
 86    let ThreadTest {
 87        model,
 88        thread,
 89        project_context,
 90        ..
 91    } = setup(cx, TestModel::Fake).await;
 92    let fake_model = model.as_fake();
 93
 94    project_context.borrow_mut().shell = "test-shell".into();
 95    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 96    thread.update(cx, |thread, cx| thread.send("abc", cx));
 97    cx.run_until_parked();
 98    let mut pending_completions = fake_model.pending_completions();
 99    assert_eq!(
100        pending_completions.len(),
101        1,
102        "unexpected pending completions: {:?}",
103        pending_completions
104    );
105
106    let pending_completion = pending_completions.pop().unwrap();
107    assert_eq!(pending_completion.messages[0].role, Role::System);
108
109    let system_message = &pending_completion.messages[0];
110    let system_prompt = system_message.content[0].to_str().unwrap();
111    assert!(
112        system_prompt.contains("test-shell"),
113        "unexpected system message: {:?}",
114        system_message
115    );
116    assert!(
117        system_prompt.contains("## Fixing Diagnostics"),
118        "unexpected system message: {:?}",
119        system_message
120    );
121}
122
123#[gpui::test]
124#[ignore = "can't run on CI yet"]
125async fn test_basic_tool_calls(cx: &mut TestAppContext) {
126    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
127
128    // Test a tool call that's likely to complete *before* streaming stops.
129    let events = thread
130        .update(cx, |thread, cx| {
131            thread.add_tool(EchoTool);
132            thread.send(
133                "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
134                cx,
135            )
136        })
137        .collect()
138        .await;
139    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
140
141    // Test a tool calls that's likely to complete *after* streaming stops.
142    let events = thread
143        .update(cx, |thread, cx| {
144            thread.remove_tool(&AgentTool::name(&EchoTool));
145            thread.add_tool(DelayTool);
146            thread.send(
147                "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
148                cx,
149            )
150        })
151        .collect()
152        .await;
153    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
154    thread.update(cx, |thread, _cx| {
155        assert!(
156            thread
157                .messages()
158                .last()
159                .unwrap()
160                .content
161                .iter()
162                .any(|content| {
163                    if let MessageContent::Text(text) = content {
164                        text.contains("Ding")
165                    } else {
166                        false
167                    }
168                })
169        );
170    });
171}
172
173#[gpui::test]
174#[ignore = "can't run on CI yet"]
175async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
176    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
177
178    // Test a tool call that's likely to complete *before* streaming stops.
179    let mut events = thread.update(cx, |thread, cx| {
180        thread.add_tool(WordListTool);
181        thread.send("Test the word_list tool.", cx)
182    });
183
184    let mut saw_partial_tool_use = false;
185    while let Some(event) = events.next().await {
186        if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
187            thread.update(cx, |thread, _cx| {
188                // Look for a tool use in the thread's last message
189                let last_content = thread.messages().last().unwrap().content.last().unwrap();
190                if let MessageContent::ToolUse(last_tool_use) = last_content {
191                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
192                    if tool_call.status == acp::ToolCallStatus::Pending {
193                        if !last_tool_use.is_input_complete
194                            && last_tool_use.input.get("g").is_none()
195                        {
196                            saw_partial_tool_use = true;
197                        }
198                    } else {
199                        last_tool_use
200                            .input
201                            .get("a")
202                            .expect("'a' has streamed because input is now complete");
203                        last_tool_use
204                            .input
205                            .get("g")
206                            .expect("'g' has streamed because input is now complete");
207                    }
208                } else {
209                    panic!("last content should be a tool use");
210                }
211            });
212        }
213    }
214
215    assert!(
216        saw_partial_tool_use,
217        "should see at least one partially streamed tool use in the history"
218    );
219}
220
221#[gpui::test]
222async fn test_tool_authorization(cx: &mut TestAppContext) {
223    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
224    let fake_model = model.as_fake();
225
226    let mut events = thread.update(cx, |thread, cx| {
227        thread.add_tool(ToolRequiringPermission);
228        thread.send("abc", cx)
229    });
230    cx.run_until_parked();
231    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
232        LanguageModelToolUse {
233            id: "tool_id_1".into(),
234            name: ToolRequiringPermission.name().into(),
235            raw_input: "{}".into(),
236            input: json!({}),
237            is_input_complete: true,
238        },
239    ));
240    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
241        LanguageModelToolUse {
242            id: "tool_id_2".into(),
243            name: ToolRequiringPermission.name().into(),
244            raw_input: "{}".into(),
245            input: json!({}),
246            is_input_complete: true,
247        },
248    ));
249    fake_model.end_last_completion_stream();
250    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
251    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
252
253    // Approve the first
254    tool_call_auth_1
255        .response
256        .send(tool_call_auth_1.options[1].id.clone())
257        .unwrap();
258    cx.run_until_parked();
259
260    // Reject the second
261    tool_call_auth_2
262        .response
263        .send(tool_call_auth_1.options[2].id.clone())
264        .unwrap();
265    cx.run_until_parked();
266
267    let completion = fake_model.pending_completions().pop().unwrap();
268    let message = completion.messages.last().unwrap();
269    assert_eq!(
270        message.content,
271        vec![
272            MessageContent::ToolResult(LanguageModelToolResult {
273                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
274                tool_name: ToolRequiringPermission.name().into(),
275                is_error: false,
276                content: "Allowed".into(),
277                output: Some("Allowed".into())
278            }),
279            MessageContent::ToolResult(LanguageModelToolResult {
280                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
281                tool_name: ToolRequiringPermission.name().into(),
282                is_error: true,
283                content: "Permission to run tool denied by user".into(),
284                output: None
285            })
286        ]
287    );
288
289    // Simulate yet another tool call.
290    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
291        LanguageModelToolUse {
292            id: "tool_id_3".into(),
293            name: ToolRequiringPermission.name().into(),
294            raw_input: "{}".into(),
295            input: json!({}),
296            is_input_complete: true,
297        },
298    ));
299    fake_model.end_last_completion_stream();
300
301    // Respond by always allowing tools.
302    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
303    tool_call_auth_3
304        .response
305        .send(tool_call_auth_3.options[0].id.clone())
306        .unwrap();
307    cx.run_until_parked();
308    let completion = fake_model.pending_completions().pop().unwrap();
309    let message = completion.messages.last().unwrap();
310    assert_eq!(
311        message.content,
312        vec![MessageContent::ToolResult(LanguageModelToolResult {
313            tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
314            tool_name: ToolRequiringPermission.name().into(),
315            is_error: false,
316            content: "Allowed".into(),
317            output: Some("Allowed".into())
318        })]
319    );
320
321    // Simulate a final tool call, ensuring we don't trigger authorization.
322    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
323        LanguageModelToolUse {
324            id: "tool_id_4".into(),
325            name: ToolRequiringPermission.name().into(),
326            raw_input: "{}".into(),
327            input: json!({}),
328            is_input_complete: true,
329        },
330    ));
331    fake_model.end_last_completion_stream();
332    cx.run_until_parked();
333    let completion = fake_model.pending_completions().pop().unwrap();
334    let message = completion.messages.last().unwrap();
335    assert_eq!(
336        message.content,
337        vec![MessageContent::ToolResult(LanguageModelToolResult {
338            tool_use_id: "tool_id_4".into(),
339            tool_name: ToolRequiringPermission.name().into(),
340            is_error: false,
341            content: "Allowed".into(),
342            output: Some("Allowed".into())
343        })]
344    );
345}
346
347#[gpui::test]
348async fn test_tool_hallucination(cx: &mut TestAppContext) {
349    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
350    let fake_model = model.as_fake();
351
352    let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx));
353    cx.run_until_parked();
354    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
355        LanguageModelToolUse {
356            id: "tool_id_1".into(),
357            name: "nonexistent_tool".into(),
358            raw_input: "{}".into(),
359            input: json!({}),
360            is_input_complete: true,
361        },
362    ));
363    fake_model.end_last_completion_stream();
364
365    let tool_call = expect_tool_call(&mut events).await;
366    assert_eq!(tool_call.title, "nonexistent_tool");
367    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
368    let update = expect_tool_call_update_fields(&mut events).await;
369    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
370}
371
372async fn expect_tool_call(
373    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
374) -> acp::ToolCall {
375    let event = events
376        .next()
377        .await
378        .expect("no tool call authorization event received")
379        .unwrap();
380    match event {
381        AgentResponseEvent::ToolCall(tool_call) => return tool_call,
382        event => {
383            panic!("Unexpected event {event:?}");
384        }
385    }
386}
387
388async fn expect_tool_call_update_fields(
389    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
390) -> acp::ToolCallUpdate {
391    let event = events
392        .next()
393        .await
394        .expect("no tool call authorization event received")
395        .unwrap();
396    match event {
397        AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
398            return update;
399        }
400        event => {
401            panic!("Unexpected event {event:?}");
402        }
403    }
404}
405
406async fn next_tool_call_authorization(
407    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
408) -> ToolCallAuthorization {
409    loop {
410        let event = events
411            .next()
412            .await
413            .expect("no tool call authorization event received")
414            .unwrap();
415        if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
416            let permission_kinds = tool_call_authorization
417                .options
418                .iter()
419                .map(|o| o.kind)
420                .collect::<Vec<_>>();
421            assert_eq!(
422                permission_kinds,
423                vec![
424                    acp::PermissionOptionKind::AllowAlways,
425                    acp::PermissionOptionKind::AllowOnce,
426                    acp::PermissionOptionKind::RejectOnce,
427                ]
428            );
429            return tool_call_authorization;
430        }
431    }
432}
433
434#[gpui::test]
435#[ignore = "can't run on CI yet"]
436async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
437    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
438
439    // Test concurrent tool calls with different delay times
440    let events = thread
441        .update(cx, |thread, cx| {
442            thread.add_tool(DelayTool);
443            thread.send(
444                "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
445                cx,
446            )
447        })
448        .collect()
449        .await;
450
451    let stop_reasons = stop_events(events);
452    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
453
454    thread.update(cx, |thread, _cx| {
455        let last_message = thread.messages().last().unwrap();
456        let text = last_message
457            .content
458            .iter()
459            .filter_map(|content| {
460                if let MessageContent::Text(text) = content {
461                    Some(text.as_str())
462                } else {
463                    None
464                }
465            })
466            .collect::<String>();
467
468        assert!(text.contains("Ding"));
469    });
470}
471
472#[gpui::test]
473#[ignore = "can't run on CI yet"]
474async fn test_cancellation(cx: &mut TestAppContext) {
475    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
476
477    let mut events = thread.update(cx, |thread, cx| {
478        thread.add_tool(InfiniteTool);
479        thread.add_tool(EchoTool);
480        thread.send(
481            "Call the echo tool and then call the infinite tool, then explain their output",
482            cx,
483        )
484    });
485
486    // Wait until both tools are called.
487    let mut expected_tools = vec!["Echo", "Infinite Tool"];
488    let mut echo_id = None;
489    let mut echo_completed = false;
490    while let Some(event) = events.next().await {
491        match event.unwrap() {
492            AgentResponseEvent::ToolCall(tool_call) => {
493                assert_eq!(tool_call.title, expected_tools.remove(0));
494                if tool_call.title == "Echo" {
495                    echo_id = Some(tool_call.id);
496                }
497            }
498            AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
499                acp::ToolCallUpdate {
500                    id,
501                    fields:
502                        acp::ToolCallUpdateFields {
503                            status: Some(acp::ToolCallStatus::Completed),
504                            ..
505                        },
506                },
507            )) if Some(&id) == echo_id.as_ref() => {
508                echo_completed = true;
509            }
510            _ => {}
511        }
512
513        if expected_tools.is_empty() && echo_completed {
514            break;
515        }
516    }
517
518    // Cancel the current send and ensure that the event stream is closed, even
519    // if one of the tools is still running.
520    thread.update(cx, |thread, _cx| thread.cancel());
521    events.collect::<Vec<_>>().await;
522
523    // Ensure we can still send a new message after cancellation.
524    let events = thread
525        .update(cx, |thread, cx| {
526            thread.send("Testing: reply with 'Hello' then stop.", cx)
527        })
528        .collect::<Vec<_>>()
529        .await;
530    thread.update(cx, |thread, _cx| {
531        assert_eq!(
532            thread.messages().last().unwrap().content,
533            vec![MessageContent::Text("Hello".to_string())]
534        );
535    });
536    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
537}
538
539#[gpui::test]
540async fn test_refusal(cx: &mut TestAppContext) {
541    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
542    let fake_model = model.as_fake();
543
544    let events = thread.update(cx, |thread, cx| thread.send("Hello", cx));
545    cx.run_until_parked();
546    thread.read_with(cx, |thread, _| {
547        assert_eq!(
548            thread.to_markdown(),
549            indoc! {"
550                ## user
551                Hello
552            "}
553        );
554    });
555
556    fake_model.send_last_completion_stream_text_chunk("Hey!");
557    cx.run_until_parked();
558    thread.read_with(cx, |thread, _| {
559        assert_eq!(
560            thread.to_markdown(),
561            indoc! {"
562                ## user
563                Hello
564                ## assistant
565                Hey!
566            "}
567        );
568    });
569
570    // If the model refuses to continue, the thread should remove all the messages after the last user message.
571    fake_model
572        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
573    let events = events.collect::<Vec<_>>().await;
574    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
575    thread.read_with(cx, |thread, _| {
576        assert_eq!(thread.to_markdown(), "");
577    });
578}
579
580#[gpui::test]
581async fn test_agent_connection(cx: &mut TestAppContext) {
582    cx.update(settings::init);
583    let templates = Templates::new();
584
585    // Initialize language model system with test provider
586    cx.update(|cx| {
587        gpui_tokio::init(cx);
588        client::init_settings(cx);
589
590        let http_client = FakeHttpClient::with_404_response();
591        let clock = Arc::new(clock::FakeSystemClock::new());
592        let client = Client::new(clock, http_client, cx);
593        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
594        language_model::init(client.clone(), cx);
595        language_models::init(user_store.clone(), client.clone(), cx);
596        Project::init_settings(cx);
597        LanguageModelRegistry::test(cx);
598    });
599    cx.executor().forbid_parking();
600
601    // Create a project for new_thread
602    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
603    fake_fs.insert_tree(path!("/test"), json!({})).await;
604    let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
605    let cwd = Path::new("/test");
606
607    // Create agent and connection
608    let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
609        .await
610        .unwrap();
611    let connection = NativeAgentConnection(agent.clone());
612
613    // Test model_selector returns Some
614    let selector_opt = connection.model_selector();
615    assert!(
616        selector_opt.is_some(),
617        "agent2 should always support ModelSelector"
618    );
619    let selector = selector_opt.unwrap();
620
621    // Test list_models
622    let listed_models = cx
623        .update(|cx| {
624            let mut async_cx = cx.to_async();
625            selector.list_models(&mut async_cx)
626        })
627        .await
628        .expect("list_models should succeed");
629    assert!(!listed_models.is_empty(), "should have at least one model");
630    assert_eq!(listed_models[0].id().0, "fake");
631
632    // Create a thread using new_thread
633    let connection_rc = Rc::new(connection.clone());
634    let acp_thread = cx
635        .update(|cx| {
636            let mut async_cx = cx.to_async();
637            connection_rc.new_thread(project, cwd, &mut async_cx)
638        })
639        .await
640        .expect("new_thread should succeed");
641
642    // Get the session_id from the AcpThread
643    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
644
645    // Test selected_model returns the default
646    let model = cx
647        .update(|cx| {
648            let mut async_cx = cx.to_async();
649            selector.selected_model(&session_id, &mut async_cx)
650        })
651        .await
652        .expect("selected_model should succeed");
653    let model = model.as_fake();
654    assert_eq!(model.id().0, "fake", "should return default model");
655
656    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
657    cx.run_until_parked();
658    model.send_last_completion_stream_text_chunk("def");
659    cx.run_until_parked();
660    acp_thread.read_with(cx, |thread, cx| {
661        assert_eq!(
662            thread.to_markdown(cx),
663            indoc! {"
664                ## User
665
666                abc
667
668                ## Assistant
669
670                def
671
672            "}
673        )
674    });
675
676    // Test cancel
677    cx.update(|cx| connection.cancel(&session_id, cx));
678    request.await.expect("prompt should fail gracefully");
679
680    // Ensure that dropping the ACP thread causes the native thread to be
681    // dropped as well.
682    cx.update(|_| drop(acp_thread));
683    let result = cx
684        .update(|cx| {
685            connection.prompt(
686                acp::PromptRequest {
687                    session_id: session_id.clone(),
688                    prompt: vec!["ghi".into()],
689                },
690                cx,
691            )
692        })
693        .await;
694    assert_eq!(
695        result.as_ref().unwrap_err().to_string(),
696        "Session not found",
697        "unexpected result: {:?}",
698        result
699    );
700}
701
702#[gpui::test]
703async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
704    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
705    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
706    let fake_model = model.as_fake();
707
708    let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx));
709    cx.run_until_parked();
710
711    // Simulate streaming partial input.
712    let input = json!({});
713    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
714        LanguageModelToolUse {
715            id: "1".into(),
716            name: ThinkingTool.name().into(),
717            raw_input: input.to_string(),
718            input,
719            is_input_complete: false,
720        },
721    ));
722
723    // Input streaming completed
724    let input = json!({ "content": "Thinking hard!" });
725    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
726        LanguageModelToolUse {
727            id: "1".into(),
728            name: "thinking".into(),
729            raw_input: input.to_string(),
730            input,
731            is_input_complete: true,
732        },
733    ));
734    fake_model.end_last_completion_stream();
735    cx.run_until_parked();
736
737    let tool_call = expect_tool_call(&mut events).await;
738    assert_eq!(
739        tool_call,
740        acp::ToolCall {
741            id: acp::ToolCallId("1".into()),
742            title: "Thinking".into(),
743            kind: acp::ToolKind::Think,
744            status: acp::ToolCallStatus::Pending,
745            content: vec![],
746            locations: vec![],
747            raw_input: Some(json!({})),
748            raw_output: None,
749        }
750    );
751    let update = expect_tool_call_update_fields(&mut events).await;
752    assert_eq!(
753        update,
754        acp::ToolCallUpdate {
755            id: acp::ToolCallId("1".into()),
756            fields: acp::ToolCallUpdateFields {
757                title: Some("Thinking".into()),
758                kind: Some(acp::ToolKind::Think),
759                raw_input: Some(json!({ "content": "Thinking hard!" })),
760                ..Default::default()
761            },
762        }
763    );
764    let update = expect_tool_call_update_fields(&mut events).await;
765    assert_eq!(
766        update,
767        acp::ToolCallUpdate {
768            id: acp::ToolCallId("1".into()),
769            fields: acp::ToolCallUpdateFields {
770                status: Some(acp::ToolCallStatus::InProgress),
771                ..Default::default()
772            },
773        }
774    );
775    let update = expect_tool_call_update_fields(&mut events).await;
776    assert_eq!(
777        update,
778        acp::ToolCallUpdate {
779            id: acp::ToolCallId("1".into()),
780            fields: acp::ToolCallUpdateFields {
781                content: Some(vec!["Thinking hard!".into()]),
782                ..Default::default()
783            },
784        }
785    );
786    let update = expect_tool_call_update_fields(&mut events).await;
787    assert_eq!(
788        update,
789        acp::ToolCallUpdate {
790            id: acp::ToolCallId("1".into()),
791            fields: acp::ToolCallUpdateFields {
792                status: Some(acp::ToolCallStatus::Completed),
793                ..Default::default()
794            },
795        }
796    );
797}
798
799/// Filters out the stop events for asserting against in tests
800fn stop_events(
801    result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
802) -> Vec<acp::StopReason> {
803    result_events
804        .into_iter()
805        .filter_map(|event| match event.unwrap() {
806            AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
807            _ => None,
808        })
809        .collect()
810}
811
/// Bundle of objects returned by `setup` for use in a test.
struct ThreadTest {
    /// The language model the thread was created with.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// Project context shared with the thread, so tests can modify it.
    project_context: Rc<RefCell<ProjectContext>>,
}
817
/// Selects which language model `setup` wires into the thread.
enum TestModel {
    /// The model registered under the id "claude-sonnet-4-latest".
    Sonnet4,
    /// The model registered under the id "claude-sonnet-4-thinking-latest".
    Sonnet4Thinking,
    /// An in-process `FakeLanguageModel` whose output is driven by the test.
    Fake,
}
823
824impl TestModel {
825    fn id(&self) -> LanguageModelId {
826        match self {
827            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
828            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
829            TestModel::Fake => unreachable!(),
830        }
831    }
832}
833
834async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
835    cx.executor().allow_parking();
836
837    let fs = FakeFs::new(cx.background_executor.clone());
838
839    cx.update(|cx| {
840        settings::init(cx);
841        watch_settings(fs.clone(), cx);
842        Project::init_settings(cx);
843        agent_settings::init(cx);
844    });
845    let templates = Templates::new();
846
847    fs.insert_tree(path!("/test"), json!({})).await;
848    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
849
850    let model = cx
851        .update(|cx| {
852            gpui_tokio::init(cx);
853            let http_client = ReqwestClient::user_agent("agent tests").unwrap();
854            cx.set_http_client(Arc::new(http_client));
855
856            client::init_settings(cx);
857            let client = Client::production(cx);
858            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
859            language_model::init(client.clone(), cx);
860            language_models::init(user_store.clone(), client.clone(), cx);
861
862            if let TestModel::Fake = model {
863                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
864            } else {
865                let model_id = model.id();
866                let models = LanguageModelRegistry::read_global(cx);
867                let model = models
868                    .available_models(cx)
869                    .find(|model| model.id() == model_id)
870                    .unwrap();
871
872                let provider = models.provider(&model.provider_id()).unwrap();
873                let authenticated = provider.authenticate(cx);
874
875                cx.spawn(async move |_cx| {
876                    authenticated.await.unwrap();
877                    model
878                })
879            }
880        })
881        .await;
882
883    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
884    let action_log = cx.new(|_| ActionLog::new(project.clone()));
885    let thread = cx.new(|_| {
886        Thread::new(
887            project,
888            project_context.clone(),
889            action_log,
890            templates,
891            model.clone(),
892        )
893    });
894    ThreadTest {
895        model,
896        thread,
897        project_context,
898    }
899}
900
/// Initializes `env_logger` at load time, but only when `RUST_LOG` is set to a
/// readable value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
908
909fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
910    let fs = fs.clone();
911    cx.spawn({
912        async move |cx| {
913            let mut new_settings_content_rx = settings::watch_config_file(
914                cx.background_executor(),
915                fs,
916                paths::settings_file().clone(),
917            );
918
919            while let Some(new_settings_content) = new_settings_content_rx.next().await {
920                cx.update(|cx| {
921                    SettingsStore::update_global(cx, |settings, cx| {
922                        settings.set_user_settings(&new_settings_content, cx)
923                    })
924                })
925                .ok();
926            }
927        }
928    })
929    .detach();
930}