1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use action_log::ActionLog;
4use agent_client_protocol::{self as acp};
5use agent_settings::AgentProfileId;
6use anyhow::Result;
7use client::{Client, UserStore};
8use fs::{FakeFs, Fs};
9use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
10use gpui::{
11 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
12};
13use indoc::indoc;
14use language_model::{
15 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
16 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
17 LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
18 fake_provider::FakeLanguageModel,
19};
20use pretty_assertions::assert_eq;
21use project::Project;
22use prompt_store::ProjectContext;
23use reqwest_client::ReqwestClient;
24use schemars::JsonSchema;
25use serde::{Deserialize, Serialize};
26use serde_json::json;
27use settings::SettingsStore;
28use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
29use util::path;
30
31mod test_tools;
32use test_tools::*;
33
#[gpui::test]
async fn test_echo(cx: &mut TestAppContext) {
    // Smoke test for the basic round trip: a user message goes out, the fake
    // model streams a single text chunk back, and the thread records it as
    // the assistant's reply.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Drive the fake completion: one text chunk, a normal end-of-turn stop,
    // then close the stream.
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    // The streamed chunk should appear verbatim as the last (assistant) message.
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Hello
            "}
        )
    });
    // Exactly one stop event, and it is a clean EndTurn.
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
63
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    // Verifies that streamed Thinking events are preserved in the final
    // assistant message, rendered as a <think>…</think> block preceding the
    // ordinary answer text.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    // Stream a thinking step followed by regular text, then end the turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    // The markdown rendering wraps the thinking content in a <think> tag.
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
107
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    // Checks that the first message of a completion request is a system
    // prompt built from the project context (e.g. the configured shell) and
    // that it includes tool-related guidance sections.
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Inject a recognizable value into the project context so we can assert
    // that it flows into the rendered system prompt.
    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    // The system prompt is always the first message in the request.
    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
152
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the prompt-caching policy: on every request, only the *last*
    // message carries `cache: true`; all earlier history is sent uncached.
    // `messages[1..]` is asserted throughout because messages[0] is the
    // system prompt.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool.name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The follow-up request triggered by the tool call should mark only the
    // tool-result message (the new tail of the conversation) as cached.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
286
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test against a real model (only runs with the "e2e"
    // feature): exercises tool calls that finish both before and after the
    // model's streaming completes.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&AgentTool::name(&EchoTool));
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The delay tool's output ("Ding") should have been echoed back in the
    // model's final answer.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
346
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (requires the "e2e" feature): verifies that tool-call
    // input is streamed incrementally, i.e. the thread's history exposes a
    // tool use whose input is still partial before streaming completes.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Invoke a tool whose input schema has many fields ("a" through "g"), so
    // partially-streamed input is observable while the call is Pending.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, input may be incomplete: the last
                        // field "g" should not have streamed yet.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
398
399#[gpui::test]
400async fn test_tool_authorization(cx: &mut TestAppContext) {
401 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
402 let fake_model = model.as_fake();
403
404 let mut events = thread
405 .update(cx, |thread, cx| {
406 thread.add_tool(ToolRequiringPermission);
407 thread.send(UserMessageId::new(), ["abc"], cx)
408 })
409 .unwrap();
410 cx.run_until_parked();
411 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
412 LanguageModelToolUse {
413 id: "tool_id_1".into(),
414 name: ToolRequiringPermission.name().into(),
415 raw_input: "{}".into(),
416 input: json!({}),
417 is_input_complete: true,
418 },
419 ));
420 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
421 LanguageModelToolUse {
422 id: "tool_id_2".into(),
423 name: ToolRequiringPermission.name().into(),
424 raw_input: "{}".into(),
425 input: json!({}),
426 is_input_complete: true,
427 },
428 ));
429 fake_model.end_last_completion_stream();
430 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
431 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
432
433 // Approve the first
434 tool_call_auth_1
435 .response
436 .send(tool_call_auth_1.options[1].id.clone())
437 .unwrap();
438 cx.run_until_parked();
439
440 // Reject the second
441 tool_call_auth_2
442 .response
443 .send(tool_call_auth_1.options[2].id.clone())
444 .unwrap();
445 cx.run_until_parked();
446
447 let completion = fake_model.pending_completions().pop().unwrap();
448 let message = completion.messages.last().unwrap();
449 assert_eq!(
450 message.content,
451 vec![
452 language_model::MessageContent::ToolResult(LanguageModelToolResult {
453 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
454 tool_name: ToolRequiringPermission.name().into(),
455 is_error: false,
456 content: "Allowed".into(),
457 output: Some("Allowed".into())
458 }),
459 language_model::MessageContent::ToolResult(LanguageModelToolResult {
460 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
461 tool_name: ToolRequiringPermission.name().into(),
462 is_error: true,
463 content: "Permission to run tool denied by user".into(),
464 output: None
465 })
466 ]
467 );
468
469 // Simulate yet another tool call.
470 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
471 LanguageModelToolUse {
472 id: "tool_id_3".into(),
473 name: ToolRequiringPermission.name().into(),
474 raw_input: "{}".into(),
475 input: json!({}),
476 is_input_complete: true,
477 },
478 ));
479 fake_model.end_last_completion_stream();
480
481 // Respond by always allowing tools.
482 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
483 tool_call_auth_3
484 .response
485 .send(tool_call_auth_3.options[0].id.clone())
486 .unwrap();
487 cx.run_until_parked();
488 let completion = fake_model.pending_completions().pop().unwrap();
489 let message = completion.messages.last().unwrap();
490 assert_eq!(
491 message.content,
492 vec![language_model::MessageContent::ToolResult(
493 LanguageModelToolResult {
494 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
495 tool_name: ToolRequiringPermission.name().into(),
496 is_error: false,
497 content: "Allowed".into(),
498 output: Some("Allowed".into())
499 }
500 )]
501 );
502
503 // Simulate a final tool call, ensuring we don't trigger authorization.
504 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
505 LanguageModelToolUse {
506 id: "tool_id_4".into(),
507 name: ToolRequiringPermission.name().into(),
508 raw_input: "{}".into(),
509 input: json!({}),
510 is_input_complete: true,
511 },
512 ));
513 fake_model.end_last_completion_stream();
514 cx.run_until_parked();
515 let completion = fake_model.pending_completions().pop().unwrap();
516 let message = completion.messages.last().unwrap();
517 assert_eq!(
518 message.content,
519 vec![language_model::MessageContent::ToolResult(
520 LanguageModelToolResult {
521 tool_use_id: "tool_id_4".into(),
522 tool_name: ToolRequiringPermission.name().into(),
523 is_error: false,
524 content: "Allowed".into(),
525 output: Some("Allowed".into())
526 }
527 )]
528 );
529}
530
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    // If the model invokes a tool that was never registered, the thread
    // should surface the call (Pending) and then immediately fail it rather
    // than panicking or silently dropping it.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();

    // First a Pending tool call is reported, then an update flips it to Failed.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
559
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // When the provider reports ToolUseLimitReached, the send stream ends in
    // an error and `resume` continues the conversation by appending a
    // "Continue where you left off" user message. Resuming without having
    // hit the limit must fail.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool.name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    // The tool ran, so a follow-up completion carrying the tool result is
    // requested; only the final (tool-result) message is cached.
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream's final event should be a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a synthetic user message asking the model to continue.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });

    // Ensure we error if calling resume when tool use limit was *not* reached.
    let error = thread
        .update(cx, |thread, cx| thread.resume(cx))
        .unwrap_err();
    assert_eq!(
        error.to_string(),
        "can only resume after tool use limit is reached"
    )
}
677
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After hitting the tool-use limit, sending a brand-new user message
    // (instead of resuming) must still work: the history, including the tool
    // use/result pair, is preserved and the new message becomes the cached
    // tail of the request.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool.name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Emit a tool use and immediately signal the tool-use limit.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; the new user message is last and cached.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
751
752async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
753 let event = events
754 .next()
755 .await
756 .expect("no tool call authorization event received")
757 .unwrap();
758 match event {
759 ThreadEvent::ToolCall(tool_call) => tool_call,
760 event => {
761 panic!("Unexpected event {event:?}");
762 }
763 }
764}
765
766async fn expect_tool_call_update_fields(
767 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
768) -> acp::ToolCallUpdate {
769 let event = events
770 .next()
771 .await
772 .expect("no tool call authorization event received")
773 .unwrap();
774 match event {
775 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
776 event => {
777 panic!("Unexpected event {event:?}");
778 }
779 }
780}
781
782async fn next_tool_call_authorization(
783 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
784) -> ToolCallAuthorization {
785 loop {
786 let event = events
787 .next()
788 .await
789 .expect("no tool call authorization event received")
790 .unwrap();
791 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
792 let permission_kinds = tool_call_authorization
793 .options
794 .iter()
795 .map(|o| o.kind)
796 .collect::<Vec<_>>();
797 assert_eq!(
798 permission_kinds,
799 vec![
800 acp::PermissionOptionKind::AllowAlways,
801 acp::PermissionOptionKind::AllowOnce,
802 acp::PermissionOptionKind::RejectOnce,
803 ]
804 );
805 return tool_call_authorization;
806 }
807 }
808}
809
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (requires the "e2e" feature): two delay-tool calls in
    // the same turn should run concurrently and both complete before the
    // model describes their outputs.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    // The final answer should mention the delay tool's output ("Ding").
    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
854
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that the active agent profile filters which registered tools
    // are offered to the model, and that switching profiles mid-thread takes
    // effect on the next request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools; the profiles below enable subsets of them.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool.name()]);
}
934
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    // End-to-end test (requires the "e2e" feature): cancelling mid-turn must
    // close the event stream with a Cancelled stop reason — even while a
    // never-terminating tool is still running — and the thread must accept a
    // new message afterwards.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tools should be invoked in the requested order.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Track completion of the echo call specifically; the infinite
            // tool never completes by design.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1019
#[gpui::test]
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
    // Sending a new message while a previous send is still streaming should
    // cancel the first send (its stream ends with Cancelled) and let the
    // second run to completion.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Start the first send and stream a partial response without ending it.
    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    cx.run_until_parked();

    // Starting a second send implicitly cancels the first.
    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events_1 = events_1.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
    let events_2 = events_2.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
1050
#[gpui::test]
async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
    // Counterpart to the cancellation test above: when the first send fully
    // completes before the second begins, both sends should end with EndTurn
    // and neither is cancelled.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First send runs to completion before the second starts.
    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_1 = events_1.collect::<Vec<_>>().await;

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_2 = events_2.collect::<Vec<_>>().await;

    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
1083
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    // When the model stops with a Refusal, the thread should roll back every
    // message after (and including) the last user message, leaving the
    // transcript empty in this single-exchange case.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    // Stream a partial assistant reply before the refusal arrives.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1132
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the very first message should empty the thread, reset
    // token usage, and leave the thread usable for a fresh exchange.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the id so we can truncate at this exact message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage is reported until the model sends a UsageUpdate.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // Usage reflects the UsageUpdate event (input + output).
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message clears both the transcript and usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    // The user message is appended synchronously, before the turn is driven.
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage from the new turn replaces the pre-truncation numbers.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1248
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    // Truncating at the second user message should restore the thread to the
    // state it had after the first exchange — including token usage.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: completes normally with a usage report.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Shared assertion: the thread exactly matches the post-first-exchange
    // state. Used both before the second send and after truncation.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange; keep the id so we can truncate at this message.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at message 2: everything from it onward is removed and the
    // first exchange's state (including usage) is restored.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1357
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // A separate summarization model should be asked to title the thread
    // after the first exchange only; later exchanges keep the title as-is.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Distinct fake model so title-generation requests can be observed
    // independently of the main completion stream.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Until the summary model responds, the placeholder title remains.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // The title is the first line only; text after '\n' must be discarded.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // No new completion request reached the summary model.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1403
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    // End-to-end exercise of `NativeAgentConnection`: model listing and
    // selection, thread creation, prompting, cancellation, and session
    // cleanup when the ACP thread is dropped.
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Drive a prompt through the connection and confirm the streamed
    // response lands in the thread transcript.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1536
1537#[gpui::test]
1538async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
1539 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1540 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
1541 let fake_model = model.as_fake();
1542
1543 let mut events = thread
1544 .update(cx, |thread, cx| {
1545 thread.send(UserMessageId::new(), ["Think"], cx)
1546 })
1547 .unwrap();
1548 cx.run_until_parked();
1549
1550 // Simulate streaming partial input.
1551 let input = json!({});
1552 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1553 LanguageModelToolUse {
1554 id: "1".into(),
1555 name: ThinkingTool.name().into(),
1556 raw_input: input.to_string(),
1557 input,
1558 is_input_complete: false,
1559 },
1560 ));
1561
1562 // Input streaming completed
1563 let input = json!({ "content": "Thinking hard!" });
1564 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1565 LanguageModelToolUse {
1566 id: "1".into(),
1567 name: "thinking".into(),
1568 raw_input: input.to_string(),
1569 input,
1570 is_input_complete: true,
1571 },
1572 ));
1573 fake_model.end_last_completion_stream();
1574 cx.run_until_parked();
1575
1576 let tool_call = expect_tool_call(&mut events).await;
1577 assert_eq!(
1578 tool_call,
1579 acp::ToolCall {
1580 id: acp::ToolCallId("1".into()),
1581 title: "Thinking".into(),
1582 kind: acp::ToolKind::Think,
1583 status: acp::ToolCallStatus::Pending,
1584 content: vec![],
1585 locations: vec![],
1586 raw_input: Some(json!({})),
1587 raw_output: None,
1588 }
1589 );
1590 let update = expect_tool_call_update_fields(&mut events).await;
1591 assert_eq!(
1592 update,
1593 acp::ToolCallUpdate {
1594 id: acp::ToolCallId("1".into()),
1595 fields: acp::ToolCallUpdateFields {
1596 title: Some("Thinking".into()),
1597 kind: Some(acp::ToolKind::Think),
1598 raw_input: Some(json!({ "content": "Thinking hard!" })),
1599 ..Default::default()
1600 },
1601 }
1602 );
1603 let update = expect_tool_call_update_fields(&mut events).await;
1604 assert_eq!(
1605 update,
1606 acp::ToolCallUpdate {
1607 id: acp::ToolCallId("1".into()),
1608 fields: acp::ToolCallUpdateFields {
1609 status: Some(acp::ToolCallStatus::InProgress),
1610 ..Default::default()
1611 },
1612 }
1613 );
1614 let update = expect_tool_call_update_fields(&mut events).await;
1615 assert_eq!(
1616 update,
1617 acp::ToolCallUpdate {
1618 id: acp::ToolCallId("1".into()),
1619 fields: acp::ToolCallUpdateFields {
1620 content: Some(vec!["Thinking hard!".into()]),
1621 ..Default::default()
1622 },
1623 }
1624 );
1625 let update = expect_tool_call_update_fields(&mut events).await;
1626 assert_eq!(
1627 update,
1628 acp::ToolCallUpdate {
1629 id: acp::ToolCallId("1".into()),
1630 fields: acp::ToolCallUpdateFields {
1631 status: Some(acp::ToolCallStatus::Completed),
1632 raw_output: Some("Finished thinking.".into()),
1633 ..Default::default()
1634 },
1635 }
1636 );
1637}
1638
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    // Baseline for the retry tests below: a turn that completes cleanly must
    // emit zero `ThreadEvent::Retry` events.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode is what enables retries; same mode as the retry tests.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Drain the event stream until the turn stops, collecting retries.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
1682
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    // A provider overload error should trigger exactly one retry (after the
    // provider-specified delay), and the retried turn should then succeed.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // First attempt fails with a retryable overload error.
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Let the retry_after backoff elapse so the retry is issued.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried attempt succeeds.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    // The transcript reflects only the successful attempt.
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
1739
1740#[gpui::test]
1741async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
1742 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1743 let fake_model = model.as_fake();
1744
1745 let mut events = thread
1746 .update(cx, |thread, cx| {
1747 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
1748 thread.send(UserMessageId::new(), ["Hello!"], cx)
1749 })
1750 .unwrap();
1751 cx.run_until_parked();
1752
1753 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
1754 fake_model.send_last_completion_stream_error(
1755 LanguageModelCompletionError::ServerOverloaded {
1756 provider: LanguageModelProviderName::new("Anthropic"),
1757 retry_after: Some(Duration::from_secs(3)),
1758 },
1759 );
1760 fake_model.end_last_completion_stream();
1761 cx.executor().advance_clock(Duration::from_secs(3));
1762 cx.run_until_parked();
1763 }
1764
1765 let mut errors = Vec::new();
1766 let mut retry_events = Vec::new();
1767 while let Some(event) = events.next().await {
1768 match event {
1769 Ok(ThreadEvent::Retry(retry_status)) => {
1770 retry_events.push(retry_status);
1771 }
1772 Ok(ThreadEvent::Stop(..)) => break,
1773 Err(error) => errors.push(error),
1774 _ => {}
1775 }
1776 }
1777
1778 assert_eq!(
1779 retry_events.len(),
1780 crate::thread::MAX_RETRY_ATTEMPTS as usize
1781 );
1782 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
1783 assert_eq!(retry_events[i].attempt, i + 1);
1784 }
1785 assert_eq!(errors.len(), 1);
1786 let error = errors[0]
1787 .downcast_ref::<LanguageModelCompletionError>()
1788 .unwrap();
1789 assert!(matches!(
1790 error,
1791 LanguageModelCompletionError::ServerOverloaded { .. }
1792 ));
1793}
1794
1795/// Filters out the stop events for asserting against in tests
1796fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
1797 result_events
1798 .into_iter()
1799 .filter_map(|event| match event.unwrap() {
1800 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
1801 _ => None,
1802 })
1803 .collect()
1804}
1805
/// Handles returned by [`setup`] for driving a `Thread` in tests.
struct ThreadTest {
    // The model the thread was configured with; tests using `TestModel::Fake`
    // downcast it via `.as_fake()` to drive completions manually.
    model: Arc<dyn LanguageModel>,
    thread: Entity<Thread>,
    project_context: Entity<ProjectContext>,
    // The fake filesystem backing the project and the settings file.
    fs: Arc<FakeFs>,
}
1812
/// Which language model a test thread should be backed by.
enum TestModel {
    // A real Anthropic model resolved from the registry; requires
    // authentication and network access (see `setup`).
    Sonnet4,
    // The in-process `FakeLanguageModel`, driven manually by each test.
    Fake,
}
1817
1818impl TestModel {
1819 fn id(&self) -> LanguageModelId {
1820 match self {
1821 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
1822 TestModel::Fake => unreachable!(),
1823 }
1824 }
1825}
1826
/// Builds a [`ThreadTest`] backed by either the fake model or a real,
/// authenticated provider model, with a settings file that enables all the
/// test tools under a dedicated profile.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // Real-model tests perform actual network I/O, which needs parking.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Seed a settings file whose "test-profile" enables every test tool, so
    // profile-based tool filtering doesn't hide them from the thread.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                            WordListTool.name(): true,
                            ToolRequiringPermission.name(): true,
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        // A real HTTP client so TestModel::Sonnet4 can reach its provider.
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);

        // Keep the global SettingsStore in sync with edits to the fake
        // settings file made later in a test.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                // The fake model needs no registry lookup or authentication.
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Resolve the real model from the registry and authenticate
                // with its provider before handing it to the thread.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let action_log = cx.new(|_| ActionLog::new(project.clone()));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            action_log,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        fs,
    }
}
1925
/// Constructor run once before tests start (via `ctor`); enables `env_logger`
/// only when the caller opted in by setting `RUST_LOG`.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
1933
1934fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
1935 let fs = fs.clone();
1936 cx.spawn({
1937 async move |cx| {
1938 let mut new_settings_content_rx = settings::watch_config_file(
1939 cx.background_executor(),
1940 fs,
1941 paths::settings_file().clone(),
1942 );
1943
1944 while let Some(new_settings_content) = new_settings_content_rx.next().await {
1945 cx.update(|cx| {
1946 SettingsStore::update_global(cx, |settings, cx| {
1947 settings.set_user_settings(&new_settings_content, cx)
1948 })
1949 })
1950 .ok();
1951 }
1952 }
1953 })
1954 .detach();
1955}