use super::*;
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
use action_log::ActionLog;
use agent_client_protocol::{self as acp};
use agent_settings::AgentProfileId;
use anyhow::Result;
use client::{Client, UserStore};
use fs::{FakeFs, Fs};
use futures::channel::mpsc::UnboundedReceiver;
use gpui::{
    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
};
use indoc::indoc;
use language_model::{
    LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
    Role, StopReason, fake_provider::FakeLanguageModel,
};
use pretty_assertions::assert_eq;
use project::Project;
use prompt_store::ProjectContext;
use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use settings::SettingsStore;
use smol::stream::StreamExt;
use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path;

mod test_tools;
use test_tools::*;

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_echo(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
        })
        .collect()
        .await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_thinking(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .collect()
        .await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}

#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    project_context.borrow_mut().shell = "test-shell".into();
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["abc"], cx)
    });
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}

#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Message 1"], cx)
    });
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Message 2"], cx)
    });
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
    });
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool.name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&AgentTool::name(&EchoTool));
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread.update(cx, |thread, cx| {
        thread.add_tool(WordListTool);
        thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
    });

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}

#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread.update(cx, |thread, cx| {
        thread.add_tool(ToolRequiringPermission);
        thread.send(UserMessageId::new(), ["abc"], cx)
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission.name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission.name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first
    tool_call_auth_1
        .response
        .send(tool_call_auth_1.options[1].id.clone())
        .unwrap();
    cx.run_until_parked();

    // Reject the second
    tool_call_auth_2
        .response
        .send(tool_call_auth_2.options[2].id.clone())
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
                tool_name: ToolRequiringPermission.name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
                tool_name: ToolRequiringPermission.name().into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: None
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission.name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools.
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(tool_call_auth_3.options[0].id.clone())
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
                tool_name: ToolRequiringPermission.name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission.name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission.name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}

#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["abc"], cx)
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}

#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["abc"], cx)
    });
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool.name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });

    // Ensure we error if calling resume when tool use limit was *not* reached.
    let error = thread
        .update(cx, |thread, cx| thread.resume(cx))
        .unwrap_err();
    assert_eq!(
        error.to_string(),
        "can only resume after tool use limit is reached"
    )
}

#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["abc"], cx)
    });
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool.name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), vec!["ghi"], cx)
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}

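/// Reads the next event from the stream, asserting that it is a tool call
/// and returning it; panics on any other event.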
async fn expect_tool_call(
    events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> acp::ToolCall {
    let event = events
        .next()
        .await
        .expect("no tool call event received")
        .unwrap();
    match event {
        AgentResponseEvent::ToolCall(tool_call) => return tool_call,
        event => {
            panic!("Unexpected event {event:?}");
        }
    }
}

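/// Reads the next event from the stream, asserting that it is a tool call
/// update with changed fields and returning the update; panics on any other
/// event.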
async fn expect_tool_call_update_fields(
    events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> acp::ToolCallUpdate {
    let event = events
        .next()
        .await
        .expect("no tool call update event received")
        .unwrap();
    match event {
        AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
            return update;
        }
        event => {
            panic!("Unexpected event {event:?}");
        }
    }
}

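/// Skips ahead in the stream until a tool call authorization request arrives,
/// asserting that it offers the expected permission options (allow always,
/// allow once, reject once) before returning it.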
async fn next_tool_call_authorization(
    events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> ToolCallAuthorization {
    loop {
        let event = events
            .next()
            .await
            .expect("no tool call authorization event received")
            .unwrap();
        if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
            let permission_kinds = tool_call_authorization
                .options
                .iter()
                .map(|o| o.kind)
                .collect::<Vec<_>>();
            assert_eq!(
                permission_kinds,
                vec![
                    acp::PermissionOptionKind::AllowAlways,
                    acp::PermissionOptionKind::AllowOnce,
                    acp::PermissionOptionKind::RejectOnce,
                ]
            );
            return tool_call_authorization;
        }
    }
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}

#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test-1".into()));
        thread.send(UserMessageId::new(), ["test"], cx);
    });
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test-2".into()));
        thread.send(UserMessageId::new(), ["test2"], cx)
    });
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool.name()]);
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread.update(cx, |thread, cx| {
        thread.add_tool(InfiniteTool);
        thread.add_tool(EchoTool);
        thread.send(
            UserMessageId::new(),
            ["Call the echo tool, then call the infinite tool, then explain their output"],
            cx,
        )
    });

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            AgentResponseEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, _cx| thread.cancel());
    events.collect::<Vec<_>>().await;

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}

#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hello"], cx)
    });
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove the refused
    // turn, including the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}

#[gpui::test]
async fn test_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread.update(cx, |thread, cx| {
        thread.send(message_id.clone(), ["Hello"], cx)
    });
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    thread
        .update(cx, |thread, _cx| thread.truncate(message_id))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });

    // Ensure we can still send a new message after truncation.
    thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hi"], cx)
    });
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );
    });
}

#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store.clone(), client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}

#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Think"], cx)
    });
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool.name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}

/// Extracts just the stop events, for asserting against in tests
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
    result_events
        .into_iter()
        .filter_map(|event| match event.unwrap() {
            AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
            _ => None,
        })
        .collect()
}

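/// Handles returned by [`setup`]: the model under test, the thread entity,
/// the project context shared with the thread, and the fake filesystem
/// backing the project and settings file.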
struct ThreadTest {
    model: Arc<dyn LanguageModel>,
    thread: Entity<Thread>,
    project_context: Rc<RefCell<ProjectContext>>,
    fs: Arc<FakeFs>,
}

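/// The language model a test thread runs against. `Fake` uses
/// `FakeLanguageModel` and needs no credentials; the Sonnet variants resolve
/// real models from the registry and are only exercised by `#[ignore]`d tests.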
enum TestModel {
    Sonnet4,
    Sonnet4Thinking,
    Fake,
}

impl TestModel {
    fn id(&self) -> LanguageModelId {
        match self {
            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
            TestModel::Fake => unreachable!(),
        }
    }
}

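/// Builds a `Thread` backed by a fake filesystem with a test settings profile
/// that enables all of the test tools, resolving the requested model (and
/// authenticating with its provider for the non-fake variants).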
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                            WordListTool.name(): true,
                            ToolRequiringPermission.name(): true,
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store.clone(), client.clone(), cx);

        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let action_log = cx.new(|_| ActionLog::new(project.clone()));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            action_log,
            templates,
            model.clone(),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        fs,
    }
}

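/// Initializes `env_logger` at process start (via `ctor`) when `RUST_LOG` is set.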
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}

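/// Spawns a task that watches the settings file on the given filesystem and
/// applies each new version to the global `SettingsStore`, so tests can
/// override settings by writing to `paths::settings_file()`.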
fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
    let fs = fs.clone();
    cx.spawn({
        async move |cx| {
            let mut new_settings_content_rx = settings::watch_config_file(
                cx.background_executor(),
                fs,
                paths::settings_file().clone(),
            );

            while let Some(new_settings_content) = new_settings_content_rx.next().await {
                cx.update(|cx| {
                    SettingsStore::update_global(cx, |settings, cx| {
                        settings.set_user_settings(&new_settings_content, cx)
                    })
                })
                .ok();
            }
        }
    })
    .detach();
}