1use crate::{
2 AssistantContext, CacheStatus, ContextEvent, ContextId, ContextOperation, ContextSummary,
3 InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus,
4};
5use anyhow::Result;
6use assistant_slash_command::{
7 ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
8 SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
9};
10use assistant_slash_commands::FileSlashCommand;
11use collections::{HashMap, HashSet};
12use fs::FakeFs;
13use futures::{
14 channel::mpsc,
15 stream::{self, StreamExt},
16};
17use gpui::{App, Entity, SharedString, Task, TestAppContext, WeakEntity, prelude::*};
18use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
19use language_model::{
20 ConfiguredModel, LanguageModelCacheConfiguration, LanguageModelRegistry, Role,
21 fake_provider::{FakeLanguageModel, FakeLanguageModelProvider},
22};
23use parking_lot::Mutex;
24use pretty_assertions::assert_eq;
25use project::Project;
26use prompt_store::PromptBuilder;
27use rand::prelude::*;
28use serde_json::json;
29use settings::SettingsStore;
30use std::{
31 cell::RefCell,
32 env,
33 ops::Range,
34 path::Path,
35 rc::Rc,
36 sync::{Arc, atomic::AtomicBool},
37};
38use text::{ReplicaId, ToOffset, network::Network};
39use ui::{IconName, Window};
40use unindent::Unindent;
41use util::RandomCharIter;
42use workspace::Workspace;
43
/// Verifies that message ranges track buffer edits: inserting messages after
/// existing ones, typing inside messages, deleting across message boundaries
/// (which merges the affected messages), and undo/redo of that merge.
#[gpui::test]
fn test_inserting_and_removing_messages(cx: &mut App) {
    init_test(cx);

    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context = cx.new(|cx| {
        AssistantContext::local(
            registry,
            None,
            None,
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            cx,
        )
    });
    let buffer = context.read(cx).buffer.clone();

    // A new local context starts with a single, empty user message.
    let message_1 = context.read(cx).message_anchors[0].clone();
    assert_eq!(
        messages(&context, cx),
        vec![(message_1.id, Role::User, 0..0)]
    );

    // Inserting a message grows the preceding message by one separator
    // character; the new message starts out empty right after it.
    let message_2 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..1),
            (message_2.id, Role::Assistant, 1..1)
        ]
    );

    // Text typed on either side of the separator lands in the respective message.
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..3)
        ]
    );

    let message_3 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_3.id, Role::User, 4..4)
        ]
    );

    // Inserting after `message_2` a second time places the new message directly
    // after it, i.e. before the previously inserted `message_3`.
    let message_4 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_4.id, Role::User, 4..5),
            (message_3.id, Role::User, 5..5),
        ]
    );

    buffer.update(cx, |buffer, cx| {
        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_4.id, Role::User, 4..6),
            (message_3.id, Role::User, 6..7),
        ]
    );

    // Deleting across message boundaries merges the messages.
    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..3),
            (message_3.id, Role::User, 3..4),
        ]
    );

    // Undoing the deletion should also undo the merge.
    buffer.update(cx, |buffer, cx| buffer.undo(cx));
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_4.id, Role::User, 4..6),
            (message_3.id, Role::User, 6..7),
        ]
    );

    // Redoing the deletion should also redo the merge.
    buffer.update(cx, |buffer, cx| buffer.redo(cx));
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..3),
            (message_3.id, Role::User, 3..4),
        ]
    );

    // Ensure we can still insert after a merged message.
    let message_5 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..3),
            (message_5.id, Role::System, 3..4),
            (message_3.id, Role::User, 4..5)
        ]
    );
}
181
/// Verifies `AssistantContext::split_message` on a four-line buffer:
/// splitting at a line start reuses the existing newline as the boundary;
/// splitting at the same spot again inserts a new empty line; and splitting a
/// non-empty range isolates the selection in its own message.
#[gpui::test]
fn test_message_splitting(cx: &mut App) {
    init_test(cx);

    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));

    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context = cx.new(|cx| {
        AssistantContext::local(
            registry.clone(),
            None,
            None,
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            cx,
        )
    });
    let buffer = context.read(cx).buffer.clone();

    let message_1 = context.read(cx).message_anchors[0].clone();
    assert_eq!(
        messages(&context, cx),
        vec![(message_1.id, Role::User, 0..0)]
    );

    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
    });

    // `split_message` returns a pair of optional new messages; for the empty
    // (cursor) ranges used here only the second element is inspected.
    let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
    let message_2 = message_2.unwrap();

    // We recycle newlines in the middle of a split message
    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_2.id, Role::User, 4..16),
        ]
    );

    let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
    let message_3 = message_3.unwrap();

    // We don't recycle newlines at the end of a split message
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..17),
        ]
    );

    // Offset 9 is the start of "ccc": the preceding newline is reused, so the
    // text is unchanged.
    let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
    let message_4 = message_4.unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..9),
            (message_4.id, Role::User, 9..17),
        ]
    );

    // Splitting at the same offset again inserts a fresh empty line.
    let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
    let message_5 = message_5.unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..9),
            (message_4.id, Role::User, 9..10),
            (message_5.id, Role::User, 10..18),
        ]
    );

    // Splitting the non-empty range 14..16 ("dd") yields two new messages: one
    // for the selection and one for the rest of the line, with a newline
    // inserted between them.
    let (message_6, message_7) =
        context.update(cx, |context, cx| context.split_message(14..16, cx));
    let message_6 = message_6.unwrap();
    let message_7 = message_7.unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..9),
            (message_4.id, Role::User, 9..10),
            (message_5.id, Role::User, 10..14),
            (message_6.id, Role::User, 14..17),
            (message_7.id, Role::User, 17..19),
        ]
    );
}
283
/// Verifies `AssistantContext::messages_for_offsets`, which maps buffer offsets
/// to the messages containing them. Several offsets inside one message yield
/// that message once. The end-of-buffer offset resolves to the last message,
/// or to a trailing empty message when one exists.
#[gpui::test]
fn test_messages_for_offsets(cx: &mut App) {
    init_test(cx);

    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context = cx.new(|cx| {
        AssistantContext::local(
            registry,
            None,
            None,
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            cx,
        )
    });
    let buffer = context.read(cx).buffer.clone();

    let message_1 = context.read(cx).message_anchors[0].clone();
    assert_eq!(
        messages(&context, cx),
        vec![(message_1.id, Role::User, 0..0)]
    );

    // Build three user messages "aaa\n", "bbb\n", "ccc". Each edit offset
    // accounts for the separator newline added by the preceding insertion.
    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
    let message_2 = context
        .update(cx, |context, cx| {
            context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
        })
        .unwrap();
    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

    let message_3 = context
        .update(cx, |context, cx| {
            context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
        })
        .unwrap();
    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_2.id, Role::User, 4..8),
            (message_3.id, Role::User, 8..11)
        ]
    );

    assert_eq!(
        message_ids_for_offsets(&context, &[0, 4, 9], cx),
        [message_1.id, message_2.id, message_3.id]
    );
    // Offsets 0 and 1 both fall in `message_1`, which is reported once;
    // offset 11 (end of buffer) resolves to the last message.
    assert_eq!(
        message_ids_for_offsets(&context, &[0, 1, 11], cx),
        [message_1.id, message_3.id]
    );

    // With a trailing empty message, the end-of-buffer offset maps to it.
    let message_4 = context
        .update(cx, |context, cx| {
            context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
        })
        .unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_2.id, Role::User, 4..8),
            (message_3.id, Role::User, 8..12),
            (message_4.id, Role::User, 12..12)
        ]
    );
    assert_eq!(
        message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
        [message_1.id, message_2.id, message_3.id, message_4.id]
    );

    /// Returns the ids of the messages that `messages_for_offsets` reports for
    /// `offsets`, in the order it yields them.
    fn message_ids_for_offsets(
        context: &Entity<AssistantContext>,
        offsets: &[usize],
        cx: &App,
    ) -> Vec<MessageId> {
        context
            .read(cx)
            .messages_for_offsets(offsets.iter().copied(), cx)
            .into_iter()
            .map(|message| message.id)
            .collect()
    }
}
375
/// Exercises the lifecycle of a slash command in a context buffer:
/// - `/file` is parsed as it is typed and edited;
/// - the parse is dropped when the command name is unknown, and restored on undo;
/// - command output streams into the buffer until the output stream is closed.
///
/// Expected strings use these inline markers (rendered by the nested
/// `assert_text_and_context_ranges`):
/// - `«…»` parsed command source ranges;
/// - `⟦…⟧` invoked command output ranges;
/// - `⟪…⟫` output sections.
///
/// A bare `…` is not a marker; it is literal buffer text. The assertions show
/// it while the output stream is open and not after the stream is dropped.
#[gpui::test]
async fn test_slash_commands(cx: &mut TestAppContext) {
    cx.update(init_test);

    let fs = FakeFs::new(cx.background_executor.clone());

    fs.insert_tree(
        "/test",
        json!({
            "src": {
                "lib.rs": "fn one() -> usize { 1 }",
                "main.rs": "
                    use crate::one;
                    fn main() { one(); }
                ".unindent(),
            }
        }),
    )
    .await;

    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
    slash_command_registry.register_command(FileSlashCommand, false);

    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context = cx.new(|cx| {
        AssistantContext::local(
            registry.clone(),
            None,
            None,
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            cx,
        )
    });

    /// Ranges mirrored from the context's events, used to render markers.
    #[derive(Default)]
    struct ContextRanges {
        /// Source ranges of currently parsed slash commands.
        parsed_commands: HashSet<Range<language::Anchor>>,
        /// Latest range of each invoked slash command, keyed by invocation id.
        command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
        /// Ranges of slash command output sections added so far.
        output_sections: HashSet<Range<language::Anchor>>,
    }

    // Mirror the context's slash-command bookkeeping into `context_ranges` by
    // subscribing to its events.
    let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
    context.update(cx, |_, cx| {
        cx.subscribe(&context, {
            let context_ranges = context_ranges.clone();
            move |context, _, event, _| {
                let mut context_ranges = context_ranges.borrow_mut();
                match event {
                    ContextEvent::InvokedSlashCommandChanged { command_id } => {
                        let command = context.invoked_slash_command(command_id).unwrap();
                        context_ranges
                            .command_outputs
                            .insert(*command_id, command.range.clone());
                    }
                    ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
                        for range in removed {
                            context_ranges.parsed_commands.remove(range);
                        }
                        for command in updated {
                            context_ranges
                                .parsed_commands
                                .insert(command.source_range.clone());
                        }
                    }
                    ContextEvent::SlashCommandOutputSectionAdded { section } => {
                        context_ranges.output_sections.insert(section.range.clone());
                    }
                    _ => {}
                }
            }
        })
        .detach();
    });

    let buffer = context.read_with(cx, |context, _| context.buffer.clone());

    // Insert a slash command
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
    });
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        «/file src/lib.rs»"
            .unindent(),
        cx,
    );

    // Edit the argument of the slash command.
    buffer.update(cx, |buffer, cx| {
        let edit_offset = buffer.text().find("lib.rs").unwrap();
        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
    });
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        «/file src/main.rs»"
            .unindent(),
        cx,
    );

    // Edit the name of the slash command, using one that doesn't exist.
    buffer.update(cx, |buffer, cx| {
        let edit_offset = buffer.text().find("/file").unwrap();
        buffer.edit(
            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
            None,
            cx,
        );
    });
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        /unknown src/main.rs"
            .unindent(),
        cx,
    );

    // Undoing the insertion of a non-existent slash command restores the previous one.
    buffer.update(cx, |buffer, cx| buffer.undo(cx));
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        «/file src/main.rs»"
            .unindent(),
        cx,
    );

    // Run the parsed command, feeding its output through a channel so the
    // test controls when each output event arrives.
    let (command_output_tx, command_output_rx) = mpsc::unbounded();
    context.update(cx, |context, cx| {
        let command_source_range = context.parsed_slash_commands[0].source_range.clone();
        context.insert_command_output(
            command_source_range,
            "file",
            Task::ready(Ok(command_output_rx.boxed())),
            true,
            cx,
        );
    });
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦«/file src/main.rs»
        …⟧
        "
        .unindent(),
        cx,
    );

    command_output_tx
        .unbounded_send(Ok(SlashCommandEvent::StartSection {
            icon: IconName::Ai,
            label: "src/main.rs".into(),
            metadata: None,
        }))
        .unwrap();
    command_output_tx
        .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
        .unwrap();
    cx.run_until_parked();
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦«/file src/main.rs»
        src/main.rs…⟧
        "
        .unindent(),
        cx,
    );

    command_output_tx
        .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
        .unwrap();
    cx.run_until_parked();
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦«/file src/main.rs»
        src/main.rs
        fn main() {}…⟧
        "
        .unindent(),
        cx,
    );

    // Closing the section is what makes it show up as an output section.
    command_output_tx
        .unbounded_send(Ok(SlashCommandEvent::EndSection))
        .unwrap();
    cx.run_until_parked();
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦«/file src/main.rs»
        ⟪src/main.rs
        fn main() {}⟫…⟧
        "
        .unindent(),
        cx,
    );

    // Dropping the sender ends the output stream. Afterwards the command source
    // text and its parse are gone, and only the output remains.
    drop(command_output_tx);
    cx.run_until_parked();
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦⟪src/main.rs
        fn main() {}⟫⟧
        "
        .unindent(),
        cx,
    );

    /// Renders the buffer text with every tracked range marked inline and
    /// asserts it equals `expected_marked_text`.
    ///
    /// Endpoints are ordered with `sort_by_key`, a stable sort. At equal
    /// offsets, markers therefore appear in push order: all start markers
    /// (outputs, then commands, then sections) precede all end markers
    /// (sections, then commands, then outputs). This keeps nesting readable
    /// when ranges share an endpoint.
    #[track_caller]
    fn assert_text_and_context_ranges(
        buffer: &Entity<Buffer>,
        ranges: &RefCell<ContextRanges>,
        expected_marked_text: &str,
        cx: &mut TestAppContext,
    ) {
        let mut actual_marked_text = String::new();
        buffer.update(cx, |buffer, _| {
            struct Endpoint {
                offset: usize,
                marker: char,
            }

            let ranges = ranges.borrow();
            let mut endpoints = Vec::new();
            for range in ranges.command_outputs.values() {
                endpoints.push(Endpoint {
                    offset: range.start.to_offset(buffer),
                    marker: '⟦',
                });
            }
            for range in ranges.parsed_commands.iter() {
                endpoints.push(Endpoint {
                    offset: range.start.to_offset(buffer),
                    marker: '«',
                });
            }
            for range in ranges.output_sections.iter() {
                endpoints.push(Endpoint {
                    offset: range.start.to_offset(buffer),
                    marker: '⟪',
                });
            }

            for range in ranges.output_sections.iter() {
                endpoints.push(Endpoint {
                    offset: range.end.to_offset(buffer),
                    marker: '⟫',
                });
            }
            for range in ranges.parsed_commands.iter() {
                endpoints.push(Endpoint {
                    offset: range.end.to_offset(buffer),
                    marker: '»',
                });
            }
            for range in ranges.command_outputs.values() {
                endpoints.push(Endpoint {
                    offset: range.end.to_offset(buffer),
                    marker: '⟧',
                });
            }

            // Stable sort: ties keep the push order established above.
            endpoints.sort_by_key(|endpoint| endpoint.offset);
            let mut offset = 0;
            for endpoint in endpoints {
                actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
                actual_marked_text.push(endpoint.marker);
                offset = endpoint.offset;
            }
            actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
        });

        assert_eq!(actual_marked_text, expected_marked_text);
    }
}
666
667#[gpui::test]
668async fn test_serialization(cx: &mut TestAppContext) {
669 cx.update(init_test);
670
671 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
672 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
673 let context = cx.new(|cx| {
674 AssistantContext::local(
675 registry.clone(),
676 None,
677 None,
678 prompt_builder.clone(),
679 Arc::new(SlashCommandWorkingSet::default()),
680 cx,
681 )
682 });
683 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
684 let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
685 let message_1 = context.update(cx, |context, cx| {
686 context
687 .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
688 .unwrap()
689 });
690 let message_2 = context.update(cx, |context, cx| {
691 context
692 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
693 .unwrap()
694 });
695 buffer.update(cx, |buffer, cx| {
696 buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
697 buffer.finalize_last_transaction();
698 });
699 let _message_3 = context.update(cx, |context, cx| {
700 context
701 .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
702 .unwrap()
703 });
704 buffer.update(cx, |buffer, cx| buffer.undo(cx));
705 assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
706 assert_eq!(
707 cx.read(|cx| messages(&context, cx)),
708 [
709 (message_0, Role::User, 0..2),
710 (message_1.id, Role::Assistant, 2..6),
711 (message_2.id, Role::System, 6..6),
712 ]
713 );
714
715 let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
716 let deserialized_context = cx.new(|cx| {
717 AssistantContext::deserialize(
718 serialized_context,
719 Path::new("").into(),
720 registry.clone(),
721 prompt_builder.clone(),
722 Arc::new(SlashCommandWorkingSet::default()),
723 None,
724 None,
725 cx,
726 )
727 });
728 let deserialized_buffer =
729 deserialized_context.read_with(cx, |context, _| context.buffer.clone());
730 assert_eq!(
731 deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
732 "a\nb\nc\n"
733 );
734 assert_eq!(
735 cx.read(|cx| messages(&deserialized_context, cx)),
736 [
737 (message_0, Role::User, 0..2),
738 (message_1.id, Role::Assistant, 2..6),
739 (message_2.id, Role::System, 6..6),
740 ]
741 );
742}
743
/// Randomized convergence test for collaborative editing of a single context.
///
/// Setup:
/// - Creates between `MIN_PEERS` and `MAX_PEERS` replicas (defaults 2 and 5)
///   that share one `ContextId`.
/// - Broadcasts every `ContextEvent::Operation` a replica emits over a
///   simulated `Network`.
///
/// The test then performs `OPERATIONS` random mutations (default 50),
/// interleaved with random delivery, disconnection and reconnection. Once all
/// mutations are done, the network is idle and no peer is disconnected, every
/// replica must match replica 0.
///
/// NOTE: every `rng` call shapes the random trace, so reordering or adding
/// calls changes which scenario a given seed reproduces.
#[gpui::test(iterations = 100)]
async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
    cx.update(init_test);

    let min_peers = env::var("MIN_PEERS")
        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
        .unwrap_or(2);
    let max_peers = env::var("MAX_PEERS")
        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
        .unwrap_or(5);
    let operations = env::var("OPERATIONS")
        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
        .unwrap_or(50);

    // Fake commands that the slash-command branch below picks from at random.
    let slash_commands = cx.update(SlashCommandRegistry::default_global);
    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);

    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
    let mut contexts = Vec::new();

    // One replica per peer; replica `i` is registered on the network under
    // replica id `i` and forwards its operations (as protos) to the network.
    let num_peers = rng.random_range(min_peers..=max_peers);
    let context_id = ContextId::new();
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    for i in 0..num_peers {
        let context = cx.new(|cx| {
            AssistantContext::new(
                context_id.clone(),
                i as ReplicaId,
                language::Capability::ReadWrite,
                registry.clone(),
                prompt_builder.clone(),
                Arc::new(SlashCommandWorkingSet::default()),
                None,
                None,
                cx,
            )
        });

        cx.update(|cx| {
            cx.subscribe(&context, {
                let network = network.clone();
                move |_, event, _| {
                    if let ContextEvent::Operation(op) = event {
                        network
                            .lock()
                            .broadcast(i as ReplicaId, vec![op.to_proto()]);
                    }
                }
            })
            .detach();
        });

        contexts.push(context);
        network.lock().add_peer(i as ReplicaId);
    }

    let mut mutation_count = operations;

    // Keep going until all mutations are spent AND every operation has been
    // delivered to every (reconnected) peer.
    while mutation_count > 0
        || !network.lock().is_idle()
        || network.lock().contains_disconnected_peers()
    {
        let context_index = rng.random_range(0..contexts.len());
        let context = &contexts[context_index];

        // Mutation arms are guarded by `mutation_count > 0`; once mutations are
        // exhausted every roll falls through to the network arm.
        match rng.random_range(0..100) {
            // Random buffer edit.
            0..=29 if mutation_count > 0 => {
                log::info!("Context {}: edit buffer", context_index);
                context.update(cx, |context, cx| {
                    context
                        .buffer
                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
                });
                mutation_count -= 1;
            }
            // Split a message at a random range.
            30..=44 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
                    log::info!("Context {}: split message at {:?}", context_index, range);
                    context.split_message(range, cx);
                });
                mutation_count -= 1;
            }
            // Insert a message with a random role after a random message.
            45..=59 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let role = *[Role::User, Role::Assistant, Role::System]
                            .choose(&mut rng)
                            .unwrap();
                        log::info!(
                            "Context {}: insert message after {:?} with {:?}",
                            context_index,
                            message.id,
                            role
                        );
                        context.insert_message_after(message.id, role, MessageStatus::Done, cx);
                    }
                });
                mutation_count -= 1;
            }
            // Insert `\n/<command>\n` at a random offset and stream synthetic
            // output for it.
            60..=74 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    let command_text = "/".to_string()
                        + slash_commands
                            .command_names()
                            .choose(&mut rng)
                            .unwrap()
                            .clone()
                            .as_ref();

                    // Range of the command text itself, skipping the leading "\n".
                    let command_range = context.buffer.update(cx, |buffer, cx| {
                        let offset = buffer.random_byte_range(0, &mut rng).start;
                        buffer.edit(
                            [(offset..offset, format!("\n{}\n", command_text))],
                            None,
                            cx,
                        );
                        offset + 1..offset + 1 + command_text.len()
                    });

                    // Ten random chars; carriage returns are filtered out
                    // (presumably to avoid line-ending normalization skewing the
                    // comparison — NOTE(review): confirm).
                    let output_text = RandomCharIter::new(&mut rng)
                        .filter(|c| *c != '\r')
                        .take(10)
                        .collect::<String>();

                    let mut events = vec![Ok(SlashCommandEvent::StartMessage {
                        role: Role::User,
                        merge_same_roles: true,
                    })];

                    // Up to three sections, each covering the next consecutive
                    // slice of `output_text`.
                    let num_sections = rng.random_range(0..=3);
                    let mut section_start = 0;
                    for _ in 0..num_sections {
                        let mut section_end = rng.random_range(section_start..=output_text.len());
                        // Move forward to a char boundary so slicing a
                        // multi-byte char cannot panic.
                        while !output_text.is_char_boundary(section_end) {
                            section_end += 1;
                        }
                        events.push(Ok(SlashCommandEvent::StartSection {
                            icon: IconName::Ai,
                            label: "section".into(),
                            metadata: None,
                        }));
                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
                            text: output_text[section_start..section_end].to_string(),
                            run_commands_in_text: false,
                        })));
                        events.push(Ok(SlashCommandEvent::EndSection));
                        section_start = section_end;
                    }

                    // Any remaining text is emitted outside of a section.
                    if section_start < output_text.len() {
                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
                            text: output_text[section_start..].to_string(),
                            run_commands_in_text: false,
                        })));
                    }

                    log::info!(
                        "Context {}: insert slash command output at {:?} with {:?} events",
                        context_index,
                        command_range,
                        events.len()
                    );

                    let command_range = context.buffer.read(cx).anchor_after(command_range.start)
                        ..context.buffer.read(cx).anchor_after(command_range.end);
                    context.insert_command_output(
                        command_range,
                        "/command",
                        Task::ready(Ok(stream::iter(events).boxed())),
                        true,
                        cx,
                    );
                });
                // Let the output stream be fully consumed before the next step.
                cx.run_until_parked();
                mutation_count -= 1;
            }
            // Change a random message's status (Done, Pending or Error).
            75..=84 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let new_status = match rng.random_range(0..3) {
                            0 => MessageStatus::Done,
                            1 => MessageStatus::Pending,
                            _ => MessageStatus::Error(SharedString::from("Random error")),
                        };
                        log::info!(
                            "Context {}: update message {:?} status to {:?}",
                            context_index,
                            message.id,
                            new_status
                        );
                        context.update_metadata(message.id, cx, |metadata| {
                            metadata.status = new_status;
                        });
                    }
                });
                mutation_count -= 1;
            }
            // Network activity: reconnect, disconnect, or deliver pending ops.
            _ => {
                let replica_id = context_index as ReplicaId;
                if network.lock().is_disconnected(replica_id) {
                    network.lock().reconnect_peer(replica_id, 0);

                    // Resync with replica 0, which is never disconnected. Each
                    // side serializes the operations the other's version lacks.
                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
                        let host_context = &contexts[0].read(cx);
                        let guest_context = context.read(cx);
                        (
                            guest_context.serialize_ops(&host_context.version(cx), cx),
                            host_context.serialize_ops(&guest_context.version(cx), cx),
                        )
                    });
                    let ops_to_send = ops_to_send.await;
                    let ops_to_receive = ops_to_receive
                        .await
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    log::info!(
                        "Context {}: reconnecting. Sent {} operations, received {} operations",
                        context_index,
                        ops_to_send.len(),
                        ops_to_receive.len()
                    );

                    network.lock().broadcast(replica_id, ops_to_send);
                    context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
                } else if rng.random_bool(0.1) && replica_id != 0 {
                    // Replica 0 stays connected so it can serve as the resync host.
                    log::info!("Context {}: disconnecting", context_index);
                    network.lock().disconnect_peer(replica_id);
                } else if network.lock().has_unreceived(replica_id) {
                    log::info!("Context {}: applying operations", context_index);
                    let ops = network.lock().receive(replica_id);
                    let ops = ops
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    context.update(cx, |context, cx| context.apply_ops(ops, cx));
                }
            }
        }
    }

    // All replicas must have converged on replica 0's state.
    cx.read(|cx| {
        let first_context = contexts[0].read(cx);
        for context in &contexts[1..] {
            let context = context.read(cx);
            assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
            assert_eq!(
                context.buffer.read(cx).text(),
                first_context.buffer.read(cx).text(),
                "Context {} text != Context 0 text",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.message_anchors,
                first_context.message_anchors,
                "Context {} messages != Context 0 messages",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.messages_metadata,
                first_context.messages_metadata,
                "Context {} message metadata != Context 0 message metadata",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.slash_command_output_sections,
                first_context.slash_command_output_sections,
                "Context {} slash command output sections != Context 0 slash command output sections",
                context.buffer.read(cx).replica_id()
            );
        }
    });
}
1023
1024#[gpui::test]
1025fn test_mark_cache_anchors(cx: &mut App) {
1026 init_test(cx);
1027
1028 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1029 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1030 let context = cx.new(|cx| {
1031 AssistantContext::local(
1032 registry,
1033 None,
1034 None,
1035 prompt_builder.clone(),
1036 Arc::new(SlashCommandWorkingSet::default()),
1037 cx,
1038 )
1039 });
1040 let buffer = context.read(cx).buffer.clone();
1041
1042 // Create a test cache configuration
1043 let cache_configuration = &Some(LanguageModelCacheConfiguration {
1044 max_cache_anchors: 3,
1045 should_speculate: true,
1046 min_total_token: 10,
1047 });
1048
1049 let message_1 = context.read(cx).message_anchors[0].clone();
1050
1051 context.update(cx, |context, cx| {
1052 context.mark_cache_anchors(cache_configuration, false, cx)
1053 });
1054
1055 assert_eq!(
1056 messages_cache(&context, cx)
1057 .iter()
1058 .filter(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1059 .count(),
1060 0,
1061 "Empty messages should not have any cache anchors."
1062 );
1063
1064 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1065 let message_2 = context
1066 .update(cx, |context, cx| {
1067 context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1068 })
1069 .unwrap();
1070
1071 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1072 let message_3 = context
1073 .update(cx, |context, cx| {
1074 context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1075 })
1076 .unwrap();
1077 buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1078
1079 context.update(cx, |context, cx| {
1080 context.mark_cache_anchors(cache_configuration, false, cx)
1081 });
1082 assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1083 assert_eq!(
1084 messages_cache(&context, cx)
1085 .iter()
1086 .filter(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1087 .count(),
1088 0,
1089 "Messages should not be marked for cache before going over the token minimum."
1090 );
1091 context.update(cx, |context, _| {
1092 context.token_count = Some(20);
1093 });
1094
1095 context.update(cx, |context, cx| {
1096 context.mark_cache_anchors(cache_configuration, true, cx)
1097 });
1098 assert_eq!(
1099 messages_cache(&context, cx)
1100 .iter()
1101 .map(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1102 .collect::<Vec<bool>>(),
1103 vec![true, true, false],
1104 "Last message should not be an anchor on speculative request."
1105 );
1106
1107 context
1108 .update(cx, |context, cx| {
1109 context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1110 })
1111 .unwrap();
1112
1113 context.update(cx, |context, cx| {
1114 context.mark_cache_anchors(cache_configuration, false, cx)
1115 });
1116 assert_eq!(
1117 messages_cache(&context, cx)
1118 .iter()
1119 .map(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1120 .collect::<Vec<bool>>(),
1121 vec![false, true, true, false],
1122 "Most recent message should also be cached if not a speculative request."
1123 );
1124 context.update(cx, |context, cx| {
1125 context.update_cache_status_for_completion(cx)
1126 });
1127 assert_eq!(
1128 messages_cache(&context, cx)
1129 .iter()
1130 .map(|(_, cache)| cache
1131 .as_ref()
1132 .map_or(None, |cache| Some(cache.status.clone())))
1133 .collect::<Vec<Option<CacheStatus>>>(),
1134 vec![
1135 Some(CacheStatus::Cached),
1136 Some(CacheStatus::Cached),
1137 Some(CacheStatus::Cached),
1138 None
1139 ],
1140 "All user messages prior to anchor should be marked as cached."
1141 );
1142
1143 buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1144 context.update(cx, |context, cx| {
1145 context.mark_cache_anchors(cache_configuration, false, cx)
1146 });
1147 assert_eq!(
1148 messages_cache(&context, cx)
1149 .iter()
1150 .map(|(_, cache)| cache
1151 .as_ref()
1152 .map_or(None, |cache| Some(cache.status.clone())))
1153 .collect::<Vec<Option<CacheStatus>>>(),
1154 vec![
1155 Some(CacheStatus::Cached),
1156 Some(CacheStatus::Cached),
1157 Some(CacheStatus::Pending),
1158 None
1159 ],
1160 "Modifying a message should invalidate it's cache but leave previous messages."
1161 );
1162 buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1163 context.update(cx, |context, cx| {
1164 context.mark_cache_anchors(cache_configuration, false, cx)
1165 });
1166 assert_eq!(
1167 messages_cache(&context, cx)
1168 .iter()
1169 .map(|(_, cache)| cache
1170 .as_ref()
1171 .map_or(None, |cache| Some(cache.status.clone())))
1172 .collect::<Vec<Option<CacheStatus>>>(),
1173 vec![
1174 Some(CacheStatus::Pending),
1175 Some(CacheStatus::Pending),
1176 Some(CacheStatus::Pending),
1177 None
1178 ],
1179 "Modifying a message should invalidate all future messages."
1180 );
1181}
1182
1183#[gpui::test]
1184async fn test_summarization(cx: &mut TestAppContext) {
1185 let (context, fake_model) = setup_context_editor_with_fake_model(cx);
1186
1187 // Initial state should be pending
1188 context.read_with(cx, |context, _| {
1189 assert!(matches!(context.summary(), ContextSummary::Pending));
1190 assert_eq!(context.summary().or_default(), ContextSummary::DEFAULT);
1191 });
1192
1193 let message_1 = context.read_with(cx, |context, _cx| context.message_anchors[0].clone());
1194 context.update(cx, |context, cx| {
1195 context
1196 .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
1197 .unwrap();
1198 });
1199
1200 // Send a message
1201 context.update(cx, |context, cx| {
1202 context.assist(cx);
1203 });
1204
1205 simulate_successful_response(&fake_model, cx);
1206
1207 // Should start generating summary when there are >= 2 messages
1208 context.read_with(cx, |context, _| {
1209 assert!(!context.summary().content().unwrap().done);
1210 });
1211
1212 cx.run_until_parked();
1213 fake_model.send_last_completion_stream_text_chunk("Brief");
1214 fake_model.send_last_completion_stream_text_chunk(" Introduction");
1215 fake_model.end_last_completion_stream();
1216 cx.run_until_parked();
1217
1218 // Summary should be set
1219 context.read_with(cx, |context, _| {
1220 assert_eq!(context.summary().or_default(), "Brief Introduction");
1221 });
1222
1223 // We should be able to manually set a summary
1224 context.update(cx, |context, cx| {
1225 context.set_custom_summary("Brief Intro".into(), cx);
1226 });
1227
1228 context.read_with(cx, |context, _| {
1229 assert_eq!(context.summary().or_default(), "Brief Intro");
1230 });
1231}
1232
1233#[gpui::test]
1234async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
1235 let (context, fake_model) = setup_context_editor_with_fake_model(cx);
1236
1237 test_summarize_error(&fake_model, &context, cx);
1238
1239 // Now we should be able to set a summary
1240 context.update(cx, |context, cx| {
1241 context.set_custom_summary("Brief Intro".into(), cx);
1242 });
1243
1244 context.read_with(cx, |context, _| {
1245 assert_eq!(context.summary().or_default(), "Brief Intro");
1246 });
1247}
1248
1249#[gpui::test]
1250async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
1251 let (context, fake_model) = setup_context_editor_with_fake_model(cx);
1252
1253 test_summarize_error(&fake_model, &context, cx);
1254
1255 // Sending another message should not trigger another summarize request
1256 context.update(cx, |context, cx| {
1257 context.assist(cx);
1258 });
1259
1260 simulate_successful_response(&fake_model, cx);
1261
1262 context.read_with(cx, |context, _| {
1263 // State is still Error, not Generating
1264 assert!(matches!(context.summary(), ContextSummary::Error));
1265 });
1266
1267 // But the summarize request can be invoked manually
1268 context.update(cx, |context, cx| {
1269 context.summarize(true, cx);
1270 });
1271
1272 context.read_with(cx, |context, _| {
1273 assert!(!context.summary().content().unwrap().done);
1274 });
1275
1276 cx.run_until_parked();
1277 fake_model.send_last_completion_stream_text_chunk("A successful summary");
1278 fake_model.end_last_completion_stream();
1279 cx.run_until_parked();
1280
1281 context.read_with(cx, |context, _| {
1282 assert_eq!(context.summary().or_default(), "A successful summary");
1283 });
1284}
1285
1286fn test_summarize_error(
1287 model: &Arc<FakeLanguageModel>,
1288 context: &Entity<AssistantContext>,
1289 cx: &mut TestAppContext,
1290) {
1291 let message_1 = context.read_with(cx, |context, _cx| context.message_anchors[0].clone());
1292 context.update(cx, |context, cx| {
1293 context
1294 .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
1295 .unwrap();
1296 });
1297
1298 // Send a message
1299 context.update(cx, |context, cx| {
1300 context.assist(cx);
1301 });
1302
1303 simulate_successful_response(model, cx);
1304
1305 context.read_with(cx, |context, _| {
1306 assert!(!context.summary().content().unwrap().done);
1307 });
1308
1309 // Simulate summary request ending
1310 cx.run_until_parked();
1311 model.end_last_completion_stream();
1312 cx.run_until_parked();
1313
1314 // State is set to Error and default message
1315 context.read_with(cx, |context, _| {
1316 assert_eq!(*context.summary(), ContextSummary::Error);
1317 assert_eq!(context.summary().or_default(), ContextSummary::DEFAULT);
1318 });
1319}
1320
1321fn setup_context_editor_with_fake_model(
1322 cx: &mut TestAppContext,
1323) -> (Entity<AssistantContext>, Arc<FakeLanguageModel>) {
1324 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
1325
1326 let fake_provider = Arc::new(FakeLanguageModelProvider::default());
1327 let fake_model = Arc::new(fake_provider.test_model());
1328
1329 cx.update(|cx| {
1330 init_test(cx);
1331 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
1332 registry.set_default_model(
1333 Some(ConfiguredModel {
1334 provider: fake_provider.clone(),
1335 model: fake_model.clone(),
1336 }),
1337 cx,
1338 )
1339 })
1340 });
1341
1342 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1343 let context = cx.new(|cx| {
1344 AssistantContext::local(
1345 registry,
1346 None,
1347 None,
1348 prompt_builder.clone(),
1349 Arc::new(SlashCommandWorkingSet::default()),
1350 cx,
1351 )
1352 });
1353
1354 (context, fake_model)
1355}
1356
1357fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
1358 cx.run_until_parked();
1359 fake_model.send_last_completion_stream_text_chunk("Assistant response");
1360 fake_model.end_last_completion_stream();
1361 cx.run_until_parked();
1362}
1363
1364fn messages(context: &Entity<AssistantContext>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
1365 context
1366 .read(cx)
1367 .messages(cx)
1368 .map(|message| (message.id, message.role, message.offset_range))
1369 .collect()
1370}
1371
1372fn messages_cache(
1373 context: &Entity<AssistantContext>,
1374 cx: &App,
1375) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1376 context
1377 .read(cx)
1378 .messages(cx)
1379 .map(|message| (message.id, message.cache))
1380 .collect()
1381}
1382
/// Shared test initialization.
///
/// Sets up the globals and settings these tests use:
/// - a test `SettingsStore`;
/// - the prompt store;
/// - a test `LanguageModelRegistry`;
/// - the `language`, `agent_settings` and project settings.
fn init_test(cx: &mut App) {
    let settings_store = SettingsStore::test(cx);
    prompt_store::init(cx);
    // Return value is ignored; presumably this installs the registry as a global, since
    // callers later use `LanguageModelRegistry::global(cx)` — TODO confirm.
    LanguageModelRegistry::test(cx);
    // NOTE(review): the settings store is created first but only installed as a global here,
    // after the prompt store and model registry are initialized. Keep this ordering unless it
    // is confirmed that those two are indifferent to it.
    cx.set_global(settings_store);
    language::init(cx);
    agent_settings::init(cx);
    Project::init_settings(cx);
}
1392
/// Minimal `SlashCommand` implementation for tests.
///
/// The wrapped string is the command's name. Running it produces the fixed text
/// "Executed fake command: <name>".
#[derive(Clone)]
struct FakeSlashCommand(String);
1395
1396impl SlashCommand for FakeSlashCommand {
1397 fn name(&self) -> String {
1398 self.0.clone()
1399 }
1400
1401 fn description(&self) -> String {
1402 format!("Fake slash command: {}", self.0)
1403 }
1404
1405 fn menu_text(&self) -> String {
1406 format!("Run fake command: {}", self.0)
1407 }
1408
1409 fn complete_argument(
1410 self: Arc<Self>,
1411 _arguments: &[String],
1412 _cancel: Arc<AtomicBool>,
1413 _workspace: Option<WeakEntity<Workspace>>,
1414 _window: &mut Window,
1415 _cx: &mut App,
1416 ) -> Task<Result<Vec<ArgumentCompletion>>> {
1417 Task::ready(Ok(vec![]))
1418 }
1419
1420 fn requires_argument(&self) -> bool {
1421 false
1422 }
1423
1424 fn run(
1425 self: Arc<Self>,
1426 _arguments: &[String],
1427 _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1428 _context_buffer: BufferSnapshot,
1429 _workspace: WeakEntity<Workspace>,
1430 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1431 _window: &mut Window,
1432 _cx: &mut App,
1433 ) -> Task<SlashCommandResult> {
1434 Task::ready(Ok(SlashCommandOutput {
1435 text: format!("Executed fake command: {}", self.0),
1436 sections: vec![],
1437 run_commands_in_text: false,
1438 }
1439 .into_event_stream()))
1440 }
1441}