1use super::*;
2use crate::templates::Templates;
3use client::{Client, UserStore};
4use gpui::{AppContext, Entity, TestAppContext};
5use language_model::{
6 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
7 LanguageModelRegistry, MessageContent, StopReason,
8};
9use reqwest_client::ReqwestClient;
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize};
12use smol::stream::StreamExt;
13use std::{sync::Arc, time::Duration};
14
15mod test_tools;
16use test_tools::*;
17
18#[gpui::test]
19async fn test_echo(cx: &mut TestAppContext) {
20 let ThreadTest { model, thread, .. } = setup(cx).await;
21
22 let events = thread
23 .update(cx, |thread, cx| {
24 thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
25 })
26 .collect()
27 .await;
28 thread.update(cx, |thread, _cx| {
29 assert_eq!(
30 thread.messages().last().unwrap().content,
31 vec![MessageContent::Text("Hello".to_string())]
32 );
33 });
34 assert_eq!(stop_events(events), vec![StopReason::EndTurn]);
35}
36
37#[gpui::test]
38async fn test_basic_tool_calls(cx: &mut TestAppContext) {
39 let ThreadTest { model, thread, .. } = setup(cx).await;
40
41 // Test a tool call that's likely to complete *before* streaming stops.
42 let events = thread
43 .update(cx, |thread, cx| {
44 thread.add_tool(EchoTool);
45 thread.send(
46 model.clone(),
47 "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
48 cx,
49 )
50 })
51 .collect()
52 .await;
53 assert_eq!(
54 stop_events(events),
55 vec![StopReason::ToolUse, StopReason::EndTurn]
56 );
57
58 // Test a tool calls that's likely to complete *after* streaming stops.
59 let events = thread
60 .update(cx, |thread, cx| {
61 thread.remove_tool(&AgentTool::name(&EchoTool));
62 thread.add_tool(DelayTool);
63 thread.send(
64 model.clone(),
65 "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
66 cx,
67 )
68 })
69 .collect()
70 .await;
71 assert_eq!(
72 stop_events(events),
73 vec![StopReason::ToolUse, StopReason::EndTurn]
74 );
75 thread.update(cx, |thread, _cx| {
76 assert!(thread
77 .messages()
78 .last()
79 .unwrap()
80 .content
81 .iter()
82 .any(|content| {
83 if let MessageContent::Text(text) = content {
84 text.contains("Ding")
85 } else {
86 false
87 }
88 }));
89 });
90}
91
92#[gpui::test]
93async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
94 let ThreadTest { model, thread, .. } = setup(cx).await;
95
96 // Test a tool call that's likely to complete *before* streaming stops.
97 let mut events = thread.update(cx, |thread, cx| {
98 thread.add_tool(WordListTool);
99 thread.send(model.clone(), "Test the word_list tool.", cx)
100 });
101
102 let mut saw_partial_tool_use = false;
103 while let Some(event) = events.next().await {
104 if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event {
105 thread.update(cx, |thread, _cx| {
106 // Look for a tool use in the thread's last message
107 let last_content = thread.messages().last().unwrap().content.last().unwrap();
108 if let MessageContent::ToolUse(last_tool_use) = last_content {
109 assert_eq!(last_tool_use.name.as_ref(), "word_list");
110 if tool_use_event.is_input_complete {
111 last_tool_use
112 .input
113 .get("a")
114 .expect("'a' has streamed because input is now complete");
115 last_tool_use
116 .input
117 .get("g")
118 .expect("'g' has streamed because input is now complete");
119 } else {
120 if !last_tool_use.is_input_complete
121 && last_tool_use.input.get("g").is_none()
122 {
123 saw_partial_tool_use = true;
124 }
125 }
126 } else {
127 panic!("last content should be a tool use");
128 }
129 });
130 }
131 }
132
133 assert!(
134 saw_partial_tool_use,
135 "should see at least one partially streamed tool use in the history"
136 );
137}
138
139#[gpui::test]
140async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
141 let ThreadTest { model, thread, .. } = setup(cx).await;
142
143 // Test concurrent tool calls with different delay times
144 let events = thread
145 .update(cx, |thread, cx| {
146 thread.add_tool(DelayTool);
147 thread.send(
148 model.clone(),
149 "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
150 cx,
151 )
152 })
153 .collect()
154 .await;
155
156 let stop_reasons = stop_events(events);
157 if stop_reasons.len() == 2 {
158 assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]);
159 } else if stop_reasons.len() == 3 {
160 assert_eq!(
161 stop_reasons,
162 vec![
163 StopReason::ToolUse,
164 StopReason::ToolUse,
165 StopReason::EndTurn
166 ]
167 );
168 } else {
169 panic!("Expected either 1 or 2 tool uses followed by end turn");
170 }
171
172 thread.update(cx, |thread, _cx| {
173 let last_message = thread.messages().last().unwrap();
174 let text = last_message
175 .content
176 .iter()
177 .filter_map(|content| {
178 if let MessageContent::Text(text) = content {
179 Some(text.as_str())
180 } else {
181 None
182 }
183 })
184 .collect::<String>();
185
186 assert!(text.contains("Ding"));
187 });
188}
189
190/// Filters out the stop events for asserting against in tests
191fn stop_events(
192 result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
193) -> Vec<StopReason> {
194 result_events
195 .into_iter()
196 .filter_map(|event| match event.unwrap() {
197 LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason),
198 _ => None,
199 })
200 .collect()
201}
202
203struct ThreadTest {
204 model: Arc<dyn LanguageModel>,
205 thread: Entity<Thread>,
206}
207
208async fn setup(cx: &mut TestAppContext) -> ThreadTest {
209 cx.executor().allow_parking();
210 cx.update(settings::init);
211 let templates = Templates::new();
212
213 let model = cx
214 .update(|cx| {
215 gpui_tokio::init(cx);
216 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
217 cx.set_http_client(Arc::new(http_client));
218
219 client::init_settings(cx);
220 let client = Client::production(cx);
221 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
222 language_model::init(client.clone(), cx);
223 language_models::init(user_store.clone(), client.clone(), cx);
224
225 let models = LanguageModelRegistry::read_global(cx);
226 let model = models
227 .available_models(cx)
228 .find(|model| model.id().0 == "claude-3-7-sonnet-latest")
229 .unwrap();
230
231 let provider = models.provider(&model.provider_id()).unwrap();
232 let authenticated = provider.authenticate(cx);
233
234 cx.spawn(async move |_cx| {
235 authenticated.await.unwrap();
236 model
237 })
238 })
239 .await;
240
241 let thread = cx.new(|_| Thread::new(templates, model.clone()));
242
243 ThreadTest { model, thread }
244}
245
// Runs at program startup via `ctor`. It installs `env_logger` only when
// `RUST_LOG` is set (and readable as Unicode), so test output stays quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}