mod.rs

  1use super::*;
  2use crate::templates::Templates;
  3use acp_thread::AgentConnection;
  4use agent_client_protocol as acp;
  5use client::{Client, UserStore};
  6use fs::FakeFs;
  7use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
  8use indoc::indoc;
  9use language_model::{
 10    fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
 11    LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, MessageContent,
 12    StopReason,
 13};
 14use project::Project;
 15use reqwest_client::ReqwestClient;
 16use schemars::JsonSchema;
 17use serde::{Deserialize, Serialize};
 18use serde_json::json;
 19use smol::stream::StreamExt;
 20use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
 21use util::path;
 22
 23mod test_tools;
 24use test_tools::*;
 25
 26#[gpui::test]
 27#[ignore = "temporarily disabled until it can be run on CI"]
 28async fn test_echo(cx: &mut TestAppContext) {
 29    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
 30
 31    let events = thread
 32        .update(cx, |thread, cx| {
 33            thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
 34        })
 35        .collect()
 36        .await;
 37    thread.update(cx, |thread, _cx| {
 38        assert_eq!(
 39            thread.messages().last().unwrap().content,
 40            vec![MessageContent::Text("Hello".to_string())]
 41        );
 42    });
 43    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 44}
 45
 46#[gpui::test]
 47#[ignore = "temporarily disabled until it can be run on CI"]
 48async fn test_thinking(cx: &mut TestAppContext) {
 49    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
 50
 51    let events = thread
 52        .update(cx, |thread, cx| {
 53            thread.send(
 54                model.clone(),
 55                indoc! {"
 56                    Testing:
 57
 58                    Generate a thinking step where you just think the word 'Think',
 59                    and have your final answer be 'Hello'
 60                "},
 61                cx,
 62            )
 63        })
 64        .collect()
 65        .await;
 66    thread.update(cx, |thread, _cx| {
 67        assert_eq!(
 68            thread.messages().last().unwrap().to_markdown(),
 69            indoc! {"
 70                ## assistant
 71                <think>Think</think>
 72                Hello
 73            "}
 74        )
 75    });
 76    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 77}
 78
 79#[gpui::test]
 80#[ignore = "temporarily disabled until it can be run on CI"]
 81async fn test_basic_tool_calls(cx: &mut TestAppContext) {
 82    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
 83
 84    // Test a tool call that's likely to complete *before* streaming stops.
 85    let events = thread
 86        .update(cx, |thread, cx| {
 87            thread.add_tool(EchoTool);
 88            thread.send(
 89                model.clone(),
 90                "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
 91                cx,
 92            )
 93        })
 94        .collect()
 95        .await;
 96    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 97
 98    // Test a tool calls that's likely to complete *after* streaming stops.
 99    let events = thread
100        .update(cx, |thread, cx| {
101            thread.remove_tool(&AgentTool::name(&EchoTool));
102            thread.add_tool(DelayTool);
103            thread.send(
104                model.clone(),
105                "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
106                cx,
107            )
108        })
109        .collect()
110        .await;
111    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
112    thread.update(cx, |thread, _cx| {
113        assert!(thread
114            .messages()
115            .last()
116            .unwrap()
117            .content
118            .iter()
119            .any(|content| {
120                if let MessageContent::Text(text) = content {
121                    text.contains("Ding")
122                } else {
123                    false
124                }
125            }));
126    });
127}
128
129#[gpui::test]
130#[ignore = "temporarily disabled until it can be run on CI"]
131async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
132    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
133
134    // Test a tool call that's likely to complete *before* streaming stops.
135    let mut events = thread.update(cx, |thread, cx| {
136        thread.add_tool(WordListTool);
137        thread.send(model.clone(), "Test the word_list tool.", cx)
138    });
139
140    let mut saw_partial_tool_use = false;
141    while let Some(event) = events.next().await {
142        if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
143            thread.update(cx, |thread, _cx| {
144                // Look for a tool use in the thread's last message
145                let last_content = thread.messages().last().unwrap().content.last().unwrap();
146                if let MessageContent::ToolUse(last_tool_use) = last_content {
147                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
148                    if tool_call.status == acp::ToolCallStatus::Pending {
149                        if !last_tool_use.is_input_complete
150                            && last_tool_use.input.get("g").is_none()
151                        {
152                            saw_partial_tool_use = true;
153                        }
154                    } else {
155                        last_tool_use
156                            .input
157                            .get("a")
158                            .expect("'a' has streamed because input is now complete");
159                        last_tool_use
160                            .input
161                            .get("g")
162                            .expect("'g' has streamed because input is now complete");
163                    }
164                } else {
165                    panic!("last content should be a tool use");
166                }
167            });
168        }
169    }
170
171    assert!(
172        saw_partial_tool_use,
173        "should see at least one partially streamed tool use in the history"
174    );
175}
176
177#[gpui::test]
178#[ignore = "temporarily disabled until it can be run on CI"]
179async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
180    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
181
182    // Test concurrent tool calls with different delay times
183    let events = thread
184        .update(cx, |thread, cx| {
185            thread.add_tool(DelayTool);
186            thread.send(
187                model.clone(),
188                "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
189                cx,
190            )
191        })
192        .collect()
193        .await;
194
195    let stop_reasons = stop_events(events);
196    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
197
198    thread.update(cx, |thread, _cx| {
199        let last_message = thread.messages().last().unwrap();
200        let text = last_message
201            .content
202            .iter()
203            .filter_map(|content| {
204                if let MessageContent::Text(text) = content {
205                    Some(text.as_str())
206                } else {
207                    None
208                }
209            })
210            .collect::<String>();
211
212        assert!(text.contains("Ding"));
213    });
214}
215
216#[gpui::test]
217#[ignore = "temporarily disabled until it can be run on CI"]
218async fn test_cancellation(cx: &mut TestAppContext) {
219    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
220
221    let mut events = thread.update(cx, |thread, cx| {
222        thread.add_tool(InfiniteTool);
223        thread.add_tool(EchoTool);
224        thread.send(
225            model.clone(),
226            "Call the echo tool and then call the infinite tool, then explain their output",
227            cx,
228        )
229    });
230
231    // Wait until both tools are called.
232    let mut expected_tool_calls = vec!["echo", "infinite"];
233    let mut echo_id = None;
234    let mut echo_completed = false;
235    while let Some(event) = events.next().await {
236        match event.unwrap() {
237            AgentResponseEvent::ToolCall(tool_call) => {
238                assert_eq!(tool_call.title, expected_tool_calls.remove(0));
239                if tool_call.title == "echo" {
240                    echo_id = Some(tool_call.id);
241                }
242            }
243            AgentResponseEvent::ToolCallUpdate(acp::ToolCallUpdate {
244                id,
245                fields:
246                    acp::ToolCallUpdateFields {
247                        status: Some(acp::ToolCallStatus::Completed),
248                        ..
249                    },
250            }) if Some(&id) == echo_id.as_ref() => {
251                echo_completed = true;
252            }
253            _ => {}
254        }
255
256        if expected_tool_calls.is_empty() && echo_completed {
257            break;
258        }
259    }
260
261    // Cancel the current send and ensure that the event stream is closed, even
262    // if one of the tools is still running.
263    thread.update(cx, |thread, _cx| thread.cancel());
264    events.collect::<Vec<_>>().await;
265
266    // Ensure we can still send a new message after cancellation.
267    let events = thread
268        .update(cx, |thread, cx| {
269            thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx)
270        })
271        .collect::<Vec<_>>()
272        .await;
273    thread.update(cx, |thread, _cx| {
274        assert_eq!(
275            thread.messages().last().unwrap().content,
276            vec![MessageContent::Text("Hello".to_string())]
277        );
278    });
279    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
280}
281
282#[gpui::test]
283async fn test_refusal(cx: &mut TestAppContext) {
284    let fake_model = Arc::new(FakeLanguageModel::default());
285    let ThreadTest { thread, .. } = setup(cx, TestModel::Fake(fake_model.clone())).await;
286
287    let events = thread.update(cx, |thread, cx| {
288        thread.send(fake_model.clone(), "Hello", cx)
289    });
290    cx.run_until_parked();
291    thread.read_with(cx, |thread, _| {
292        assert_eq!(
293            thread.to_markdown(),
294            indoc! {"
295                ## user
296                Hello
297            "}
298        );
299    });
300
301    fake_model.send_last_completion_stream_text_chunk("Hey!");
302    cx.run_until_parked();
303    thread.read_with(cx, |thread, _| {
304        assert_eq!(
305            thread.to_markdown(),
306            indoc! {"
307                ## user
308                Hello
309                ## assistant
310                Hey!
311            "}
312        );
313    });
314
315    // If the model refuses to continue, the thread should remove all the messages after the last user message.
316    fake_model
317        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
318    let events = events.collect::<Vec<_>>().await;
319    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
320    thread.read_with(cx, |thread, _| {
321        assert_eq!(thread.to_markdown(), "");
322    });
323}
324
325#[gpui::test]
326async fn test_agent_connection(cx: &mut TestAppContext) {
327    cx.update(settings::init);
328    let templates = Templates::new();
329
330    // Initialize language model system with test provider
331    cx.update(|cx| {
332        gpui_tokio::init(cx);
333        client::init_settings(cx);
334
335        let http_client = FakeHttpClient::with_404_response();
336        let clock = Arc::new(clock::FakeSystemClock::new());
337        let client = Client::new(clock, http_client, cx);
338        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
339        language_model::init(client.clone(), cx);
340        language_models::init(user_store.clone(), client.clone(), cx);
341        Project::init_settings(cx);
342        LanguageModelRegistry::test(cx);
343    });
344    cx.executor().forbid_parking();
345
346    // Create agent and connection
347    let agent = cx.new(|_| NativeAgent::new(templates.clone()));
348    let connection = NativeAgentConnection(agent.clone());
349
350    // Test model_selector returns Some
351    let selector_opt = connection.model_selector();
352    assert!(
353        selector_opt.is_some(),
354        "agent2 should always support ModelSelector"
355    );
356    let selector = selector_opt.unwrap();
357
358    // Test list_models
359    let listed_models = cx
360        .update(|cx| {
361            let mut async_cx = cx.to_async();
362            selector.list_models(&mut async_cx)
363        })
364        .await
365        .expect("list_models should succeed");
366    assert!(!listed_models.is_empty(), "should have at least one model");
367    assert_eq!(listed_models[0].id().0, "fake");
368
369    // Create a project for new_thread
370    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
371    let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
372
373    // Create a thread using new_thread
374    let cwd = Path::new("/test");
375    let connection_rc = Rc::new(connection.clone());
376    let acp_thread = cx
377        .update(|cx| {
378            let mut async_cx = cx.to_async();
379            connection_rc.new_thread(project, cwd, &mut async_cx)
380        })
381        .await
382        .expect("new_thread should succeed");
383
384    // Get the session_id from the AcpThread
385    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
386
387    // Test selected_model returns the default
388    let model = cx
389        .update(|cx| {
390            let mut async_cx = cx.to_async();
391            selector.selected_model(&session_id, &mut async_cx)
392        })
393        .await
394        .expect("selected_model should succeed");
395    let model = model.as_fake();
396    assert_eq!(model.id().0, "fake", "should return default model");
397
398    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
399    cx.run_until_parked();
400    model.send_last_completion_stream_text_chunk("def");
401    cx.run_until_parked();
402    acp_thread.read_with(cx, |thread, cx| {
403        assert_eq!(
404            thread.to_markdown(cx),
405            indoc! {"
406                ## User
407
408                abc
409
410                ## Assistant
411
412                def
413
414            "}
415        )
416    });
417
418    // Test cancel
419    cx.update(|cx| connection.cancel(&session_id, cx));
420    request.await.expect("prompt should fail gracefully");
421
422    // Ensure that dropping the ACP thread causes the native thread to be
423    // dropped as well.
424    cx.update(|_| drop(acp_thread));
425    let result = cx
426        .update(|cx| {
427            connection.prompt(
428                acp::PromptRequest {
429                    session_id: session_id.clone(),
430                    prompt: vec!["ghi".into()],
431                },
432                cx,
433            )
434        })
435        .await;
436    assert_eq!(
437        result.as_ref().unwrap_err().to_string(),
438        "Session not found",
439        "unexpected result: {:?}",
440        result
441    );
442}
443
444/// Filters out the stop events for asserting against in tests
445fn stop_events(
446    result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
447) -> Vec<acp::StopReason> {
448    result_events
449        .into_iter()
450        .filter_map(|event| match event.unwrap() {
451            AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
452            _ => None,
453        })
454        .collect()
455}
456
457struct ThreadTest {
458    model: Arc<dyn LanguageModel>,
459    thread: Entity<Thread>,
460}
461
462enum TestModel {
463    Sonnet4,
464    Sonnet4Thinking,
465    Fake(Arc<FakeLanguageModel>),
466}
467
468impl TestModel {
469    fn id(&self) -> LanguageModelId {
470        match self {
471            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
472            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
473            TestModel::Fake(fake_model) => fake_model.id(),
474        }
475    }
476}
477
478async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
479    cx.executor().allow_parking();
480    cx.update(|cx| {
481        settings::init(cx);
482        Project::init_settings(cx);
483    });
484    let templates = Templates::new();
485
486    let fs = FakeFs::new(cx.background_executor.clone());
487    fs.insert_tree(path!("/test"), json!({})).await;
488    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
489
490    let model = cx
491        .update(|cx| {
492            gpui_tokio::init(cx);
493            let http_client = ReqwestClient::user_agent("agent tests").unwrap();
494            cx.set_http_client(Arc::new(http_client));
495
496            client::init_settings(cx);
497            let client = Client::production(cx);
498            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
499            language_model::init(client.clone(), cx);
500            language_models::init(user_store.clone(), client.clone(), cx);
501
502            if let TestModel::Fake(model) = model {
503                Task::ready(model as Arc<_>)
504            } else {
505                let model_id = model.id();
506                let models = LanguageModelRegistry::read_global(cx);
507                let model = models
508                    .available_models(cx)
509                    .find(|model| model.id() == model_id)
510                    .unwrap();
511
512                let provider = models.provider(&model.provider_id()).unwrap();
513                let authenticated = provider.authenticate(cx);
514
515                cx.spawn(async move |_cx| {
516                    authenticated.await.unwrap();
517                    model
518                })
519            }
520        })
521        .await;
522
523    let thread = cx.new(|_| Thread::new(project, templates, model.clone()));
524
525    ThreadTest { model, thread }
526}
527
// Installs `env_logger` when `RUST_LOG` can be read from the environment;
// otherwise does nothing. Registered through `ctor`, so it is presumably run
// once at program load, before any test.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}