1use super::*;
2use client::{proto::language_server_prompt_request, Client, UserStore};
3use fs::FakeFs;
4use gpui::{AppContext, Entity, TestAppContext};
5use language_model::{
6 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
7 LanguageModelRegistry, MessageContent, StopReason,
8};
9use reqwest_client::ReqwestClient;
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize};
12use smol::stream::StreamExt;
13use std::{sync::Arc, time::Duration};
14
15mod test_tools;
16use test_tools::*;
17
18#[gpui::test]
19async fn test_echo(cx: &mut TestAppContext) {
20 let AgentTest { model, agent, .. } = setup(cx).await;
21
22 let events = agent
23 .update(cx, |agent, cx| {
24 agent.send(model.clone(), "Testing: Reply with 'Hello'", cx)
25 })
26 .collect()
27 .await;
28 agent.update(cx, |agent, _cx| {
29 assert_eq!(
30 agent.messages.last().unwrap().content,
31 vec![MessageContent::Text("Hello".to_string())]
32 );
33 });
34 assert_eq!(stop_events(events), vec![StopReason::EndTurn]);
35}
36
37#[gpui::test]
38async fn test_basic_tool_calls(cx: &mut TestAppContext) {
39 let AgentTest { model, agent, .. } = setup(cx).await;
40
41 // Test a tool call that's likely to complete *before* streaming stops.
42 let events = agent
43 .update(cx, |agent, cx| {
44 agent.add_tool(EchoTool);
45 agent.send(
46 model.clone(),
47 "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
48 cx,
49 )
50 })
51 .collect()
52 .await;
53 assert_eq!(
54 stop_events(events),
55 vec![StopReason::ToolUse, StopReason::EndTurn]
56 );
57
58 // Test a tool calls that's likely to complete *after* streaming stops.
59 let events = agent
60 .update(cx, |agent, cx| {
61 agent.remove_tool(&AgentTool::name(&EchoTool));
62 agent.add_tool(DelayTool);
63 agent.send(
64 model.clone(),
65 "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
66 cx,
67 )
68 })
69 .collect()
70 .await;
71 assert_eq!(
72 stop_events(events),
73 vec![StopReason::ToolUse, StopReason::EndTurn]
74 );
75 agent.update(cx, |agent, _cx| {
76 assert!(agent
77 .messages
78 .last()
79 .unwrap()
80 .content
81 .iter()
82 .any(|content| {
83 if let MessageContent::Text(text) = content {
84 text.contains("Ding")
85 } else {
86 false
87 }
88 }));
89 });
90}
91
92#[gpui::test]
93async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
94 let AgentTest { model, agent, .. } = setup(cx).await;
95
96 // Test a tool call that's likely to complete *before* streaming stops.
97 let mut events = agent.update(cx, |agent, cx| {
98 agent.add_tool(WordListTool);
99 agent.send(model.clone(), "Test the word_list tool.", cx)
100 });
101
102 let mut saw_partial_tool_use = false;
103 while let Some(event) = events.next().await {
104 if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event {
105 agent.update(cx, |agent, _cx| {
106 // Look for a tool use in the agent's last message
107 let last_content = agent.messages().last().unwrap().content.last().unwrap();
108 if let MessageContent::ToolUse(last_tool_use) = last_content {
109 assert_eq!(last_tool_use.name.as_ref(), "word_list");
110 if tool_use_event.is_input_complete {
111 last_tool_use
112 .input
113 .get("a")
114 .expect("'a' has streamed because input is now complete");
115 last_tool_use
116 .input
117 .get("g")
118 .expect("'g' has streamed because input is now complete");
119 } else {
120 if !last_tool_use.is_input_complete
121 && last_tool_use.input.get("g").is_none()
122 {
123 saw_partial_tool_use = true;
124 }
125 }
126 } else {
127 panic!("last content should be a tool use");
128 }
129 });
130 }
131 }
132
133 assert!(
134 saw_partial_tool_use,
135 "should see at least one partially streamed tool use in the history"
136 );
137}
138
139#[gpui::test]
140async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
141 let AgentTest { model, agent, .. } = setup(cx).await;
142
143 // Test concurrent tool calls with different delay times
144 let events = agent
145 .update(cx, |agent, cx| {
146 agent.add_tool(DelayTool);
147 agent.send(
148 model.clone(),
149 "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
150 cx,
151 )
152 })
153 .collect()
154 .await;
155
156 let stop_reasons = stop_events(events);
157 if stop_reasons.len() == 2 {
158 assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]);
159 } else if stop_reasons.len() == 3 {
160 assert_eq!(
161 stop_reasons,
162 vec![
163 StopReason::ToolUse,
164 StopReason::ToolUse,
165 StopReason::EndTurn
166 ]
167 );
168 } else {
169 panic!("Expected either 1 or 2 tool uses followed by end turn");
170 }
171
172 agent.update(cx, |agent, _cx| {
173 let last_message = agent.messages.last().unwrap();
174 let text = last_message
175 .content
176 .iter()
177 .filter_map(|content| {
178 if let MessageContent::Text(text) = content {
179 Some(text.as_str())
180 } else {
181 None
182 }
183 })
184 .collect::<String>();
185
186 assert!(text.contains("Ding"));
187 });
188}
189
190/// Filters out the stop events for asserting against in tests
191fn stop_events(
192 result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
193) -> Vec<StopReason> {
194 result_events
195 .into_iter()
196 .filter_map(|event| match event.unwrap() {
197 LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason),
198 _ => None,
199 })
200 .collect()
201}
202
203struct AgentTest {
204 model: Arc<dyn LanguageModel>,
205 agent: Entity<Thread>,
206}
207
208async fn setup(cx: &mut TestAppContext) -> AgentTest {
209 cx.executor().allow_parking();
210 cx.update(settings::init);
211 let fs = FakeFs::new(cx.executor().clone());
212 // let project = Project::test(fs.clone(), [], cx).await;
213 // let action_log = cx.new(|_| ActionLog::new(project.clone()));
214 let templates = Templates::new();
215 let agent = cx.new(|_| Thread::new(templates));
216
217 let model = cx
218 .update(|cx| {
219 gpui_tokio::init(cx);
220 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
221 cx.set_http_client(Arc::new(http_client));
222
223 client::init_settings(cx);
224 let client = Client::production(cx);
225 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
226 language_model::init(client.clone(), cx);
227 language_models::init(user_store.clone(), client.clone(), cx);
228
229 let models = LanguageModelRegistry::read_global(cx);
230 let model = models
231 .available_models(cx)
232 .find(|model| model.id().0 == "claude-3-7-sonnet-latest")
233 .unwrap();
234
235 let provider = models.provider(&model.provider_id()).unwrap();
236 let authenticated = provider.authenticate(cx);
237
238 cx.spawn(async move |cx| {
239 authenticated.await.unwrap();
240 model
241 })
242 })
243 .await;
244
245 AgentTest { model, agent }
246}
247
/// Installs `env_logger` for the test binary, but only when `RUST_LOG` is set
/// to a valid Unicode value. Otherwise no logger is installed.
///
/// Registered with `#[ctor::ctor]`, so it runs at binary startup before any
/// test executes.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if !rust_log_is_set {
        return;
    }
    env_logger::init();
}