1use crate::{
2 AssistantContext, CacheStatus, ContextEvent, ContextId, ContextOperation,
3 InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus,
4};
5use anyhow::Result;
6use assistant_slash_command::{
7 ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
8 SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
9};
10use assistant_slash_commands::FileSlashCommand;
11use collections::{HashMap, HashSet};
12use fs::FakeFs;
13use futures::{
14 channel::mpsc,
15 stream::{self, StreamExt},
16};
17use gpui::{App, Entity, SharedString, Task, TestAppContext, WeakEntity, prelude::*};
18use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
19use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
20use parking_lot::Mutex;
21use pretty_assertions::assert_eq;
22use project::Project;
23use prompt_store::PromptBuilder;
24use rand::prelude::*;
25use serde_json::json;
26use settings::SettingsStore;
27use std::{
28 cell::RefCell,
29 env,
30 ops::Range,
31 path::Path,
32 rc::Rc,
33 sync::{Arc, atomic::AtomicBool},
34};
35use text::{ReplicaId, ToOffset, network::Network};
36use ui::{IconName, Window};
37use unindent::Unindent;
38use util::RandomCharIter;
39use workspace::Workspace;
40
41#[gpui::test]
42fn test_inserting_and_removing_messages(cx: &mut App) {
43 init_test(cx);
44
45 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
46 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
47 let context = cx.new(|cx| {
48 AssistantContext::local(
49 registry,
50 None,
51 None,
52 prompt_builder.clone(),
53 Arc::new(SlashCommandWorkingSet::default()),
54 cx,
55 )
56 });
57 let buffer = context.read(cx).buffer.clone();
58
59 let message_1 = context.read(cx).message_anchors[0].clone();
60 assert_eq!(
61 messages(&context, cx),
62 vec![(message_1.id, Role::User, 0..0)]
63 );
64
65 let message_2 = context.update(cx, |context, cx| {
66 context
67 .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
68 .unwrap()
69 });
70 assert_eq!(
71 messages(&context, cx),
72 vec![
73 (message_1.id, Role::User, 0..1),
74 (message_2.id, Role::Assistant, 1..1)
75 ]
76 );
77
78 buffer.update(cx, |buffer, cx| {
79 buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
80 });
81 assert_eq!(
82 messages(&context, cx),
83 vec![
84 (message_1.id, Role::User, 0..2),
85 (message_2.id, Role::Assistant, 2..3)
86 ]
87 );
88
89 let message_3 = context.update(cx, |context, cx| {
90 context
91 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
92 .unwrap()
93 });
94 assert_eq!(
95 messages(&context, cx),
96 vec![
97 (message_1.id, Role::User, 0..2),
98 (message_2.id, Role::Assistant, 2..4),
99 (message_3.id, Role::User, 4..4)
100 ]
101 );
102
103 let message_4 = context.update(cx, |context, cx| {
104 context
105 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
106 .unwrap()
107 });
108 assert_eq!(
109 messages(&context, cx),
110 vec![
111 (message_1.id, Role::User, 0..2),
112 (message_2.id, Role::Assistant, 2..4),
113 (message_4.id, Role::User, 4..5),
114 (message_3.id, Role::User, 5..5),
115 ]
116 );
117
118 buffer.update(cx, |buffer, cx| {
119 buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
120 });
121 assert_eq!(
122 messages(&context, cx),
123 vec![
124 (message_1.id, Role::User, 0..2),
125 (message_2.id, Role::Assistant, 2..4),
126 (message_4.id, Role::User, 4..6),
127 (message_3.id, Role::User, 6..7),
128 ]
129 );
130
131 // Deleting across message boundaries merges the messages.
132 buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
133 assert_eq!(
134 messages(&context, cx),
135 vec![
136 (message_1.id, Role::User, 0..3),
137 (message_3.id, Role::User, 3..4),
138 ]
139 );
140
141 // Undoing the deletion should also undo the merge.
142 buffer.update(cx, |buffer, cx| buffer.undo(cx));
143 assert_eq!(
144 messages(&context, cx),
145 vec![
146 (message_1.id, Role::User, 0..2),
147 (message_2.id, Role::Assistant, 2..4),
148 (message_4.id, Role::User, 4..6),
149 (message_3.id, Role::User, 6..7),
150 ]
151 );
152
153 // Redoing the deletion should also redo the merge.
154 buffer.update(cx, |buffer, cx| buffer.redo(cx));
155 assert_eq!(
156 messages(&context, cx),
157 vec![
158 (message_1.id, Role::User, 0..3),
159 (message_3.id, Role::User, 3..4),
160 ]
161 );
162
163 // Ensure we can still insert after a merged message.
164 let message_5 = context.update(cx, |context, cx| {
165 context
166 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
167 .unwrap()
168 });
169 assert_eq!(
170 messages(&context, cx),
171 vec![
172 (message_1.id, Role::User, 0..3),
173 (message_5.id, Role::System, 3..4),
174 (message_3.id, Role::User, 4..5)
175 ]
176 );
177}
178
179#[gpui::test]
180fn test_message_splitting(cx: &mut App) {
181 init_test(cx);
182
183 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
184
185 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
186 let context = cx.new(|cx| {
187 AssistantContext::local(
188 registry.clone(),
189 None,
190 None,
191 prompt_builder.clone(),
192 Arc::new(SlashCommandWorkingSet::default()),
193 cx,
194 )
195 });
196 let buffer = context.read(cx).buffer.clone();
197
198 let message_1 = context.read(cx).message_anchors[0].clone();
199 assert_eq!(
200 messages(&context, cx),
201 vec![(message_1.id, Role::User, 0..0)]
202 );
203
204 buffer.update(cx, |buffer, cx| {
205 buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
206 });
207
208 let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
209 let message_2 = message_2.unwrap();
210
211 // We recycle newlines in the middle of a split message
212 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
213 assert_eq!(
214 messages(&context, cx),
215 vec![
216 (message_1.id, Role::User, 0..4),
217 (message_2.id, Role::User, 4..16),
218 ]
219 );
220
221 let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
222 let message_3 = message_3.unwrap();
223
224 // We don't recycle newlines at the end of a split message
225 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
226 assert_eq!(
227 messages(&context, cx),
228 vec![
229 (message_1.id, Role::User, 0..4),
230 (message_3.id, Role::User, 4..5),
231 (message_2.id, Role::User, 5..17),
232 ]
233 );
234
235 let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
236 let message_4 = message_4.unwrap();
237 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
238 assert_eq!(
239 messages(&context, cx),
240 vec![
241 (message_1.id, Role::User, 0..4),
242 (message_3.id, Role::User, 4..5),
243 (message_2.id, Role::User, 5..9),
244 (message_4.id, Role::User, 9..17),
245 ]
246 );
247
248 let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
249 let message_5 = message_5.unwrap();
250 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
251 assert_eq!(
252 messages(&context, cx),
253 vec![
254 (message_1.id, Role::User, 0..4),
255 (message_3.id, Role::User, 4..5),
256 (message_2.id, Role::User, 5..9),
257 (message_4.id, Role::User, 9..10),
258 (message_5.id, Role::User, 10..18),
259 ]
260 );
261
262 let (message_6, message_7) =
263 context.update(cx, |context, cx| context.split_message(14..16, cx));
264 let message_6 = message_6.unwrap();
265 let message_7 = message_7.unwrap();
266 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
267 assert_eq!(
268 messages(&context, cx),
269 vec![
270 (message_1.id, Role::User, 0..4),
271 (message_3.id, Role::User, 4..5),
272 (message_2.id, Role::User, 5..9),
273 (message_4.id, Role::User, 9..10),
274 (message_5.id, Role::User, 10..14),
275 (message_6.id, Role::User, 14..17),
276 (message_7.id, Role::User, 17..19),
277 ]
278 );
279}
280
281#[gpui::test]
282fn test_messages_for_offsets(cx: &mut App) {
283 init_test(cx);
284
285 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
286 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
287 let context = cx.new(|cx| {
288 AssistantContext::local(
289 registry,
290 None,
291 None,
292 prompt_builder.clone(),
293 Arc::new(SlashCommandWorkingSet::default()),
294 cx,
295 )
296 });
297 let buffer = context.read(cx).buffer.clone();
298
299 let message_1 = context.read(cx).message_anchors[0].clone();
300 assert_eq!(
301 messages(&context, cx),
302 vec![(message_1.id, Role::User, 0..0)]
303 );
304
305 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
306 let message_2 = context
307 .update(cx, |context, cx| {
308 context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
309 })
310 .unwrap();
311 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
312
313 let message_3 = context
314 .update(cx, |context, cx| {
315 context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
316 })
317 .unwrap();
318 buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
319
320 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
321 assert_eq!(
322 messages(&context, cx),
323 vec![
324 (message_1.id, Role::User, 0..4),
325 (message_2.id, Role::User, 4..8),
326 (message_3.id, Role::User, 8..11)
327 ]
328 );
329
330 assert_eq!(
331 message_ids_for_offsets(&context, &[0, 4, 9], cx),
332 [message_1.id, message_2.id, message_3.id]
333 );
334 assert_eq!(
335 message_ids_for_offsets(&context, &[0, 1, 11], cx),
336 [message_1.id, message_3.id]
337 );
338
339 let message_4 = context
340 .update(cx, |context, cx| {
341 context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
342 })
343 .unwrap();
344 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
345 assert_eq!(
346 messages(&context, cx),
347 vec![
348 (message_1.id, Role::User, 0..4),
349 (message_2.id, Role::User, 4..8),
350 (message_3.id, Role::User, 8..12),
351 (message_4.id, Role::User, 12..12)
352 ]
353 );
354 assert_eq!(
355 message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
356 [message_1.id, message_2.id, message_3.id, message_4.id]
357 );
358
359 fn message_ids_for_offsets(
360 context: &Entity<AssistantContext>,
361 offsets: &[usize],
362 cx: &App,
363 ) -> Vec<MessageId> {
364 context
365 .read(cx)
366 .messages_for_offsets(offsets.iter().copied(), cx)
367 .into_iter()
368 .map(|message| message.id)
369 .collect()
370 }
371}
372
373#[gpui::test]
374async fn test_slash_commands(cx: &mut TestAppContext) {
375 cx.update(init_test);
376
377 let fs = FakeFs::new(cx.background_executor.clone());
378
379 fs.insert_tree(
380 "/test",
381 json!({
382 "src": {
383 "lib.rs": "fn one() -> usize { 1 }",
384 "main.rs": "
385 use crate::one;
386 fn main() { one(); }
387 ".unindent(),
388 }
389 }),
390 )
391 .await;
392
393 let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
394 slash_command_registry.register_command(FileSlashCommand, false);
395
396 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
397 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
398 let context = cx.new(|cx| {
399 AssistantContext::local(
400 registry.clone(),
401 None,
402 None,
403 prompt_builder.clone(),
404 Arc::new(SlashCommandWorkingSet::default()),
405 cx,
406 )
407 });
408
409 #[derive(Default)]
410 struct ContextRanges {
411 parsed_commands: HashSet<Range<language::Anchor>>,
412 command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
413 output_sections: HashSet<Range<language::Anchor>>,
414 }
415
416 let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
417 context.update(cx, |_, cx| {
418 cx.subscribe(&context, {
419 let context_ranges = context_ranges.clone();
420 move |context, _, event, _| {
421 let mut context_ranges = context_ranges.borrow_mut();
422 match event {
423 ContextEvent::InvokedSlashCommandChanged { command_id } => {
424 let command = context.invoked_slash_command(command_id).unwrap();
425 context_ranges
426 .command_outputs
427 .insert(*command_id, command.range.clone());
428 }
429 ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
430 for range in removed {
431 context_ranges.parsed_commands.remove(range);
432 }
433 for command in updated {
434 context_ranges
435 .parsed_commands
436 .insert(command.source_range.clone());
437 }
438 }
439 ContextEvent::SlashCommandOutputSectionAdded { section } => {
440 context_ranges.output_sections.insert(section.range.clone());
441 }
442 _ => {}
443 }
444 }
445 })
446 .detach();
447 });
448
449 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
450
451 // Insert a slash command
452 buffer.update(cx, |buffer, cx| {
453 buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
454 });
455 assert_text_and_context_ranges(
456 &buffer,
457 &context_ranges,
458 &"
459 «/file src/lib.rs»"
460 .unindent(),
461 cx,
462 );
463
464 // Edit the argument of the slash command.
465 buffer.update(cx, |buffer, cx| {
466 let edit_offset = buffer.text().find("lib.rs").unwrap();
467 buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
468 });
469 assert_text_and_context_ranges(
470 &buffer,
471 &context_ranges,
472 &"
473 «/file src/main.rs»"
474 .unindent(),
475 cx,
476 );
477
478 // Edit the name of the slash command, using one that doesn't exist.
479 buffer.update(cx, |buffer, cx| {
480 let edit_offset = buffer.text().find("/file").unwrap();
481 buffer.edit(
482 [(edit_offset..edit_offset + "/file".len(), "/unknown")],
483 None,
484 cx,
485 );
486 });
487 assert_text_and_context_ranges(
488 &buffer,
489 &context_ranges,
490 &"
491 /unknown src/main.rs"
492 .unindent(),
493 cx,
494 );
495
496 // Undoing the insertion of an non-existent slash command resorts the previous one.
497 buffer.update(cx, |buffer, cx| buffer.undo(cx));
498 assert_text_and_context_ranges(
499 &buffer,
500 &context_ranges,
501 &"
502 «/file src/main.rs»"
503 .unindent(),
504 cx,
505 );
506
507 let (command_output_tx, command_output_rx) = mpsc::unbounded();
508 context.update(cx, |context, cx| {
509 let command_source_range = context.parsed_slash_commands[0].source_range.clone();
510 context.insert_command_output(
511 command_source_range,
512 "file",
513 Task::ready(Ok(command_output_rx.boxed())),
514 true,
515 cx,
516 );
517 });
518 assert_text_and_context_ranges(
519 &buffer,
520 &context_ranges,
521 &"
522 ⟦«/file src/main.rs»
523 …⟧
524 "
525 .unindent(),
526 cx,
527 );
528
529 command_output_tx
530 .unbounded_send(Ok(SlashCommandEvent::StartSection {
531 icon: IconName::Ai,
532 label: "src/main.rs".into(),
533 metadata: None,
534 }))
535 .unwrap();
536 command_output_tx
537 .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
538 .unwrap();
539 cx.run_until_parked();
540 assert_text_and_context_ranges(
541 &buffer,
542 &context_ranges,
543 &"
544 ⟦«/file src/main.rs»
545 src/main.rs…⟧
546 "
547 .unindent(),
548 cx,
549 );
550
551 command_output_tx
552 .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
553 .unwrap();
554 cx.run_until_parked();
555 assert_text_and_context_ranges(
556 &buffer,
557 &context_ranges,
558 &"
559 ⟦«/file src/main.rs»
560 src/main.rs
561 fn main() {}…⟧
562 "
563 .unindent(),
564 cx,
565 );
566
567 command_output_tx
568 .unbounded_send(Ok(SlashCommandEvent::EndSection))
569 .unwrap();
570 cx.run_until_parked();
571 assert_text_and_context_ranges(
572 &buffer,
573 &context_ranges,
574 &"
575 ⟦«/file src/main.rs»
576 ⟪src/main.rs
577 fn main() {}⟫…⟧
578 "
579 .unindent(),
580 cx,
581 );
582
583 drop(command_output_tx);
584 cx.run_until_parked();
585 assert_text_and_context_ranges(
586 &buffer,
587 &context_ranges,
588 &"
589 ⟦⟪src/main.rs
590 fn main() {}⟫⟧
591 "
592 .unindent(),
593 cx,
594 );
595
596 #[track_caller]
597 fn assert_text_and_context_ranges(
598 buffer: &Entity<Buffer>,
599 ranges: &RefCell<ContextRanges>,
600 expected_marked_text: &str,
601 cx: &mut TestAppContext,
602 ) {
603 let mut actual_marked_text = String::new();
604 buffer.update(cx, |buffer, _| {
605 struct Endpoint {
606 offset: usize,
607 marker: char,
608 }
609
610 let ranges = ranges.borrow();
611 let mut endpoints = Vec::new();
612 for range in ranges.command_outputs.values() {
613 endpoints.push(Endpoint {
614 offset: range.start.to_offset(buffer),
615 marker: '⟦',
616 });
617 }
618 for range in ranges.parsed_commands.iter() {
619 endpoints.push(Endpoint {
620 offset: range.start.to_offset(buffer),
621 marker: '«',
622 });
623 }
624 for range in ranges.output_sections.iter() {
625 endpoints.push(Endpoint {
626 offset: range.start.to_offset(buffer),
627 marker: '⟪',
628 });
629 }
630
631 for range in ranges.output_sections.iter() {
632 endpoints.push(Endpoint {
633 offset: range.end.to_offset(buffer),
634 marker: '⟫',
635 });
636 }
637 for range in ranges.parsed_commands.iter() {
638 endpoints.push(Endpoint {
639 offset: range.end.to_offset(buffer),
640 marker: '»',
641 });
642 }
643 for range in ranges.command_outputs.values() {
644 endpoints.push(Endpoint {
645 offset: range.end.to_offset(buffer),
646 marker: '⟧',
647 });
648 }
649
650 endpoints.sort_by_key(|endpoint| endpoint.offset);
651 let mut offset = 0;
652 for endpoint in endpoints {
653 actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
654 actual_marked_text.push(endpoint.marker);
655 offset = endpoint.offset;
656 }
657 actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
658 });
659
660 assert_eq!(actual_marked_text, expected_marked_text);
661 }
662}
663
664#[gpui::test]
665async fn test_serialization(cx: &mut TestAppContext) {
666 cx.update(init_test);
667
668 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
669 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
670 let context = cx.new(|cx| {
671 AssistantContext::local(
672 registry.clone(),
673 None,
674 None,
675 prompt_builder.clone(),
676 Arc::new(SlashCommandWorkingSet::default()),
677 cx,
678 )
679 });
680 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
681 let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
682 let message_1 = context.update(cx, |context, cx| {
683 context
684 .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
685 .unwrap()
686 });
687 let message_2 = context.update(cx, |context, cx| {
688 context
689 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
690 .unwrap()
691 });
692 buffer.update(cx, |buffer, cx| {
693 buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
694 buffer.finalize_last_transaction();
695 });
696 let _message_3 = context.update(cx, |context, cx| {
697 context
698 .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
699 .unwrap()
700 });
701 buffer.update(cx, |buffer, cx| buffer.undo(cx));
702 assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
703 assert_eq!(
704 cx.read(|cx| messages(&context, cx)),
705 [
706 (message_0, Role::User, 0..2),
707 (message_1.id, Role::Assistant, 2..6),
708 (message_2.id, Role::System, 6..6),
709 ]
710 );
711
712 let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
713 let deserialized_context = cx.new(|cx| {
714 AssistantContext::deserialize(
715 serialized_context,
716 Path::new("").into(),
717 registry.clone(),
718 prompt_builder.clone(),
719 Arc::new(SlashCommandWorkingSet::default()),
720 None,
721 None,
722 cx,
723 )
724 });
725 let deserialized_buffer =
726 deserialized_context.read_with(cx, |context, _| context.buffer.clone());
727 assert_eq!(
728 deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
729 "a\nb\nc\n"
730 );
731 assert_eq!(
732 cx.read(|cx| messages(&deserialized_context, cx)),
733 [
734 (message_0, Role::User, 0..2),
735 (message_1.id, Role::Assistant, 2..6),
736 (message_2.id, Role::System, 6..6),
737 ]
738 );
739}
740
741#[gpui::test(iterations = 100)]
742async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
743 cx.update(init_test);
744
745 let min_peers = env::var("MIN_PEERS")
746 .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
747 .unwrap_or(2);
748 let max_peers = env::var("MAX_PEERS")
749 .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
750 .unwrap_or(5);
751 let operations = env::var("OPERATIONS")
752 .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
753 .unwrap_or(50);
754
755 let slash_commands = cx.update(SlashCommandRegistry::default_global);
756 slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
757 slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
758 slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
759
760 let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
761 let network = Arc::new(Mutex::new(Network::new(rng.clone())));
762 let mut contexts = Vec::new();
763
764 let num_peers = rng.gen_range(min_peers..=max_peers);
765 let context_id = ContextId::new();
766 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
767 for i in 0..num_peers {
768 let context = cx.new(|cx| {
769 AssistantContext::new(
770 context_id.clone(),
771 i as ReplicaId,
772 language::Capability::ReadWrite,
773 registry.clone(),
774 prompt_builder.clone(),
775 Arc::new(SlashCommandWorkingSet::default()),
776 None,
777 None,
778 cx,
779 )
780 });
781
782 cx.update(|cx| {
783 cx.subscribe(&context, {
784 let network = network.clone();
785 move |_, event, _| {
786 if let ContextEvent::Operation(op) = event {
787 network
788 .lock()
789 .broadcast(i as ReplicaId, vec![op.to_proto()]);
790 }
791 }
792 })
793 .detach();
794 });
795
796 contexts.push(context);
797 network.lock().add_peer(i as ReplicaId);
798 }
799
800 let mut mutation_count = operations;
801
802 while mutation_count > 0
803 || !network.lock().is_idle()
804 || network.lock().contains_disconnected_peers()
805 {
806 let context_index = rng.gen_range(0..contexts.len());
807 let context = &contexts[context_index];
808
809 match rng.gen_range(0..100) {
810 0..=29 if mutation_count > 0 => {
811 log::info!("Context {}: edit buffer", context_index);
812 context.update(cx, |context, cx| {
813 context
814 .buffer
815 .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
816 });
817 mutation_count -= 1;
818 }
819 30..=44 if mutation_count > 0 => {
820 context.update(cx, |context, cx| {
821 let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
822 log::info!("Context {}: split message at {:?}", context_index, range);
823 context.split_message(range, cx);
824 });
825 mutation_count -= 1;
826 }
827 45..=59 if mutation_count > 0 => {
828 context.update(cx, |context, cx| {
829 if let Some(message) = context.messages(cx).choose(&mut rng) {
830 let role = *[Role::User, Role::Assistant, Role::System]
831 .choose(&mut rng)
832 .unwrap();
833 log::info!(
834 "Context {}: insert message after {:?} with {:?}",
835 context_index,
836 message.id,
837 role
838 );
839 context.insert_message_after(message.id, role, MessageStatus::Done, cx);
840 }
841 });
842 mutation_count -= 1;
843 }
844 60..=74 if mutation_count > 0 => {
845 context.update(cx, |context, cx| {
846 let command_text = "/".to_string()
847 + slash_commands
848 .command_names()
849 .choose(&mut rng)
850 .unwrap()
851 .clone()
852 .as_ref();
853
854 let command_range = context.buffer.update(cx, |buffer, cx| {
855 let offset = buffer.random_byte_range(0, &mut rng).start;
856 buffer.edit(
857 [(offset..offset, format!("\n{}\n", command_text))],
858 None,
859 cx,
860 );
861 offset + 1..offset + 1 + command_text.len()
862 });
863
864 let output_text = RandomCharIter::new(&mut rng)
865 .filter(|c| *c != '\r')
866 .take(10)
867 .collect::<String>();
868
869 let mut events = vec![Ok(SlashCommandEvent::StartMessage {
870 role: Role::User,
871 merge_same_roles: true,
872 })];
873
874 let num_sections = rng.gen_range(0..=3);
875 let mut section_start = 0;
876 for _ in 0..num_sections {
877 let mut section_end = rng.gen_range(section_start..=output_text.len());
878 while !output_text.is_char_boundary(section_end) {
879 section_end += 1;
880 }
881 events.push(Ok(SlashCommandEvent::StartSection {
882 icon: IconName::Ai,
883 label: "section".into(),
884 metadata: None,
885 }));
886 events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
887 text: output_text[section_start..section_end].to_string(),
888 run_commands_in_text: false,
889 })));
890 events.push(Ok(SlashCommandEvent::EndSection));
891 section_start = section_end;
892 }
893
894 if section_start < output_text.len() {
895 events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
896 text: output_text[section_start..].to_string(),
897 run_commands_in_text: false,
898 })));
899 }
900
901 log::info!(
902 "Context {}: insert slash command output at {:?} with {:?} events",
903 context_index,
904 command_range,
905 events.len()
906 );
907
908 let command_range = context.buffer.read(cx).anchor_after(command_range.start)
909 ..context.buffer.read(cx).anchor_after(command_range.end);
910 context.insert_command_output(
911 command_range,
912 "/command",
913 Task::ready(Ok(stream::iter(events).boxed())),
914 true,
915 cx,
916 );
917 });
918 cx.run_until_parked();
919 mutation_count -= 1;
920 }
921 75..=84 if mutation_count > 0 => {
922 context.update(cx, |context, cx| {
923 if let Some(message) = context.messages(cx).choose(&mut rng) {
924 let new_status = match rng.gen_range(0..3) {
925 0 => MessageStatus::Done,
926 1 => MessageStatus::Pending,
927 _ => MessageStatus::Error(SharedString::from("Random error")),
928 };
929 log::info!(
930 "Context {}: update message {:?} status to {:?}",
931 context_index,
932 message.id,
933 new_status
934 );
935 context.update_metadata(message.id, cx, |metadata| {
936 metadata.status = new_status;
937 });
938 }
939 });
940 mutation_count -= 1;
941 }
942 _ => {
943 let replica_id = context_index as ReplicaId;
944 if network.lock().is_disconnected(replica_id) {
945 network.lock().reconnect_peer(replica_id, 0);
946
947 let (ops_to_send, ops_to_receive) = cx.read(|cx| {
948 let host_context = &contexts[0].read(cx);
949 let guest_context = context.read(cx);
950 (
951 guest_context.serialize_ops(&host_context.version(cx), cx),
952 host_context.serialize_ops(&guest_context.version(cx), cx),
953 )
954 });
955 let ops_to_send = ops_to_send.await;
956 let ops_to_receive = ops_to_receive
957 .await
958 .into_iter()
959 .map(ContextOperation::from_proto)
960 .collect::<Result<Vec<_>>>()
961 .unwrap();
962 log::info!(
963 "Context {}: reconnecting. Sent {} operations, received {} operations",
964 context_index,
965 ops_to_send.len(),
966 ops_to_receive.len()
967 );
968
969 network.lock().broadcast(replica_id, ops_to_send);
970 context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
971 } else if rng.gen_bool(0.1) && replica_id != 0 {
972 log::info!("Context {}: disconnecting", context_index);
973 network.lock().disconnect_peer(replica_id);
974 } else if network.lock().has_unreceived(replica_id) {
975 log::info!("Context {}: applying operations", context_index);
976 let ops = network.lock().receive(replica_id);
977 let ops = ops
978 .into_iter()
979 .map(ContextOperation::from_proto)
980 .collect::<Result<Vec<_>>>()
981 .unwrap();
982 context.update(cx, |context, cx| context.apply_ops(ops, cx));
983 }
984 }
985 }
986 }
987
988 cx.read(|cx| {
989 let first_context = contexts[0].read(cx);
990 for context in &contexts[1..] {
991 let context = context.read(cx);
992 assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
993 assert_eq!(
994 context.buffer.read(cx).text(),
995 first_context.buffer.read(cx).text(),
996 "Context {} text != Context 0 text",
997 context.buffer.read(cx).replica_id()
998 );
999 assert_eq!(
1000 context.message_anchors,
1001 first_context.message_anchors,
1002 "Context {} messages != Context 0 messages",
1003 context.buffer.read(cx).replica_id()
1004 );
1005 assert_eq!(
1006 context.messages_metadata,
1007 first_context.messages_metadata,
1008 "Context {} message metadata != Context 0 message metadata",
1009 context.buffer.read(cx).replica_id()
1010 );
1011 assert_eq!(
1012 context.slash_command_output_sections,
1013 first_context.slash_command_output_sections,
1014 "Context {} slash command output sections != Context 0 slash command output sections",
1015 context.buffer.read(cx).replica_id()
1016 );
1017 }
1018 });
1019}
1020
1021#[gpui::test]
1022fn test_mark_cache_anchors(cx: &mut App) {
1023 init_test(cx);
1024
1025 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1026 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1027 let context = cx.new(|cx| {
1028 AssistantContext::local(
1029 registry,
1030 None,
1031 None,
1032 prompt_builder.clone(),
1033 Arc::new(SlashCommandWorkingSet::default()),
1034 cx,
1035 )
1036 });
1037 let buffer = context.read(cx).buffer.clone();
1038
1039 // Create a test cache configuration
1040 let cache_configuration = &Some(LanguageModelCacheConfiguration {
1041 max_cache_anchors: 3,
1042 should_speculate: true,
1043 min_total_token: 10,
1044 });
1045
1046 let message_1 = context.read(cx).message_anchors[0].clone();
1047
1048 context.update(cx, |context, cx| {
1049 context.mark_cache_anchors(cache_configuration, false, cx)
1050 });
1051
1052 assert_eq!(
1053 messages_cache(&context, cx)
1054 .iter()
1055 .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1056 .count(),
1057 0,
1058 "Empty messages should not have any cache anchors."
1059 );
1060
1061 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1062 let message_2 = context
1063 .update(cx, |context, cx| {
1064 context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1065 })
1066 .unwrap();
1067
1068 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1069 let message_3 = context
1070 .update(cx, |context, cx| {
1071 context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1072 })
1073 .unwrap();
1074 buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1075
1076 context.update(cx, |context, cx| {
1077 context.mark_cache_anchors(cache_configuration, false, cx)
1078 });
1079 assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1080 assert_eq!(
1081 messages_cache(&context, cx)
1082 .iter()
1083 .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1084 .count(),
1085 0,
1086 "Messages should not be marked for cache before going over the token minimum."
1087 );
1088 context.update(cx, |context, _| {
1089 context.token_count = Some(20);
1090 });
1091
1092 context.update(cx, |context, cx| {
1093 context.mark_cache_anchors(cache_configuration, true, cx)
1094 });
1095 assert_eq!(
1096 messages_cache(&context, cx)
1097 .iter()
1098 .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1099 .collect::<Vec<bool>>(),
1100 vec![true, true, false],
1101 "Last message should not be an anchor on speculative request."
1102 );
1103
1104 context
1105 .update(cx, |context, cx| {
1106 context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1107 })
1108 .unwrap();
1109
1110 context.update(cx, |context, cx| {
1111 context.mark_cache_anchors(cache_configuration, false, cx)
1112 });
1113 assert_eq!(
1114 messages_cache(&context, cx)
1115 .iter()
1116 .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1117 .collect::<Vec<bool>>(),
1118 vec![false, true, true, false],
1119 "Most recent message should also be cached if not a speculative request."
1120 );
1121 context.update(cx, |context, cx| {
1122 context.update_cache_status_for_completion(cx)
1123 });
1124 assert_eq!(
1125 messages_cache(&context, cx)
1126 .iter()
1127 .map(|(_, cache)| cache
1128 .as_ref()
1129 .map_or(None, |cache| Some(cache.status.clone())))
1130 .collect::<Vec<Option<CacheStatus>>>(),
1131 vec![
1132 Some(CacheStatus::Cached),
1133 Some(CacheStatus::Cached),
1134 Some(CacheStatus::Cached),
1135 None
1136 ],
1137 "All user messages prior to anchor should be marked as cached."
1138 );
1139
1140 buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1141 context.update(cx, |context, cx| {
1142 context.mark_cache_anchors(cache_configuration, false, cx)
1143 });
1144 assert_eq!(
1145 messages_cache(&context, cx)
1146 .iter()
1147 .map(|(_, cache)| cache
1148 .as_ref()
1149 .map_or(None, |cache| Some(cache.status.clone())))
1150 .collect::<Vec<Option<CacheStatus>>>(),
1151 vec![
1152 Some(CacheStatus::Cached),
1153 Some(CacheStatus::Cached),
1154 Some(CacheStatus::Pending),
1155 None
1156 ],
1157 "Modifying a message should invalidate it's cache but leave previous messages."
1158 );
1159 buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1160 context.update(cx, |context, cx| {
1161 context.mark_cache_anchors(cache_configuration, false, cx)
1162 });
1163 assert_eq!(
1164 messages_cache(&context, cx)
1165 .iter()
1166 .map(|(_, cache)| cache
1167 .as_ref()
1168 .map_or(None, |cache| Some(cache.status.clone())))
1169 .collect::<Vec<Option<CacheStatus>>>(),
1170 vec![
1171 Some(CacheStatus::Pending),
1172 Some(CacheStatus::Pending),
1173 Some(CacheStatus::Pending),
1174 None
1175 ],
1176 "Modifying a message should invalidate all future messages."
1177 );
1178}
1179
1180fn messages(context: &Entity<AssistantContext>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
1181 context
1182 .read(cx)
1183 .messages(cx)
1184 .map(|message| (message.id, message.role, message.offset_range))
1185 .collect()
1186}
1187
1188fn messages_cache(
1189 context: &Entity<AssistantContext>,
1190 cx: &App,
1191) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1192 context
1193 .read(cx)
1194 .messages(cx)
1195 .map(|message| (message.id, message.cache.clone()))
1196 .collect()
1197}
1198
1199fn init_test(cx: &mut App) {
1200 let settings_store = SettingsStore::test(cx);
1201 prompt_store::init(cx);
1202 LanguageModelRegistry::test(cx);
1203 cx.set_global(settings_store);
1204 language::init(cx);
1205 assistant_settings::init(cx);
1206 Project::init_settings(cx);
1207}
1208
1209#[derive(Clone)]
1210struct FakeSlashCommand(String);
1211
1212impl SlashCommand for FakeSlashCommand {
1213 fn name(&self) -> String {
1214 self.0.clone()
1215 }
1216
1217 fn description(&self) -> String {
1218 format!("Fake slash command: {}", self.0)
1219 }
1220
1221 fn menu_text(&self) -> String {
1222 format!("Run fake command: {}", self.0)
1223 }
1224
1225 fn complete_argument(
1226 self: Arc<Self>,
1227 _arguments: &[String],
1228 _cancel: Arc<AtomicBool>,
1229 _workspace: Option<WeakEntity<Workspace>>,
1230 _window: &mut Window,
1231 _cx: &mut App,
1232 ) -> Task<Result<Vec<ArgumentCompletion>>> {
1233 Task::ready(Ok(vec![]))
1234 }
1235
1236 fn requires_argument(&self) -> bool {
1237 false
1238 }
1239
1240 fn run(
1241 self: Arc<Self>,
1242 _arguments: &[String],
1243 _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1244 _context_buffer: BufferSnapshot,
1245 _workspace: WeakEntity<Workspace>,
1246 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1247 _window: &mut Window,
1248 _cx: &mut App,
1249 ) -> Task<SlashCommandResult> {
1250 Task::ready(Ok(SlashCommandOutput {
1251 text: format!("Executed fake command: {}", self.0),
1252 sections: vec![],
1253 run_commands_in_text: false,
1254 }
1255 .to_event_stream()))
1256 }
1257}