1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use fs::{FakeFs, Fs};
8use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
9use gpui::{
10 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
11};
12use indoc::indoc;
13use language_model::{
14 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
15 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
16 LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
17 fake_provider::FakeLanguageModel,
18};
19use pretty_assertions::assert_eq;
20use project::Project;
21use prompt_store::ProjectContext;
22use reqwest_client::ReqwestClient;
23use schemars::JsonSchema;
24use serde::{Deserialize, Serialize};
25use serde_json::json;
26use settings::SettingsStore;
27use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
28use util::path;
29
30mod test_tools;
31use test_tools::*;
32
33#[gpui::test]
34async fn test_echo(cx: &mut TestAppContext) {
35 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
36 let fake_model = model.as_fake();
37
38 let events = thread
39 .update(cx, |thread, cx| {
40 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
41 })
42 .unwrap();
43 cx.run_until_parked();
44 fake_model.send_last_completion_stream_text_chunk("Hello");
45 fake_model
46 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
47 fake_model.end_last_completion_stream();
48
49 let events = events.collect().await;
50 thread.update(cx, |thread, _cx| {
51 assert_eq!(
52 thread.last_message().unwrap().to_markdown(),
53 indoc! {"
54 ## Assistant
55
56 Hello
57 "}
58 )
59 });
60 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
61}
62
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    // Verifies that a Thinking event is rendered as a `<think>…</think>` block
    // preceding the assistant's visible text in the thread's markdown.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    // Stream a thinking chunk, then the visible answer, then end the turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
106
107#[gpui::test]
108async fn test_system_prompt(cx: &mut TestAppContext) {
109 let ThreadTest {
110 model,
111 thread,
112 project_context,
113 ..
114 } = setup(cx, TestModel::Fake).await;
115 let fake_model = model.as_fake();
116
117 project_context.update(cx, |project_context, _cx| {
118 project_context.shell = "test-shell".into()
119 });
120 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
121 thread
122 .update(cx, |thread, cx| {
123 thread.send(UserMessageId::new(), ["abc"], cx)
124 })
125 .unwrap();
126 cx.run_until_parked();
127 let mut pending_completions = fake_model.pending_completions();
128 assert_eq!(
129 pending_completions.len(),
130 1,
131 "unexpected pending completions: {:?}",
132 pending_completions
133 );
134
135 let pending_completion = pending_completions.pop().unwrap();
136 assert_eq!(pending_completion.messages[0].role, Role::System);
137
138 let system_message = &pending_completion.messages[0];
139 let system_prompt = system_message.content[0].to_str().unwrap();
140 assert!(
141 system_prompt.contains("test-shell"),
142 "unexpected system message: {:?}",
143 system_message
144 );
145 assert!(
146 system_prompt.contains("## Fixing Diagnostics"),
147 "unexpected system message: {:?}",
148 system_message
149 );
150}
151
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the cache-marking policy: exactly one message per request —
    // the most recent one (latest user message or tool result) — carries
    // `cache: true`, and previously-cached messages are downgraded to false.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    // messages[0] is the system prompt; compare everything after it.
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Running the tool triggers a follow-up completion request; the tool
    // result (sent as a User message) should now be the cached anchor.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
285
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (only runs with the `e2e` feature) against a real
    // Sonnet 4 model: exercises tool calls that finish both before and after
    // the model's streaming response ends.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        // The delay tool's output ("Ding") should appear somewhere in the
        // agent's final message text.
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
345
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (only runs with the `e2e` feature): verifies that tool
    // input streams incrementally, i.e. the thread history exposes a tool use
    // whose input is still incomplete while the model is streaming it.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, the trailing field "g" may not have
                        // streamed yet — that's the partial state we want to see.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once past Pending, the input must be fully streamed.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
397
398#[gpui::test]
399async fn test_tool_authorization(cx: &mut TestAppContext) {
400 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
401 let fake_model = model.as_fake();
402
403 let mut events = thread
404 .update(cx, |thread, cx| {
405 thread.add_tool(ToolRequiringPermission);
406 thread.send(UserMessageId::new(), ["abc"], cx)
407 })
408 .unwrap();
409 cx.run_until_parked();
410 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
411 LanguageModelToolUse {
412 id: "tool_id_1".into(),
413 name: ToolRequiringPermission::name().into(),
414 raw_input: "{}".into(),
415 input: json!({}),
416 is_input_complete: true,
417 },
418 ));
419 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
420 LanguageModelToolUse {
421 id: "tool_id_2".into(),
422 name: ToolRequiringPermission::name().into(),
423 raw_input: "{}".into(),
424 input: json!({}),
425 is_input_complete: true,
426 },
427 ));
428 fake_model.end_last_completion_stream();
429 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
430 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
431
432 // Approve the first
433 tool_call_auth_1
434 .response
435 .send(tool_call_auth_1.options[1].id.clone())
436 .unwrap();
437 cx.run_until_parked();
438
439 // Reject the second
440 tool_call_auth_2
441 .response
442 .send(tool_call_auth_1.options[2].id.clone())
443 .unwrap();
444 cx.run_until_parked();
445
446 let completion = fake_model.pending_completions().pop().unwrap();
447 let message = completion.messages.last().unwrap();
448 assert_eq!(
449 message.content,
450 vec![
451 language_model::MessageContent::ToolResult(LanguageModelToolResult {
452 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
453 tool_name: ToolRequiringPermission::name().into(),
454 is_error: false,
455 content: "Allowed".into(),
456 output: Some("Allowed".into())
457 }),
458 language_model::MessageContent::ToolResult(LanguageModelToolResult {
459 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
460 tool_name: ToolRequiringPermission::name().into(),
461 is_error: true,
462 content: "Permission to run tool denied by user".into(),
463 output: None
464 })
465 ]
466 );
467
468 // Simulate yet another tool call.
469 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
470 LanguageModelToolUse {
471 id: "tool_id_3".into(),
472 name: ToolRequiringPermission::name().into(),
473 raw_input: "{}".into(),
474 input: json!({}),
475 is_input_complete: true,
476 },
477 ));
478 fake_model.end_last_completion_stream();
479
480 // Respond by always allowing tools.
481 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
482 tool_call_auth_3
483 .response
484 .send(tool_call_auth_3.options[0].id.clone())
485 .unwrap();
486 cx.run_until_parked();
487 let completion = fake_model.pending_completions().pop().unwrap();
488 let message = completion.messages.last().unwrap();
489 assert_eq!(
490 message.content,
491 vec![language_model::MessageContent::ToolResult(
492 LanguageModelToolResult {
493 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
494 tool_name: ToolRequiringPermission::name().into(),
495 is_error: false,
496 content: "Allowed".into(),
497 output: Some("Allowed".into())
498 }
499 )]
500 );
501
502 // Simulate a final tool call, ensuring we don't trigger authorization.
503 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
504 LanguageModelToolUse {
505 id: "tool_id_4".into(),
506 name: ToolRequiringPermission::name().into(),
507 raw_input: "{}".into(),
508 input: json!({}),
509 is_input_complete: true,
510 },
511 ));
512 fake_model.end_last_completion_stream();
513 cx.run_until_parked();
514 let completion = fake_model.pending_completions().pop().unwrap();
515 let message = completion.messages.last().unwrap();
516 assert_eq!(
517 message.content,
518 vec![language_model::MessageContent::ToolResult(
519 LanguageModelToolResult {
520 tool_use_id: "tool_id_4".into(),
521 tool_name: ToolRequiringPermission::name().into(),
522 is_error: false,
523 content: "Allowed".into(),
524 output: Some("Allowed".into())
525 }
526 )]
527 );
528}
529
530#[gpui::test]
531async fn test_tool_hallucination(cx: &mut TestAppContext) {
532 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
533 let fake_model = model.as_fake();
534
535 let mut events = thread
536 .update(cx, |thread, cx| {
537 thread.send(UserMessageId::new(), ["abc"], cx)
538 })
539 .unwrap();
540 cx.run_until_parked();
541 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
542 LanguageModelToolUse {
543 id: "tool_id_1".into(),
544 name: "nonexistent_tool".into(),
545 raw_input: "{}".into(),
546 input: json!({}),
547 is_input_complete: true,
548 },
549 ));
550 fake_model.end_last_completion_stream();
551
552 let tool_call = expect_tool_call(&mut events).await;
553 assert_eq!(tool_call.title, "nonexistent_tool");
554 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
555 let update = expect_tool_call_update_fields(&mut events).await;
556 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
557}
558
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // Verifies that after hitting the tool-use limit the turn ends with a
    // ToolUseLimitReachedError, that `resume` continues the conversation with
    // an injected "Continue where you left off" user message, and that
    // `resume` errors when the limit was never reached.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // Running the tool triggers a follow-up request; the tool result should
    // be the cached anchor (messages[0] is the system prompt, skipped here).
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream's final event must be the limit error.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming sends a synthetic user message asking the model to continue;
    // that message becomes the new cached anchor.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });

    // Ensure we error if calling resume when tool use limit was *not* reached.
    let error = thread
        .update(cx, |thread, cx| thread.resume(cx))
        .unwrap_err();
    assert_eq!(
        error.to_string(),
        "can only resume after tool use limit is reached"
    )
}
676
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // Verifies that after the tool-use limit is hit, a regular `send` (as
    // opposed to `resume`) still works: the new user message is appended
    // after the preserved tool use/result history and becomes the cached anchor.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // The limit status arrives in the same turn as the tool use.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream's final event must be the limit error.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; compare everything after it.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
750
751async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
752 let event = events
753 .next()
754 .await
755 .expect("no tool call authorization event received")
756 .unwrap();
757 match event {
758 ThreadEvent::ToolCall(tool_call) => tool_call,
759 event => {
760 panic!("Unexpected event {event:?}");
761 }
762 }
763}
764
765async fn expect_tool_call_update_fields(
766 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
767) -> acp::ToolCallUpdate {
768 let event = events
769 .next()
770 .await
771 .expect("no tool call authorization event received")
772 .unwrap();
773 match event {
774 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
775 event => {
776 panic!("Unexpected event {event:?}");
777 }
778 }
779}
780
781async fn next_tool_call_authorization(
782 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
783) -> ToolCallAuthorization {
784 loop {
785 let event = events
786 .next()
787 .await
788 .expect("no tool call authorization event received")
789 .unwrap();
790 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
791 let permission_kinds = tool_call_authorization
792 .options
793 .iter()
794 .map(|o| o.kind)
795 .collect::<Vec<_>>();
796 assert_eq!(
797 permission_kinds,
798 vec![
799 acp::PermissionOptionKind::AllowAlways,
800 acp::PermissionOptionKind::AllowOnce,
801 acp::PermissionOptionKind::RejectOnce,
802 ]
803 );
804 return tool_call_authorization;
805 }
806 }
807}
808
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (only runs with the `e2e` feature): two delay-tool
    // calls issued in the same message should both complete, and the model's
    // final text should describe their output.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        // Concatenate all text content of the final agent message and check
        // the delay tool's output ("Ding") made it into the reply.
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
853
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that switching agent profiles changes which registered tools
    // are offered to the model in completion requests.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Select the test-1 profile and verify the request offers only its tools
    // (echo and delay).
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
933
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    // End-to-end test (only runs with the `e2e` feature): cancelling mid-turn
    // must close the event stream with a Cancelled stop reason — even while a
    // tool is still running — and the thread must remain usable afterwards.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tools must be called in the order the prompt requested.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Wait for the echo tool to finish; the infinite tool never will.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1018
1019#[gpui::test]
1020async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1021 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1022 let fake_model = model.as_fake();
1023
1024 let events_1 = thread
1025 .update(cx, |thread, cx| {
1026 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1027 })
1028 .unwrap();
1029 cx.run_until_parked();
1030 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1031 cx.run_until_parked();
1032
1033 let events_2 = thread
1034 .update(cx, |thread, cx| {
1035 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1036 })
1037 .unwrap();
1038 cx.run_until_parked();
1039 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1040 fake_model
1041 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1042 fake_model.end_last_completion_stream();
1043
1044 let events_1 = events_1.collect::<Vec<_>>().await;
1045 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1046 let events_2 = events_2.collect::<Vec<_>>().await;
1047 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1048}
1049
1050#[gpui::test]
1051async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1052 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1053 let fake_model = model.as_fake();
1054
1055 let events_1 = thread
1056 .update(cx, |thread, cx| {
1057 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1058 })
1059 .unwrap();
1060 cx.run_until_parked();
1061 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1062 fake_model
1063 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1064 fake_model.end_last_completion_stream();
1065 let events_1 = events_1.collect::<Vec<_>>().await;
1066
1067 let events_2 = thread
1068 .update(cx, |thread, cx| {
1069 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1070 })
1071 .unwrap();
1072 cx.run_until_parked();
1073 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1074 fake_model
1075 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1076 fake_model.end_last_completion_stream();
1077 let events_2 = events_2.collect::<Vec<_>>().await;
1078
1079 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1080 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1081}
1082
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    // Verifies that a Refusal stop reason rolls the thread back: all messages
    // from the last user message onward are removed.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    // Stream a partial assistant reply so there is content to roll back.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1131
/// Verifies that truncating at the first (and only) user message empties the
/// thread and clears token usage, and that the thread remains fully usable for
/// a subsequent send afterwards.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the message id so we can truncate at exactly this message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage has been reported yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a reply plus a usage update so there is state to clear.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message should wipe both content and usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    // The user message appears synchronously, before the model responds.
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage reflects only the post-truncation turn.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1247
/// Verifies that truncating at the second user message restores the thread to
/// exactly its state after the first exchange — markdown content and token
/// usage both roll back to the first turn's values.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: message + streamed response + usage update.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Reusable assertion: the thread looks exactly as it did after turn one.
    // Used both before the second send and again after truncation.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange; keep its id so we can truncate at this message.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Sanity-check the full two-turn state before truncating.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the second message: turn two disappears entirely.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1356
/// Verifies that the summarization model is asked to generate a thread title
/// after the first exchange (taking only the first line of its output), and
/// that no further title generation happens on subsequent messages.
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Separate fake model dedicated to summarization/title generation.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Title is still the default until the summary model responds.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // The title is streamed in chunks; only the first line should be kept.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // No pending request on the summary model: title generation ran only once.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1402
/// End-to-end test of `NativeAgentConnection` through the ACP surface:
/// model listing/selection, thread creation, prompting, cancellation, and
/// session teardown when the `AcpThread` is dropped.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        Project::init_settings(cx);
        agent_settings::init(cx);
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        // Registers the fake provider/model used throughout this test.
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    // The registry's fake provider exposes a single "fake/fake" model.
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    // Resolve the selected id back to a concrete model via the agent's registry.
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Prompt through the ACP thread and stream a response from the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    // Prompting a dropped session must fail with "Session not found".
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1535
/// Verifies the full lifecycle of tool-call events emitted by a thread:
/// initial `ToolCall` (Pending), an update carrying the completed raw input,
/// an InProgress status update, a content update, and a final Completed
/// update with the tool's raw output.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1. Initial tool call: Pending, with the partial (empty) input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    // 2. Update once the full input has streamed in.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    // 3. Status flips to InProgress when the tool starts running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    // 4. The tool reports its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    // 5. Final update: Completed, carrying the tool's raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}
1637
/// Verifies that a completion which succeeds on the first attempt emits no
/// `ThreadEvent::Retry` events (Burn mode enables retries, so their absence
/// is meaningful here).
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode is the mode under which retries can occur.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // The first attempt succeeds outright.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Collect retry events until the turn stops.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    // No retries should have been attempted.
    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
1681
/// Verifies that a transient `ServerOverloaded` error triggers exactly one
/// retry (attempt 1) after the provider's `retry_after` delay, and that the
/// retried completion's output lands in the thread normally.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // First attempt fails with a retryable error carrying a 3s backoff.
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock past the backoff so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried attempt succeeds.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    // Exactly one retry, reported as attempt number 1.
    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
1738
/// Verifies that after `MAX_RETRY_ATTEMPTS` consecutive retryable failures the
/// thread gives up: it emits one Retry event per attempt (numbered 1..=N) and
/// then surfaces the final `ServerOverloaded` error to the caller.
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail the initial attempt plus every retry (hence MAX + 1 iterations),
    // advancing the fake clock past each backoff so the next attempt runs.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // One Retry event per allowed attempt, numbered sequentially from 1.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // The final failure is reported as a single error of the original kind.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
1793
1794/// Filters out the stop events for asserting against in tests
1795fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
1796 result_events
1797 .into_iter()
1798 .filter_map(|event| match event.unwrap() {
1799 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
1800 _ => None,
1801 })
1802 .collect()
1803}
1804
/// Bundle of entities produced by [`setup`] that individual tests pull apart.
struct ThreadTest {
    // The language model the thread was constructed with (fake or real).
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Project context handed to the thread; kept so tests can mutate it.
    project_context: Entity<ProjectContext>,
    // The fake filesystem backing the test project and settings file.
    fs: Arc<FakeFs>,
}
1811
/// Which language model a test wants [`setup`] to provide.
enum TestModel {
    // A real Anthropic model resolved from the registry (requires auth).
    Sonnet4,
    // An in-process `FakeLanguageModel` with scripted responses.
    Fake,
}
1816
1817impl TestModel {
1818 fn id(&self) -> LanguageModelId {
1819 match self {
1820 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
1821 TestModel::Fake => unreachable!(),
1822 }
1823 }
1824}
1825
/// Builds the full test fixture: fake filesystem with a settings file that
/// enables the test tools, global settings/client/model initialization, a
/// test project, and a `Thread` wired to the requested model.
///
/// For `TestModel::Fake` the model is constructed directly; otherwise it is
/// resolved from the registry and its provider is authenticated (so non-fake
/// variants hit the network and require credentials).
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Write a settings file enabling every test tool under a dedicated
    // profile, which is also made the default.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        // Real HTTP client: needed when tests run against a non-fake model.
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);

        // Keep global settings in sync with the fake settings file above.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Resolve the model from the registry by id and authenticate
                // its provider before returning it.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        fs,
    }
}
1922
// Runs once at test-binary load (via `ctor`) and enables env_logger only when
// RUST_LOG is set, so tests are silent by default but loggable on demand.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}
1930
/// Spawns a detached background task that watches the settings file on `fs`
/// and applies each new version to the global `SettingsStore`, mirroring how
/// the real app reacts to settings edits.
fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
    let fs = fs.clone();
    cx.spawn({
        async move |cx| {
            let mut new_settings_content_rx = settings::watch_config_file(
                cx.background_executor(),
                fs,
                paths::settings_file().clone(),
            );

            // Apply every new settings snapshot; `.ok()` ignores failures if
            // the app context has already been dropped during test teardown.
            while let Some(new_settings_content) = new_settings_content_rx.next().await {
                cx.update(|cx| {
                    SettingsStore::update_global(cx, |settings, cx| {
                        settings.set_user_settings(&new_settings_content, cx)
                    })
                })
                .ok();
            }
        }
    })
    .detach();
}