mod.rs

  1use super::*;
  2use crate::templates::Templates;
  3use acp_thread::AgentConnection as _;
  4use agent_client_protocol as acp;
  5use client::{Client, UserStore};
  6use gpui::{AppContext, Entity, TestAppContext};
  7use language_model::{
  8    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  9    LanguageModelRegistry, MessageContent, StopReason,
 10};
 11use project::Project;
 12use reqwest_client::ReqwestClient;
 13use schemars::JsonSchema;
 14use serde::{Deserialize, Serialize};
 15use smol::stream::StreamExt;
 16use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
 17
 18mod test_tools;
 19use test_tools::*;
 20
 21#[gpui::test]
 22async fn test_echo(cx: &mut TestAppContext) {
 23    let ThreadTest { model, thread, .. } = setup(cx).await;
 24
 25    let events = thread
 26        .update(cx, |thread, cx| {
 27            thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
 28        })
 29        .collect()
 30        .await;
 31    thread.update(cx, |thread, _cx| {
 32        assert_eq!(
 33            thread.messages().last().unwrap().content,
 34            vec![MessageContent::Text("Hello".to_string())]
 35        );
 36    });
 37    assert_eq!(stop_events(events), vec![StopReason::EndTurn]);
 38}
 39
 40#[gpui::test]
 41async fn test_basic_tool_calls(cx: &mut TestAppContext) {
 42    let ThreadTest { model, thread, .. } = setup(cx).await;
 43
 44    // Test a tool call that's likely to complete *before* streaming stops.
 45    let events = thread
 46        .update(cx, |thread, cx| {
 47            thread.add_tool(EchoTool);
 48            thread.send(
 49                model.clone(),
 50                "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
 51                cx,
 52            )
 53        })
 54        .collect()
 55        .await;
 56    assert_eq!(
 57        stop_events(events),
 58        vec![StopReason::ToolUse, StopReason::EndTurn]
 59    );
 60
 61    // Test a tool calls that's likely to complete *after* streaming stops.
 62    let events = thread
 63        .update(cx, |thread, cx| {
 64            thread.remove_tool(&AgentTool::name(&EchoTool));
 65            thread.add_tool(DelayTool);
 66            thread.send(
 67                model.clone(),
 68                "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
 69                cx,
 70            )
 71        })
 72        .collect()
 73        .await;
 74    assert_eq!(
 75        stop_events(events),
 76        vec![StopReason::ToolUse, StopReason::EndTurn]
 77    );
 78    thread.update(cx, |thread, _cx| {
 79        assert!(thread
 80            .messages()
 81            .last()
 82            .unwrap()
 83            .content
 84            .iter()
 85            .any(|content| {
 86                if let MessageContent::Text(text) = content {
 87                    text.contains("Ding")
 88                } else {
 89                    false
 90                }
 91            }));
 92    });
 93}
 94
 95#[gpui::test]
 96async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
 97    let ThreadTest { model, thread, .. } = setup(cx).await;
 98
 99    // Test a tool call that's likely to complete *before* streaming stops.
100    let mut events = thread.update(cx, |thread, cx| {
101        thread.add_tool(WordListTool);
102        thread.send(model.clone(), "Test the word_list tool.", cx)
103    });
104
105    let mut saw_partial_tool_use = false;
106    while let Some(event) = events.next().await {
107        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event {
108            thread.update(cx, |thread, _cx| {
109                // Look for a tool use in the thread's last message
110                let last_content = thread.messages().last().unwrap().content.last().unwrap();
111                if let MessageContent::ToolUse(last_tool_use) = last_content {
112                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
113                    if tool_use_event.is_input_complete {
114                        last_tool_use
115                            .input
116                            .get("a")
117                            .expect("'a' has streamed because input is now complete");
118                        last_tool_use
119                            .input
120                            .get("g")
121                            .expect("'g' has streamed because input is now complete");
122                    } else {
123                        if !last_tool_use.is_input_complete
124                            && last_tool_use.input.get("g").is_none()
125                        {
126                            saw_partial_tool_use = true;
127                        }
128                    }
129                } else {
130                    panic!("last content should be a tool use");
131                }
132            });
133        }
134    }
135
136    assert!(
137        saw_partial_tool_use,
138        "should see at least one partially streamed tool use in the history"
139    );
140}
141
142#[gpui::test]
143async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
144    let ThreadTest { model, thread, .. } = setup(cx).await;
145
146    // Test concurrent tool calls with different delay times
147    let events = thread
148        .update(cx, |thread, cx| {
149            thread.add_tool(DelayTool);
150            thread.send(
151                model.clone(),
152                "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
153                cx,
154            )
155        })
156        .collect()
157        .await;
158
159    let stop_reasons = stop_events(events);
160    if stop_reasons.len() == 2 {
161        assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]);
162    } else if stop_reasons.len() == 3 {
163        assert_eq!(
164            stop_reasons,
165            vec![
166                StopReason::ToolUse,
167                StopReason::ToolUse,
168                StopReason::EndTurn
169            ]
170        );
171    } else {
172        panic!("Expected either 1 or 2 tool uses followed by end turn");
173    }
174
175    thread.update(cx, |thread, _cx| {
176        let last_message = thread.messages().last().unwrap();
177        let text = last_message
178            .content
179            .iter()
180            .filter_map(|content| {
181                if let MessageContent::Text(text) = content {
182                    Some(text.as_str())
183                } else {
184                    None
185                }
186            })
187            .collect::<String>();
188
189        assert!(text.contains("Ding"));
190    });
191}
192
193#[gpui::test]
194async fn test_agent_connection(cx: &mut TestAppContext) {
195    cx.executor().allow_parking();
196    cx.update(settings::init);
197    let templates = Templates::new();
198
199    // Initialize language model system with test provider
200    cx.update(|cx| {
201        gpui_tokio::init(cx);
202        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
203        cx.set_http_client(Arc::new(http_client));
204
205        client::init_settings(cx);
206        let client = Client::production(cx);
207        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
208        language_model::init(client.clone(), cx);
209        language_models::init(user_store.clone(), client.clone(), cx);
210
211        // Initialize project settings
212        Project::init_settings(cx);
213
214        // Use test registry with fake provider
215        LanguageModelRegistry::test(cx);
216    });
217
218    // Create agent and connection
219    let agent = cx.new(|_| NativeAgent::new(templates.clone()));
220    let connection = NativeAgentConnection(agent.clone());
221
222    // Test model_selector returns Some
223    let selector_opt = connection.model_selector();
224    assert!(
225        selector_opt.is_some(),
226        "agent2 should always support ModelSelector"
227    );
228    let selector = selector_opt.unwrap();
229
230    // Test list_models
231    let listed_models = cx
232        .update(|cx| {
233            let mut async_cx = cx.to_async();
234            selector.list_models(&mut async_cx)
235        })
236        .await
237        .expect("list_models should succeed");
238    assert!(!listed_models.is_empty(), "should have at least one model");
239    assert_eq!(listed_models[0].id().0, "fake");
240
241    // Create a project for new_thread
242    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
243    let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
244
245    // Create a thread using new_thread
246    let cwd = Path::new("/test");
247    let connection_rc = Rc::new(connection.clone());
248    let acp_thread = cx
249        .update(|cx| {
250            let mut async_cx = cx.to_async();
251            connection_rc.new_thread(project, cwd, &mut async_cx)
252        })
253        .await
254        .expect("new_thread should succeed");
255
256    // Get the session_id from the AcpThread
257    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
258
259    // Test selected_model returns the default
260    let selected = cx
261        .update(|cx| {
262            let mut async_cx = cx.to_async();
263            selector.selected_model(&session_id, &mut async_cx)
264        })
265        .await
266        .expect("selected_model should succeed");
267    assert_eq!(selected.id().0, "fake", "should return default model");
268
269    // The thread was created via prompt with the default model
270    // We can verify it through selected_model
271
272    // Test prompt uses the selected model
273    let prompt_request = acp::PromptRequest {
274        session_id: session_id.clone(),
275        prompt: vec![acp::ContentBlock::Text(acp::TextContent {
276            text: "Test prompt".into(),
277            annotations: None,
278        })],
279    };
280
281    cx.update(|cx| connection.prompt(prompt_request, cx))
282        .await
283        .expect("prompt should succeed");
284
285    // The prompt was sent successfully
286
287    // Test cancel
288    cx.update(|cx| connection.cancel(&session_id, cx));
289
290    // After cancel, selected_model should fail
291    let result = cx
292        .update(|cx| {
293            let mut async_cx = cx.to_async();
294            selector.selected_model(&session_id, &mut async_cx)
295        })
296        .await;
297    assert!(result.is_err(), "selected_model should fail after cancel");
298
299    // Test error case: invalid session
300    let invalid_session = acp::SessionId("invalid".into());
301    let result = cx
302        .update(|cx| {
303            let mut async_cx = cx.to_async();
304            selector.selected_model(&invalid_session, &mut async_cx)
305        })
306        .await;
307    assert!(result.is_err(), "should fail for invalid session");
308    if let Err(e) = result {
309        assert!(
310            e.to_string().contains("Session not found"),
311            "should have correct error message"
312        );
313    }
314}
315
316/// Filters out the stop events for asserting against in tests
317fn stop_events(
318    result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
319) -> Vec<StopReason> {
320    result_events
321        .into_iter()
322        .filter_map(|event| match event.unwrap() {
323            LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason),
324            _ => None,
325        })
326        .collect()
327}
328
329struct ThreadTest {
330    model: Arc<dyn LanguageModel>,
331    thread: Entity<Thread>,
332}
333
334async fn setup(cx: &mut TestAppContext) -> ThreadTest {
335    cx.executor().allow_parking();
336    cx.update(settings::init);
337    let templates = Templates::new();
338
339    let model = cx
340        .update(|cx| {
341            gpui_tokio::init(cx);
342            let http_client = ReqwestClient::user_agent("agent tests").unwrap();
343            cx.set_http_client(Arc::new(http_client));
344
345            client::init_settings(cx);
346            let client = Client::production(cx);
347            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
348            language_model::init(client.clone(), cx);
349            language_models::init(user_store.clone(), client.clone(), cx);
350
351            let models = LanguageModelRegistry::read_global(cx);
352            let model = models
353                .available_models(cx)
354                .find(|model| model.id().0 == "claude-3-7-sonnet-latest")
355                .unwrap();
356
357            let provider = models.provider(&model.provider_id()).unwrap();
358            let authenticated = provider.authenticate(cx);
359
360            cx.spawn(async move |_cx| {
361                authenticated.await.unwrap();
362                model
363            })
364        })
365        .await;
366
367    let thread = cx.new(|_| Thread::new(templates, model.clone()));
368
369    ThreadTest { model, thread }
370}
371
/// Installs `env_logger` when the test binary starts, but only if `RUST_LOG`
/// is set, so test output stays quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_set = std::env::var("RUST_LOG").is_ok();
    if !rust_log_set {
        return;
    }
    env_logger::init();
}