tests.rs

  1use super::*;
  2use client::{proto::language_server_prompt_request, Client, UserStore};
  3use fs::FakeFs;
  4use gpui::{AppContext, Entity, TestAppContext};
  5use language_model::{
  6    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  7    LanguageModelRegistry, MessageContent, StopReason,
  8};
  9use reqwest_client::ReqwestClient;
 10use schemars::JsonSchema;
 11use serde::{Deserialize, Serialize};
 12use smol::stream::StreamExt;
 13use std::{sync::Arc, time::Duration};
 14
 15mod test_tools;
 16use test_tools::*;
 17
 18#[gpui::test]
 19async fn test_echo(cx: &mut TestAppContext) {
 20    let AgentTest { model, agent, .. } = setup(cx).await;
 21
 22    let events = agent
 23        .update(cx, |agent, cx| {
 24            agent.send(model.clone(), "Testing: Reply with 'Hello'", cx)
 25        })
 26        .collect()
 27        .await;
 28    agent.update(cx, |agent, _cx| {
 29        assert_eq!(
 30            agent.messages.last().unwrap().content,
 31            vec![MessageContent::Text("Hello".to_string())]
 32        );
 33    });
 34    assert_eq!(stop_events(events), vec![StopReason::EndTurn]);
 35}
 36
 37#[gpui::test]
 38async fn test_basic_tool_calls(cx: &mut TestAppContext) {
 39    let AgentTest { model, agent, .. } = setup(cx).await;
 40
 41    // Test a tool call that's likely to complete *before* streaming stops.
 42    let events = agent
 43        .update(cx, |agent, cx| {
 44            agent.add_tool(EchoTool);
 45            agent.send(
 46                model.clone(),
 47                "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
 48                cx,
 49            )
 50        })
 51        .collect()
 52        .await;
 53    assert_eq!(
 54        stop_events(events),
 55        vec![StopReason::ToolUse, StopReason::EndTurn]
 56    );
 57
 58    // Test a tool calls that's likely to complete *after* streaming stops.
 59    let events = agent
 60        .update(cx, |agent, cx| {
 61            agent.remove_tool(&AgentTool::name(&EchoTool));
 62            agent.add_tool(DelayTool);
 63            agent.send(
 64                model.clone(),
 65                "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
 66                cx,
 67            )
 68        })
 69        .collect()
 70        .await;
 71    assert_eq!(
 72        stop_events(events),
 73        vec![StopReason::ToolUse, StopReason::EndTurn]
 74    );
 75    agent.update(cx, |agent, _cx| {
 76        assert!(agent
 77            .messages
 78            .last()
 79            .unwrap()
 80            .content
 81            .iter()
 82            .any(|content| {
 83                if let MessageContent::Text(text) = content {
 84                    text.contains("Ding")
 85                } else {
 86                    false
 87                }
 88            }));
 89    });
 90}
 91
 92#[gpui::test]
 93async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
 94    let AgentTest { model, agent, .. } = setup(cx).await;
 95
 96    // Test a tool call that's likely to complete *before* streaming stops.
 97    let mut events = agent.update(cx, |agent, cx| {
 98        agent.add_tool(WordListTool);
 99        agent.send(model.clone(), "Test the word_list tool.", cx)
100    });
101
102    let mut saw_partial_tool_use = false;
103    while let Some(event) = events.next().await {
104        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event {
105            agent.update(cx, |agent, _cx| {
106                // Look for a tool use in the agent's last message
107                let last_content = agent.messages().last().unwrap().content.last().unwrap();
108                if let MessageContent::ToolUse(last_tool_use) = last_content {
109                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
110                    if tool_use_event.is_input_complete {
111                        last_tool_use
112                            .input
113                            .get("a")
114                            .expect("'a' has streamed because input is now complete");
115                        last_tool_use
116                            .input
117                            .get("g")
118                            .expect("'g' has streamed because input is now complete");
119                    } else {
120                        if !last_tool_use.is_input_complete
121                            && last_tool_use.input.get("g").is_none()
122                        {
123                            saw_partial_tool_use = true;
124                        }
125                    }
126                } else {
127                    panic!("last content should be a tool use");
128                }
129            });
130        }
131    }
132
133    assert!(
134        saw_partial_tool_use,
135        "should see at least one partially streamed tool use in the history"
136    );
137}
138
139#[gpui::test]
140async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
141    let AgentTest { model, agent, .. } = setup(cx).await;
142
143    // Test concurrent tool calls with different delay times
144    let events = agent
145        .update(cx, |agent, cx| {
146            agent.add_tool(DelayTool);
147            agent.send(
148                model.clone(),
149                "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
150                cx,
151            )
152        })
153        .collect()
154        .await;
155
156    let stop_reasons = stop_events(events);
157    if stop_reasons.len() == 2 {
158        assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]);
159    } else if stop_reasons.len() == 3 {
160        assert_eq!(
161            stop_reasons,
162            vec![
163                StopReason::ToolUse,
164                StopReason::ToolUse,
165                StopReason::EndTurn
166            ]
167        );
168    } else {
169        panic!("Expected either 1 or 2 tool uses followed by end turn");
170    }
171
172    agent.update(cx, |agent, _cx| {
173        let last_message = agent.messages.last().unwrap();
174        let text = last_message
175            .content
176            .iter()
177            .filter_map(|content| {
178                if let MessageContent::Text(text) = content {
179                    Some(text.as_str())
180                } else {
181                    None
182                }
183            })
184            .collect::<String>();
185
186        assert!(text.contains("Ding"));
187    });
188}
189
190/// Filters out the stop events for asserting against in tests
191fn stop_events(
192    result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
193) -> Vec<StopReason> {
194    result_events
195        .into_iter()
196        .filter_map(|event| match event.unwrap() {
197            LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason),
198            _ => None,
199        })
200        .collect()
201}
202
/// Fixture returned by `setup`, bundling what each test needs to drive a
/// conversation.
struct AgentTest {
    /// Language model handle that the tests pass to `send`.
    model: Arc<dyn LanguageModel>,
    /// The `Thread` entity under test.
    agent: Entity<Thread>,
}
207
208async fn setup(cx: &mut TestAppContext) -> AgentTest {
209    cx.executor().allow_parking();
210    cx.update(settings::init);
211    let fs = FakeFs::new(cx.executor().clone());
212    // let project = Project::test(fs.clone(), [], cx).await;
213    // let action_log = cx.new(|_| ActionLog::new(project.clone()));
214    let templates = Templates::new();
215    let agent = cx.new(|_| Thread::new(templates));
216
217    let model = cx
218        .update(|cx| {
219            gpui_tokio::init(cx);
220            let http_client = ReqwestClient::user_agent("agent tests").unwrap();
221            cx.set_http_client(Arc::new(http_client));
222
223            client::init_settings(cx);
224            let client = Client::production(cx);
225            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
226            language_model::init(client.clone(), cx);
227            language_models::init(user_store.clone(), client.clone(), cx);
228
229            let models = LanguageModelRegistry::read_global(cx);
230            let model = models
231                .available_models(cx)
232                .find(|model| model.id().0 == "claude-3-7-sonnet-latest")
233                .unwrap();
234
235            let provider = models.provider(&model.provider_id()).unwrap();
236            let authenticated = provider.authenticate(cx);
237
238            cx.spawn(async move |cx| {
239                authenticated.await.unwrap();
240                model
241            })
242        })
243        .await;
244
245    AgentTest { model, agent }
246}
247
/// Installs `env_logger` from a `ctor` hook, but only when the `RUST_LOG`
/// environment variable can be read (set and valid Unicode). Otherwise no
/// logger is initialized.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_unavailable = std::env::var("RUST_LOG").is_err();
    if rust_log_unavailable {
        return;
    }
    env_logger::init();
}
254}