1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use action_log::ActionLog;
4use agent_client_protocol::{self as acp};
5use agent_settings::AgentProfileId;
6use anyhow::Result;
7use client::{Client, UserStore};
8use fs::{FakeFs, Fs};
9use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
10use gpui::{
11 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
12};
13use indoc::indoc;
14use language_model::{
15 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
16 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
17 LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
18 fake_provider::FakeLanguageModel,
19};
20use pretty_assertions::assert_eq;
21use project::Project;
22use prompt_store::ProjectContext;
23use reqwest_client::ReqwestClient;
24use schemars::JsonSchema;
25use serde::{Deserialize, Serialize};
26use serde_json::json;
27use settings::SettingsStore;
28use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
29use util::path;
30
31mod test_tools;
32use test_tools::*;
33
34#[gpui::test]
35async fn test_echo(cx: &mut TestAppContext) {
36 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
37 let fake_model = model.as_fake();
38
39 let events = thread
40 .update(cx, |thread, cx| {
41 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
42 })
43 .unwrap();
44 cx.run_until_parked();
45 fake_model.send_last_completion_stream_text_chunk("Hello");
46 fake_model
47 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
48 fake_model.end_last_completion_stream();
49
50 let events = events.collect().await;
51 thread.update(cx, |thread, _cx| {
52 assert_eq!(
53 thread.last_message().unwrap().to_markdown(),
54 indoc! {"
55 ## Assistant
56
57 Hello
58 "}
59 )
60 });
61 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
62}
63
/// Verifies that a model "thinking" event is rendered as a `<think>` block
/// preceding the final answer text in the assistant message markdown.
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    // Stream a thinking chunk, then regular text, then end the turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
107
/// Verifies that the first message of a completion request is a system prompt
/// built from the project context (e.g. the configured shell) and that it
/// includes expected prompt sections.
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Inject a recognizable value into the project context so we can assert
    // it surfaces in the generated system prompt.
    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
152
/// Verifies prompt-caching markers: only the most recent message in each
/// request carries `cache: true`, including when the latest entry is a tool
/// result rather than a user message.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt; compare from index 1 onward.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool.name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The follow-up request (sent automatically after the tool runs) should
    // mark only the trailing tool-result message as cached.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
286
/// End-to-end test (real Sonnet 4 model, gated behind the `e2e` feature):
/// exercises tool calls that finish both before and after the model's
/// streaming stops, and checks the tool output makes it into the reply.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool calls that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&AgentTool::name(&EchoTool));
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // DelayTool's output contains "Ding"; the model was asked to echo it, so
    // some text content of the final agent message should mention it.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
346
/// End-to-end test (real model, `e2e` feature): verifies tool-use input is
/// streamed incrementally — we must observe at least one partially streamed
/// tool use before the input is complete.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, input may still be streaming in:
                        // missing the late field "g" marks a partial tool use.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
398
399#[gpui::test]
400async fn test_tool_authorization(cx: &mut TestAppContext) {
401 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
402 let fake_model = model.as_fake();
403
404 let mut events = thread
405 .update(cx, |thread, cx| {
406 thread.add_tool(ToolRequiringPermission);
407 thread.send(UserMessageId::new(), ["abc"], cx)
408 })
409 .unwrap();
410 cx.run_until_parked();
411 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
412 LanguageModelToolUse {
413 id: "tool_id_1".into(),
414 name: ToolRequiringPermission.name().into(),
415 raw_input: "{}".into(),
416 input: json!({}),
417 is_input_complete: true,
418 },
419 ));
420 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
421 LanguageModelToolUse {
422 id: "tool_id_2".into(),
423 name: ToolRequiringPermission.name().into(),
424 raw_input: "{}".into(),
425 input: json!({}),
426 is_input_complete: true,
427 },
428 ));
429 fake_model.end_last_completion_stream();
430 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
431 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
432
433 // Approve the first
434 tool_call_auth_1
435 .response
436 .send(tool_call_auth_1.options[1].id.clone())
437 .unwrap();
438 cx.run_until_parked();
439
440 // Reject the second
441 tool_call_auth_2
442 .response
443 .send(tool_call_auth_1.options[2].id.clone())
444 .unwrap();
445 cx.run_until_parked();
446
447 let completion = fake_model.pending_completions().pop().unwrap();
448 let message = completion.messages.last().unwrap();
449 assert_eq!(
450 message.content,
451 vec![
452 language_model::MessageContent::ToolResult(LanguageModelToolResult {
453 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
454 tool_name: ToolRequiringPermission.name().into(),
455 is_error: false,
456 content: "Allowed".into(),
457 output: Some("Allowed".into())
458 }),
459 language_model::MessageContent::ToolResult(LanguageModelToolResult {
460 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
461 tool_name: ToolRequiringPermission.name().into(),
462 is_error: true,
463 content: "Permission to run tool denied by user".into(),
464 output: None
465 })
466 ]
467 );
468
469 // Simulate yet another tool call.
470 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
471 LanguageModelToolUse {
472 id: "tool_id_3".into(),
473 name: ToolRequiringPermission.name().into(),
474 raw_input: "{}".into(),
475 input: json!({}),
476 is_input_complete: true,
477 },
478 ));
479 fake_model.end_last_completion_stream();
480
481 // Respond by always allowing tools.
482 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
483 tool_call_auth_3
484 .response
485 .send(tool_call_auth_3.options[0].id.clone())
486 .unwrap();
487 cx.run_until_parked();
488 let completion = fake_model.pending_completions().pop().unwrap();
489 let message = completion.messages.last().unwrap();
490 assert_eq!(
491 message.content,
492 vec![language_model::MessageContent::ToolResult(
493 LanguageModelToolResult {
494 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
495 tool_name: ToolRequiringPermission.name().into(),
496 is_error: false,
497 content: "Allowed".into(),
498 output: Some("Allowed".into())
499 }
500 )]
501 );
502
503 // Simulate a final tool call, ensuring we don't trigger authorization.
504 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
505 LanguageModelToolUse {
506 id: "tool_id_4".into(),
507 name: ToolRequiringPermission.name().into(),
508 raw_input: "{}".into(),
509 input: json!({}),
510 is_input_complete: true,
511 },
512 ));
513 fake_model.end_last_completion_stream();
514 cx.run_until_parked();
515 let completion = fake_model.pending_completions().pop().unwrap();
516 let message = completion.messages.last().unwrap();
517 assert_eq!(
518 message.content,
519 vec![language_model::MessageContent::ToolResult(
520 LanguageModelToolResult {
521 tool_use_id: "tool_id_4".into(),
522 tool_name: ToolRequiringPermission.name().into(),
523 is_error: false,
524 content: "Allowed".into(),
525 output: Some("Allowed".into())
526 }
527 )]
528 );
529}
530
/// Verifies that a tool call to a tool the thread doesn't have (a model
/// hallucination) surfaces as a pending tool call that is then marked Failed.
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Note: no tools are registered on the thread before sending.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
559
/// Verifies `Thread::resume` after the provider reports a tool-use limit:
/// the send errors with `ToolUseLimitReachedError`, resuming appends a
/// synthetic "Continue where you left off" user message (which becomes the
/// cached tail), and resuming when the limit was *not* hit is an error.
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool.name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // The follow-up request after the tool runs carries the tool result as
    // the cached tail (messages[0] is the system prompt).
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming re-sends the conversation plus a synthetic continuation prompt.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });

    // Ensure we error if calling resume when tool use limit was *not* reached.
    let error = thread
        .update(cx, |thread, cx| thread.resume(cx))
        .unwrap_err();
    assert_eq!(
        error.to_string(),
        "can only resume after tool use limit is reached"
    )
}
677
/// Verifies that after hitting the tool-use limit, a regular `send` still
/// works: the new user message is appended after the tool use/result pair and
/// becomes the cached tail of the next request.
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool.name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Stream a tool use and immediately report the tool-use limit.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // A fresh send should succeed despite the earlier limit error.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
751
752async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
753 let event = events
754 .next()
755 .await
756 .expect("no tool call authorization event received")
757 .unwrap();
758 match event {
759 ThreadEvent::ToolCall(tool_call) => tool_call,
760 event => {
761 panic!("Unexpected event {event:?}");
762 }
763 }
764}
765
766async fn expect_tool_call_update_fields(
767 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
768) -> acp::ToolCallUpdate {
769 let event = events
770 .next()
771 .await
772 .expect("no tool call authorization event received")
773 .unwrap();
774 match event {
775 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
776 event => {
777 panic!("Unexpected event {event:?}");
778 }
779 }
780}
781
782async fn next_tool_call_authorization(
783 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
784) -> ToolCallAuthorization {
785 loop {
786 let event = events
787 .next()
788 .await
789 .expect("no tool call authorization event received")
790 .unwrap();
791 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
792 let permission_kinds = tool_call_authorization
793 .options
794 .iter()
795 .map(|o| o.kind)
796 .collect::<Vec<_>>();
797 assert_eq!(
798 permission_kinds,
799 vec![
800 acp::PermissionOptionKind::AllowAlways,
801 acp::PermissionOptionKind::AllowOnce,
802 acp::PermissionOptionKind::RejectOnce,
803 ]
804 );
805 return tool_call_authorization;
806 }
807 }
808}
809
/// End-to-end test (real model, `e2e` feature): issues two delay-tool calls in
/// one message and checks the turn completes and the combined text output
/// mentions the tool's "Ding" result.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        // Concatenate all text content of the final message for the check.
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
854
/// Verifies that agent profiles filter which registered tools are offered to
/// the model: profiles are defined via the settings file, and switching the
/// thread's profile changes the tool list sent with the next completion.
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register three tools; the active profile decides which are exposed.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool.name()]);
}
934
/// End-to-end test (real model, `e2e` feature): cancelling mid-turn — while a
/// never-finishing tool is still running — closes the event stream with a
/// Cancelled stop reason, and the thread remains usable afterwards.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tools must be called in the order the prompt requested.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Wait for the echo tool (but not the infinite one) to complete.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1019
/// Verifies that starting a new send while a previous one is still streaming
/// cancels the first turn (its stream ends with Cancelled) while the second
/// turn completes normally.
#[gpui::test]
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // First turn is mid-stream (no Stop event yet) when the second send starts.
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    cx.run_until_parked();

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events_1 = events_1.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
    let events_2 = events_2.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
1050
/// Verifies the complement of the cancellation-on-overlap behavior: two sends
/// that each run to completion before the next begins both end with EndTurn —
/// a finished turn is never retroactively cancelled.
#[gpui::test]
async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    // First turn fully collected before the second send begins.
    let events_1 = events_1.collect::<Vec<_>>().await;

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_2 = events_2.collect::<Vec<_>>().await;

    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
1083
/// Verifies refusal handling: when the model stops with `StopReason::Refusal`,
/// the thread removes all messages after (and including) the last user
/// message, leaving the transcript empty in this case.
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    // Partial assistant output is present before the refusal arrives.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1132
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the first (and only) user message should clear the whole
    // thread and its token usage, and the thread must remain usable afterwards.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No completion has reported usage yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // Usage reflects the reported input + output token counts.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message drops the entire conversation,
    // including the recorded token usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    // The user message is visible immediately, before the executor runs.
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1248
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    // Truncating at the second user message should restore the thread (and
    // its token usage) to the state it had after the first exchange.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot assertion reused both before and after truncation.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Both exchanges and the latest usage are present before truncation.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the second message; the first exchange must be intact.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1357
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // The summarization model should be asked for a title once, after the
    // first exchange; only the first line of its streamed output is used.
    // Subsequent turns must not retrigger title generation.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // A separate fake model acts as the dedicated summarization model.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Until the summary model responds, the placeholder title is kept.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    // Only the first line ("Hello world") becomes the title.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // The summary model received no new completion request.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1402
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    // End-to-end exercise of `NativeAgentConnection`: model listing and
    // selection, prompting through an `AcpThread`, cancellation, and session
    // teardown when the ACP thread is dropped.
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        // A 404-only HTTP client guarantees no real network traffic.
        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Send a prompt and let the fake model stream a reply.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    // Prompting the now-defunct session must fail with "Session not found".
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1535
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    // Verifies the full ACP tool-call lifecycle emitted by the thread:
    // Pending -> input-complete update -> InProgress -> content -> Completed.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool.name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The initial tool call is reported as Pending with the partial input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    // Completed input produces a field update carrying the full raw input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    // Running the tool transitions the call to InProgress...
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    // ...then streams the tool's content...
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    // ...and finally completes, carrying the raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}
1637
1638#[gpui::test]
1639async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
1640 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1641 let fake_model = model.as_fake();
1642
1643 let mut events = thread
1644 .update(cx, |thread, cx| {
1645 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
1646 thread.send(UserMessageId::new(), ["Hello!"], cx)
1647 })
1648 .unwrap();
1649 cx.run_until_parked();
1650
1651 fake_model.send_last_completion_stream_text_chunk("Hey!");
1652 fake_model.end_last_completion_stream();
1653
1654 let mut retry_events = Vec::new();
1655 while let Some(Ok(event)) = events.next().await {
1656 match event {
1657 ThreadEvent::Retry(retry_status) => {
1658 retry_events.push(retry_status);
1659 }
1660 ThreadEvent::Stop(..) => break,
1661 _ => {}
1662 }
1663 }
1664
1665 assert_eq!(retry_events.len(), 0);
1666 thread.read_with(cx, |thread, _cx| {
1667 assert_eq!(
1668 thread.to_markdown(),
1669 indoc! {"
1670 ## User
1671
1672 Hello!
1673
1674 ## Assistant
1675
1676 Hey!
1677 "}
1678 )
1679 });
1680}
1681
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    // A `ServerOverloaded` error should trigger exactly one retry after the
    // provider-supplied `retry_after` delay, and the retried completion's
    // output must be recorded as if the first attempt never failed.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // First attempt fails with a retryable provider error.
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance past the retry delay so the second attempt is issued.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    // Exactly one retry, numbered as the first attempt.
    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
1738
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    // After `MAX_RETRY_ATTEMPTS` consecutive failures, the thread should stop
    // retrying and surface the final provider error to the caller.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail the initial attempt plus every retry (hence the +1).
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // One retry event per attempt, numbered starting at 1.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // The final failure is reported as a stream error, not another retry.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
1793
1794/// Filters out the stop events for asserting against in tests
1795fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
1796 result_events
1797 .into_iter()
1798 .filter_map(|event| match event.unwrap() {
1799 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
1800 _ => None,
1801 })
1802 .collect()
1803}
1804
/// Bundle of entities produced by `setup` that individual tests destructure.
struct ThreadTest {
    // The model the thread was created with; tests using `TestModel::Fake`
    // downcast it via `as_fake()` to drive completions manually.
    model: Arc<dyn LanguageModel>,
    thread: Entity<Thread>,
    // Shared project context handed to `Thread::new`.
    project_context: Entity<ProjectContext>,
    // The fake filesystem backing the test project and settings file.
    fs: Arc<FakeFs>,
}
1811
/// Which language model a test runs against: a real registry model
/// (authenticated in `setup`) or the in-process fake.
enum TestModel {
    Sonnet4,
    Fake,
}
1816
1817impl TestModel {
1818 fn id(&self) -> LanguageModelId {
1819 match self {
1820 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
1821 TestModel::Fake => unreachable!(),
1822 }
1823 }
1824}
1825
/// Builds a `Thread` backed by either the fake model or a real, authenticated
/// registry model, with a settings file that enables all test tools under a
/// "test-profile" agent profile.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // Real-model runs perform blocking work (network auth), so allow parking.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Seed a settings file whose default profile enables every test tool.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                            WordListTool.name(): true,
                            ToolRequiringPermission.name(): true,
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);

        // Keep the global SettingsStore in sync with the file seeded above.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                // The fake model needs no registry lookup or authentication.
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Resolve the model by id and authenticate its provider
                // before handing it to the thread.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let action_log = cx.new(|_| ActionLog::new(project.clone()));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            action_log,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        fs,
    }
}
1924
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Only wire up env_logger when the caller opted in via RUST_LOG.
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
1932
/// Spawns a detached background task that mirrors changes to the settings
/// file on the (fake) filesystem into the global `SettingsStore`, so tests
/// can change settings by writing to `paths::settings_file()`.
fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
    let fs = fs.clone();
    cx.spawn({
        async move |cx| {
            let mut new_settings_content_rx = settings::watch_config_file(
                cx.background_executor(),
                fs,
                paths::settings_file().clone(),
            );

            // Apply each new settings payload as the user settings; failures
            // to update (e.g. during app teardown) are deliberately ignored.
            while let Some(new_settings_content) = new_settings_content_rx.next().await {
                cx.update(|cx| {
                    SettingsStore::update_global(cx, |settings, cx| {
                        settings.set_user_settings(&new_settings_content, cx)
                    })
                })
                .ok();
            }
        }
    })
    .detach();
}