1use crate::{
2 AssistantContext, AssistantEdit, AssistantEditKind, CacheStatus, ContextEvent, ContextId,
3 ContextOperation, InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus,
4};
5use anyhow::Result;
6use assistant_slash_command::{
7 ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
8 SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
9};
10use assistant_slash_commands::FileSlashCommand;
11use collections::{HashMap, HashSet};
12use fs::FakeFs;
13use futures::{
14 channel::mpsc,
15 stream::{self, StreamExt},
16};
17use gpui::{App, Entity, SharedString, Task, TestAppContext, WeakEntity, prelude::*};
18use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
19use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
20use parking_lot::Mutex;
21use pretty_assertions::assert_eq;
22use project::Project;
23use prompt_store::PromptBuilder;
24use rand::prelude::*;
25use serde_json::json;
26use settings::SettingsStore;
27use std::{
28 cell::RefCell,
29 env,
30 ops::Range,
31 path::Path,
32 rc::Rc,
33 sync::{Arc, atomic::AtomicBool},
34};
35use text::{OffsetRangeExt as _, ReplicaId, ToOffset, network::Network};
36use ui::{IconName, Window};
37use unindent::Unindent;
38use util::{
39 RandomCharIter,
40 test::{generate_marked_text, marked_text_ranges},
41};
42use workspace::Workspace;
43
44#[gpui::test]
45fn test_inserting_and_removing_messages(cx: &mut App) {
46 let settings_store = SettingsStore::test(cx);
47 LanguageModelRegistry::test(cx);
48 cx.set_global(settings_store);
49 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
50 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
51 let context = cx.new(|cx| {
52 AssistantContext::local(
53 registry,
54 None,
55 None,
56 prompt_builder.clone(),
57 Arc::new(SlashCommandWorkingSet::default()),
58 cx,
59 )
60 });
61 let buffer = context.read(cx).buffer.clone();
62
63 let message_1 = context.read(cx).message_anchors[0].clone();
64 assert_eq!(
65 messages(&context, cx),
66 vec![(message_1.id, Role::User, 0..0)]
67 );
68
69 let message_2 = context.update(cx, |context, cx| {
70 context
71 .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
72 .unwrap()
73 });
74 assert_eq!(
75 messages(&context, cx),
76 vec![
77 (message_1.id, Role::User, 0..1),
78 (message_2.id, Role::Assistant, 1..1)
79 ]
80 );
81
82 buffer.update(cx, |buffer, cx| {
83 buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
84 });
85 assert_eq!(
86 messages(&context, cx),
87 vec![
88 (message_1.id, Role::User, 0..2),
89 (message_2.id, Role::Assistant, 2..3)
90 ]
91 );
92
93 let message_3 = context.update(cx, |context, cx| {
94 context
95 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
96 .unwrap()
97 });
98 assert_eq!(
99 messages(&context, cx),
100 vec![
101 (message_1.id, Role::User, 0..2),
102 (message_2.id, Role::Assistant, 2..4),
103 (message_3.id, Role::User, 4..4)
104 ]
105 );
106
107 let message_4 = context.update(cx, |context, cx| {
108 context
109 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
110 .unwrap()
111 });
112 assert_eq!(
113 messages(&context, cx),
114 vec![
115 (message_1.id, Role::User, 0..2),
116 (message_2.id, Role::Assistant, 2..4),
117 (message_4.id, Role::User, 4..5),
118 (message_3.id, Role::User, 5..5),
119 ]
120 );
121
122 buffer.update(cx, |buffer, cx| {
123 buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
124 });
125 assert_eq!(
126 messages(&context, cx),
127 vec![
128 (message_1.id, Role::User, 0..2),
129 (message_2.id, Role::Assistant, 2..4),
130 (message_4.id, Role::User, 4..6),
131 (message_3.id, Role::User, 6..7),
132 ]
133 );
134
135 // Deleting across message boundaries merges the messages.
136 buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
137 assert_eq!(
138 messages(&context, cx),
139 vec![
140 (message_1.id, Role::User, 0..3),
141 (message_3.id, Role::User, 3..4),
142 ]
143 );
144
145 // Undoing the deletion should also undo the merge.
146 buffer.update(cx, |buffer, cx| buffer.undo(cx));
147 assert_eq!(
148 messages(&context, cx),
149 vec![
150 (message_1.id, Role::User, 0..2),
151 (message_2.id, Role::Assistant, 2..4),
152 (message_4.id, Role::User, 4..6),
153 (message_3.id, Role::User, 6..7),
154 ]
155 );
156
157 // Redoing the deletion should also redo the merge.
158 buffer.update(cx, |buffer, cx| buffer.redo(cx));
159 assert_eq!(
160 messages(&context, cx),
161 vec![
162 (message_1.id, Role::User, 0..3),
163 (message_3.id, Role::User, 3..4),
164 ]
165 );
166
167 // Ensure we can still insert after a merged message.
168 let message_5 = context.update(cx, |context, cx| {
169 context
170 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
171 .unwrap()
172 });
173 assert_eq!(
174 messages(&context, cx),
175 vec![
176 (message_1.id, Role::User, 0..3),
177 (message_5.id, Role::System, 3..4),
178 (message_3.id, Role::User, 4..5)
179 ]
180 );
181}
182
183#[gpui::test]
184fn test_message_splitting(cx: &mut App) {
185 let settings_store = SettingsStore::test(cx);
186 cx.set_global(settings_store);
187 LanguageModelRegistry::test(cx);
188 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
189
190 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
191 let context = cx.new(|cx| {
192 AssistantContext::local(
193 registry.clone(),
194 None,
195 None,
196 prompt_builder.clone(),
197 Arc::new(SlashCommandWorkingSet::default()),
198 cx,
199 )
200 });
201 let buffer = context.read(cx).buffer.clone();
202
203 let message_1 = context.read(cx).message_anchors[0].clone();
204 assert_eq!(
205 messages(&context, cx),
206 vec![(message_1.id, Role::User, 0..0)]
207 );
208
209 buffer.update(cx, |buffer, cx| {
210 buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
211 });
212
213 let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
214 let message_2 = message_2.unwrap();
215
216 // We recycle newlines in the middle of a split message
217 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
218 assert_eq!(
219 messages(&context, cx),
220 vec![
221 (message_1.id, Role::User, 0..4),
222 (message_2.id, Role::User, 4..16),
223 ]
224 );
225
226 let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
227 let message_3 = message_3.unwrap();
228
229 // We don't recycle newlines at the end of a split message
230 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
231 assert_eq!(
232 messages(&context, cx),
233 vec![
234 (message_1.id, Role::User, 0..4),
235 (message_3.id, Role::User, 4..5),
236 (message_2.id, Role::User, 5..17),
237 ]
238 );
239
240 let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
241 let message_4 = message_4.unwrap();
242 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
243 assert_eq!(
244 messages(&context, cx),
245 vec![
246 (message_1.id, Role::User, 0..4),
247 (message_3.id, Role::User, 4..5),
248 (message_2.id, Role::User, 5..9),
249 (message_4.id, Role::User, 9..17),
250 ]
251 );
252
253 let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
254 let message_5 = message_5.unwrap();
255 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
256 assert_eq!(
257 messages(&context, cx),
258 vec![
259 (message_1.id, Role::User, 0..4),
260 (message_3.id, Role::User, 4..5),
261 (message_2.id, Role::User, 5..9),
262 (message_4.id, Role::User, 9..10),
263 (message_5.id, Role::User, 10..18),
264 ]
265 );
266
267 let (message_6, message_7) =
268 context.update(cx, |context, cx| context.split_message(14..16, cx));
269 let message_6 = message_6.unwrap();
270 let message_7 = message_7.unwrap();
271 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
272 assert_eq!(
273 messages(&context, cx),
274 vec![
275 (message_1.id, Role::User, 0..4),
276 (message_3.id, Role::User, 4..5),
277 (message_2.id, Role::User, 5..9),
278 (message_4.id, Role::User, 9..10),
279 (message_5.id, Role::User, 10..14),
280 (message_6.id, Role::User, 14..17),
281 (message_7.id, Role::User, 17..19),
282 ]
283 );
284}
285
286#[gpui::test]
287fn test_messages_for_offsets(cx: &mut App) {
288 let settings_store = SettingsStore::test(cx);
289 LanguageModelRegistry::test(cx);
290 cx.set_global(settings_store);
291 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
292 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
293 let context = cx.new(|cx| {
294 AssistantContext::local(
295 registry,
296 None,
297 None,
298 prompt_builder.clone(),
299 Arc::new(SlashCommandWorkingSet::default()),
300 cx,
301 )
302 });
303 let buffer = context.read(cx).buffer.clone();
304
305 let message_1 = context.read(cx).message_anchors[0].clone();
306 assert_eq!(
307 messages(&context, cx),
308 vec![(message_1.id, Role::User, 0..0)]
309 );
310
311 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
312 let message_2 = context
313 .update(cx, |context, cx| {
314 context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
315 })
316 .unwrap();
317 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
318
319 let message_3 = context
320 .update(cx, |context, cx| {
321 context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
322 })
323 .unwrap();
324 buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
325
326 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
327 assert_eq!(
328 messages(&context, cx),
329 vec![
330 (message_1.id, Role::User, 0..4),
331 (message_2.id, Role::User, 4..8),
332 (message_3.id, Role::User, 8..11)
333 ]
334 );
335
336 assert_eq!(
337 message_ids_for_offsets(&context, &[0, 4, 9], cx),
338 [message_1.id, message_2.id, message_3.id]
339 );
340 assert_eq!(
341 message_ids_for_offsets(&context, &[0, 1, 11], cx),
342 [message_1.id, message_3.id]
343 );
344
345 let message_4 = context
346 .update(cx, |context, cx| {
347 context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
348 })
349 .unwrap();
350 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
351 assert_eq!(
352 messages(&context, cx),
353 vec![
354 (message_1.id, Role::User, 0..4),
355 (message_2.id, Role::User, 4..8),
356 (message_3.id, Role::User, 8..12),
357 (message_4.id, Role::User, 12..12)
358 ]
359 );
360 assert_eq!(
361 message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
362 [message_1.id, message_2.id, message_3.id, message_4.id]
363 );
364
365 fn message_ids_for_offsets(
366 context: &Entity<AssistantContext>,
367 offsets: &[usize],
368 cx: &App,
369 ) -> Vec<MessageId> {
370 context
371 .read(cx)
372 .messages_for_offsets(offsets.iter().copied(), cx)
373 .into_iter()
374 .map(|message| message.id)
375 .collect()
376 }
377}
378
379#[gpui::test]
380async fn test_slash_commands(cx: &mut TestAppContext) {
381 let settings_store = cx.update(SettingsStore::test);
382 cx.set_global(settings_store);
383 cx.update(LanguageModelRegistry::test);
384 cx.update(Project::init_settings);
385 let fs = FakeFs::new(cx.background_executor.clone());
386
387 fs.insert_tree(
388 "/test",
389 json!({
390 "src": {
391 "lib.rs": "fn one() -> usize { 1 }",
392 "main.rs": "
393 use crate::one;
394 fn main() { one(); }
395 ".unindent(),
396 }
397 }),
398 )
399 .await;
400
401 let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
402 slash_command_registry.register_command(FileSlashCommand, false);
403
404 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
405 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
406 let context = cx.new(|cx| {
407 AssistantContext::local(
408 registry.clone(),
409 None,
410 None,
411 prompt_builder.clone(),
412 Arc::new(SlashCommandWorkingSet::default()),
413 cx,
414 )
415 });
416
417 #[derive(Default)]
418 struct ContextRanges {
419 parsed_commands: HashSet<Range<language::Anchor>>,
420 command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
421 output_sections: HashSet<Range<language::Anchor>>,
422 }
423
424 let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
425 context.update(cx, |_, cx| {
426 cx.subscribe(&context, {
427 let context_ranges = context_ranges.clone();
428 move |context, _, event, _| {
429 let mut context_ranges = context_ranges.borrow_mut();
430 match event {
431 ContextEvent::InvokedSlashCommandChanged { command_id } => {
432 let command = context.invoked_slash_command(command_id).unwrap();
433 context_ranges
434 .command_outputs
435 .insert(*command_id, command.range.clone());
436 }
437 ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
438 for range in removed {
439 context_ranges.parsed_commands.remove(range);
440 }
441 for command in updated {
442 context_ranges
443 .parsed_commands
444 .insert(command.source_range.clone());
445 }
446 }
447 ContextEvent::SlashCommandOutputSectionAdded { section } => {
448 context_ranges.output_sections.insert(section.range.clone());
449 }
450 _ => {}
451 }
452 }
453 })
454 .detach();
455 });
456
457 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
458
459 // Insert a slash command
460 buffer.update(cx, |buffer, cx| {
461 buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
462 });
463 assert_text_and_context_ranges(
464 &buffer,
465 &context_ranges,
466 &"
467 «/file src/lib.rs»"
468 .unindent(),
469 cx,
470 );
471
472 // Edit the argument of the slash command.
473 buffer.update(cx, |buffer, cx| {
474 let edit_offset = buffer.text().find("lib.rs").unwrap();
475 buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
476 });
477 assert_text_and_context_ranges(
478 &buffer,
479 &context_ranges,
480 &"
481 «/file src/main.rs»"
482 .unindent(),
483 cx,
484 );
485
486 // Edit the name of the slash command, using one that doesn't exist.
487 buffer.update(cx, |buffer, cx| {
488 let edit_offset = buffer.text().find("/file").unwrap();
489 buffer.edit(
490 [(edit_offset..edit_offset + "/file".len(), "/unknown")],
491 None,
492 cx,
493 );
494 });
495 assert_text_and_context_ranges(
496 &buffer,
497 &context_ranges,
498 &"
499 /unknown src/main.rs"
500 .unindent(),
501 cx,
502 );
503
504 // Undoing the insertion of an non-existent slash command resorts the previous one.
505 buffer.update(cx, |buffer, cx| buffer.undo(cx));
506 assert_text_and_context_ranges(
507 &buffer,
508 &context_ranges,
509 &"
510 «/file src/main.rs»"
511 .unindent(),
512 cx,
513 );
514
515 let (command_output_tx, command_output_rx) = mpsc::unbounded();
516 context.update(cx, |context, cx| {
517 let command_source_range = context.parsed_slash_commands[0].source_range.clone();
518 context.insert_command_output(
519 command_source_range,
520 "file",
521 Task::ready(Ok(command_output_rx.boxed())),
522 true,
523 cx,
524 );
525 });
526 assert_text_and_context_ranges(
527 &buffer,
528 &context_ranges,
529 &"
530 ⟦«/file src/main.rs»
531 …⟧
532 "
533 .unindent(),
534 cx,
535 );
536
537 command_output_tx
538 .unbounded_send(Ok(SlashCommandEvent::StartSection {
539 icon: IconName::Ai,
540 label: "src/main.rs".into(),
541 metadata: None,
542 }))
543 .unwrap();
544 command_output_tx
545 .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
546 .unwrap();
547 cx.run_until_parked();
548 assert_text_and_context_ranges(
549 &buffer,
550 &context_ranges,
551 &"
552 ⟦«/file src/main.rs»
553 src/main.rs…⟧
554 "
555 .unindent(),
556 cx,
557 );
558
559 command_output_tx
560 .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
561 .unwrap();
562 cx.run_until_parked();
563 assert_text_and_context_ranges(
564 &buffer,
565 &context_ranges,
566 &"
567 ⟦«/file src/main.rs»
568 src/main.rs
569 fn main() {}…⟧
570 "
571 .unindent(),
572 cx,
573 );
574
575 command_output_tx
576 .unbounded_send(Ok(SlashCommandEvent::EndSection))
577 .unwrap();
578 cx.run_until_parked();
579 assert_text_and_context_ranges(
580 &buffer,
581 &context_ranges,
582 &"
583 ⟦«/file src/main.rs»
584 ⟪src/main.rs
585 fn main() {}⟫…⟧
586 "
587 .unindent(),
588 cx,
589 );
590
591 drop(command_output_tx);
592 cx.run_until_parked();
593 assert_text_and_context_ranges(
594 &buffer,
595 &context_ranges,
596 &"
597 ⟦⟪src/main.rs
598 fn main() {}⟫⟧
599 "
600 .unindent(),
601 cx,
602 );
603
604 #[track_caller]
605 fn assert_text_and_context_ranges(
606 buffer: &Entity<Buffer>,
607 ranges: &RefCell<ContextRanges>,
608 expected_marked_text: &str,
609 cx: &mut TestAppContext,
610 ) {
611 let mut actual_marked_text = String::new();
612 buffer.update(cx, |buffer, _| {
613 struct Endpoint {
614 offset: usize,
615 marker: char,
616 }
617
618 let ranges = ranges.borrow();
619 let mut endpoints = Vec::new();
620 for range in ranges.command_outputs.values() {
621 endpoints.push(Endpoint {
622 offset: range.start.to_offset(buffer),
623 marker: '⟦',
624 });
625 }
626 for range in ranges.parsed_commands.iter() {
627 endpoints.push(Endpoint {
628 offset: range.start.to_offset(buffer),
629 marker: '«',
630 });
631 }
632 for range in ranges.output_sections.iter() {
633 endpoints.push(Endpoint {
634 offset: range.start.to_offset(buffer),
635 marker: '⟪',
636 });
637 }
638
639 for range in ranges.output_sections.iter() {
640 endpoints.push(Endpoint {
641 offset: range.end.to_offset(buffer),
642 marker: '⟫',
643 });
644 }
645 for range in ranges.parsed_commands.iter() {
646 endpoints.push(Endpoint {
647 offset: range.end.to_offset(buffer),
648 marker: '»',
649 });
650 }
651 for range in ranges.command_outputs.values() {
652 endpoints.push(Endpoint {
653 offset: range.end.to_offset(buffer),
654 marker: '⟧',
655 });
656 }
657
658 endpoints.sort_by_key(|endpoint| endpoint.offset);
659 let mut offset = 0;
660 for endpoint in endpoints {
661 actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
662 actual_marked_text.push(endpoint.marker);
663 offset = endpoint.offset;
664 }
665 actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
666 });
667
668 assert_eq!(actual_marked_text, expected_marked_text);
669 }
670}
671
672#[gpui::test]
673async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
674 cx.update(prompt_store::init);
675 let mut settings_store = cx.update(SettingsStore::test);
676 cx.update(|cx| {
677 settings_store
678 .set_user_settings(
679 r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
680 cx,
681 )
682 .unwrap()
683 });
684 cx.set_global(settings_store);
685 cx.update(language::init);
686 cx.update(Project::init_settings);
687 let fs = FakeFs::new(cx.executor());
688 let project = Project::test(fs, [Path::new("/root")], cx).await;
689 cx.update(LanguageModelRegistry::test);
690
691 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
692
693 // Create a new context
694 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
695 let context = cx.new(|cx| {
696 AssistantContext::local(
697 registry.clone(),
698 Some(project),
699 None,
700 prompt_builder.clone(),
701 Arc::new(SlashCommandWorkingSet::default()),
702 cx,
703 )
704 });
705
706 // Insert an assistant message to simulate a response.
707 let assistant_message_id = context.update(cx, |context, cx| {
708 let user_message_id = context.messages(cx).next().unwrap().id;
709 context
710 .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
711 .unwrap()
712 .id
713 });
714
715 // No edit tags
716 edit(
717 &context,
718 "
719
720 «one
721 two
722 »",
723 cx,
724 );
725 expect_patches(
726 &context,
727 "
728
729 one
730 two
731 ",
732 &[],
733 cx,
734 );
735
736 // Partial edit step tag is added
737 edit(
738 &context,
739 "
740
741 one
742 two
743 «
744 <patch»",
745 cx,
746 );
747 expect_patches(
748 &context,
749 "
750
751 one
752 two
753
754 <patch",
755 &[],
756 cx,
757 );
758
759 // The rest of the step tag is added. The unclosed
760 // step is treated as incomplete.
761 edit(
762 &context,
763 "
764
765 one
766 two
767
768 <patch«>
769 <edit>»",
770 cx,
771 );
772 expect_patches(
773 &context,
774 "
775
776 one
777 two
778
779 «<patch>
780 <edit>»",
781 &[&[]],
782 cx,
783 );
784
785 // The full patch is added
786 edit(
787 &context,
788 "
789
790 one
791 two
792
793 <patch>
794 <edit>«
795 <description>add a `two` function</description>
796 <path>src/lib.rs</path>
797 <operation>insert_after</operation>
798 <old_text>fn one</old_text>
799 <new_text>
800 fn two() {}
801 </new_text>
802 </edit>
803 </patch>
804
805 also,»",
806 cx,
807 );
808 expect_patches(
809 &context,
810 "
811
812 one
813 two
814
815 «<patch>
816 <edit>
817 <description>add a `two` function</description>
818 <path>src/lib.rs</path>
819 <operation>insert_after</operation>
820 <old_text>fn one</old_text>
821 <new_text>
822 fn two() {}
823 </new_text>
824 </edit>
825 </patch>
826 »
827 also,",
828 &[&[AssistantEdit {
829 path: "src/lib.rs".into(),
830 kind: AssistantEditKind::InsertAfter {
831 old_text: "fn one".into(),
832 new_text: "fn two() {}".into(),
833 description: Some("add a `two` function".into()),
834 },
835 }]],
836 cx,
837 );
838
839 // The step is manually edited.
840 edit(
841 &context,
842 "
843
844 one
845 two
846
847 <patch>
848 <edit>
849 <description>add a `two` function</description>
850 <path>src/lib.rs</path>
851 <operation>insert_after</operation>
852 <old_text>«fn zero»</old_text>
853 <new_text>
854 fn two() {}
855 </new_text>
856 </edit>
857 </patch>
858
859 also,",
860 cx,
861 );
862 expect_patches(
863 &context,
864 "
865
866 one
867 two
868
869 «<patch>
870 <edit>
871 <description>add a `two` function</description>
872 <path>src/lib.rs</path>
873 <operation>insert_after</operation>
874 <old_text>fn zero</old_text>
875 <new_text>
876 fn two() {}
877 </new_text>
878 </edit>
879 </patch>
880 »
881 also,",
882 &[&[AssistantEdit {
883 path: "src/lib.rs".into(),
884 kind: AssistantEditKind::InsertAfter {
885 old_text: "fn zero".into(),
886 new_text: "fn two() {}".into(),
887 description: Some("add a `two` function".into()),
888 },
889 }]],
890 cx,
891 );
892
893 // When setting the message role to User, the steps are cleared.
894 context.update(cx, |context, cx| {
895 context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
896 context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
897 });
898 expect_patches(
899 &context,
900 "
901
902 one
903 two
904
905 <patch>
906 <edit>
907 <description>add a `two` function</description>
908 <path>src/lib.rs</path>
909 <operation>insert_after</operation>
910 <old_text>fn zero</old_text>
911 <new_text>
912 fn two() {}
913 </new_text>
914 </edit>
915 </patch>
916
917 also,",
918 &[],
919 cx,
920 );
921
922 // When setting the message role back to Assistant, the steps are reparsed.
923 context.update(cx, |context, cx| {
924 context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
925 });
926 expect_patches(
927 &context,
928 "
929
930 one
931 two
932
933 «<patch>
934 <edit>
935 <description>add a `two` function</description>
936 <path>src/lib.rs</path>
937 <operation>insert_after</operation>
938 <old_text>fn zero</old_text>
939 <new_text>
940 fn two() {}
941 </new_text>
942 </edit>
943 </patch>
944 »
945 also,",
946 &[&[AssistantEdit {
947 path: "src/lib.rs".into(),
948 kind: AssistantEditKind::InsertAfter {
949 old_text: "fn zero".into(),
950 new_text: "fn two() {}".into(),
951 description: Some("add a `two` function".into()),
952 },
953 }]],
954 cx,
955 );
956
957 // Ensure steps are re-parsed when deserializing.
958 let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
959 let deserialized_context = cx.new(|cx| {
960 AssistantContext::deserialize(
961 serialized_context,
962 Default::default(),
963 registry.clone(),
964 prompt_builder.clone(),
965 Arc::new(SlashCommandWorkingSet::default()),
966 None,
967 None,
968 cx,
969 )
970 });
971 expect_patches(
972 &deserialized_context,
973 "
974
975 one
976 two
977
978 «<patch>
979 <edit>
980 <description>add a `two` function</description>
981 <path>src/lib.rs</path>
982 <operation>insert_after</operation>
983 <old_text>fn zero</old_text>
984 <new_text>
985 fn two() {}
986 </new_text>
987 </edit>
988 </patch>
989 »
990 also,",
991 &[&[AssistantEdit {
992 path: "src/lib.rs".into(),
993 kind: AssistantEditKind::InsertAfter {
994 old_text: "fn zero".into(),
995 new_text: "fn two() {}".into(),
996 description: Some("add a `two` function".into()),
997 },
998 }]],
999 cx,
1000 );
1001
1002 fn edit(
1003 context: &Entity<AssistantContext>,
1004 new_text_marked_with_edits: &str,
1005 cx: &mut TestAppContext,
1006 ) {
1007 context.update(cx, |context, cx| {
1008 context.buffer.update(cx, |buffer, cx| {
1009 buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
1010 });
1011 });
1012 cx.executor().run_until_parked();
1013 }
1014
1015 #[track_caller]
1016 fn expect_patches(
1017 context: &Entity<AssistantContext>,
1018 expected_marked_text: &str,
1019 expected_suggestions: &[&[AssistantEdit]],
1020 cx: &mut TestAppContext,
1021 ) {
1022 let expected_marked_text = expected_marked_text.unindent();
1023 let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);
1024
1025 let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
1026 context.buffer.read_with(cx, |buffer, _| {
1027 let ranges = context
1028 .patches
1029 .iter()
1030 .map(|entry| entry.range.to_offset(buffer))
1031 .collect::<Vec<_>>();
1032 (
1033 buffer.text(),
1034 ranges,
1035 context
1036 .patches
1037 .iter()
1038 .map(|step| step.edits.clone())
1039 .collect::<Vec<_>>(),
1040 )
1041 })
1042 });
1043
1044 assert_eq!(buffer_text, expected_text);
1045
1046 let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
1047 assert_eq!(actual_marked_text, expected_marked_text);
1048
1049 assert_eq!(
1050 patches
1051 .iter()
1052 .map(|patch| {
1053 patch
1054 .iter()
1055 .map(|edit| {
1056 let edit = edit.as_ref().unwrap();
1057 AssistantEdit {
1058 path: edit.path.clone(),
1059 kind: edit.kind.clone(),
1060 }
1061 })
1062 .collect::<Vec<_>>()
1063 })
1064 .collect::<Vec<_>>(),
1065 expected_suggestions
1066 );
1067 }
1068}
1069
1070#[gpui::test]
1071async fn test_serialization(cx: &mut TestAppContext) {
1072 let settings_store = cx.update(SettingsStore::test);
1073 cx.set_global(settings_store);
1074 cx.update(LanguageModelRegistry::test);
1075 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
1076 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1077 let context = cx.new(|cx| {
1078 AssistantContext::local(
1079 registry.clone(),
1080 None,
1081 None,
1082 prompt_builder.clone(),
1083 Arc::new(SlashCommandWorkingSet::default()),
1084 cx,
1085 )
1086 });
1087 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
1088 let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
1089 let message_1 = context.update(cx, |context, cx| {
1090 context
1091 .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
1092 .unwrap()
1093 });
1094 let message_2 = context.update(cx, |context, cx| {
1095 context
1096 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
1097 .unwrap()
1098 });
1099 buffer.update(cx, |buffer, cx| {
1100 buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
1101 buffer.finalize_last_transaction();
1102 });
1103 let _message_3 = context.update(cx, |context, cx| {
1104 context
1105 .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
1106 .unwrap()
1107 });
1108 buffer.update(cx, |buffer, cx| buffer.undo(cx));
1109 assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
1110 assert_eq!(
1111 cx.read(|cx| messages(&context, cx)),
1112 [
1113 (message_0, Role::User, 0..2),
1114 (message_1.id, Role::Assistant, 2..6),
1115 (message_2.id, Role::System, 6..6),
1116 ]
1117 );
1118
1119 let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
1120 let deserialized_context = cx.new(|cx| {
1121 AssistantContext::deserialize(
1122 serialized_context,
1123 Default::default(),
1124 registry.clone(),
1125 prompt_builder.clone(),
1126 Arc::new(SlashCommandWorkingSet::default()),
1127 None,
1128 None,
1129 cx,
1130 )
1131 });
1132 let deserialized_buffer =
1133 deserialized_context.read_with(cx, |context, _| context.buffer.clone());
1134 assert_eq!(
1135 deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
1136 "a\nb\nc\n"
1137 );
1138 assert_eq!(
1139 cx.read(|cx| messages(&deserialized_context, cx)),
1140 [
1141 (message_0, Role::User, 0..2),
1142 (message_1.id, Role::Assistant, 2..6),
1143 (message_2.id, Role::System, 6..6),
1144 ]
1145 );
1146}
1147
1148#[gpui::test(iterations = 100)]
1149async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
1150 let min_peers = env::var("MIN_PEERS")
1151 .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
1152 .unwrap_or(2);
1153 let max_peers = env::var("MAX_PEERS")
1154 .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
1155 .unwrap_or(5);
1156 let operations = env::var("OPERATIONS")
1157 .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
1158 .unwrap_or(50);
1159
1160 let settings_store = cx.update(SettingsStore::test);
1161 cx.set_global(settings_store);
1162 cx.update(LanguageModelRegistry::test);
1163
1164 let slash_commands = cx.update(SlashCommandRegistry::default_global);
1165 slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
1166 slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
1167 slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
1168
1169 let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
1170 let network = Arc::new(Mutex::new(Network::new(rng.clone())));
1171 let mut contexts = Vec::new();
1172
1173 let num_peers = rng.gen_range(min_peers..=max_peers);
1174 let context_id = ContextId::new();
1175 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1176 for i in 0..num_peers {
1177 let context = cx.new(|cx| {
1178 AssistantContext::new(
1179 context_id.clone(),
1180 i as ReplicaId,
1181 language::Capability::ReadWrite,
1182 registry.clone(),
1183 prompt_builder.clone(),
1184 Arc::new(SlashCommandWorkingSet::default()),
1185 None,
1186 None,
1187 cx,
1188 )
1189 });
1190
1191 cx.update(|cx| {
1192 cx.subscribe(&context, {
1193 let network = network.clone();
1194 move |_, event, _| {
1195 if let ContextEvent::Operation(op) = event {
1196 network
1197 .lock()
1198 .broadcast(i as ReplicaId, vec![op.to_proto()]);
1199 }
1200 }
1201 })
1202 .detach();
1203 });
1204
1205 contexts.push(context);
1206 network.lock().add_peer(i as ReplicaId);
1207 }
1208
1209 let mut mutation_count = operations;
1210
1211 while mutation_count > 0
1212 || !network.lock().is_idle()
1213 || network.lock().contains_disconnected_peers()
1214 {
1215 let context_index = rng.gen_range(0..contexts.len());
1216 let context = &contexts[context_index];
1217
1218 match rng.gen_range(0..100) {
1219 0..=29 if mutation_count > 0 => {
1220 log::info!("Context {}: edit buffer", context_index);
1221 context.update(cx, |context, cx| {
1222 context
1223 .buffer
1224 .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
1225 });
1226 mutation_count -= 1;
1227 }
1228 30..=44 if mutation_count > 0 => {
1229 context.update(cx, |context, cx| {
1230 let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
1231 log::info!("Context {}: split message at {:?}", context_index, range);
1232 context.split_message(range, cx);
1233 });
1234 mutation_count -= 1;
1235 }
1236 45..=59 if mutation_count > 0 => {
1237 context.update(cx, |context, cx| {
1238 if let Some(message) = context.messages(cx).choose(&mut rng) {
1239 let role = *[Role::User, Role::Assistant, Role::System]
1240 .choose(&mut rng)
1241 .unwrap();
1242 log::info!(
1243 "Context {}: insert message after {:?} with {:?}",
1244 context_index,
1245 message.id,
1246 role
1247 );
1248 context.insert_message_after(message.id, role, MessageStatus::Done, cx);
1249 }
1250 });
1251 mutation_count -= 1;
1252 }
1253 60..=74 if mutation_count > 0 => {
1254 context.update(cx, |context, cx| {
1255 let command_text = "/".to_string()
1256 + slash_commands
1257 .command_names()
1258 .choose(&mut rng)
1259 .unwrap()
1260 .clone()
1261 .as_ref();
1262
1263 let command_range = context.buffer.update(cx, |buffer, cx| {
1264 let offset = buffer.random_byte_range(0, &mut rng).start;
1265 buffer.edit(
1266 [(offset..offset, format!("\n{}\n", command_text))],
1267 None,
1268 cx,
1269 );
1270 offset + 1..offset + 1 + command_text.len()
1271 });
1272
1273 let output_text = RandomCharIter::new(&mut rng)
1274 .filter(|c| *c != '\r')
1275 .take(10)
1276 .collect::<String>();
1277
1278 let mut events = vec![Ok(SlashCommandEvent::StartMessage {
1279 role: Role::User,
1280 merge_same_roles: true,
1281 })];
1282
1283 let num_sections = rng.gen_range(0..=3);
1284 let mut section_start = 0;
1285 for _ in 0..num_sections {
1286 let mut section_end = rng.gen_range(section_start..=output_text.len());
1287 while !output_text.is_char_boundary(section_end) {
1288 section_end += 1;
1289 }
1290 events.push(Ok(SlashCommandEvent::StartSection {
1291 icon: IconName::Ai,
1292 label: "section".into(),
1293 metadata: None,
1294 }));
1295 events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
1296 text: output_text[section_start..section_end].to_string(),
1297 run_commands_in_text: false,
1298 })));
1299 events.push(Ok(SlashCommandEvent::EndSection));
1300 section_start = section_end;
1301 }
1302
1303 if section_start < output_text.len() {
1304 events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
1305 text: output_text[section_start..].to_string(),
1306 run_commands_in_text: false,
1307 })));
1308 }
1309
1310 log::info!(
1311 "Context {}: insert slash command output at {:?} with {:?} events",
1312 context_index,
1313 command_range,
1314 events.len()
1315 );
1316
1317 let command_range = context.buffer.read(cx).anchor_after(command_range.start)
1318 ..context.buffer.read(cx).anchor_after(command_range.end);
1319 context.insert_command_output(
1320 command_range,
1321 "/command",
1322 Task::ready(Ok(stream::iter(events).boxed())),
1323 true,
1324 cx,
1325 );
1326 });
1327 cx.run_until_parked();
1328 mutation_count -= 1;
1329 }
1330 75..=84 if mutation_count > 0 => {
1331 context.update(cx, |context, cx| {
1332 if let Some(message) = context.messages(cx).choose(&mut rng) {
1333 let new_status = match rng.gen_range(0..3) {
1334 0 => MessageStatus::Done,
1335 1 => MessageStatus::Pending,
1336 _ => MessageStatus::Error(SharedString::from("Random error")),
1337 };
1338 log::info!(
1339 "Context {}: update message {:?} status to {:?}",
1340 context_index,
1341 message.id,
1342 new_status
1343 );
1344 context.update_metadata(message.id, cx, |metadata| {
1345 metadata.status = new_status;
1346 });
1347 }
1348 });
1349 mutation_count -= 1;
1350 }
1351 _ => {
1352 let replica_id = context_index as ReplicaId;
1353 if network.lock().is_disconnected(replica_id) {
1354 network.lock().reconnect_peer(replica_id, 0);
1355
1356 let (ops_to_send, ops_to_receive) = cx.read(|cx| {
1357 let host_context = &contexts[0].read(cx);
1358 let guest_context = context.read(cx);
1359 (
1360 guest_context.serialize_ops(&host_context.version(cx), cx),
1361 host_context.serialize_ops(&guest_context.version(cx), cx),
1362 )
1363 });
1364 let ops_to_send = ops_to_send.await;
1365 let ops_to_receive = ops_to_receive
1366 .await
1367 .into_iter()
1368 .map(ContextOperation::from_proto)
1369 .collect::<Result<Vec<_>>>()
1370 .unwrap();
1371 log::info!(
1372 "Context {}: reconnecting. Sent {} operations, received {} operations",
1373 context_index,
1374 ops_to_send.len(),
1375 ops_to_receive.len()
1376 );
1377
1378 network.lock().broadcast(replica_id, ops_to_send);
1379 context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
1380 } else if rng.gen_bool(0.1) && replica_id != 0 {
1381 log::info!("Context {}: disconnecting", context_index);
1382 network.lock().disconnect_peer(replica_id);
1383 } else if network.lock().has_unreceived(replica_id) {
1384 log::info!("Context {}: applying operations", context_index);
1385 let ops = network.lock().receive(replica_id);
1386 let ops = ops
1387 .into_iter()
1388 .map(ContextOperation::from_proto)
1389 .collect::<Result<Vec<_>>>()
1390 .unwrap();
1391 context.update(cx, |context, cx| context.apply_ops(ops, cx));
1392 }
1393 }
1394 }
1395 }
1396
1397 cx.read(|cx| {
1398 let first_context = contexts[0].read(cx);
1399 for context in &contexts[1..] {
1400 let context = context.read(cx);
1401 assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
1402 assert_eq!(
1403 context.buffer.read(cx).text(),
1404 first_context.buffer.read(cx).text(),
1405 "Context {} text != Context 0 text",
1406 context.buffer.read(cx).replica_id()
1407 );
1408 assert_eq!(
1409 context.message_anchors,
1410 first_context.message_anchors,
1411 "Context {} messages != Context 0 messages",
1412 context.buffer.read(cx).replica_id()
1413 );
1414 assert_eq!(
1415 context.messages_metadata,
1416 first_context.messages_metadata,
1417 "Context {} message metadata != Context 0 message metadata",
1418 context.buffer.read(cx).replica_id()
1419 );
1420 assert_eq!(
1421 context.slash_command_output_sections,
1422 first_context.slash_command_output_sections,
1423 "Context {} slash command output sections != Context 0 slash command output sections",
1424 context.buffer.read(cx).replica_id()
1425 );
1426 }
1427 });
1428}
1429
1430#[gpui::test]
1431fn test_mark_cache_anchors(cx: &mut App) {
1432 let settings_store = SettingsStore::test(cx);
1433 LanguageModelRegistry::test(cx);
1434 cx.set_global(settings_store);
1435 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1436 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1437 let context = cx.new(|cx| {
1438 AssistantContext::local(
1439 registry,
1440 None,
1441 None,
1442 prompt_builder.clone(),
1443 Arc::new(SlashCommandWorkingSet::default()),
1444 cx,
1445 )
1446 });
1447 let buffer = context.read(cx).buffer.clone();
1448
1449 // Create a test cache configuration
1450 let cache_configuration = &Some(LanguageModelCacheConfiguration {
1451 max_cache_anchors: 3,
1452 should_speculate: true,
1453 min_total_token: 10,
1454 });
1455
1456 let message_1 = context.read(cx).message_anchors[0].clone();
1457
1458 context.update(cx, |context, cx| {
1459 context.mark_cache_anchors(cache_configuration, false, cx)
1460 });
1461
1462 assert_eq!(
1463 messages_cache(&context, cx)
1464 .iter()
1465 .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1466 .count(),
1467 0,
1468 "Empty messages should not have any cache anchors."
1469 );
1470
1471 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1472 let message_2 = context
1473 .update(cx, |context, cx| {
1474 context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1475 })
1476 .unwrap();
1477
1478 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1479 let message_3 = context
1480 .update(cx, |context, cx| {
1481 context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1482 })
1483 .unwrap();
1484 buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1485
1486 context.update(cx, |context, cx| {
1487 context.mark_cache_anchors(cache_configuration, false, cx)
1488 });
1489 assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1490 assert_eq!(
1491 messages_cache(&context, cx)
1492 .iter()
1493 .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1494 .count(),
1495 0,
1496 "Messages should not be marked for cache before going over the token minimum."
1497 );
1498 context.update(cx, |context, _| {
1499 context.token_count = Some(20);
1500 });
1501
1502 context.update(cx, |context, cx| {
1503 context.mark_cache_anchors(cache_configuration, true, cx)
1504 });
1505 assert_eq!(
1506 messages_cache(&context, cx)
1507 .iter()
1508 .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1509 .collect::<Vec<bool>>(),
1510 vec![true, true, false],
1511 "Last message should not be an anchor on speculative request."
1512 );
1513
1514 context
1515 .update(cx, |context, cx| {
1516 context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1517 })
1518 .unwrap();
1519
1520 context.update(cx, |context, cx| {
1521 context.mark_cache_anchors(cache_configuration, false, cx)
1522 });
1523 assert_eq!(
1524 messages_cache(&context, cx)
1525 .iter()
1526 .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1527 .collect::<Vec<bool>>(),
1528 vec![false, true, true, false],
1529 "Most recent message should also be cached if not a speculative request."
1530 );
1531 context.update(cx, |context, cx| {
1532 context.update_cache_status_for_completion(cx)
1533 });
1534 assert_eq!(
1535 messages_cache(&context, cx)
1536 .iter()
1537 .map(|(_, cache)| cache
1538 .as_ref()
1539 .map_or(None, |cache| Some(cache.status.clone())))
1540 .collect::<Vec<Option<CacheStatus>>>(),
1541 vec![
1542 Some(CacheStatus::Cached),
1543 Some(CacheStatus::Cached),
1544 Some(CacheStatus::Cached),
1545 None
1546 ],
1547 "All user messages prior to anchor should be marked as cached."
1548 );
1549
1550 buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1551 context.update(cx, |context, cx| {
1552 context.mark_cache_anchors(cache_configuration, false, cx)
1553 });
1554 assert_eq!(
1555 messages_cache(&context, cx)
1556 .iter()
1557 .map(|(_, cache)| cache
1558 .as_ref()
1559 .map_or(None, |cache| Some(cache.status.clone())))
1560 .collect::<Vec<Option<CacheStatus>>>(),
1561 vec![
1562 Some(CacheStatus::Cached),
1563 Some(CacheStatus::Cached),
1564 Some(CacheStatus::Pending),
1565 None
1566 ],
1567 "Modifying a message should invalidate it's cache but leave previous messages."
1568 );
1569 buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1570 context.update(cx, |context, cx| {
1571 context.mark_cache_anchors(cache_configuration, false, cx)
1572 });
1573 assert_eq!(
1574 messages_cache(&context, cx)
1575 .iter()
1576 .map(|(_, cache)| cache
1577 .as_ref()
1578 .map_or(None, |cache| Some(cache.status.clone())))
1579 .collect::<Vec<Option<CacheStatus>>>(),
1580 vec![
1581 Some(CacheStatus::Pending),
1582 Some(CacheStatus::Pending),
1583 Some(CacheStatus::Pending),
1584 None
1585 ],
1586 "Modifying a message should invalidate all future messages."
1587 );
1588}
1589
1590fn messages(context: &Entity<AssistantContext>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
1591 context
1592 .read(cx)
1593 .messages(cx)
1594 .map(|message| (message.id, message.role, message.offset_range))
1595 .collect()
1596}
1597
1598fn messages_cache(
1599 context: &Entity<AssistantContext>,
1600 cx: &App,
1601) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1602 context
1603 .read(cx)
1604 .messages(cx)
1605 .map(|message| (message.id, message.cache.clone()))
1606 .collect()
1607}
1608
1609#[derive(Clone)]
1610struct FakeSlashCommand(String);
1611
1612impl SlashCommand for FakeSlashCommand {
1613 fn name(&self) -> String {
1614 self.0.clone()
1615 }
1616
1617 fn description(&self) -> String {
1618 format!("Fake slash command: {}", self.0)
1619 }
1620
1621 fn menu_text(&self) -> String {
1622 format!("Run fake command: {}", self.0)
1623 }
1624
1625 fn complete_argument(
1626 self: Arc<Self>,
1627 _arguments: &[String],
1628 _cancel: Arc<AtomicBool>,
1629 _workspace: Option<WeakEntity<Workspace>>,
1630 _window: &mut Window,
1631 _cx: &mut App,
1632 ) -> Task<Result<Vec<ArgumentCompletion>>> {
1633 Task::ready(Ok(vec![]))
1634 }
1635
1636 fn requires_argument(&self) -> bool {
1637 false
1638 }
1639
1640 fn run(
1641 self: Arc<Self>,
1642 _arguments: &[String],
1643 _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1644 _context_buffer: BufferSnapshot,
1645 _workspace: WeakEntity<Workspace>,
1646 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1647 _window: &mut Window,
1648 _cx: &mut App,
1649 ) -> Task<SlashCommandResult> {
1650 Task::ready(Ok(SlashCommandOutput {
1651 text: format!("Executed fake command: {}", self.0),
1652 sections: vec![],
1653 run_commands_in_text: false,
1654 }
1655 .to_event_stream()))
1656 }
1657}