mod.rs

  1use super::*;
  2use crate::templates::Templates;
  3use acp_thread::AgentConnection as _;
  4use agent_client_protocol as acp;
  5use client::{Client, UserStore};
  6use fs::FakeFs;
  7use gpui::{AppContext, Entity, Task, TestAppContext};
  8use indoc::indoc;
  9use language_model::{
 10    fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
 11    LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, MessageContent,
 12    StopReason,
 13};
 14use project::Project;
 15use reqwest_client::ReqwestClient;
 16use schemars::JsonSchema;
 17use serde::{Deserialize, Serialize};
 18use serde_json::json;
 19use smol::stream::StreamExt;
 20use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
 21use util::path;
 22
 23mod test_tools;
 24use test_tools::*;
 25
 26#[gpui::test]
 27#[ignore = "temporarily disabled until it can be run on CI"]
 28async fn test_echo(cx: &mut TestAppContext) {
 29    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
 30
 31    let events = thread
 32        .update(cx, |thread, cx| {
 33            thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
 34        })
 35        .collect()
 36        .await;
 37    thread.update(cx, |thread, _cx| {
 38        assert_eq!(
 39            thread.messages().last().unwrap().content,
 40            vec![MessageContent::Text("Hello".to_string())]
 41        );
 42    });
 43    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 44}
 45
 46#[gpui::test]
 47#[ignore = "temporarily disabled until it can be run on CI"]
 48async fn test_thinking(cx: &mut TestAppContext) {
 49    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
 50
 51    let events = thread
 52        .update(cx, |thread, cx| {
 53            thread.send(
 54                model.clone(),
 55                indoc! {"
 56                    Testing:
 57
 58                    Generate a thinking step where you just think the word 'Think',
 59                    and have your final answer be 'Hello'
 60                "},
 61                cx,
 62            )
 63        })
 64        .collect()
 65        .await;
 66    thread.update(cx, |thread, _cx| {
 67        assert_eq!(
 68            thread.messages().last().unwrap().to_markdown(),
 69            indoc! {"
 70                ## assistant
 71                <think>Think</think>
 72                Hello
 73            "}
 74        )
 75    });
 76    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 77}
 78
 79#[gpui::test]
 80#[ignore = "temporarily disabled until it can be run on CI"]
 81async fn test_basic_tool_calls(cx: &mut TestAppContext) {
 82    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
 83
 84    // Test a tool call that's likely to complete *before* streaming stops.
 85    let events = thread
 86        .update(cx, |thread, cx| {
 87            thread.add_tool(EchoTool);
 88            thread.send(
 89                model.clone(),
 90                "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
 91                cx,
 92            )
 93        })
 94        .collect()
 95        .await;
 96    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 97
 98    // Test a tool calls that's likely to complete *after* streaming stops.
 99    let events = thread
100        .update(cx, |thread, cx| {
101            thread.remove_tool(&AgentTool::name(&EchoTool));
102            thread.add_tool(DelayTool);
103            thread.send(
104                model.clone(),
105                "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
106                cx,
107            )
108        })
109        .collect()
110        .await;
111    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
112    thread.update(cx, |thread, _cx| {
113        assert!(thread
114            .messages()
115            .last()
116            .unwrap()
117            .content
118            .iter()
119            .any(|content| {
120                if let MessageContent::Text(text) = content {
121                    text.contains("Ding")
122                } else {
123                    false
124                }
125            }));
126    });
127}
128
129#[gpui::test]
130#[ignore = "temporarily disabled until it can be run on CI"]
131async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
132    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
133
134    // Test a tool call that's likely to complete *before* streaming stops.
135    let mut events = thread.update(cx, |thread, cx| {
136        thread.add_tool(WordListTool);
137        thread.send(model.clone(), "Test the word_list tool.", cx)
138    });
139
140    let mut saw_partial_tool_use = false;
141    while let Some(event) = events.next().await {
142        if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
143            thread.update(cx, |thread, _cx| {
144                // Look for a tool use in the thread's last message
145                let last_content = thread.messages().last().unwrap().content.last().unwrap();
146                if let MessageContent::ToolUse(last_tool_use) = last_content {
147                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
148                    if tool_call.status == acp::ToolCallStatus::Pending {
149                        if !last_tool_use.is_input_complete
150                            && last_tool_use.input.get("g").is_none()
151                        {
152                            saw_partial_tool_use = true;
153                        }
154                    } else {
155                        last_tool_use
156                            .input
157                            .get("a")
158                            .expect("'a' has streamed because input is now complete");
159                        last_tool_use
160                            .input
161                            .get("g")
162                            .expect("'g' has streamed because input is now complete");
163                    }
164                } else {
165                    panic!("last content should be a tool use");
166                }
167            });
168        }
169    }
170
171    assert!(
172        saw_partial_tool_use,
173        "should see at least one partially streamed tool use in the history"
174    );
175}
176
177#[gpui::test]
178#[ignore = "temporarily disabled until it can be run on CI"]
179async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
180    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
181
182    // Test concurrent tool calls with different delay times
183    let events = thread
184        .update(cx, |thread, cx| {
185            thread.add_tool(DelayTool);
186            thread.send(
187                model.clone(),
188                "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
189                cx,
190            )
191        })
192        .collect()
193        .await;
194
195    let stop_reasons = stop_events(events);
196    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
197
198    thread.update(cx, |thread, _cx| {
199        let last_message = thread.messages().last().unwrap();
200        let text = last_message
201            .content
202            .iter()
203            .filter_map(|content| {
204                if let MessageContent::Text(text) = content {
205                    Some(text.as_str())
206                } else {
207                    None
208                }
209            })
210            .collect::<String>();
211
212        assert!(text.contains("Ding"));
213    });
214}
215
/// Verifies that cancelling an in-flight turn closes its event stream even while a tool is
/// still running, and that the thread accepts a new message afterwards.
///
/// The first turn registers `InfiniteTool` and `EchoTool` and waits until both have been
/// called and the echo call has completed before cancelling.
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread.update(cx, |thread, cx| {
        thread.add_tool(InfiniteTool);
        thread.add_tool(EchoTool);
        thread.send(
            model.clone(),
            "Call the echo tool and then call the infinite tool, then explain their output",
            cx,
        )
    });

    // Wait until both tools are called.
    // `expected_tool_calls` is consumed front-to-back, so the calls must arrive in this order.
    let mut expected_tool_calls = vec!["echo", "infinite"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        // Any error event fails the test immediately.
        match event.unwrap() {
            AgentResponseEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tool_calls.remove(0));
                if tool_call.title == "echo" {
                    // Remember the echo call's id so its completion update can be recognised.
                    echo_id = Some(tool_call.id);
                }
            }
            // Matches only a `Completed` status update that belongs to the echo call.
            AgentResponseEvent::ToolCallUpdate(acp::ToolCallUpdate {
                id,
                fields:
                    acp::ToolCallUpdateFields {
                        status: Some(acp::ToolCallStatus::Completed),
                        ..
                    },
            }) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        // Stop reading once both calls were seen and the echo call has completed.
        if expected_tool_calls.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, _cx| thread.cancel());
    // Draining the stream only returns once it has ended; the collected events are discarded.
    events.collect::<Vec<_>>().await;

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx)
        })
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.messages().last().unwrap().content,
            vec![MessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
281
282#[gpui::test]
283async fn test_refusal(cx: &mut TestAppContext) {
284    let fake_model = Arc::new(FakeLanguageModel::default());
285    let ThreadTest { thread, .. } = setup(cx, TestModel::Fake(fake_model.clone())).await;
286
287    let events = thread.update(cx, |thread, cx| {
288        thread.send(fake_model.clone(), "Hello", cx)
289    });
290    cx.run_until_parked();
291    thread.read_with(cx, |thread, _| {
292        assert_eq!(
293            thread.to_markdown(),
294            indoc! {"
295                ## user
296                Hello
297            "}
298        );
299    });
300
301    fake_model.send_last_completion_stream_text_chunk("Hey!");
302    cx.run_until_parked();
303    thread.read_with(cx, |thread, _| {
304        assert_eq!(
305            thread.to_markdown(),
306            indoc! {"
307                ## user
308                Hello
309                ## assistant
310                Hey!
311            "}
312        );
313    });
314
315    // If the model refuses to continue, the thread should remove all the messages after the last user message.
316    fake_model
317        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
318    let events = events.collect::<Vec<_>>().await;
319    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
320    thread.read_with(cx, |thread, _| {
321        assert_eq!(thread.to_markdown(), "");
322    });
323}
324
325#[ignore = "temporarily disabled until it can be run on CI"]
326#[gpui::test]
327async fn test_agent_connection(cx: &mut TestAppContext) {
328    cx.executor().allow_parking();
329    cx.update(settings::init);
330    let templates = Templates::new();
331
332    // Initialize language model system with test provider
333    cx.update(|cx| {
334        gpui_tokio::init(cx);
335        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
336        cx.set_http_client(Arc::new(http_client));
337
338        client::init_settings(cx);
339        let client = Client::production(cx);
340        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
341        language_model::init(client.clone(), cx);
342        language_models::init(user_store.clone(), client.clone(), cx);
343
344        // Initialize project settings
345        Project::init_settings(cx);
346
347        // Use test registry with fake provider
348        LanguageModelRegistry::test(cx);
349    });
350
351    // Create agent and connection
352    let agent = cx.new(|_| NativeAgent::new(templates.clone()));
353    let connection = NativeAgentConnection(agent.clone());
354
355    // Test model_selector returns Some
356    let selector_opt = connection.model_selector();
357    assert!(
358        selector_opt.is_some(),
359        "agent2 should always support ModelSelector"
360    );
361    let selector = selector_opt.unwrap();
362
363    // Test list_models
364    let listed_models = cx
365        .update(|cx| {
366            let mut async_cx = cx.to_async();
367            selector.list_models(&mut async_cx)
368        })
369        .await
370        .expect("list_models should succeed");
371    assert!(!listed_models.is_empty(), "should have at least one model");
372    assert_eq!(listed_models[0].id().0, "fake");
373
374    // Create a project for new_thread
375    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
376    let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
377
378    // Create a thread using new_thread
379    let cwd = Path::new("/test");
380    let connection_rc = Rc::new(connection.clone());
381    let acp_thread = cx
382        .update(|cx| {
383            let mut async_cx = cx.to_async();
384            connection_rc.new_thread(project, cwd, &mut async_cx)
385        })
386        .await
387        .expect("new_thread should succeed");
388
389    // Get the session_id from the AcpThread
390    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
391
392    // Test selected_model returns the default
393    let selected = cx
394        .update(|cx| {
395            let mut async_cx = cx.to_async();
396            selector.selected_model(&session_id, &mut async_cx)
397        })
398        .await
399        .expect("selected_model should succeed");
400    assert_eq!(selected.id().0, "fake", "should return default model");
401
402    // The thread was created via prompt with the default model
403    // We can verify it through selected_model
404
405    // Test prompt uses the selected model
406    let prompt_request = acp::PromptRequest {
407        session_id: session_id.clone(),
408        prompt: vec![acp::ContentBlock::Text(acp::TextContent {
409            text: "Test prompt".into(),
410            annotations: None,
411        })],
412    };
413
414    let request = cx.update(|cx| connection.prompt(prompt_request, cx));
415    let request = cx.background_spawn(request);
416    smol::Timer::after(Duration::from_millis(100)).await;
417
418    // Test cancel
419    cx.update(|cx| connection.cancel(&session_id, cx));
420    request.await.expect("prompt should fail gracefully");
421}
422
423/// Filters out the stop events for asserting against in tests
424fn stop_events(
425    result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
426) -> Vec<acp::StopReason> {
427    result_events
428        .into_iter()
429        .filter_map(|event| match event.unwrap() {
430            AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
431            _ => None,
432        })
433        .collect()
434}
435
/// Fixture returned by `setup`: the language model under test plus the thread created for it.
struct ThreadTest {
    /// Model resolved by `setup`; tests pass it to `Thread::send`.
    model: Arc<dyn LanguageModel>,
    /// Thread entity that `setup` constructed with this model.
    thread: Entity<Thread>,
}
440
/// Selects which language model `setup` hands to a test.
enum TestModel {
    /// Looked up in the model registry under the id `claude-sonnet-4-latest`.
    Sonnet4,
    /// Looked up in the model registry under the id `claude-sonnet-4-thinking-latest`.
    Sonnet4Thinking,
    /// A caller-supplied fake model, used directly without registry lookup or authentication.
    Fake(Arc<FakeLanguageModel>),
}
446
447impl TestModel {
448    fn id(&self) -> LanguageModelId {
449        match self {
450            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
451            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
452            TestModel::Fake(fake_model) => fake_model.id(),
453        }
454    }
455}
456
457async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
458    cx.executor().allow_parking();
459    cx.update(|cx| {
460        settings::init(cx);
461        Project::init_settings(cx);
462    });
463    let templates = Templates::new();
464
465    let fs = FakeFs::new(cx.background_executor.clone());
466    fs.insert_tree(path!("/test"), json!({})).await;
467    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
468
469    let model = cx
470        .update(|cx| {
471            gpui_tokio::init(cx);
472            let http_client = ReqwestClient::user_agent("agent tests").unwrap();
473            cx.set_http_client(Arc::new(http_client));
474
475            client::init_settings(cx);
476            let client = Client::production(cx);
477            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
478            language_model::init(client.clone(), cx);
479            language_models::init(user_store.clone(), client.clone(), cx);
480
481            if let TestModel::Fake(model) = model {
482                Task::ready(model as Arc<_>)
483            } else {
484                let model_id = model.id();
485                let models = LanguageModelRegistry::read_global(cx);
486                let model = models
487                    .available_models(cx)
488                    .find(|model| model.id() == model_id)
489                    .unwrap();
490
491                let provider = models.provider(&model.provider_id()).unwrap();
492                let authenticated = provider.authenticate(cx);
493
494                cx.spawn(async move |_cx| {
495                    authenticated.await.unwrap();
496                    model
497                })
498            }
499        })
500        .await;
501
502    let thread = cx.new(|_| Thread::new(project, templates, model.clone()));
503
504    ThreadTest { model, thread }
505}
506
/// Initializes `env_logger` when the `RUST_LOG` environment variable is set to a valid value.
///
/// Registered through `#[ctor::ctor]`, so it is expected to run once when the test binary
/// starts rather than being called by individual tests.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}