1use super::*;
2use crate::templates::Templates;
3use client::{Client, UserStore};
4use fs::FakeFs;
5use gpui::{AppContext, Entity, TestAppContext};
6use language_model::{
7 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
8 LanguageModelRegistry, MessageContent, StopReason,
9};
10use reqwest_client::ReqwestClient;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use smol::stream::StreamExt;
14use std::{sync::Arc, time::Duration};
15
16mod test_tools;
17use test_tools::*;
18
19#[gpui::test]
20async fn test_echo(cx: &mut TestAppContext) {
21 let AgentTest { model, agent, .. } = setup(cx).await;
22
23 let events = agent
24 .update(cx, |agent, cx| {
25 agent.send(model.clone(), "Testing: Reply with 'Hello'", cx)
26 })
27 .collect()
28 .await;
29 agent.update(cx, |agent, _cx| {
30 assert_eq!(
31 agent.messages().last().unwrap().content,
32 vec![MessageContent::Text("Hello".to_string())]
33 );
34 });
35 assert_eq!(stop_events(events), vec![StopReason::EndTurn]);
36}
37
38#[gpui::test]
39async fn test_basic_tool_calls(cx: &mut TestAppContext) {
40 let AgentTest { model, agent, .. } = setup(cx).await;
41
42 // Test a tool call that's likely to complete *before* streaming stops.
43 let events = agent
44 .update(cx, |agent, cx| {
45 agent.add_tool(EchoTool);
46 agent.send(
47 model.clone(),
48 "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
49 cx,
50 )
51 })
52 .collect()
53 .await;
54 assert_eq!(
55 stop_events(events),
56 vec![StopReason::ToolUse, StopReason::EndTurn]
57 );
58
59 // Test a tool calls that's likely to complete *after* streaming stops.
60 let events = agent
61 .update(cx, |agent, cx| {
62 agent.remove_tool(&AgentTool::name(&EchoTool));
63 agent.add_tool(DelayTool);
64 agent.send(
65 model.clone(),
66 "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
67 cx,
68 )
69 })
70 .collect()
71 .await;
72 assert_eq!(
73 stop_events(events),
74 vec![StopReason::ToolUse, StopReason::EndTurn]
75 );
76 agent.update(cx, |agent, _cx| {
77 assert!(agent
78 .messages()
79 .last()
80 .unwrap()
81 .content
82 .iter()
83 .any(|content| {
84 if let MessageContent::Text(text) = content {
85 text.contains("Ding")
86 } else {
87 false
88 }
89 }));
90 });
91}
92
93#[gpui::test]
94async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
95 let AgentTest { model, agent, .. } = setup(cx).await;
96
97 // Test a tool call that's likely to complete *before* streaming stops.
98 let mut events = agent.update(cx, |agent, cx| {
99 agent.add_tool(WordListTool);
100 agent.send(model.clone(), "Test the word_list tool.", cx)
101 });
102
103 let mut saw_partial_tool_use = false;
104 while let Some(event) = events.next().await {
105 if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event {
106 agent.update(cx, |agent, _cx| {
107 // Look for a tool use in the agent's last message
108 let last_content = agent.messages().last().unwrap().content.last().unwrap();
109 if let MessageContent::ToolUse(last_tool_use) = last_content {
110 assert_eq!(last_tool_use.name.as_ref(), "word_list");
111 if tool_use_event.is_input_complete {
112 last_tool_use
113 .input
114 .get("a")
115 .expect("'a' has streamed because input is now complete");
116 last_tool_use
117 .input
118 .get("g")
119 .expect("'g' has streamed because input is now complete");
120 } else {
121 if !last_tool_use.is_input_complete
122 && last_tool_use.input.get("g").is_none()
123 {
124 saw_partial_tool_use = true;
125 }
126 }
127 } else {
128 panic!("last content should be a tool use");
129 }
130 });
131 }
132 }
133
134 assert!(
135 saw_partial_tool_use,
136 "should see at least one partially streamed tool use in the history"
137 );
138}
139
140#[gpui::test]
141async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
142 let AgentTest { model, agent, .. } = setup(cx).await;
143
144 // Test concurrent tool calls with different delay times
145 let events = agent
146 .update(cx, |agent, cx| {
147 agent.add_tool(DelayTool);
148 agent.send(
149 model.clone(),
150 "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
151 cx,
152 )
153 })
154 .collect()
155 .await;
156
157 let stop_reasons = stop_events(events);
158 if stop_reasons.len() == 2 {
159 assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]);
160 } else if stop_reasons.len() == 3 {
161 assert_eq!(
162 stop_reasons,
163 vec![
164 StopReason::ToolUse,
165 StopReason::ToolUse,
166 StopReason::EndTurn
167 ]
168 );
169 } else {
170 panic!("Expected either 1 or 2 tool uses followed by end turn");
171 }
172
173 agent.update(cx, |agent, _cx| {
174 let last_message = agent.messages().last().unwrap();
175 let text = last_message
176 .content
177 .iter()
178 .filter_map(|content| {
179 if let MessageContent::Text(text) = content {
180 Some(text.as_str())
181 } else {
182 None
183 }
184 })
185 .collect::<String>();
186
187 assert!(text.contains("Ding"));
188 });
189}
190
191/// Filters out the stop events for asserting against in tests
192fn stop_events(
193 result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
194) -> Vec<StopReason> {
195 result_events
196 .into_iter()
197 .filter_map(|event| match event.unwrap() {
198 LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason),
199 _ => None,
200 })
201 .collect()
202}
203
/// Fixture returned by `setup`: the model prompts are sent to, and the thread
/// entity under test.
struct AgentTest {
    /// Language model resolved from the registry and authenticated by `setup`.
    model: Arc<dyn LanguageModel>,
    /// The `Thread` entity that tests drive via `send`, `add_tool`, and `messages`.
    agent: Entity<Thread>,
}
208
209async fn setup(cx: &mut TestAppContext) -> AgentTest {
210 cx.executor().allow_parking();
211 cx.update(settings::init);
212 let fs = FakeFs::new(cx.executor().clone());
213 // let project = Project::test(fs.clone(), [], cx).await;
214 // let action_log = cx.new(|_| ActionLog::new(project.clone()));
215 let templates = Templates::new();
216 let agent = cx.new(|_| Thread::new(templates));
217
218 let model = cx
219 .update(|cx| {
220 gpui_tokio::init(cx);
221 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
222 cx.set_http_client(Arc::new(http_client));
223
224 client::init_settings(cx);
225 let client = Client::production(cx);
226 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
227 language_model::init(client.clone(), cx);
228 language_models::init(user_store.clone(), client.clone(), cx);
229
230 let models = LanguageModelRegistry::read_global(cx);
231 let model = models
232 .available_models(cx)
233 .find(|model| model.id().0 == "claude-3-7-sonnet-latest")
234 .unwrap();
235
236 let provider = models.provider(&model.provider_id()).unwrap();
237 let authenticated = provider.authenticate(cx);
238
239 cx.spawn(async move |cx| {
240 authenticated.await.unwrap();
241 model
242 })
243 })
244 .await;
245
246 AgentTest { model, agent }
247}
248
/// Runs once before any test (via `ctor`). It installs `env_logger` only when
/// `RUST_LOG` is set, so test output stays quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Bail out unless RUST_LOG is present (and valid Unicode).
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}