1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use action_log::ActionLog;
4use agent_client_protocol::{self as acp};
5use agent_settings::AgentProfileId;
6use anyhow::Result;
7use client::{Client, UserStore};
8use fs::{FakeFs, Fs};
9use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
10use gpui::{
11 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
12};
13use indoc::indoc;
14use language_model::{
15 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
16 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
17 LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
18 fake_provider::FakeLanguageModel,
19};
20use pretty_assertions::assert_eq;
21use project::Project;
22use prompt_store::ProjectContext;
23use reqwest_client::ReqwestClient;
24use schemars::JsonSchema;
25use serde::{Deserialize, Serialize};
26use serde_json::json;
27use settings::SettingsStore;
28use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
29use util::path;
30
31mod test_tools;
32use test_tools::*;
33
34#[gpui::test]
35#[ignore = "can't run on CI yet"]
36async fn test_echo(cx: &mut TestAppContext) {
37 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
38
39 let events = thread
40 .update(cx, |thread, cx| {
41 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
42 })
43 .unwrap()
44 .collect()
45 .await;
46 thread.update(cx, |thread, _cx| {
47 assert_eq!(
48 thread.last_message().unwrap().to_markdown(),
49 indoc! {"
50 ## Assistant
51
52 Hello
53 "}
54 )
55 });
56 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
57}
58
59#[gpui::test]
60#[ignore = "can't run on CI yet"]
61async fn test_thinking(cx: &mut TestAppContext) {
62 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
63
64 let events = thread
65 .update(cx, |thread, cx| {
66 thread.send(
67 UserMessageId::new(),
68 [indoc! {"
69 Testing:
70
71 Generate a thinking step where you just think the word 'Think',
72 and have your final answer be 'Hello'
73 "}],
74 cx,
75 )
76 })
77 .unwrap()
78 .collect()
79 .await;
80 thread.update(cx, |thread, _cx| {
81 assert_eq!(
82 thread.last_message().unwrap().to_markdown(),
83 indoc! {"
84 ## Assistant
85
86 <think>Think</think>
87 Hello
88 "}
89 )
90 });
91 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
92}
93
94#[gpui::test]
95async fn test_system_prompt(cx: &mut TestAppContext) {
96 let ThreadTest {
97 model,
98 thread,
99 project_context,
100 ..
101 } = setup(cx, TestModel::Fake).await;
102 let fake_model = model.as_fake();
103
104 project_context.update(cx, |project_context, _cx| {
105 project_context.shell = "test-shell".into()
106 });
107 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
108 thread
109 .update(cx, |thread, cx| {
110 thread.send(UserMessageId::new(), ["abc"], cx)
111 })
112 .unwrap();
113 cx.run_until_parked();
114 let mut pending_completions = fake_model.pending_completions();
115 assert_eq!(
116 pending_completions.len(),
117 1,
118 "unexpected pending completions: {:?}",
119 pending_completions
120 );
121
122 let pending_completion = pending_completions.pop().unwrap();
123 assert_eq!(pending_completion.messages[0].role, Role::System);
124
125 let system_message = &pending_completion.messages[0];
126 let system_prompt = system_message.content[0].to_str().unwrap();
127 assert!(
128 system_prompt.contains("test-shell"),
129 "unexpected system message: {:?}",
130 system_message
131 );
132 assert!(
133 system_prompt.contains("## Fixing Diagnostics"),
134 "unexpected system message: {:?}",
135 system_message
136 );
137}
138
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the prompt-caching policy: in every outgoing request, only the
    // *last* message carries `cache: true`, so the cache anchor always sits at
    // the end of the conversation prefix. messages[0] (the system prompt) is
    // skipped in all assertions below via `messages[1..]`.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Earlier messages must have been demoted to `cache: false`.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool.name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The follow-up request (sending the tool result back to the model) should
    // cache only the trailing tool-result message.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
272
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real model): tool calls run to completion and their
    // output is reflected in the agent's final message.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&AgentTool::name(&EchoTool));
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The delay tool's "Ding" output should appear in the agent's last message.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
332
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // Verifies that a tool call's input JSON streams into the thread
    // incrementally: while the call is Pending we should observe at least one
    // partially-populated input, and once complete all fields must be present.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Trigger the word_list tool and watch its input arrive incrementally.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, the input may still be mid-stream:
                        // missing the final field 'g' means it's partial.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once no longer pending, the full input (first and
                        // last fields alike) must have streamed in.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
384
385#[gpui::test]
386async fn test_tool_authorization(cx: &mut TestAppContext) {
387 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
388 let fake_model = model.as_fake();
389
390 let mut events = thread
391 .update(cx, |thread, cx| {
392 thread.add_tool(ToolRequiringPermission);
393 thread.send(UserMessageId::new(), ["abc"], cx)
394 })
395 .unwrap();
396 cx.run_until_parked();
397 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
398 LanguageModelToolUse {
399 id: "tool_id_1".into(),
400 name: ToolRequiringPermission.name().into(),
401 raw_input: "{}".into(),
402 input: json!({}),
403 is_input_complete: true,
404 },
405 ));
406 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
407 LanguageModelToolUse {
408 id: "tool_id_2".into(),
409 name: ToolRequiringPermission.name().into(),
410 raw_input: "{}".into(),
411 input: json!({}),
412 is_input_complete: true,
413 },
414 ));
415 fake_model.end_last_completion_stream();
416 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
417 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
418
419 // Approve the first
420 tool_call_auth_1
421 .response
422 .send(tool_call_auth_1.options[1].id.clone())
423 .unwrap();
424 cx.run_until_parked();
425
426 // Reject the second
427 tool_call_auth_2
428 .response
429 .send(tool_call_auth_1.options[2].id.clone())
430 .unwrap();
431 cx.run_until_parked();
432
433 let completion = fake_model.pending_completions().pop().unwrap();
434 let message = completion.messages.last().unwrap();
435 assert_eq!(
436 message.content,
437 vec![
438 language_model::MessageContent::ToolResult(LanguageModelToolResult {
439 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
440 tool_name: ToolRequiringPermission.name().into(),
441 is_error: false,
442 content: "Allowed".into(),
443 output: Some("Allowed".into())
444 }),
445 language_model::MessageContent::ToolResult(LanguageModelToolResult {
446 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
447 tool_name: ToolRequiringPermission.name().into(),
448 is_error: true,
449 content: "Permission to run tool denied by user".into(),
450 output: None
451 })
452 ]
453 );
454
455 // Simulate yet another tool call.
456 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
457 LanguageModelToolUse {
458 id: "tool_id_3".into(),
459 name: ToolRequiringPermission.name().into(),
460 raw_input: "{}".into(),
461 input: json!({}),
462 is_input_complete: true,
463 },
464 ));
465 fake_model.end_last_completion_stream();
466
467 // Respond by always allowing tools.
468 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
469 tool_call_auth_3
470 .response
471 .send(tool_call_auth_3.options[0].id.clone())
472 .unwrap();
473 cx.run_until_parked();
474 let completion = fake_model.pending_completions().pop().unwrap();
475 let message = completion.messages.last().unwrap();
476 assert_eq!(
477 message.content,
478 vec![language_model::MessageContent::ToolResult(
479 LanguageModelToolResult {
480 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
481 tool_name: ToolRequiringPermission.name().into(),
482 is_error: false,
483 content: "Allowed".into(),
484 output: Some("Allowed".into())
485 }
486 )]
487 );
488
489 // Simulate a final tool call, ensuring we don't trigger authorization.
490 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
491 LanguageModelToolUse {
492 id: "tool_id_4".into(),
493 name: ToolRequiringPermission.name().into(),
494 raw_input: "{}".into(),
495 input: json!({}),
496 is_input_complete: true,
497 },
498 ));
499 fake_model.end_last_completion_stream();
500 cx.run_until_parked();
501 let completion = fake_model.pending_completions().pop().unwrap();
502 let message = completion.messages.last().unwrap();
503 assert_eq!(
504 message.content,
505 vec![language_model::MessageContent::ToolResult(
506 LanguageModelToolResult {
507 tool_use_id: "tool_id_4".into(),
508 tool_name: ToolRequiringPermission.name().into(),
509 is_error: false,
510 content: "Allowed".into(),
511 output: Some("Allowed".into())
512 }
513 )]
514 );
515}
516
517#[gpui::test]
518async fn test_tool_hallucination(cx: &mut TestAppContext) {
519 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
520 let fake_model = model.as_fake();
521
522 let mut events = thread
523 .update(cx, |thread, cx| {
524 thread.send(UserMessageId::new(), ["abc"], cx)
525 })
526 .unwrap();
527 cx.run_until_parked();
528 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
529 LanguageModelToolUse {
530 id: "tool_id_1".into(),
531 name: "nonexistent_tool".into(),
532 raw_input: "{}".into(),
533 input: json!({}),
534 is_input_complete: true,
535 },
536 ));
537 fake_model.end_last_completion_stream();
538
539 let tool_call = expect_tool_call(&mut events).await;
540 assert_eq!(tool_call.title, "nonexistent_tool");
541 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
542 let update = expect_tool_call_update_fields(&mut events).await;
543 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
544}
545
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // After the provider reports ToolUseLimitReached, `resume` should send a
    // follow-up request with a synthetic "Continue where you left off" user
    // message, and `resume` must fail when the limit was not actually reached.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool.name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    // The tool result goes back to the model; the trailing message is cached.
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The event stream terminates with a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a synthetic user message asking the model to continue.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });

    // Ensure we error if calling resume when tool use limit was *not* reached.
    let error = thread
        .update(cx, |thread, cx| thread.resume(cx))
        .unwrap_err();
    assert_eq!(
        error.to_string(),
        "can only resume after tool use limit is reached"
    )
}
663
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After hitting the tool-use limit, a plain `send` (as opposed to `resume`)
    // should still work: the new user message is appended after the preserved
    // tool use/result pair and becomes the cached tail.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool.name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Emit a tool use and the limit status in the same completion stream.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The turn ends with a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
737
738async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
739 let event = events
740 .next()
741 .await
742 .expect("no tool call authorization event received")
743 .unwrap();
744 match event {
745 ThreadEvent::ToolCall(tool_call) => tool_call,
746 event => {
747 panic!("Unexpected event {event:?}");
748 }
749 }
750}
751
752async fn expect_tool_call_update_fields(
753 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
754) -> acp::ToolCallUpdate {
755 let event = events
756 .next()
757 .await
758 .expect("no tool call authorization event received")
759 .unwrap();
760 match event {
761 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
762 event => {
763 panic!("Unexpected event {event:?}");
764 }
765 }
766}
767
768async fn next_tool_call_authorization(
769 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
770) -> ToolCallAuthorization {
771 loop {
772 let event = events
773 .next()
774 .await
775 .expect("no tool call authorization event received")
776 .unwrap();
777 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
778 let permission_kinds = tool_call_authorization
779 .options
780 .iter()
781 .map(|o| o.kind)
782 .collect::<Vec<_>>();
783 assert_eq!(
784 permission_kinds,
785 vec![
786 acp::PermissionOptionKind::AllowAlways,
787 acp::PermissionOptionKind::AllowOnce,
788 acp::PermissionOptionKind::RejectOnce,
789 ]
790 );
791 return tool_call_authorization;
792 }
793 }
794}
795
796#[gpui::test]
797#[ignore = "can't run on CI yet"]
798async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
799 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
800
801 // Test concurrent tool calls with different delay times
802 let events = thread
803 .update(cx, |thread, cx| {
804 thread.add_tool(DelayTool);
805 thread.send(
806 UserMessageId::new(),
807 [
808 "Call the delay tool twice in the same message.",
809 "Once with 100ms. Once with 300ms.",
810 "When both timers are complete, describe the outputs.",
811 ],
812 cx,
813 )
814 })
815 .unwrap()
816 .collect()
817 .await;
818
819 let stop_reasons = stop_events(events);
820 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
821
822 thread.update(cx, |thread, _cx| {
823 let last_message = thread.last_message().unwrap();
824 let agent_message = last_message.as_agent_message().unwrap();
825 let text = agent_message
826 .content
827 .iter()
828 .filter_map(|content| {
829 if let AgentMessageContent::Text(text) = content {
830 Some(text.as_str())
831 } else {
832 None
833 }
834 })
835 .collect::<String>();
836
837 assert!(text.contains("Ding"));
838 });
839}
840
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Profiles gate which registered tools are actually offered to the model.
    // We register three tools, define two profiles in settings, and check the
    // tool list sent with each completion request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool.name()]);
}
920
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_cancellation(cx: &mut TestAppContext) {
    // Cancelling mid-turn must close the event stream with a Cancelled stop
    // reason — even while a tool (InfiniteTool) is still running — and the
    // thread must remain usable for a subsequent send.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tool calls should arrive in the requested order.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Watch for the Echo tool's completion update specifically.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1005
1006#[gpui::test]
1007async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1008 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1009 let fake_model = model.as_fake();
1010
1011 let events_1 = thread
1012 .update(cx, |thread, cx| {
1013 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1014 })
1015 .unwrap();
1016 cx.run_until_parked();
1017 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1018 cx.run_until_parked();
1019
1020 let events_2 = thread
1021 .update(cx, |thread, cx| {
1022 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1023 })
1024 .unwrap();
1025 cx.run_until_parked();
1026 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1027 fake_model
1028 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1029 fake_model.end_last_completion_stream();
1030
1031 let events_1 = events_1.collect::<Vec<_>>().await;
1032 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1033 let events_2 = events_2.collect::<Vec<_>>().await;
1034 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1035}
1036
1037#[gpui::test]
1038async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1039 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1040 let fake_model = model.as_fake();
1041
1042 let events_1 = thread
1043 .update(cx, |thread, cx| {
1044 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1045 })
1046 .unwrap();
1047 cx.run_until_parked();
1048 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1049 fake_model
1050 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1051 fake_model.end_last_completion_stream();
1052 let events_1 = events_1.collect::<Vec<_>>().await;
1053
1054 let events_2 = thread
1055 .update(cx, |thread, cx| {
1056 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1057 })
1058 .unwrap();
1059 cx.run_until_parked();
1060 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1061 fake_model
1062 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1063 fake_model.end_last_completion_stream();
1064 let events_2 = events_2.collect::<Vec<_>>().await;
1065
1066 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1067 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1068}
1069
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    // When the model stops with a Refusal, the thread should drop everything
    // back to (and including) the last user message, leaving it empty here.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    // Stream a partial assistant reply before the refusal arrives.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1118
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the first (and only) user message should empty the
    // thread, reset token accounting, and leave the thread usable for
    // subsequent sends.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the id so we can truncate at this exact message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage is reported before the model streams a usage update.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // Usage reflects the update streamed above (input + output tokens).
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncation at the first message removes it and everything after it.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage from before truncation must not leak into the new turn.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1234
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    // Truncating at the second user message should roll the thread (and its
    // token accounting) back to the state after the first exchange only.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot assertion reused before and after truncation: the thread
    // should contain exactly the first exchange and its token usage.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Keep the second message's id so we can truncate at it below.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the second message: the first exchange (and its usage)
    // must be restored exactly.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1343
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // The summarization model should be asked to generate a title exactly
    // once — after the first exchange — and never again for later messages.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Use a separate fake model for summarization so title requests can be
    // told apart from regular completion requests.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Until the summary model responds, the default title remains.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    // Only the first line of the streamed summary becomes the title.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // No pending summarization request means no second title generation.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1388
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    // End-to-end exercise of `NativeAgentConnection`: model listing and
    // selection, thread creation, prompting, cancellation, and session
    // cleanup when the ACP thread is dropped.
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        // 404-only HTTP client: nothing in this test may hit the network.
        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Prompt the thread and stream a response chunk from the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    // Prompting a dropped session must fail rather than silently no-op.
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1521
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    // Follows a single tool call through its full lifecycle of ACP events:
    // Pending -> input streamed -> InProgress -> content -> Completed.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool.name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The initial event announces the tool call: empty input, Pending.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    // Completed input arrives as a raw_input field update.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    // The tool transitions to running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    // The tool reports its content while running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    // Finally the call completes with its raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}
1623
1624#[gpui::test]
1625async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
1626 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1627 let fake_model = model.as_fake();
1628
1629 let mut events = thread
1630 .update(cx, |thread, cx| {
1631 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
1632 thread.send(UserMessageId::new(), ["Hello!"], cx)
1633 })
1634 .unwrap();
1635 cx.run_until_parked();
1636
1637 fake_model.send_last_completion_stream_text_chunk("Hey!");
1638 fake_model.end_last_completion_stream();
1639
1640 let mut retry_events = Vec::new();
1641 while let Some(Ok(event)) = events.next().await {
1642 match event {
1643 ThreadEvent::Retry(retry_status) => {
1644 retry_events.push(retry_status);
1645 }
1646 ThreadEvent::Stop(..) => break,
1647 _ => {}
1648 }
1649 }
1650
1651 assert_eq!(retry_events.len(), 0);
1652 thread.read_with(cx, |thread, _cx| {
1653 assert_eq!(
1654 thread.to_markdown(),
1655 indoc! {"
1656 ## User
1657
1658 Hello!
1659
1660 ## Assistant
1661
1662 Hey!
1663 "}
1664 )
1665 });
1666}
1667
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    // In Burn mode, a retryable provider error (server overloaded) should
    // trigger exactly one automatic retry, after which the turn succeeds.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // First attempt fails with a retryable error.
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance past the retry_after delay so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // Second attempt succeeds.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    // The final transcript contains no trace of the failed attempt.
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
1724
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    // After MAX_RETRY_ATTEMPTS failed retries, the next failure should be
    // surfaced to the caller as an error instead of yet another retry.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail the initial attempt plus every allowed retry.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // One retry event per allowed attempt, numbered 1..=MAX.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // The final failure surfaces as a single ServerOverloaded error.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
1779
1780/// Filters out the stop events for asserting against in tests
1781fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
1782 result_events
1783 .into_iter()
1784 .filter_map(|event| match event.unwrap() {
1785 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
1786 _ => None,
1787 })
1788 .collect()
1789}
1790
/// Bundle of entities handed back by `setup` for driving a thread in tests.
struct ThreadTest {
    // The selected language model (a `FakeLanguageModel` for `TestModel::Fake`).
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Project context shared with the thread; tests may mutate it.
    project_context: Entity<ProjectContext>,
    // Fake filesystem backing the project and the settings file.
    fs: Arc<FakeFs>,
}
1797
/// Which language model a test runs against: real Anthropic models (used by
/// the `#[ignore]`d tests that can't run on CI) or the in-process fake.
enum TestModel {
    Sonnet4,
    Sonnet4Thinking,
    Fake,
}
1803
1804impl TestModel {
1805 fn id(&self) -> LanguageModelId {
1806 match self {
1807 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
1808 TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
1809 TestModel::Fake => unreachable!(),
1810 }
1811 }
1812}
1813
/// Builds a `ThreadTest` fixture: a fake filesystem with a settings file
/// enabling the test tools, a test project, and a `Thread` driven by either
/// the fake model or a real authenticated model.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // Real-model tests block on network I/O, which requires parking.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Write a settings file that enables every test tool in the default
    // profile; `watch_settings` below feeds it into the SettingsStore.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                            WordListTool.name(): true,
                            ToolRequiringPermission.name(): true,
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        // A real HTTP client and production `Client` are needed so the
        // non-fake models can authenticate against their provider.
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);

        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                // The fake model needs no registry lookup or authentication.
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Look up the requested model in the registry and wait for
                // its provider to authenticate before returning it.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let action_log = cx.new(|_| ActionLog::new(project.clone()));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            action_log,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        fs,
    }
}
1912
/// Test-binary constructor: enables env_logger output when RUST_LOG is set.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        // `try_init` instead of `init`: another `#[ctor]` constructor (or an
        // earlier test harness) may already have installed a global logger,
        // and `init` would panic in that case. Ignoring the error keeps the
        // first-installed logger active.
        env_logger::try_init().ok();
    }
}
1920
/// Mirrors changes to the settings file on the (fake) filesystem into the
/// global `SettingsStore`, mimicking what the real app does at startup.
fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
    let fs = fs.clone();
    cx.spawn({
        async move |cx| {
            let mut new_settings_content_rx = settings::watch_config_file(
                cx.background_executor(),
                fs,
                paths::settings_file().clone(),
            );

            // Apply every new settings-file snapshot; `.ok()` swallows the
            // failure if the app context has already been dropped.
            while let Some(new_settings_content) = new_settings_content_rx.next().await {
                cx.update(|cx| {
                    SettingsStore::update_global(cx, |settings, cx| {
                        settings.set_user_settings(&new_settings_content, cx)
                    })
                })
                .ok();
            }
        }
    })
    .detach();
}