mod.rs

  1use super::*;
  2use crate::templates::Templates;
  3use client::{Client, UserStore};
  4use fs::FakeFs;
  5use gpui::{AppContext, Entity, TestAppContext};
  6use language_model::{
  7    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  8    LanguageModelRegistry, MessageContent, StopReason,
  9};
 10use reqwest_client::ReqwestClient;
 11use schemars::JsonSchema;
 12use serde::{Deserialize, Serialize};
 13use smol::stream::StreamExt;
 14use std::{sync::Arc, time::Duration};
 15
 16mod test_tools;
 17use test_tools::*;
 18
 19#[gpui::test]
 20async fn test_echo(cx: &mut TestAppContext) {
 21    let AgentTest { model, agent, .. } = setup(cx).await;
 22
 23    let events = agent
 24        .update(cx, |agent, cx| {
 25            agent.send(model.clone(), "Testing: Reply with 'Hello'", cx)
 26        })
 27        .collect()
 28        .await;
 29    agent.update(cx, |agent, _cx| {
 30        assert_eq!(
 31            agent.messages().last().unwrap().content,
 32            vec![MessageContent::Text("Hello".to_string())]
 33        );
 34    });
 35    assert_eq!(stop_events(events), vec![StopReason::EndTurn]);
 36}
 37
 38#[gpui::test]
 39async fn test_basic_tool_calls(cx: &mut TestAppContext) {
 40    let AgentTest { model, agent, .. } = setup(cx).await;
 41
 42    // Test a tool call that's likely to complete *before* streaming stops.
 43    let events = agent
 44        .update(cx, |agent, cx| {
 45            agent.add_tool(EchoTool);
 46            agent.send(
 47                model.clone(),
 48                "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
 49                cx,
 50            )
 51        })
 52        .collect()
 53        .await;
 54    assert_eq!(
 55        stop_events(events),
 56        vec![StopReason::ToolUse, StopReason::EndTurn]
 57    );
 58
 59    // Test a tool calls that's likely to complete *after* streaming stops.
 60    let events = agent
 61        .update(cx, |agent, cx| {
 62            agent.remove_tool(&AgentTool::name(&EchoTool));
 63            agent.add_tool(DelayTool);
 64            agent.send(
 65                model.clone(),
 66                "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
 67                cx,
 68            )
 69        })
 70        .collect()
 71        .await;
 72    assert_eq!(
 73        stop_events(events),
 74        vec![StopReason::ToolUse, StopReason::EndTurn]
 75    );
 76    agent.update(cx, |agent, _cx| {
 77        assert!(agent
 78            .messages()
 79            .last()
 80            .unwrap()
 81            .content
 82            .iter()
 83            .any(|content| {
 84                if let MessageContent::Text(text) = content {
 85                    text.contains("Ding")
 86                } else {
 87                    false
 88                }
 89            }));
 90    });
 91}
 92
 93#[gpui::test]
 94async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
 95    let AgentTest { model, agent, .. } = setup(cx).await;
 96
 97    // Test a tool call that's likely to complete *before* streaming stops.
 98    let mut events = agent.update(cx, |agent, cx| {
 99        agent.add_tool(WordListTool);
100        agent.send(model.clone(), "Test the word_list tool.", cx)
101    });
102
103    let mut saw_partial_tool_use = false;
104    while let Some(event) = events.next().await {
105        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event {
106            agent.update(cx, |agent, _cx| {
107                // Look for a tool use in the agent's last message
108                let last_content = agent.messages().last().unwrap().content.last().unwrap();
109                if let MessageContent::ToolUse(last_tool_use) = last_content {
110                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
111                    if tool_use_event.is_input_complete {
112                        last_tool_use
113                            .input
114                            .get("a")
115                            .expect("'a' has streamed because input is now complete");
116                        last_tool_use
117                            .input
118                            .get("g")
119                            .expect("'g' has streamed because input is now complete");
120                    } else {
121                        if !last_tool_use.is_input_complete
122                            && last_tool_use.input.get("g").is_none()
123                        {
124                            saw_partial_tool_use = true;
125                        }
126                    }
127                } else {
128                    panic!("last content should be a tool use");
129                }
130            });
131        }
132    }
133
134    assert!(
135        saw_partial_tool_use,
136        "should see at least one partially streamed tool use in the history"
137    );
138}
139
140#[gpui::test]
141async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
142    let AgentTest { model, agent, .. } = setup(cx).await;
143
144    // Test concurrent tool calls with different delay times
145    let events = agent
146        .update(cx, |agent, cx| {
147            agent.add_tool(DelayTool);
148            agent.send(
149                model.clone(),
150                "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
151                cx,
152            )
153        })
154        .collect()
155        .await;
156
157    let stop_reasons = stop_events(events);
158    if stop_reasons.len() == 2 {
159        assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]);
160    } else if stop_reasons.len() == 3 {
161        assert_eq!(
162            stop_reasons,
163            vec![
164                StopReason::ToolUse,
165                StopReason::ToolUse,
166                StopReason::EndTurn
167            ]
168        );
169    } else {
170        panic!("Expected either 1 or 2 tool uses followed by end turn");
171    }
172
173    agent.update(cx, |agent, _cx| {
174        let last_message = agent.messages().last().unwrap();
175        let text = last_message
176            .content
177            .iter()
178            .filter_map(|content| {
179                if let MessageContent::Text(text) = content {
180                    Some(text.as_str())
181                } else {
182                    None
183                }
184            })
185            .collect::<String>();
186
187        assert!(text.contains("Ding"));
188    });
189}
190
191/// Filters out the stop events for asserting against in tests
192fn stop_events(
193    result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
194) -> Vec<StopReason> {
195    result_events
196        .into_iter()
197        .filter_map(|event| match event.unwrap() {
198            LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason),
199            _ => None,
200        })
201        .collect()
202}
203
/// Fixture returned by `setup`, bundling what each test in this file needs.
struct AgentTest {
    /// Language model handle the tests pass to `Thread::send`.
    model: Arc<dyn LanguageModel>,
    /// The `Thread` entity under test.
    agent: Entity<Thread>,
}
208
209async fn setup(cx: &mut TestAppContext) -> AgentTest {
210    cx.executor().allow_parking();
211    cx.update(settings::init);
212    let fs = FakeFs::new(cx.executor().clone());
213    // let project = Project::test(fs.clone(), [], cx).await;
214    // let action_log = cx.new(|_| ActionLog::new(project.clone()));
215    let templates = Templates::new();
216    let agent = cx.new(|_| Thread::new(templates));
217
218    let model = cx
219        .update(|cx| {
220            gpui_tokio::init(cx);
221            let http_client = ReqwestClient::user_agent("agent tests").unwrap();
222            cx.set_http_client(Arc::new(http_client));
223
224            client::init_settings(cx);
225            let client = Client::production(cx);
226            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
227            language_model::init(client.clone(), cx);
228            language_models::init(user_store.clone(), client.clone(), cx);
229
230            let models = LanguageModelRegistry::read_global(cx);
231            let model = models
232                .available_models(cx)
233                .find(|model| model.id().0 == "claude-3-7-sonnet-latest")
234                .unwrap();
235
236            let provider = models.provider(&model.provider_id()).unwrap();
237            let authenticated = provider.authenticate(cx);
238
239            cx.spawn(async move |cx| {
240                authenticated.await.unwrap();
241                model
242            })
243        })
244        .await;
245
246    AgentTest { model, agent }
247}
248
/// Initializes `env_logger`, but only when `RUST_LOG` is set to a valid
/// Unicode value.
///
/// Registered through `#[ctor::ctor]` and compiled only for test builds.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_is_set {
        env_logger::init();
    }
}