1use super::{MessageCacheMetadata, WorkflowStepEdit};
2use crate::{
3 assistant_panel, prompt_library, slash_command::file_command, CacheStatus, Context,
4 ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
5 WorkflowStepEditKind,
6};
7use anyhow::Result;
8use assistant_slash_command::{
9 ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
10 SlashCommandRegistry,
11};
12use collections::HashSet;
13use fs::FakeFs;
14use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView};
15use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
16use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
17use parking_lot::Mutex;
18use project::Project;
19use rand::prelude::*;
20use serde_json::json;
21use settings::SettingsStore;
22use std::{
23 cell::RefCell,
24 env,
25 ops::Range,
26 path::Path,
27 rc::Rc,
28 sync::{atomic::AtomicBool, Arc},
29};
30use text::{network::Network, OffsetRangeExt as _, ReplicaId};
31use ui::{Context as _, WindowContext};
32use unindent::Unindent;
33use util::{
34 test::{generate_marked_text, marked_text_ranges},
35 RandomCharIter,
36};
37use workspace::Workspace;
38
39#[gpui::test]
40fn test_inserting_and_removing_messages(cx: &mut AppContext) {
41 let settings_store = SettingsStore::test(cx);
42 LanguageModelRegistry::test(cx);
43 cx.set_global(settings_store);
44 assistant_panel::init(cx);
45 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
46 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
47 let context =
48 cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
49 let buffer = context.read(cx).buffer.clone();
50
51 let message_1 = context.read(cx).message_anchors[0].clone();
52 assert_eq!(
53 messages(&context, cx),
54 vec![(message_1.id, Role::User, 0..0)]
55 );
56
57 let message_2 = context.update(cx, |context, cx| {
58 context
59 .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
60 .unwrap()
61 });
62 assert_eq!(
63 messages(&context, cx),
64 vec![
65 (message_1.id, Role::User, 0..1),
66 (message_2.id, Role::Assistant, 1..1)
67 ]
68 );
69
70 buffer.update(cx, |buffer, cx| {
71 buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
72 });
73 assert_eq!(
74 messages(&context, cx),
75 vec![
76 (message_1.id, Role::User, 0..2),
77 (message_2.id, Role::Assistant, 2..3)
78 ]
79 );
80
81 let message_3 = context.update(cx, |context, cx| {
82 context
83 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
84 .unwrap()
85 });
86 assert_eq!(
87 messages(&context, cx),
88 vec![
89 (message_1.id, Role::User, 0..2),
90 (message_2.id, Role::Assistant, 2..4),
91 (message_3.id, Role::User, 4..4)
92 ]
93 );
94
95 let message_4 = context.update(cx, |context, cx| {
96 context
97 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
98 .unwrap()
99 });
100 assert_eq!(
101 messages(&context, cx),
102 vec![
103 (message_1.id, Role::User, 0..2),
104 (message_2.id, Role::Assistant, 2..4),
105 (message_4.id, Role::User, 4..5),
106 (message_3.id, Role::User, 5..5),
107 ]
108 );
109
110 buffer.update(cx, |buffer, cx| {
111 buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
112 });
113 assert_eq!(
114 messages(&context, cx),
115 vec![
116 (message_1.id, Role::User, 0..2),
117 (message_2.id, Role::Assistant, 2..4),
118 (message_4.id, Role::User, 4..6),
119 (message_3.id, Role::User, 6..7),
120 ]
121 );
122
123 // Deleting across message boundaries merges the messages.
124 buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
125 assert_eq!(
126 messages(&context, cx),
127 vec![
128 (message_1.id, Role::User, 0..3),
129 (message_3.id, Role::User, 3..4),
130 ]
131 );
132
133 // Undoing the deletion should also undo the merge.
134 buffer.update(cx, |buffer, cx| buffer.undo(cx));
135 assert_eq!(
136 messages(&context, cx),
137 vec![
138 (message_1.id, Role::User, 0..2),
139 (message_2.id, Role::Assistant, 2..4),
140 (message_4.id, Role::User, 4..6),
141 (message_3.id, Role::User, 6..7),
142 ]
143 );
144
145 // Redoing the deletion should also redo the merge.
146 buffer.update(cx, |buffer, cx| buffer.redo(cx));
147 assert_eq!(
148 messages(&context, cx),
149 vec![
150 (message_1.id, Role::User, 0..3),
151 (message_3.id, Role::User, 3..4),
152 ]
153 );
154
155 // Ensure we can still insert after a merged message.
156 let message_5 = context.update(cx, |context, cx| {
157 context
158 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
159 .unwrap()
160 });
161 assert_eq!(
162 messages(&context, cx),
163 vec![
164 (message_1.id, Role::User, 0..3),
165 (message_5.id, Role::System, 3..4),
166 (message_3.id, Role::User, 4..5)
167 ]
168 );
169}
170
171#[gpui::test]
172fn test_message_splitting(cx: &mut AppContext) {
173 let settings_store = SettingsStore::test(cx);
174 cx.set_global(settings_store);
175 LanguageModelRegistry::test(cx);
176 assistant_panel::init(cx);
177 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
178
179 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
180 let context =
181 cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
182 let buffer = context.read(cx).buffer.clone();
183
184 let message_1 = context.read(cx).message_anchors[0].clone();
185 assert_eq!(
186 messages(&context, cx),
187 vec![(message_1.id, Role::User, 0..0)]
188 );
189
190 buffer.update(cx, |buffer, cx| {
191 buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
192 });
193
194 let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
195 let message_2 = message_2.unwrap();
196
197 // We recycle newlines in the middle of a split message
198 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
199 assert_eq!(
200 messages(&context, cx),
201 vec![
202 (message_1.id, Role::User, 0..4),
203 (message_2.id, Role::User, 4..16),
204 ]
205 );
206
207 let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
208 let message_3 = message_3.unwrap();
209
210 // We don't recycle newlines at the end of a split message
211 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
212 assert_eq!(
213 messages(&context, cx),
214 vec![
215 (message_1.id, Role::User, 0..4),
216 (message_3.id, Role::User, 4..5),
217 (message_2.id, Role::User, 5..17),
218 ]
219 );
220
221 let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
222 let message_4 = message_4.unwrap();
223 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
224 assert_eq!(
225 messages(&context, cx),
226 vec![
227 (message_1.id, Role::User, 0..4),
228 (message_3.id, Role::User, 4..5),
229 (message_2.id, Role::User, 5..9),
230 (message_4.id, Role::User, 9..17),
231 ]
232 );
233
234 let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
235 let message_5 = message_5.unwrap();
236 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
237 assert_eq!(
238 messages(&context, cx),
239 vec![
240 (message_1.id, Role::User, 0..4),
241 (message_3.id, Role::User, 4..5),
242 (message_2.id, Role::User, 5..9),
243 (message_4.id, Role::User, 9..10),
244 (message_5.id, Role::User, 10..18),
245 ]
246 );
247
248 let (message_6, message_7) =
249 context.update(cx, |context, cx| context.split_message(14..16, cx));
250 let message_6 = message_6.unwrap();
251 let message_7 = message_7.unwrap();
252 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
253 assert_eq!(
254 messages(&context, cx),
255 vec![
256 (message_1.id, Role::User, 0..4),
257 (message_3.id, Role::User, 4..5),
258 (message_2.id, Role::User, 5..9),
259 (message_4.id, Role::User, 9..10),
260 (message_5.id, Role::User, 10..14),
261 (message_6.id, Role::User, 14..17),
262 (message_7.id, Role::User, 17..19),
263 ]
264 );
265}
266
267#[gpui::test]
268fn test_messages_for_offsets(cx: &mut AppContext) {
269 let settings_store = SettingsStore::test(cx);
270 LanguageModelRegistry::test(cx);
271 cx.set_global(settings_store);
272 assistant_panel::init(cx);
273 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
274 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
275 let context =
276 cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
277 let buffer = context.read(cx).buffer.clone();
278
279 let message_1 = context.read(cx).message_anchors[0].clone();
280 assert_eq!(
281 messages(&context, cx),
282 vec![(message_1.id, Role::User, 0..0)]
283 );
284
285 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
286 let message_2 = context
287 .update(cx, |context, cx| {
288 context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
289 })
290 .unwrap();
291 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
292
293 let message_3 = context
294 .update(cx, |context, cx| {
295 context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
296 })
297 .unwrap();
298 buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
299
300 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
301 assert_eq!(
302 messages(&context, cx),
303 vec![
304 (message_1.id, Role::User, 0..4),
305 (message_2.id, Role::User, 4..8),
306 (message_3.id, Role::User, 8..11)
307 ]
308 );
309
310 assert_eq!(
311 message_ids_for_offsets(&context, &[0, 4, 9], cx),
312 [message_1.id, message_2.id, message_3.id]
313 );
314 assert_eq!(
315 message_ids_for_offsets(&context, &[0, 1, 11], cx),
316 [message_1.id, message_3.id]
317 );
318
319 let message_4 = context
320 .update(cx, |context, cx| {
321 context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
322 })
323 .unwrap();
324 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
325 assert_eq!(
326 messages(&context, cx),
327 vec![
328 (message_1.id, Role::User, 0..4),
329 (message_2.id, Role::User, 4..8),
330 (message_3.id, Role::User, 8..12),
331 (message_4.id, Role::User, 12..12)
332 ]
333 );
334 assert_eq!(
335 message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
336 [message_1.id, message_2.id, message_3.id, message_4.id]
337 );
338
339 fn message_ids_for_offsets(
340 context: &Model<Context>,
341 offsets: &[usize],
342 cx: &AppContext,
343 ) -> Vec<MessageId> {
344 context
345 .read(cx)
346 .messages_for_offsets(offsets.iter().copied(), cx)
347 .into_iter()
348 .map(|message| message.id)
349 .collect()
350 }
351}
352
353#[gpui::test]
354async fn test_slash_commands(cx: &mut TestAppContext) {
355 let settings_store = cx.update(SettingsStore::test);
356 cx.set_global(settings_store);
357 cx.update(LanguageModelRegistry::test);
358 cx.update(Project::init_settings);
359 cx.update(assistant_panel::init);
360 let fs = FakeFs::new(cx.background_executor.clone());
361
362 fs.insert_tree(
363 "/test",
364 json!({
365 "src": {
366 "lib.rs": "fn one() -> usize { 1 }",
367 "main.rs": "
368 use crate::one;
369 fn main() { one(); }
370 ".unindent(),
371 }
372 }),
373 )
374 .await;
375
376 let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
377 slash_command_registry.register_command(file_command::FileSlashCommand, false);
378
379 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
380 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
381 let context =
382 cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
383
384 let output_ranges = Rc::new(RefCell::new(HashSet::default()));
385 context.update(cx, |_, cx| {
386 cx.subscribe(&context, {
387 let ranges = output_ranges.clone();
388 move |_, _, event, _| match event {
389 ContextEvent::PendingSlashCommandsUpdated { removed, updated } => {
390 for range in removed {
391 ranges.borrow_mut().remove(range);
392 }
393 for command in updated {
394 ranges.borrow_mut().insert(command.source_range.clone());
395 }
396 }
397 _ => {}
398 }
399 })
400 .detach();
401 });
402
403 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
404
405 // Insert a slash command
406 buffer.update(cx, |buffer, cx| {
407 buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
408 });
409 assert_text_and_output_ranges(
410 &buffer,
411 &output_ranges.borrow(),
412 "
413 «/file src/lib.rs»
414 "
415 .unindent()
416 .trim_end(),
417 cx,
418 );
419
420 // Edit the argument of the slash command.
421 buffer.update(cx, |buffer, cx| {
422 let edit_offset = buffer.text().find("lib.rs").unwrap();
423 buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
424 });
425 assert_text_and_output_ranges(
426 &buffer,
427 &output_ranges.borrow(),
428 "
429 «/file src/main.rs»
430 "
431 .unindent()
432 .trim_end(),
433 cx,
434 );
435
436 // Edit the name of the slash command, using one that doesn't exist.
437 buffer.update(cx, |buffer, cx| {
438 let edit_offset = buffer.text().find("/file").unwrap();
439 buffer.edit(
440 [(edit_offset..edit_offset + "/file".len(), "/unknown")],
441 None,
442 cx,
443 );
444 });
445 assert_text_and_output_ranges(
446 &buffer,
447 &output_ranges.borrow(),
448 "
449 /unknown src/main.rs
450 "
451 .unindent()
452 .trim_end(),
453 cx,
454 );
455
456 #[track_caller]
457 fn assert_text_and_output_ranges(
458 buffer: &Model<Buffer>,
459 ranges: &HashSet<Range<language::Anchor>>,
460 expected_marked_text: &str,
461 cx: &mut TestAppContext,
462 ) {
463 let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false);
464 let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| {
465 let mut ranges = ranges
466 .iter()
467 .map(|range| range.to_offset(buffer))
468 .collect::<Vec<_>>();
469 ranges.sort_by_key(|a| a.start);
470 (buffer.text(), ranges)
471 });
472
473 assert_eq!(actual_text, expected_text);
474 assert_eq!(actual_ranges, expected_ranges);
475 }
476}
477
/// Exercises parsing of `<step>`/`<edit>` tags inside an assistant message as its text
/// changes. The scenarios covered are:
/// - text without tags and a partial `<step` tag yield no steps;
/// - an unclosed `<step>` is tracked as a step with no edits;
/// - a complete step yields its parsed `WorkflowStepEdit`s;
/// - manual edits to a step are re-parsed;
/// - cycling the message role away from Assistant clears the steps, and cycling back
///   restores them;
/// - deserializing the context re-parses its steps.
#[gpui::test]
async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
    cx.update(prompt_library::init);
    let settings_store = cx.update(SettingsStore::test);
    cx.set_global(settings_store);
    cx.update(language::init);
    cx.update(Project::init_settings);
    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [Path::new("/root")], cx).await;
    cx.update(LanguageModelRegistry::test);

    cx.update(assistant_panel::init);
    let registry = Arc::new(LanguageRegistry::test(cx.executor()));

    // Create a new context
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context = cx.new_model(|cx| {
        Context::local(
            registry.clone(),
            Some(project),
            None,
            prompt_builder.clone(),
            cx,
        )
    });

    // Insert an assistant message to simulate a response.
    // Steps are only parsed from this message; its id is reused below to cycle its role.
    let assistant_message_id = context.update(cx, |context, cx| {
        let user_message_id = context.messages(cx).next().unwrap().id;
        context
            .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
            .unwrap()
            .id
    });

    // No edit tags
    edit(
        &context,
        "

        «one
        two
        »",
        cx,
    );
    expect_steps(
        &context,
        "

        one
        two
        ",
        &[],
        cx,
    );

    // Partial edit step tag is added
    edit(
        &context,
        "

        one
        two
        «
        <step»",
        cx,
    );
    expect_steps(
        &context,
        "

        one
        two

        <step",
        &[],
        cx,
    );

    // The rest of the step tag is added. The unclosed
    // step is treated as incomplete.
    edit(
        &context,
        "

        one
        two

        <step«>
        Add a second function

        ```rust
        fn two() {}
        ```

        <edit>»",
        cx,
    );
    // A single step range with an empty list of edits.
    expect_steps(
        &context,
        "

        one
        two

        «<step>
        Add a second function

        ```rust
        fn two() {}
        ```

        <edit>»",
        &[&[]],
        cx,
    );

    // The full suggestion is added
    edit(
        &context,
        "

        one
        two

        <step>
        Add a second function

        ```rust
        fn two() {}
        ```

        <edit>«
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <search>fn one</search>
        <description>add a `two` function</description>
        </edit>
        </step>

        also,»",
        cx,
    );
    // The step range ends at `</step>`; the trailing "also," is outside it.
    expect_steps(
        &context,
        "

        one
        two

        «<step>
        Add a second function

        ```rust
        fn two() {}
        ```

        <edit>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <search>fn one</search>
        <description>add a `two` function</description>
        </edit>
        </step>»

        also,",
        &[&[WorkflowStepEdit {
            path: "src/lib.rs".into(),
            kind: WorkflowStepEditKind::InsertAfter {
                search: "fn one".into(),
                description: "add a `two` function".into(),
            },
        }]],
        cx,
    );

    // The step is manually edited.
    edit(
        &context,
        "

        one
        two

        <step>
        Add a second function

        ```rust
        fn two() {}
        ```

        <edit>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <search>«fn zero»</search>
        <description>add a `two` function</description>
        </edit>
        </step>

        also,",
        cx,
    );
    expect_steps(
        &context,
        "

        one
        two

        «<step>
        Add a second function

        ```rust
        fn two() {}
        ```

        <edit>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <search>fn zero</search>
        <description>add a `two` function</description>
        </edit>
        </step>»

        also,",
        &[&[WorkflowStepEdit {
            path: "src/lib.rs".into(),
            kind: WorkflowStepEditKind::InsertAfter {
                search: "fn zero".into(),
                description: "add a `two` function".into(),
            },
        }]],
        cx,
    );

    // When setting the message role to User, the steps are cleared.
    // (Two role cycles starting from Assistant are what this test relies on to reach User.)
    context.update(cx, |context, cx| {
        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
    });
    expect_steps(
        &context,
        "

        one
        two

        <step>
        Add a second function

        ```rust
        fn two() {}
        ```

        <edit>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <search>fn zero</search>
        <description>add a `two` function</description>
        </edit>
        </step>

        also,",
        &[],
        cx,
    );

    // When setting the message role back to Assistant, the steps are reparsed.
    context.update(cx, |context, cx| {
        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
    });
    expect_steps(
        &context,
        "

        one
        two

        «<step>
        Add a second function

        ```rust
        fn two() {}
        ```

        <edit>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <search>fn zero</search>
        <description>add a `two` function</description>
        </edit>
        </step>»

        also,",
        &[&[WorkflowStepEdit {
            path: "src/lib.rs".into(),
            kind: WorkflowStepEditKind::InsertAfter {
                search: "fn zero".into(),
                description: "add a `two` function".into(),
            },
        }]],
        cx,
    );

    // Ensure steps are re-parsed when deserializing.
    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
    let deserialized_context = cx.new_model(|cx| {
        Context::deserialize(
            serialized_context,
            Default::default(),
            registry.clone(),
            prompt_builder.clone(),
            None,
            None,
            cx,
        )
    });
    expect_steps(
        &deserialized_context,
        "

        one
        two

        «<step>
        Add a second function

        ```rust
        fn two() {}
        ```

        <edit>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <search>fn zero</search>
        <description>add a `two` function</description>
        </edit>
        </step>»

        also,",
        &[&[WorkflowStepEdit {
            path: "src/lib.rs".into(),
            kind: WorkflowStepEditKind::InsertAfter {
                search: "fn zero".into(),
                description: "add a `two` function".into(),
            },
        }]],
        cx,
    );

    /// Applies the (unindented) marked text to the context's buffer via
    /// `Buffer::edit_via_marked_text`, then runs the executor until parked so background
    /// parsing settles. The `«…»` markers presumably delimit the changed text relative to
    /// the current buffer contents; see `edit_via_marked_text` to confirm.
    fn edit(context: &Model<Context>, new_text_marked_with_edits: &str, cx: &mut TestAppContext) {
        context.update(cx, |context, cx| {
            context.buffer.update(cx, |buffer, cx| {
                buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
            });
        });
        cx.executor().run_until_parked();
    }

    /// Asserts that the buffer text equals `expected_marked_text` with its markers removed.
    /// The `«…»` markers must coincide with the ranges of `context.workflow_steps`, and each
    /// step's edits must equal the matching entry in `expected_suggestions`.
    fn expect_steps(
        context: &Model<Context>,
        expected_marked_text: &str,
        expected_suggestions: &[&[WorkflowStepEdit]],
        cx: &mut TestAppContext,
    ) {
        context.update(cx, |context, cx| {
            let expected_marked_text = expected_marked_text.unindent();
            let (expected_text, expected_ranges) = marked_text_ranges(&expected_marked_text, false);
            context.buffer.read_with(cx, |buffer, _| {
                assert_eq!(buffer.text(), expected_text);
                let ranges = context
                    .workflow_steps
                    .iter()
                    .map(|entry| entry.range.to_offset(buffer))
                    .collect::<Vec<_>>();
                // Re-render the actual step ranges as markers so a mismatch prints readable text.
                let marked = generate_marked_text(&expected_text, &ranges, false);
                assert_eq!(
                    marked,
                    expected_marked_text,
                    "unexpected suggestion ranges. actual: {ranges:?}, expected: {expected_ranges:?}"
                );
                // Each parsed edit must have succeeded (`unwrap`); compare only path and kind.
                let suggestions = context
                    .workflow_steps
                    .iter()
                    .map(|step| {
                        step.edits
                            .iter()
                            .map(|edit| {
                                let edit = edit.as_ref().unwrap();
                                WorkflowStepEdit {
                                    path: edit.path.clone(),
                                    kind: edit.kind.clone(),
                                }
                            })
                            .collect::<Vec<_>>()
                    })
                    .collect::<Vec<_>>();

                assert_eq!(suggestions, expected_suggestions);
            });
        });
    }
}
881
882#[gpui::test]
883async fn test_serialization(cx: &mut TestAppContext) {
884 let settings_store = cx.update(SettingsStore::test);
885 cx.set_global(settings_store);
886 cx.update(LanguageModelRegistry::test);
887 cx.update(assistant_panel::init);
888 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
889 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
890 let context =
891 cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
892 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
893 let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
894 let message_1 = context.update(cx, |context, cx| {
895 context
896 .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
897 .unwrap()
898 });
899 let message_2 = context.update(cx, |context, cx| {
900 context
901 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
902 .unwrap()
903 });
904 buffer.update(cx, |buffer, cx| {
905 buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
906 buffer.finalize_last_transaction();
907 });
908 let _message_3 = context.update(cx, |context, cx| {
909 context
910 .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
911 .unwrap()
912 });
913 buffer.update(cx, |buffer, cx| buffer.undo(cx));
914 assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
915 assert_eq!(
916 cx.read(|cx| messages(&context, cx)),
917 [
918 (message_0, Role::User, 0..2),
919 (message_1.id, Role::Assistant, 2..6),
920 (message_2.id, Role::System, 6..6),
921 ]
922 );
923
924 let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
925 let deserialized_context = cx.new_model(|cx| {
926 Context::deserialize(
927 serialized_context,
928 Default::default(),
929 registry.clone(),
930 prompt_builder.clone(),
931 None,
932 None,
933 cx,
934 )
935 });
936 let deserialized_buffer =
937 deserialized_context.read_with(cx, |context, _| context.buffer.clone());
938 assert_eq!(
939 deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
940 "a\nb\nc\n"
941 );
942 assert_eq!(
943 cx.read(|cx| messages(&deserialized_context, cx)),
944 [
945 (message_0, Role::User, 0..2),
946 (message_1.id, Role::Assistant, 2..6),
947 (message_2.id, Role::System, 6..6),
948 ]
949 );
950}
951
952#[gpui::test(iterations = 100)]
953async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
954 let min_peers = env::var("MIN_PEERS")
955 .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
956 .unwrap_or(2);
957 let max_peers = env::var("MAX_PEERS")
958 .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
959 .unwrap_or(5);
960 let operations = env::var("OPERATIONS")
961 .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
962 .unwrap_or(50);
963
964 let settings_store = cx.update(SettingsStore::test);
965 cx.set_global(settings_store);
966 cx.update(LanguageModelRegistry::test);
967
968 cx.update(assistant_panel::init);
969 let slash_commands = cx.update(SlashCommandRegistry::default_global);
970 slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
971 slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
972 slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
973
974 let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
975 let network = Arc::new(Mutex::new(Network::new(rng.clone())));
976 let mut contexts = Vec::new();
977
978 let num_peers = rng.gen_range(min_peers..=max_peers);
979 let context_id = ContextId::new();
980 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
981 for i in 0..num_peers {
982 let context = cx.new_model(|cx| {
983 Context::new(
984 context_id.clone(),
985 i as ReplicaId,
986 language::Capability::ReadWrite,
987 registry.clone(),
988 prompt_builder.clone(),
989 None,
990 None,
991 cx,
992 )
993 });
994
995 cx.update(|cx| {
996 cx.subscribe(&context, {
997 let network = network.clone();
998 move |_, event, _| {
999 if let ContextEvent::Operation(op) = event {
1000 network
1001 .lock()
1002 .broadcast(i as ReplicaId, vec![op.to_proto()]);
1003 }
1004 }
1005 })
1006 .detach();
1007 });
1008
1009 contexts.push(context);
1010 network.lock().add_peer(i as ReplicaId);
1011 }
1012
1013 let mut mutation_count = operations;
1014
1015 while mutation_count > 0
1016 || !network.lock().is_idle()
1017 || network.lock().contains_disconnected_peers()
1018 {
1019 let context_index = rng.gen_range(0..contexts.len());
1020 let context = &contexts[context_index];
1021
1022 match rng.gen_range(0..100) {
1023 0..=29 if mutation_count > 0 => {
1024 log::info!("Context {}: edit buffer", context_index);
1025 context.update(cx, |context, cx| {
1026 context
1027 .buffer
1028 .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
1029 });
1030 mutation_count -= 1;
1031 }
1032 30..=44 if mutation_count > 0 => {
1033 context.update(cx, |context, cx| {
1034 let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
1035 log::info!("Context {}: split message at {:?}", context_index, range);
1036 context.split_message(range, cx);
1037 });
1038 mutation_count -= 1;
1039 }
1040 45..=59 if mutation_count > 0 => {
1041 context.update(cx, |context, cx| {
1042 if let Some(message) = context.messages(cx).choose(&mut rng) {
1043 let role = *[Role::User, Role::Assistant, Role::System]
1044 .choose(&mut rng)
1045 .unwrap();
1046 log::info!(
1047 "Context {}: insert message after {:?} with {:?}",
1048 context_index,
1049 message.id,
1050 role
1051 );
1052 context.insert_message_after(message.id, role, MessageStatus::Done, cx);
1053 }
1054 });
1055 mutation_count -= 1;
1056 }
1057 60..=74 if mutation_count > 0 => {
1058 context.update(cx, |context, cx| {
1059 let command_text = "/".to_string()
1060 + slash_commands
1061 .command_names()
1062 .choose(&mut rng)
1063 .unwrap()
1064 .clone()
1065 .as_ref();
1066
1067 let command_range = context.buffer.update(cx, |buffer, cx| {
1068 let offset = buffer.random_byte_range(0, &mut rng).start;
1069 buffer.edit(
1070 [(offset..offset, format!("\n{}\n", command_text))],
1071 None,
1072 cx,
1073 );
1074 offset + 1..offset + 1 + command_text.len()
1075 });
1076
1077 let output_len = rng.gen_range(1..=10);
1078 let output_text = RandomCharIter::new(&mut rng)
1079 .filter(|c| *c != '\r')
1080 .take(output_len)
1081 .collect::<String>();
1082
1083 let num_sections = rng.gen_range(0..=3);
1084 let mut sections = Vec::with_capacity(num_sections);
1085 for _ in 0..num_sections {
1086 let section_start = rng.gen_range(0..output_len);
1087 let section_end = rng.gen_range(section_start..=output_len);
1088 sections.push(SlashCommandOutputSection {
1089 range: section_start..section_end,
1090 icon: ui::IconName::Ai,
1091 label: "section".into(),
1092 metadata: None,
1093 });
1094 }
1095
1096 log::info!(
1097 "Context {}: insert slash command output at {:?} with {:?}",
1098 context_index,
1099 command_range,
1100 sections
1101 );
1102
1103 let command_range = context.buffer.read(cx).anchor_after(command_range.start)
1104 ..context.buffer.read(cx).anchor_after(command_range.end);
1105 context.insert_command_output(
1106 command_range,
1107 Task::ready(Ok(SlashCommandOutput {
1108 text: output_text,
1109 sections,
1110 run_commands_in_text: false,
1111 })),
1112 true,
1113 false,
1114 cx,
1115 );
1116 });
1117 cx.run_until_parked();
1118 mutation_count -= 1;
1119 }
1120 75..=84 if mutation_count > 0 => {
1121 context.update(cx, |context, cx| {
1122 if let Some(message) = context.messages(cx).choose(&mut rng) {
1123 let new_status = match rng.gen_range(0..3) {
1124 0 => MessageStatus::Done,
1125 1 => MessageStatus::Pending,
1126 _ => MessageStatus::Error(SharedString::from("Random error")),
1127 };
1128 log::info!(
1129 "Context {}: update message {:?} status to {:?}",
1130 context_index,
1131 message.id,
1132 new_status
1133 );
1134 context.update_metadata(message.id, cx, |metadata| {
1135 metadata.status = new_status;
1136 });
1137 }
1138 });
1139 mutation_count -= 1;
1140 }
1141 _ => {
1142 let replica_id = context_index as ReplicaId;
1143 if network.lock().is_disconnected(replica_id) {
1144 network.lock().reconnect_peer(replica_id, 0);
1145
1146 let (ops_to_send, ops_to_receive) = cx.read(|cx| {
1147 let host_context = &contexts[0].read(cx);
1148 let guest_context = context.read(cx);
1149 (
1150 guest_context.serialize_ops(&host_context.version(cx), cx),
1151 host_context.serialize_ops(&guest_context.version(cx), cx),
1152 )
1153 });
1154 let ops_to_send = ops_to_send.await;
1155 let ops_to_receive = ops_to_receive
1156 .await
1157 .into_iter()
1158 .map(ContextOperation::from_proto)
1159 .collect::<Result<Vec<_>>>()
1160 .unwrap();
1161 log::info!(
1162 "Context {}: reconnecting. Sent {} operations, received {} operations",
1163 context_index,
1164 ops_to_send.len(),
1165 ops_to_receive.len()
1166 );
1167
1168 network.lock().broadcast(replica_id, ops_to_send);
1169 context
1170 .update(cx, |context, cx| context.apply_ops(ops_to_receive, cx))
1171 .unwrap();
1172 } else if rng.gen_bool(0.1) && replica_id != 0 {
1173 log::info!("Context {}: disconnecting", context_index);
1174 network.lock().disconnect_peer(replica_id);
1175 } else if network.lock().has_unreceived(replica_id) {
1176 log::info!("Context {}: applying operations", context_index);
1177 let ops = network.lock().receive(replica_id);
1178 let ops = ops
1179 .into_iter()
1180 .map(ContextOperation::from_proto)
1181 .collect::<Result<Vec<_>>>()
1182 .unwrap();
1183 context
1184 .update(cx, |context, cx| context.apply_ops(ops, cx))
1185 .unwrap();
1186 }
1187 }
1188 }
1189 }
1190
1191 cx.read(|cx| {
1192 let first_context = contexts[0].read(cx);
1193 for context in &contexts[1..] {
1194 let context = context.read(cx);
1195 assert!(context.pending_ops.is_empty());
1196 assert_eq!(
1197 context.buffer.read(cx).text(),
1198 first_context.buffer.read(cx).text(),
1199 "Context {} text != Context 0 text",
1200 context.buffer.read(cx).replica_id()
1201 );
1202 assert_eq!(
1203 context.message_anchors,
1204 first_context.message_anchors,
1205 "Context {} messages != Context 0 messages",
1206 context.buffer.read(cx).replica_id()
1207 );
1208 assert_eq!(
1209 context.messages_metadata,
1210 first_context.messages_metadata,
1211 "Context {} message metadata != Context 0 message metadata",
1212 context.buffer.read(cx).replica_id()
1213 );
1214 assert_eq!(
1215 context.slash_command_output_sections,
1216 first_context.slash_command_output_sections,
1217 "Context {} slash command output sections != Context 0 slash command output sections",
1218 context.buffer.read(cx).replica_id()
1219 );
1220 }
1221 });
1222}
1223
1224#[gpui::test]
1225fn test_mark_cache_anchors(cx: &mut AppContext) {
1226 let settings_store = SettingsStore::test(cx);
1227 LanguageModelRegistry::test(cx);
1228 cx.set_global(settings_store);
1229 assistant_panel::init(cx);
1230 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1231 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1232 let context =
1233 cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
1234 let buffer = context.read(cx).buffer.clone();
1235
1236 // Create a test cache configuration
1237 let cache_configuration = &Some(LanguageModelCacheConfiguration {
1238 max_cache_anchors: 3,
1239 should_speculate: true,
1240 min_total_token: 10,
1241 });
1242
1243 let message_1 = context.read(cx).message_anchors[0].clone();
1244
1245 context.update(cx, |context, cx| {
1246 context.mark_cache_anchors(cache_configuration, false, cx)
1247 });
1248
1249 assert_eq!(
1250 messages_cache(&context, cx)
1251 .iter()
1252 .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1253 .count(),
1254 0,
1255 "Empty messages should not have any cache anchors."
1256 );
1257
1258 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1259 let message_2 = context
1260 .update(cx, |context, cx| {
1261 context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1262 })
1263 .unwrap();
1264
1265 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1266 let message_3 = context
1267 .update(cx, |context, cx| {
1268 context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1269 })
1270 .unwrap();
1271 buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1272
1273 context.update(cx, |context, cx| {
1274 context.mark_cache_anchors(cache_configuration, false, cx)
1275 });
1276 assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1277 assert_eq!(
1278 messages_cache(&context, cx)
1279 .iter()
1280 .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1281 .count(),
1282 0,
1283 "Messages should not be marked for cache before going over the token minimum."
1284 );
1285 context.update(cx, |context, _| {
1286 context.token_count = Some(20);
1287 });
1288
1289 context.update(cx, |context, cx| {
1290 context.mark_cache_anchors(cache_configuration, true, cx)
1291 });
1292 assert_eq!(
1293 messages_cache(&context, cx)
1294 .iter()
1295 .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1296 .collect::<Vec<bool>>(),
1297 vec![true, true, false],
1298 "Last message should not be an anchor on speculative request."
1299 );
1300
1301 context
1302 .update(cx, |context, cx| {
1303 context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1304 })
1305 .unwrap();
1306
1307 context.update(cx, |context, cx| {
1308 context.mark_cache_anchors(cache_configuration, false, cx)
1309 });
1310 assert_eq!(
1311 messages_cache(&context, cx)
1312 .iter()
1313 .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1314 .collect::<Vec<bool>>(),
1315 vec![false, true, true, false],
1316 "Most recent message should also be cached if not a speculative request."
1317 );
1318 context.update(cx, |context, cx| {
1319 context.update_cache_status_for_completion(cx)
1320 });
1321 assert_eq!(
1322 messages_cache(&context, cx)
1323 .iter()
1324 .map(|(_, cache)| cache
1325 .as_ref()
1326 .map_or(None, |cache| Some(cache.status.clone())))
1327 .collect::<Vec<Option<CacheStatus>>>(),
1328 vec![
1329 Some(CacheStatus::Cached),
1330 Some(CacheStatus::Cached),
1331 Some(CacheStatus::Cached),
1332 None
1333 ],
1334 "All user messages prior to anchor should be marked as cached."
1335 );
1336
1337 buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1338 context.update(cx, |context, cx| {
1339 context.mark_cache_anchors(cache_configuration, false, cx)
1340 });
1341 assert_eq!(
1342 messages_cache(&context, cx)
1343 .iter()
1344 .map(|(_, cache)| cache
1345 .as_ref()
1346 .map_or(None, |cache| Some(cache.status.clone())))
1347 .collect::<Vec<Option<CacheStatus>>>(),
1348 vec![
1349 Some(CacheStatus::Cached),
1350 Some(CacheStatus::Cached),
1351 Some(CacheStatus::Pending),
1352 None
1353 ],
1354 "Modifying a message should invalidate it's cache but leave previous messages."
1355 );
1356 buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1357 context.update(cx, |context, cx| {
1358 context.mark_cache_anchors(cache_configuration, false, cx)
1359 });
1360 assert_eq!(
1361 messages_cache(&context, cx)
1362 .iter()
1363 .map(|(_, cache)| cache
1364 .as_ref()
1365 .map_or(None, |cache| Some(cache.status.clone())))
1366 .collect::<Vec<Option<CacheStatus>>>(),
1367 vec![
1368 Some(CacheStatus::Pending),
1369 Some(CacheStatus::Pending),
1370 Some(CacheStatus::Pending),
1371 None
1372 ],
1373 "Modifying a message should invalidate all future messages."
1374 );
1375}
1376
1377fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
1378 context
1379 .read(cx)
1380 .messages(cx)
1381 .map(|message| (message.id, message.role, message.offset_range))
1382 .collect()
1383}
1384
1385fn messages_cache(
1386 context: &Model<Context>,
1387 cx: &AppContext,
1388) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1389 context
1390 .read(cx)
1391 .messages(cx)
1392 .map(|message| (message.id, message.cache.clone()))
1393 .collect()
1394}
1395
1396#[derive(Clone)]
1397struct FakeSlashCommand(String);
1398
1399impl SlashCommand for FakeSlashCommand {
1400 fn name(&self) -> String {
1401 self.0.clone()
1402 }
1403
1404 fn description(&self) -> String {
1405 format!("Fake slash command: {}", self.0)
1406 }
1407
1408 fn menu_text(&self) -> String {
1409 format!("Run fake command: {}", self.0)
1410 }
1411
1412 fn complete_argument(
1413 self: Arc<Self>,
1414 _arguments: &[String],
1415 _cancel: Arc<AtomicBool>,
1416 _workspace: Option<WeakView<Workspace>>,
1417 _cx: &mut WindowContext,
1418 ) -> Task<Result<Vec<ArgumentCompletion>>> {
1419 Task::ready(Ok(vec![]))
1420 }
1421
1422 fn requires_argument(&self) -> bool {
1423 false
1424 }
1425
1426 fn run(
1427 self: Arc<Self>,
1428 _arguments: &[String],
1429 _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1430 _context_buffer: BufferSnapshot,
1431 _workspace: WeakView<Workspace>,
1432 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1433 _cx: &mut WindowContext,
1434 ) -> Task<Result<SlashCommandOutput>> {
1435 Task::ready(Ok(SlashCommandOutput {
1436 text: format!("Executed fake command: {}", self.0),
1437 sections: vec![],
1438 run_commands_in_text: false,
1439 }))
1440 }
1441}