mod.rs

  1use super::*;
  2use crate::templates::Templates;
  3use acp_thread::AgentConnection;
  4use agent_client_protocol::{self as acp};
  5use anyhow::Result;
  6use assistant_tool::ActionLog;
  7use client::{Client, UserStore};
  8use fs::FakeFs;
  9use futures::channel::mpsc::UnboundedReceiver;
 10use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
 11use indoc::indoc;
 12use language_model::{
 13    fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
 14    LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult,
 15    LanguageModelToolUse, MessageContent, Role, StopReason,
 16};
 17use project::Project;
 18use prompt_store::ProjectContext;
 19use reqwest_client::ReqwestClient;
 20use schemars::JsonSchema;
 21use serde::{Deserialize, Serialize};
 22use serde_json::json;
 23use smol::stream::StreamExt;
 24use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
 25use util::path;
 26
 27mod test_tools;
 28use test_tools::*;
 29
 30#[gpui::test]
 31#[ignore = "can't run on CI yet"]
 32async fn test_echo(cx: &mut TestAppContext) {
 33    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
 34
 35    let events = thread
 36        .update(cx, |thread, cx| {
 37            thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
 38        })
 39        .collect()
 40        .await;
 41    thread.update(cx, |thread, _cx| {
 42        assert_eq!(
 43            thread.messages().last().unwrap().content,
 44            vec![MessageContent::Text("Hello".to_string())]
 45        );
 46    });
 47    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 48}
 49
 50#[gpui::test]
 51#[ignore = "can't run on CI yet"]
 52async fn test_thinking(cx: &mut TestAppContext) {
 53    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
 54
 55    let events = thread
 56        .update(cx, |thread, cx| {
 57            thread.send(
 58                model.clone(),
 59                indoc! {"
 60                    Testing:
 61
 62                    Generate a thinking step where you just think the word 'Think',
 63                    and have your final answer be 'Hello'
 64                "},
 65                cx,
 66            )
 67        })
 68        .collect()
 69        .await;
 70    thread.update(cx, |thread, _cx| {
 71        assert_eq!(
 72            thread.messages().last().unwrap().to_markdown(),
 73            indoc! {"
 74                ## assistant
 75                <think>Think</think>
 76                Hello
 77            "}
 78        )
 79    });
 80    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 81}
 82
 83#[gpui::test]
 84async fn test_system_prompt(cx: &mut TestAppContext) {
 85    let ThreadTest {
 86        model,
 87        thread,
 88        project_context,
 89        ..
 90    } = setup(cx, TestModel::Fake).await;
 91    let fake_model = model.as_fake();
 92
 93    project_context.borrow_mut().shell = "test-shell".into();
 94    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 95    thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
 96    cx.run_until_parked();
 97    let mut pending_completions = fake_model.pending_completions();
 98    assert_eq!(
 99        pending_completions.len(),
100        1,
101        "unexpected pending completions: {:?}",
102        pending_completions
103    );
104
105    let pending_completion = pending_completions.pop().unwrap();
106    assert_eq!(pending_completion.messages[0].role, Role::System);
107
108    let system_message = &pending_completion.messages[0];
109    let system_prompt = system_message.content[0].to_str().unwrap();
110    assert!(
111        system_prompt.contains("test-shell"),
112        "unexpected system message: {:?}",
113        system_message
114    );
115    assert!(
116        system_prompt.contains("## Fixing Diagnostics"),
117        "unexpected system message: {:?}",
118        system_message
119    );
120}
121
122#[gpui::test]
123#[ignore = "can't run on CI yet"]
124async fn test_basic_tool_calls(cx: &mut TestAppContext) {
125    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
126
127    // Test a tool call that's likely to complete *before* streaming stops.
128    let events = thread
129        .update(cx, |thread, cx| {
130            thread.add_tool(EchoTool);
131            thread.send(
132                model.clone(),
133                "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
134                cx,
135            )
136        })
137        .collect()
138        .await;
139    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
140
141    // Test a tool calls that's likely to complete *after* streaming stops.
142    let events = thread
143        .update(cx, |thread, cx| {
144            thread.remove_tool(&AgentTool::name(&EchoTool));
145            thread.add_tool(DelayTool);
146            thread.send(
147                model.clone(),
148                "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
149                cx,
150            )
151        })
152        .collect()
153        .await;
154    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
155    thread.update(cx, |thread, _cx| {
156        assert!(thread
157            .messages()
158            .last()
159            .unwrap()
160            .content
161            .iter()
162            .any(|content| {
163                if let MessageContent::Text(text) = content {
164                    text.contains("Ding")
165                } else {
166                    false
167                }
168            }));
169    });
170}
171
172#[gpui::test]
173#[ignore = "can't run on CI yet"]
174async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
175    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
176
177    // Test a tool call that's likely to complete *before* streaming stops.
178    let mut events = thread.update(cx, |thread, cx| {
179        thread.add_tool(WordListTool);
180        thread.send(model.clone(), "Test the word_list tool.", cx)
181    });
182
183    let mut saw_partial_tool_use = false;
184    while let Some(event) = events.next().await {
185        if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
186            thread.update(cx, |thread, _cx| {
187                // Look for a tool use in the thread's last message
188                let last_content = thread.messages().last().unwrap().content.last().unwrap();
189                if let MessageContent::ToolUse(last_tool_use) = last_content {
190                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
191                    if tool_call.status == acp::ToolCallStatus::Pending {
192                        if !last_tool_use.is_input_complete
193                            && last_tool_use.input.get("g").is_none()
194                        {
195                            saw_partial_tool_use = true;
196                        }
197                    } else {
198                        last_tool_use
199                            .input
200                            .get("a")
201                            .expect("'a' has streamed because input is now complete");
202                        last_tool_use
203                            .input
204                            .get("g")
205                            .expect("'g' has streamed because input is now complete");
206                    }
207                } else {
208                    panic!("last content should be a tool use");
209                }
210            });
211        }
212    }
213
214    assert!(
215        saw_partial_tool_use,
216        "should see at least one partially streamed tool use in the history"
217    );
218}
219
220#[gpui::test]
221async fn test_tool_authorization(cx: &mut TestAppContext) {
222    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
223    let fake_model = model.as_fake();
224
225    let mut events = thread.update(cx, |thread, cx| {
226        thread.add_tool(ToolRequiringPermission);
227        thread.send(model.clone(), "abc", cx)
228    });
229    cx.run_until_parked();
230    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
231        LanguageModelToolUse {
232            id: "tool_id_1".into(),
233            name: ToolRequiringPermission.name().into(),
234            raw_input: "{}".into(),
235            input: json!({}),
236            is_input_complete: true,
237        },
238    ));
239    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
240        LanguageModelToolUse {
241            id: "tool_id_2".into(),
242            name: ToolRequiringPermission.name().into(),
243            raw_input: "{}".into(),
244            input: json!({}),
245            is_input_complete: true,
246        },
247    ));
248    fake_model.end_last_completion_stream();
249    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
250    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
251
252    // Approve the first
253    tool_call_auth_1
254        .response
255        .send(tool_call_auth_1.options[1].id.clone())
256        .unwrap();
257    cx.run_until_parked();
258
259    // Reject the second
260    tool_call_auth_2
261        .response
262        .send(tool_call_auth_1.options[2].id.clone())
263        .unwrap();
264    cx.run_until_parked();
265
266    let completion = fake_model.pending_completions().pop().unwrap();
267    let message = completion.messages.last().unwrap();
268    assert_eq!(
269        message.content,
270        vec![
271            MessageContent::ToolResult(LanguageModelToolResult {
272                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
273                tool_name: tool_call_auth_1.tool_call.title.into(),
274                is_error: false,
275                content: "Allowed".into(),
276                output: None
277            }),
278            MessageContent::ToolResult(LanguageModelToolResult {
279                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
280                tool_name: tool_call_auth_2.tool_call.title.into(),
281                is_error: true,
282                content: "Permission to run tool denied by user".into(),
283                output: None
284            })
285        ]
286    );
287}
288
289async fn next_tool_call_authorization(
290    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
291) -> ToolCallAuthorization {
292    loop {
293        let event = events
294            .next()
295            .await
296            .expect("no tool call authorization event received")
297            .unwrap();
298        if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
299            let permission_kinds = tool_call_authorization
300                .options
301                .iter()
302                .map(|o| o.kind)
303                .collect::<Vec<_>>();
304            assert_eq!(
305                permission_kinds,
306                vec![
307                    acp::PermissionOptionKind::AllowAlways,
308                    acp::PermissionOptionKind::AllowOnce,
309                    acp::PermissionOptionKind::RejectOnce,
310                ]
311            );
312            return tool_call_authorization;
313        }
314    }
315}
316
317#[gpui::test]
318#[ignore = "can't run on CI yet"]
319async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
320    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
321
322    // Test concurrent tool calls with different delay times
323    let events = thread
324        .update(cx, |thread, cx| {
325            thread.add_tool(DelayTool);
326            thread.send(
327                model.clone(),
328                "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
329                cx,
330            )
331        })
332        .collect()
333        .await;
334
335    let stop_reasons = stop_events(events);
336    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
337
338    thread.update(cx, |thread, _cx| {
339        let last_message = thread.messages().last().unwrap();
340        let text = last_message
341            .content
342            .iter()
343            .filter_map(|content| {
344                if let MessageContent::Text(text) = content {
345                    Some(text.as_str())
346                } else {
347                    None
348                }
349            })
350            .collect::<String>();
351
352        assert!(text.contains("Ding"));
353    });
354}
355
356#[gpui::test]
357#[ignore = "can't run on CI yet"]
358async fn test_cancellation(cx: &mut TestAppContext) {
359    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
360
361    let mut events = thread.update(cx, |thread, cx| {
362        thread.add_tool(InfiniteTool);
363        thread.add_tool(EchoTool);
364        thread.send(
365            model.clone(),
366            "Call the echo tool and then call the infinite tool, then explain their output",
367            cx,
368        )
369    });
370
371    // Wait until both tools are called.
372    let mut expected_tool_calls = vec!["echo", "infinite"];
373    let mut echo_id = None;
374    let mut echo_completed = false;
375    while let Some(event) = events.next().await {
376        match event.unwrap() {
377            AgentResponseEvent::ToolCall(tool_call) => {
378                assert_eq!(tool_call.title, expected_tool_calls.remove(0));
379                if tool_call.title == "echo" {
380                    echo_id = Some(tool_call.id);
381                }
382            }
383            AgentResponseEvent::ToolCallUpdate(acp::ToolCallUpdate {
384                id,
385                fields:
386                    acp::ToolCallUpdateFields {
387                        status: Some(acp::ToolCallStatus::Completed),
388                        ..
389                    },
390            }) if Some(&id) == echo_id.as_ref() => {
391                echo_completed = true;
392            }
393            _ => {}
394        }
395
396        if expected_tool_calls.is_empty() && echo_completed {
397            break;
398        }
399    }
400
401    // Cancel the current send and ensure that the event stream is closed, even
402    // if one of the tools is still running.
403    thread.update(cx, |thread, _cx| thread.cancel());
404    events.collect::<Vec<_>>().await;
405
406    // Ensure we can still send a new message after cancellation.
407    let events = thread
408        .update(cx, |thread, cx| {
409            thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx)
410        })
411        .collect::<Vec<_>>()
412        .await;
413    thread.update(cx, |thread, _cx| {
414        assert_eq!(
415            thread.messages().last().unwrap().content,
416            vec![MessageContent::Text("Hello".to_string())]
417        );
418    });
419    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
420}
421
422#[gpui::test]
423async fn test_refusal(cx: &mut TestAppContext) {
424    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
425    let fake_model = model.as_fake();
426
427    let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx));
428    cx.run_until_parked();
429    thread.read_with(cx, |thread, _| {
430        assert_eq!(
431            thread.to_markdown(),
432            indoc! {"
433                ## user
434                Hello
435            "}
436        );
437    });
438
439    fake_model.send_last_completion_stream_text_chunk("Hey!");
440    cx.run_until_parked();
441    thread.read_with(cx, |thread, _| {
442        assert_eq!(
443            thread.to_markdown(),
444            indoc! {"
445                ## user
446                Hello
447                ## assistant
448                Hey!
449            "}
450        );
451    });
452
453    // If the model refuses to continue, the thread should remove all the messages after the last user message.
454    fake_model
455        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
456    let events = events.collect::<Vec<_>>().await;
457    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
458    thread.read_with(cx, |thread, _| {
459        assert_eq!(thread.to_markdown(), "");
460    });
461}
462
463#[gpui::test]
464async fn test_agent_connection(cx: &mut TestAppContext) {
465    cx.update(settings::init);
466    let templates = Templates::new();
467
468    // Initialize language model system with test provider
469    cx.update(|cx| {
470        gpui_tokio::init(cx);
471        client::init_settings(cx);
472
473        let http_client = FakeHttpClient::with_404_response();
474        let clock = Arc::new(clock::FakeSystemClock::new());
475        let client = Client::new(clock, http_client, cx);
476        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
477        language_model::init(client.clone(), cx);
478        language_models::init(user_store.clone(), client.clone(), cx);
479        Project::init_settings(cx);
480        LanguageModelRegistry::test(cx);
481    });
482    cx.executor().forbid_parking();
483
484    // Create a project for new_thread
485    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
486    fake_fs.insert_tree(path!("/test"), json!({})).await;
487    let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
488    let cwd = Path::new("/test");
489
490    // Create agent and connection
491    let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
492        .await
493        .unwrap();
494    let connection = NativeAgentConnection(agent.clone());
495
496    // Test model_selector returns Some
497    let selector_opt = connection.model_selector();
498    assert!(
499        selector_opt.is_some(),
500        "agent2 should always support ModelSelector"
501    );
502    let selector = selector_opt.unwrap();
503
504    // Test list_models
505    let listed_models = cx
506        .update(|cx| {
507            let mut async_cx = cx.to_async();
508            selector.list_models(&mut async_cx)
509        })
510        .await
511        .expect("list_models should succeed");
512    assert!(!listed_models.is_empty(), "should have at least one model");
513    assert_eq!(listed_models[0].id().0, "fake");
514
515    // Create a thread using new_thread
516    let connection_rc = Rc::new(connection.clone());
517    let acp_thread = cx
518        .update(|cx| {
519            let mut async_cx = cx.to_async();
520            connection_rc.new_thread(project, cwd, &mut async_cx)
521        })
522        .await
523        .expect("new_thread should succeed");
524
525    // Get the session_id from the AcpThread
526    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
527
528    // Test selected_model returns the default
529    let model = cx
530        .update(|cx| {
531            let mut async_cx = cx.to_async();
532            selector.selected_model(&session_id, &mut async_cx)
533        })
534        .await
535        .expect("selected_model should succeed");
536    let model = model.as_fake();
537    assert_eq!(model.id().0, "fake", "should return default model");
538
539    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
540    cx.run_until_parked();
541    model.send_last_completion_stream_text_chunk("def");
542    cx.run_until_parked();
543    acp_thread.read_with(cx, |thread, cx| {
544        assert_eq!(
545            thread.to_markdown(cx),
546            indoc! {"
547                ## User
548
549                abc
550
551                ## Assistant
552
553                def
554
555            "}
556        )
557    });
558
559    // Test cancel
560    cx.update(|cx| connection.cancel(&session_id, cx));
561    request.await.expect("prompt should fail gracefully");
562
563    // Ensure that dropping the ACP thread causes the native thread to be
564    // dropped as well.
565    cx.update(|_| drop(acp_thread));
566    let result = cx
567        .update(|cx| {
568            connection.prompt(
569                acp::PromptRequest {
570                    session_id: session_id.clone(),
571                    prompt: vec!["ghi".into()],
572                },
573                cx,
574            )
575        })
576        .await;
577    assert_eq!(
578        result.as_ref().unwrap_err().to_string(),
579        "Session not found",
580        "unexpected result: {:?}",
581        result
582    );
583}
584
585/// Filters out the stop events for asserting against in tests
586fn stop_events(
587    result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
588) -> Vec<acp::StopReason> {
589    result_events
590        .into_iter()
591        .filter_map(|event| match event.unwrap() {
592            AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
593            _ => None,
594        })
595        .collect()
596}
597
/// Bundle of objects returned by `setup` for a single test.
struct ThreadTest {
    /// The language model the thread was created with; tests pass it to
    /// `Thread::send`.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// Project context handle shared with the thread (a clone of the same `Rc`
    /// is handed to `Thread::new`), so tests can mutate it after setup.
    project_context: Rc<RefCell<ProjectContext>>,
}
603
/// Selects which language model `setup` should build a thread around.
enum TestModel {
    /// Model looked up in the registry under the id "claude-sonnet-4-latest".
    Sonnet4,
    /// Model looked up in the registry under the id
    /// "claude-sonnet-4-thinking-latest".
    Sonnet4Thinking,
    /// A default-constructed `FakeLanguageModel`; it has no registry id.
    Fake,
}
609
610impl TestModel {
611    fn id(&self) -> LanguageModelId {
612        match self {
613            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
614            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
615            TestModel::Fake => unreachable!(),
616        }
617    }
618}
619
/// Builds the fixture shared by the thread tests: a project rooted at an
/// empty fake `/test` directory, the requested language model, and a `Thread`
/// wired to both.
///
/// For `TestModel::Fake` the model is a default `FakeLanguageModel`. For the
/// other variants the model is looked up by id among the registry's available
/// models and its provider is authenticated first.
///
/// # Panics
///
/// Panics if the HTTP client cannot be created, if a non-fake model or its
/// provider is missing from the registry, or if authentication fails.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // NOTE(review): presumably parking is allowed because the non-fake models
    // wait on real network I/O — confirm.
    cx.executor().allow_parking();
    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
    });
    let templates = Templates::new();

    // Empty fake file system tree backing the test project.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            gpui_tokio::init(cx);
            let http_client = ReqwestClient::user_agent("agent tests").unwrap();
            cx.set_http_client(Arc::new(http_client));

            client::init_settings(cx);
            let client = Client::production(cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            language_model::init(client.clone(), cx);
            language_models::init(user_store.clone(), client.clone(), cx);

            if let TestModel::Fake = model {
                // The fake model needs no lookup or authentication, so the
                // task is ready immediately.
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Find the requested model among the registry's available models.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                // Hand the model out only after authentication has succeeded.
                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    // The same `Rc` is kept in the returned fixture so tests can mutate the
    // context the thread holds.
    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
    let action_log = cx.new(|_| ActionLog::new(project.clone()));
    let thread = cx.new(|_| {
        Thread::new(
            project,
            project_context.clone(),
            action_log,
            templates,
            model.clone(),
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
    }
}
682
/// Initializes `env_logger` for test builds, but only when the `RUST_LOG`
/// environment variable is set to a valid Unicode value. Registered through
/// `ctor`.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_is_set {
        env_logger::init();
    }
}