1use crate::{
2 AssistantContext, AssistantEdit, AssistantEditKind, CacheStatus, ContextEvent, ContextId,
3 ContextOperation, InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus,
4};
5use anyhow::Result;
6use assistant_slash_command::{
7 ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
8 SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
9};
10use assistant_slash_commands::FileSlashCommand;
11use assistant_tool::ToolWorkingSet;
12use collections::{HashMap, HashSet};
13use fs::FakeFs;
14use futures::{
15 channel::mpsc,
16 stream::{self, StreamExt},
17};
18use gpui::{prelude::*, App, Entity, SharedString, Task, TestAppContext, WeakEntity};
19use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
20use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
21use parking_lot::Mutex;
22use pretty_assertions::assert_eq;
23use project::Project;
24use prompt_library::PromptBuilder;
25use rand::prelude::*;
26use serde_json::json;
27use settings::SettingsStore;
28use std::{
29 cell::RefCell,
30 env,
31 ops::Range,
32 path::Path,
33 rc::Rc,
34 sync::{atomic::AtomicBool, Arc},
35};
36use text::{network::Network, OffsetRangeExt as _, ReplicaId, ToOffset};
37use ui::{IconName, Window};
38use unindent::Unindent;
39use util::{
40 test::{generate_marked_text, marked_text_ranges},
41 RandomCharIter,
42};
43use workspace::Workspace;
44
45#[gpui::test]
46fn test_inserting_and_removing_messages(cx: &mut App) {
47 let settings_store = SettingsStore::test(cx);
48 LanguageModelRegistry::test(cx);
49 cx.set_global(settings_store);
50 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
51 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
52 let context = cx.new(|cx| {
53 AssistantContext::local(
54 registry,
55 None,
56 None,
57 prompt_builder.clone(),
58 Arc::new(SlashCommandWorkingSet::default()),
59 Arc::new(ToolWorkingSet::default()),
60 cx,
61 )
62 });
63 let buffer = context.read(cx).buffer.clone();
64
65 let message_1 = context.read(cx).message_anchors[0].clone();
66 assert_eq!(
67 messages(&context, cx),
68 vec![(message_1.id, Role::User, 0..0)]
69 );
70
71 let message_2 = context.update(cx, |context, cx| {
72 context
73 .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
74 .unwrap()
75 });
76 assert_eq!(
77 messages(&context, cx),
78 vec![
79 (message_1.id, Role::User, 0..1),
80 (message_2.id, Role::Assistant, 1..1)
81 ]
82 );
83
84 buffer.update(cx, |buffer, cx| {
85 buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
86 });
87 assert_eq!(
88 messages(&context, cx),
89 vec![
90 (message_1.id, Role::User, 0..2),
91 (message_2.id, Role::Assistant, 2..3)
92 ]
93 );
94
95 let message_3 = context.update(cx, |context, cx| {
96 context
97 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
98 .unwrap()
99 });
100 assert_eq!(
101 messages(&context, cx),
102 vec![
103 (message_1.id, Role::User, 0..2),
104 (message_2.id, Role::Assistant, 2..4),
105 (message_3.id, Role::User, 4..4)
106 ]
107 );
108
109 let message_4 = context.update(cx, |context, cx| {
110 context
111 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
112 .unwrap()
113 });
114 assert_eq!(
115 messages(&context, cx),
116 vec![
117 (message_1.id, Role::User, 0..2),
118 (message_2.id, Role::Assistant, 2..4),
119 (message_4.id, Role::User, 4..5),
120 (message_3.id, Role::User, 5..5),
121 ]
122 );
123
124 buffer.update(cx, |buffer, cx| {
125 buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
126 });
127 assert_eq!(
128 messages(&context, cx),
129 vec![
130 (message_1.id, Role::User, 0..2),
131 (message_2.id, Role::Assistant, 2..4),
132 (message_4.id, Role::User, 4..6),
133 (message_3.id, Role::User, 6..7),
134 ]
135 );
136
137 // Deleting across message boundaries merges the messages.
138 buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
139 assert_eq!(
140 messages(&context, cx),
141 vec![
142 (message_1.id, Role::User, 0..3),
143 (message_3.id, Role::User, 3..4),
144 ]
145 );
146
147 // Undoing the deletion should also undo the merge.
148 buffer.update(cx, |buffer, cx| buffer.undo(cx));
149 assert_eq!(
150 messages(&context, cx),
151 vec![
152 (message_1.id, Role::User, 0..2),
153 (message_2.id, Role::Assistant, 2..4),
154 (message_4.id, Role::User, 4..6),
155 (message_3.id, Role::User, 6..7),
156 ]
157 );
158
159 // Redoing the deletion should also redo the merge.
160 buffer.update(cx, |buffer, cx| buffer.redo(cx));
161 assert_eq!(
162 messages(&context, cx),
163 vec![
164 (message_1.id, Role::User, 0..3),
165 (message_3.id, Role::User, 3..4),
166 ]
167 );
168
169 // Ensure we can still insert after a merged message.
170 let message_5 = context.update(cx, |context, cx| {
171 context
172 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
173 .unwrap()
174 });
175 assert_eq!(
176 messages(&context, cx),
177 vec![
178 (message_1.id, Role::User, 0..3),
179 (message_5.id, Role::System, 3..4),
180 (message_3.id, Role::User, 4..5)
181 ]
182 );
183}
184
185#[gpui::test]
186fn test_message_splitting(cx: &mut App) {
187 let settings_store = SettingsStore::test(cx);
188 cx.set_global(settings_store);
189 LanguageModelRegistry::test(cx);
190 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
191
192 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
193 let context = cx.new(|cx| {
194 AssistantContext::local(
195 registry.clone(),
196 None,
197 None,
198 prompt_builder.clone(),
199 Arc::new(SlashCommandWorkingSet::default()),
200 Arc::new(ToolWorkingSet::default()),
201 cx,
202 )
203 });
204 let buffer = context.read(cx).buffer.clone();
205
206 let message_1 = context.read(cx).message_anchors[0].clone();
207 assert_eq!(
208 messages(&context, cx),
209 vec![(message_1.id, Role::User, 0..0)]
210 );
211
212 buffer.update(cx, |buffer, cx| {
213 buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
214 });
215
216 let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
217 let message_2 = message_2.unwrap();
218
219 // We recycle newlines in the middle of a split message
220 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
221 assert_eq!(
222 messages(&context, cx),
223 vec![
224 (message_1.id, Role::User, 0..4),
225 (message_2.id, Role::User, 4..16),
226 ]
227 );
228
229 let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
230 let message_3 = message_3.unwrap();
231
232 // We don't recycle newlines at the end of a split message
233 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
234 assert_eq!(
235 messages(&context, cx),
236 vec![
237 (message_1.id, Role::User, 0..4),
238 (message_3.id, Role::User, 4..5),
239 (message_2.id, Role::User, 5..17),
240 ]
241 );
242
243 let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
244 let message_4 = message_4.unwrap();
245 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
246 assert_eq!(
247 messages(&context, cx),
248 vec![
249 (message_1.id, Role::User, 0..4),
250 (message_3.id, Role::User, 4..5),
251 (message_2.id, Role::User, 5..9),
252 (message_4.id, Role::User, 9..17),
253 ]
254 );
255
256 let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
257 let message_5 = message_5.unwrap();
258 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
259 assert_eq!(
260 messages(&context, cx),
261 vec![
262 (message_1.id, Role::User, 0..4),
263 (message_3.id, Role::User, 4..5),
264 (message_2.id, Role::User, 5..9),
265 (message_4.id, Role::User, 9..10),
266 (message_5.id, Role::User, 10..18),
267 ]
268 );
269
270 let (message_6, message_7) =
271 context.update(cx, |context, cx| context.split_message(14..16, cx));
272 let message_6 = message_6.unwrap();
273 let message_7 = message_7.unwrap();
274 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
275 assert_eq!(
276 messages(&context, cx),
277 vec![
278 (message_1.id, Role::User, 0..4),
279 (message_3.id, Role::User, 4..5),
280 (message_2.id, Role::User, 5..9),
281 (message_4.id, Role::User, 9..10),
282 (message_5.id, Role::User, 10..14),
283 (message_6.id, Role::User, 14..17),
284 (message_7.id, Role::User, 17..19),
285 ]
286 );
287}
288
289#[gpui::test]
290fn test_messages_for_offsets(cx: &mut App) {
291 let settings_store = SettingsStore::test(cx);
292 LanguageModelRegistry::test(cx);
293 cx.set_global(settings_store);
294 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
295 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
296 let context = cx.new(|cx| {
297 AssistantContext::local(
298 registry,
299 None,
300 None,
301 prompt_builder.clone(),
302 Arc::new(SlashCommandWorkingSet::default()),
303 Arc::new(ToolWorkingSet::default()),
304 cx,
305 )
306 });
307 let buffer = context.read(cx).buffer.clone();
308
309 let message_1 = context.read(cx).message_anchors[0].clone();
310 assert_eq!(
311 messages(&context, cx),
312 vec![(message_1.id, Role::User, 0..0)]
313 );
314
315 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
316 let message_2 = context
317 .update(cx, |context, cx| {
318 context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
319 })
320 .unwrap();
321 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
322
323 let message_3 = context
324 .update(cx, |context, cx| {
325 context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
326 })
327 .unwrap();
328 buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
329
330 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
331 assert_eq!(
332 messages(&context, cx),
333 vec![
334 (message_1.id, Role::User, 0..4),
335 (message_2.id, Role::User, 4..8),
336 (message_3.id, Role::User, 8..11)
337 ]
338 );
339
340 assert_eq!(
341 message_ids_for_offsets(&context, &[0, 4, 9], cx),
342 [message_1.id, message_2.id, message_3.id]
343 );
344 assert_eq!(
345 message_ids_for_offsets(&context, &[0, 1, 11], cx),
346 [message_1.id, message_3.id]
347 );
348
349 let message_4 = context
350 .update(cx, |context, cx| {
351 context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
352 })
353 .unwrap();
354 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
355 assert_eq!(
356 messages(&context, cx),
357 vec![
358 (message_1.id, Role::User, 0..4),
359 (message_2.id, Role::User, 4..8),
360 (message_3.id, Role::User, 8..12),
361 (message_4.id, Role::User, 12..12)
362 ]
363 );
364 assert_eq!(
365 message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
366 [message_1.id, message_2.id, message_3.id, message_4.id]
367 );
368
369 fn message_ids_for_offsets(
370 context: &Entity<AssistantContext>,
371 offsets: &[usize],
372 cx: &App,
373 ) -> Vec<MessageId> {
374 context
375 .read(cx)
376 .messages_for_offsets(offsets.iter().copied(), cx)
377 .into_iter()
378 .map(|message| message.id)
379 .collect()
380 }
381}
382
// Tracks the ranges reported through `ContextEvent`s: parsed slash commands,
// invoked command outputs, and output sections. The `/file` command is typed,
// edited, renamed to an unknown command, and restored via undo. It is then
// invoked with its output streamed through a channel. After every step the
// buffer text is rendered with those ranges marked and compared against the
// expected marked text.
#[gpui::test]
async fn test_slash_commands(cx: &mut TestAppContext) {
    let settings_store = cx.update(SettingsStore::test);
    cx.set_global(settings_store);
    cx.update(LanguageModelRegistry::test);
    cx.update(Project::init_settings);
    let fs = FakeFs::new(cx.background_executor.clone());

    // NOTE(review): `fs` is populated here but is not referenced again in this
    // test (the context below is created without a project). Confirm whether
    // this setup is still needed.
    fs.insert_tree(
        "/test",
        json!({
            "src": {
                "lib.rs": "fn one() -> usize { 1 }",
                "main.rs": "
                use crate::one;
                fn main() { one(); }
                ".unindent(),
            }
        }),
    )
    .await;

    // Register `/file` in the global slash command registry. The assertions
    // below expect `/file ...` to be recognized and `/unknown ...` not to be.
    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
    slash_command_registry.register_command(FileSlashCommand, false);

    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context = cx.new(|cx| {
        AssistantContext::local(
            registry.clone(),
            None,
            None,
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            Arc::new(ToolWorkingSet::default()),
            cx,
        )
    });

    // Ranges observed from context events. `assert_text_and_context_ranges`
    // renders them as marker pairs:
    // - `parsed_commands`: source ranges of parsed slash commands, as `«…»`.
    // - `command_outputs`: range of each invoked command, keyed by its id, as `⟦…⟧`.
    // - `output_sections`: ranges of command output sections, as `⟪…⟫`.
    #[derive(Default)]
    struct ContextRanges {
        parsed_commands: HashSet<Range<language::Anchor>>,
        command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
        output_sections: HashSet<Range<language::Anchor>>,
    }

    // Shared between this test and the subscription below, which updates it
    // as events arrive.
    let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
    context.update(cx, |_, cx| {
        cx.subscribe(&context, {
            let context_ranges = context_ranges.clone();
            move |context, _, event, _| {
                let mut context_ranges = context_ranges.borrow_mut();
                match event {
                    // Record (or overwrite) the current range of the invoked command.
                    ContextEvent::InvokedSlashCommandChanged { command_id } => {
                        let command = context.invoked_slash_command(command_id).unwrap();
                        context_ranges
                            .command_outputs
                            .insert(*command_id, command.range.clone());
                    }
                    // Removals are applied before insertions, so a range listed
                    // in both `removed` and `updated` ends up present.
                    ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
                        for range in removed {
                            context_ranges.parsed_commands.remove(range);
                        }
                        for command in updated {
                            context_ranges
                                .parsed_commands
                                .insert(command.source_range.clone());
                        }
                    }
                    ContextEvent::SlashCommandOutputSectionAdded { section } => {
                        context_ranges.output_sections.insert(section.range.clone());
                    }
                    _ => {}
                }
            }
        })
        .detach();
    });

    let buffer = context.read_with(cx, |context, _| context.buffer.clone());

    // Insert a slash command
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
    });
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        «/file src/lib.rs»"
            .unindent(),
        cx,
    );

    // Edit the argument of the slash command.
    buffer.update(cx, |buffer, cx| {
        let edit_offset = buffer.text().find("lib.rs").unwrap();
        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
    });
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        «/file src/main.rs»"
            .unindent(),
        cx,
    );

    // Edit the name of the slash command, using one that doesn't exist.
    buffer.update(cx, |buffer, cx| {
        let edit_offset = buffer.text().find("/file").unwrap();
        buffer.edit(
            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
            None,
            cx,
        );
    });
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        /unknown src/main.rs"
            .unindent(),
        cx,
    );

    // Undoing the insertion of a non-existent slash command restores the previous one.
    buffer.update(cx, |buffer, cx| buffer.undo(cx));
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        «/file src/main.rs»"
            .unindent(),
        cx,
    );

    // Invoke the parsed command. Its output is fed manually through this
    // channel instead of actually running `/file`.
    let (command_output_tx, command_output_rx) = mpsc::unbounded();
    context.update(cx, |context, cx| {
        let command_source_range = context.parsed_slash_commands[0].source_range.clone();
        context.insert_command_output(
            command_source_range,
            "file",
            Task::ready(Ok(command_output_rx.boxed())),
            true,
            cx,
        );
    });
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦«/file src/main.rs»
        …⟧
        "
        .unindent(),
        cx,
    );

    // Stream a section start followed by some content.
    command_output_tx
        .unbounded_send(Ok(SlashCommandEvent::StartSection {
            icon: IconName::Ai,
            label: "src/main.rs".into(),
            metadata: None,
        }))
        .unwrap();
    command_output_tx
        .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
        .unwrap();
    cx.run_until_parked();
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦«/file src/main.rs»
        src/main.rs…⟧
        "
        .unindent(),
        cx,
    );

    command_output_tx
        .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
        .unwrap();
    cx.run_until_parked();
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦«/file src/main.rs»
        src/main.rs
        fn main() {}…⟧
        "
        .unindent(),
        cx,
    );

    // Ending the section produces an output-section range (⟪…⟫).
    command_output_tx
        .unbounded_send(Ok(SlashCommandEvent::EndSection))
        .unwrap();
    cx.run_until_parked();
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦«/file src/main.rs»
        ⟪src/main.rs
        fn main() {}⟫…⟧
        "
        .unindent(),
        cx,
    );

    // Closing the stream finishes the command. The expected text no longer
    // contains the command source or the trailing `…`.
    drop(command_output_tx);
    cx.run_until_parked();
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦⟪src/main.rs
        fn main() {}⟫⟧
        "
        .unindent(),
        cx,
    );

    // Renders the buffer text with marker characters inserted at the start and
    // end offsets of every tracked range, then asserts it equals
    // `expected_marked_text`.
    #[track_caller]
    fn assert_text_and_context_ranges(
        buffer: &Entity<Buffer>,
        ranges: &RefCell<ContextRanges>,
        expected_marked_text: &str,
        cx: &mut TestAppContext,
    ) {
        let mut actual_marked_text = String::new();
        buffer.update(cx, |buffer, _| {
            struct Endpoint {
                offset: usize,
                marker: char,
            }

            let ranges = ranges.borrow();
            let mut endpoints = Vec::new();
            // Opening markers, pushed outermost kind first.
            for range in ranges.command_outputs.values() {
                endpoints.push(Endpoint {
                    offset: range.start.to_offset(buffer),
                    marker: '⟦',
                });
            }
            for range in ranges.parsed_commands.iter() {
                endpoints.push(Endpoint {
                    offset: range.start.to_offset(buffer),
                    marker: '«',
                });
            }
            for range in ranges.output_sections.iter() {
                endpoints.push(Endpoint {
                    offset: range.start.to_offset(buffer),
                    marker: '⟪',
                });
            }

            // Closing markers, pushed innermost kind first.
            for range in ranges.output_sections.iter() {
                endpoints.push(Endpoint {
                    offset: range.end.to_offset(buffer),
                    marker: '⟫',
                });
            }
            for range in ranges.parsed_commands.iter() {
                endpoints.push(Endpoint {
                    offset: range.end.to_offset(buffer),
                    marker: '»',
                });
            }
            for range in ranges.command_outputs.values() {
                endpoints.push(Endpoint {
                    offset: range.end.to_offset(buffer),
                    marker: '⟧',
                });
            }

            // `sort_by_key` is stable, so endpoints at the same offset keep the
            // push order above: all openers precede all closers.
            endpoints.sort_by_key(|endpoint| endpoint.offset);
            let mut offset = 0;
            for endpoint in endpoints {
                actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
                actual_marked_text.push(endpoint.marker);
                offset = endpoint.offset;
            }
            actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
        });

        assert_eq!(actual_marked_text, expected_marked_text);
    }
}
676
// Checks that `<patch>` blocks in an assistant message are parsed into
// `context.patches` as text is typed:
// - an unclosed patch is tracked with no edits;
// - a complete patch yields its edits;
// - manual edits are re-parsed.
// Patches are also cleared when the message role is cycled from Assistant to
// User, re-parsed when it is cycled back to Assistant, and re-parsed after
// the context is serialized and deserialized.
#[gpui::test]
async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
    cx.update(prompt_library::init);
    // Enable the experimental live-diffs setting before installing the store.
    let mut settings_store = cx.update(SettingsStore::test);
    cx.update(|cx| {
        settings_store
            .set_user_settings(
                r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
                cx,
            )
            .unwrap()
    });
    cx.set_global(settings_store);
    cx.update(language::init);
    cx.update(Project::init_settings);
    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [Path::new("/root")], cx).await;
    cx.update(LanguageModelRegistry::test);

    let registry = Arc::new(LanguageRegistry::test(cx.executor()));

    // Create a new context
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context = cx.new(|cx| {
        AssistantContext::local(
            registry.clone(),
            Some(project),
            None,
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            Arc::new(ToolWorkingSet::default()),
            cx,
        )
    });

    // Insert an assistant message to simulate a response.
    let assistant_message_id = context.update(cx, |context, cx| {
        let user_message_id = context.messages(cx).next().unwrap().id;
        context
            .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
            .unwrap()
            .id
    });

    // Marker conventions used below:
    // - In `edit`, «…» delimits the text changed by that step (interpreted by
    //   `edit_via_marked_text`).
    // - In `expect_patches`, «…» marks the expected range of each parsed patch.

    // No edit tags
    edit(
        &context,
        "

        «one
        two
        »",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two
        ",
        &[],
        cx,
    );

    // Partial edit step tag is added
    edit(
        &context,
        "

        one
        two
        «
        <patch»",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two

        <patch",
        &[],
        cx,
    );

    // The rest of the step tag is added. The unclosed
    // step is treated as incomplete.
    edit(
        &context,
        "

        one
        two

        <patch«>
        <edit>»",
        cx,
    );
    // One patch range is tracked, but it has no edits yet.
    expect_patches(
        &context,
        "

        one
        two

        «<patch>
        <edit>»",
        &[&[]],
        cx,
    );

    // The full patch is added
    edit(
        &context,
        "

        one
        two

        <patch>
        <edit>«
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn one</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>

        also,»",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two

        «<patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn one</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>
        »
        also,",
        &[&[AssistantEdit {
            path: "src/lib.rs".into(),
            kind: AssistantEditKind::InsertAfter {
                old_text: "fn one".into(),
                new_text: "fn two() {}".into(),
                description: Some("add a `two` function".into()),
            },
        }]],
        cx,
    );

    // The step is manually edited.
    edit(
        &context,
        "

        one
        two

        <patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>«fn zero»</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>

        also,",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two

        «<patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn zero</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>
        »
        also,",
        &[&[AssistantEdit {
            path: "src/lib.rs".into(),
            kind: AssistantEditKind::InsertAfter {
                old_text: "fn zero".into(),
                new_text: "fn two() {}".into(),
                description: Some("add a `two` function".into()),
            },
        }]],
        cx,
    );

    // When setting the message role to User, the steps are cleared.
    // (Two cycles are applied to move the message from Assistant to User.)
    context.update(cx, |context, cx| {
        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
    });
    expect_patches(
        &context,
        "

        one
        two

        <patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn zero</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>

        also,",
        &[],
        cx,
    );

    // When setting the message role back to Assistant, the steps are reparsed.
    context.update(cx, |context, cx| {
        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
    });
    expect_patches(
        &context,
        "

        one
        two

        «<patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn zero</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>
        »
        also,",
        &[&[AssistantEdit {
            path: "src/lib.rs".into(),
            kind: AssistantEditKind::InsertAfter {
                old_text: "fn zero".into(),
                new_text: "fn two() {}".into(),
                description: Some("add a `two` function".into()),
            },
        }]],
        cx,
    );

    // Ensure steps are re-parsed when deserializing.
    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
    let deserialized_context = cx.new(|cx| {
        AssistantContext::deserialize(
            serialized_context,
            Default::default(),
            registry.clone(),
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            Arc::new(ToolWorkingSet::default()),
            None,
            None,
            cx,
        )
    });
    expect_patches(
        &deserialized_context,
        "

        one
        two

        «<patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn zero</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>
        »
        also,",
        &[&[AssistantEdit {
            path: "src/lib.rs".into(),
            kind: AssistantEditKind::InsertAfter {
                old_text: "fn zero".into(),
                new_text: "fn two() {}".into(),
                description: Some("add a `two` function".into()),
            },
        }]],
        cx,
    );

    // Applies the unindented marked text to the context's buffer via
    // `edit_via_marked_text`, then runs the executor until parked so that
    // pending tasks finish before the next assertion.
    fn edit(
        context: &Entity<AssistantContext>,
        new_text_marked_with_edits: &str,
        cx: &mut TestAppContext,
    ) {
        context.update(cx, |context, cx| {
            context.buffer.update(cx, |buffer, cx| {
                buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
            });
        });
        cx.executor().run_until_parked();
    }

    // Asserts three things about the context:
    // - the buffer text equals `expected_marked_text` with its markers removed;
    // - the ranges of `context.patches` match the «…» markers;
    // - each patch's edits (each unwrapped, so a failed edit panics) equal the
    //   corresponding entry of `expected_suggestions`.
    #[track_caller]
    fn expect_patches(
        context: &Entity<AssistantContext>,
        expected_marked_text: &str,
        expected_suggestions: &[&[AssistantEdit]],
        cx: &mut TestAppContext,
    ) {
        let expected_marked_text = expected_marked_text.unindent();
        let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);

        let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
            context.buffer.read_with(cx, |buffer, _| {
                let ranges = context
                    .patches
                    .iter()
                    .map(|entry| entry.range.to_offset(buffer))
                    .collect::<Vec<_>>();
                (
                    buffer.text(),
                    ranges,
                    context
                        .patches
                        .iter()
                        .map(|step| step.edits.clone())
                        .collect::<Vec<_>>(),
                )
            })
        });

        assert_eq!(buffer_text, expected_text);

        // Re-mark the plain text with the actual patch ranges and compare the
        // result with the expected marked text.
        let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
        assert_eq!(actual_marked_text, expected_marked_text);

        // Rebuild each edit as an `AssistantEdit` from its `path` and `kind`
        // so it can be compared with the expectations.
        assert_eq!(
            patches
                .iter()
                .map(|patch| {
                    patch
                        .iter()
                        .map(|edit| {
                            let edit = edit.as_ref().unwrap();
                            AssistantEdit {
                                path: edit.path.clone(),
                                kind: edit.kind.clone(),
                            }
                        })
                        .collect::<Vec<_>>()
                })
                .collect::<Vec<_>>(),
            expected_suggestions
        );
    }
}
1076
1077#[gpui::test]
1078async fn test_serialization(cx: &mut TestAppContext) {
1079 let settings_store = cx.update(SettingsStore::test);
1080 cx.set_global(settings_store);
1081 cx.update(LanguageModelRegistry::test);
1082 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
1083 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1084 let context = cx.new(|cx| {
1085 AssistantContext::local(
1086 registry.clone(),
1087 None,
1088 None,
1089 prompt_builder.clone(),
1090 Arc::new(SlashCommandWorkingSet::default()),
1091 Arc::new(ToolWorkingSet::default()),
1092 cx,
1093 )
1094 });
1095 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
1096 let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
1097 let message_1 = context.update(cx, |context, cx| {
1098 context
1099 .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
1100 .unwrap()
1101 });
1102 let message_2 = context.update(cx, |context, cx| {
1103 context
1104 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
1105 .unwrap()
1106 });
1107 buffer.update(cx, |buffer, cx| {
1108 buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
1109 buffer.finalize_last_transaction();
1110 });
1111 let _message_3 = context.update(cx, |context, cx| {
1112 context
1113 .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
1114 .unwrap()
1115 });
1116 buffer.update(cx, |buffer, cx| buffer.undo(cx));
1117 assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
1118 assert_eq!(
1119 cx.read(|cx| messages(&context, cx)),
1120 [
1121 (message_0, Role::User, 0..2),
1122 (message_1.id, Role::Assistant, 2..6),
1123 (message_2.id, Role::System, 6..6),
1124 ]
1125 );
1126
1127 let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
1128 let deserialized_context = cx.new(|cx| {
1129 AssistantContext::deserialize(
1130 serialized_context,
1131 Default::default(),
1132 registry.clone(),
1133 prompt_builder.clone(),
1134 Arc::new(SlashCommandWorkingSet::default()),
1135 Arc::new(ToolWorkingSet::default()),
1136 None,
1137 None,
1138 cx,
1139 )
1140 });
1141 let deserialized_buffer =
1142 deserialized_context.read_with(cx, |context, _| context.buffer.clone());
1143 assert_eq!(
1144 deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
1145 "a\nb\nc\n"
1146 );
1147 assert_eq!(
1148 cx.read(|cx| messages(&deserialized_context, cx)),
1149 [
1150 (message_0, Role::User, 0..2),
1151 (message_1.id, Role::Assistant, 2..6),
1152 (message_2.id, Role::System, 6..6),
1153 ]
1154 );
1155}
1156
1157#[gpui::test(iterations = 100)]
1158async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
1159 let min_peers = env::var("MIN_PEERS")
1160 .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
1161 .unwrap_or(2);
1162 let max_peers = env::var("MAX_PEERS")
1163 .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
1164 .unwrap_or(5);
1165 let operations = env::var("OPERATIONS")
1166 .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
1167 .unwrap_or(50);
1168
1169 let settings_store = cx.update(SettingsStore::test);
1170 cx.set_global(settings_store);
1171 cx.update(LanguageModelRegistry::test);
1172
1173 let slash_commands = cx.update(SlashCommandRegistry::default_global);
1174 slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
1175 slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
1176 slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
1177
1178 let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
1179 let network = Arc::new(Mutex::new(Network::new(rng.clone())));
1180 let mut contexts = Vec::new();
1181
1182 let num_peers = rng.gen_range(min_peers..=max_peers);
1183 let context_id = ContextId::new();
1184 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1185 for i in 0..num_peers {
1186 let context = cx.new(|cx| {
1187 AssistantContext::new(
1188 context_id.clone(),
1189 i as ReplicaId,
1190 language::Capability::ReadWrite,
1191 registry.clone(),
1192 prompt_builder.clone(),
1193 Arc::new(SlashCommandWorkingSet::default()),
1194 Arc::new(ToolWorkingSet::default()),
1195 None,
1196 None,
1197 cx,
1198 )
1199 });
1200
1201 cx.update(|cx| {
1202 cx.subscribe(&context, {
1203 let network = network.clone();
1204 move |_, event, _| {
1205 if let ContextEvent::Operation(op) = event {
1206 network
1207 .lock()
1208 .broadcast(i as ReplicaId, vec![op.to_proto()]);
1209 }
1210 }
1211 })
1212 .detach();
1213 });
1214
1215 contexts.push(context);
1216 network.lock().add_peer(i as ReplicaId);
1217 }
1218
1219 let mut mutation_count = operations;
1220
1221 while mutation_count > 0
1222 || !network.lock().is_idle()
1223 || network.lock().contains_disconnected_peers()
1224 {
1225 let context_index = rng.gen_range(0..contexts.len());
1226 let context = &contexts[context_index];
1227
1228 match rng.gen_range(0..100) {
1229 0..=29 if mutation_count > 0 => {
1230 log::info!("Context {}: edit buffer", context_index);
1231 context.update(cx, |context, cx| {
1232 context
1233 .buffer
1234 .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
1235 });
1236 mutation_count -= 1;
1237 }
1238 30..=44 if mutation_count > 0 => {
1239 context.update(cx, |context, cx| {
1240 let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
1241 log::info!("Context {}: split message at {:?}", context_index, range);
1242 context.split_message(range, cx);
1243 });
1244 mutation_count -= 1;
1245 }
1246 45..=59 if mutation_count > 0 => {
1247 context.update(cx, |context, cx| {
1248 if let Some(message) = context.messages(cx).choose(&mut rng) {
1249 let role = *[Role::User, Role::Assistant, Role::System]
1250 .choose(&mut rng)
1251 .unwrap();
1252 log::info!(
1253 "Context {}: insert message after {:?} with {:?}",
1254 context_index,
1255 message.id,
1256 role
1257 );
1258 context.insert_message_after(message.id, role, MessageStatus::Done, cx);
1259 }
1260 });
1261 mutation_count -= 1;
1262 }
1263 60..=74 if mutation_count > 0 => {
1264 context.update(cx, |context, cx| {
1265 let command_text = "/".to_string()
1266 + slash_commands
1267 .command_names()
1268 .choose(&mut rng)
1269 .unwrap()
1270 .clone()
1271 .as_ref();
1272
1273 let command_range = context.buffer.update(cx, |buffer, cx| {
1274 let offset = buffer.random_byte_range(0, &mut rng).start;
1275 buffer.edit(
1276 [(offset..offset, format!("\n{}\n", command_text))],
1277 None,
1278 cx,
1279 );
1280 offset + 1..offset + 1 + command_text.len()
1281 });
1282
1283 let output_text = RandomCharIter::new(&mut rng)
1284 .filter(|c| *c != '\r')
1285 .take(10)
1286 .collect::<String>();
1287
1288 let mut events = vec![Ok(SlashCommandEvent::StartMessage {
1289 role: Role::User,
1290 merge_same_roles: true,
1291 })];
1292
1293 let num_sections = rng.gen_range(0..=3);
1294 let mut section_start = 0;
1295 for _ in 0..num_sections {
1296 let mut section_end = rng.gen_range(section_start..=output_text.len());
1297 while !output_text.is_char_boundary(section_end) {
1298 section_end += 1;
1299 }
1300 events.push(Ok(SlashCommandEvent::StartSection {
1301 icon: IconName::Ai,
1302 label: "section".into(),
1303 metadata: None,
1304 }));
1305 events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
1306 text: output_text[section_start..section_end].to_string(),
1307 run_commands_in_text: false,
1308 })));
1309 events.push(Ok(SlashCommandEvent::EndSection));
1310 section_start = section_end;
1311 }
1312
1313 if section_start < output_text.len() {
1314 events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
1315 text: output_text[section_start..].to_string(),
1316 run_commands_in_text: false,
1317 })));
1318 }
1319
1320 log::info!(
1321 "Context {}: insert slash command output at {:?} with {:?} events",
1322 context_index,
1323 command_range,
1324 events.len()
1325 );
1326
1327 let command_range = context.buffer.read(cx).anchor_after(command_range.start)
1328 ..context.buffer.read(cx).anchor_after(command_range.end);
1329 context.insert_command_output(
1330 command_range,
1331 "/command",
1332 Task::ready(Ok(stream::iter(events).boxed())),
1333 true,
1334 cx,
1335 );
1336 });
1337 cx.run_until_parked();
1338 mutation_count -= 1;
1339 }
1340 75..=84 if mutation_count > 0 => {
1341 context.update(cx, |context, cx| {
1342 if let Some(message) = context.messages(cx).choose(&mut rng) {
1343 let new_status = match rng.gen_range(0..3) {
1344 0 => MessageStatus::Done,
1345 1 => MessageStatus::Pending,
1346 _ => MessageStatus::Error(SharedString::from("Random error")),
1347 };
1348 log::info!(
1349 "Context {}: update message {:?} status to {:?}",
1350 context_index,
1351 message.id,
1352 new_status
1353 );
1354 context.update_metadata(message.id, cx, |metadata| {
1355 metadata.status = new_status;
1356 });
1357 }
1358 });
1359 mutation_count -= 1;
1360 }
1361 _ => {
1362 let replica_id = context_index as ReplicaId;
1363 if network.lock().is_disconnected(replica_id) {
1364 network.lock().reconnect_peer(replica_id, 0);
1365
1366 let (ops_to_send, ops_to_receive) = cx.read(|cx| {
1367 let host_context = &contexts[0].read(cx);
1368 let guest_context = context.read(cx);
1369 (
1370 guest_context.serialize_ops(&host_context.version(cx), cx),
1371 host_context.serialize_ops(&guest_context.version(cx), cx),
1372 )
1373 });
1374 let ops_to_send = ops_to_send.await;
1375 let ops_to_receive = ops_to_receive
1376 .await
1377 .into_iter()
1378 .map(ContextOperation::from_proto)
1379 .collect::<Result<Vec<_>>>()
1380 .unwrap();
1381 log::info!(
1382 "Context {}: reconnecting. Sent {} operations, received {} operations",
1383 context_index,
1384 ops_to_send.len(),
1385 ops_to_receive.len()
1386 );
1387
1388 network.lock().broadcast(replica_id, ops_to_send);
1389 context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
1390 } else if rng.gen_bool(0.1) && replica_id != 0 {
1391 log::info!("Context {}: disconnecting", context_index);
1392 network.lock().disconnect_peer(replica_id);
1393 } else if network.lock().has_unreceived(replica_id) {
1394 log::info!("Context {}: applying operations", context_index);
1395 let ops = network.lock().receive(replica_id);
1396 let ops = ops
1397 .into_iter()
1398 .map(ContextOperation::from_proto)
1399 .collect::<Result<Vec<_>>>()
1400 .unwrap();
1401 context.update(cx, |context, cx| context.apply_ops(ops, cx));
1402 }
1403 }
1404 }
1405 }
1406
1407 cx.read(|cx| {
1408 let first_context = contexts[0].read(cx);
1409 for context in &contexts[1..] {
1410 let context = context.read(cx);
1411 assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
1412 assert_eq!(
1413 context.buffer.read(cx).text(),
1414 first_context.buffer.read(cx).text(),
1415 "Context {} text != Context 0 text",
1416 context.buffer.read(cx).replica_id()
1417 );
1418 assert_eq!(
1419 context.message_anchors,
1420 first_context.message_anchors,
1421 "Context {} messages != Context 0 messages",
1422 context.buffer.read(cx).replica_id()
1423 );
1424 assert_eq!(
1425 context.messages_metadata,
1426 first_context.messages_metadata,
1427 "Context {} message metadata != Context 0 message metadata",
1428 context.buffer.read(cx).replica_id()
1429 );
1430 assert_eq!(
1431 context.slash_command_output_sections,
1432 first_context.slash_command_output_sections,
1433 "Context {} slash command output sections != Context 0 slash command output sections",
1434 context.buffer.read(cx).replica_id()
1435 );
1436 }
1437 });
1438}
1439
1440#[gpui::test]
1441fn test_mark_cache_anchors(cx: &mut App) {
1442 let settings_store = SettingsStore::test(cx);
1443 LanguageModelRegistry::test(cx);
1444 cx.set_global(settings_store);
1445 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1446 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1447 let context = cx.new(|cx| {
1448 AssistantContext::local(
1449 registry,
1450 None,
1451 None,
1452 prompt_builder.clone(),
1453 Arc::new(SlashCommandWorkingSet::default()),
1454 Arc::new(ToolWorkingSet::default()),
1455 cx,
1456 )
1457 });
1458 let buffer = context.read(cx).buffer.clone();
1459
1460 // Create a test cache configuration
1461 let cache_configuration = &Some(LanguageModelCacheConfiguration {
1462 max_cache_anchors: 3,
1463 should_speculate: true,
1464 min_total_token: 10,
1465 });
1466
1467 let message_1 = context.read(cx).message_anchors[0].clone();
1468
1469 context.update(cx, |context, cx| {
1470 context.mark_cache_anchors(cache_configuration, false, cx)
1471 });
1472
1473 assert_eq!(
1474 messages_cache(&context, cx)
1475 .iter()
1476 .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1477 .count(),
1478 0,
1479 "Empty messages should not have any cache anchors."
1480 );
1481
1482 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1483 let message_2 = context
1484 .update(cx, |context, cx| {
1485 context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1486 })
1487 .unwrap();
1488
1489 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1490 let message_3 = context
1491 .update(cx, |context, cx| {
1492 context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1493 })
1494 .unwrap();
1495 buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1496
1497 context.update(cx, |context, cx| {
1498 context.mark_cache_anchors(cache_configuration, false, cx)
1499 });
1500 assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1501 assert_eq!(
1502 messages_cache(&context, cx)
1503 .iter()
1504 .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1505 .count(),
1506 0,
1507 "Messages should not be marked for cache before going over the token minimum."
1508 );
1509 context.update(cx, |context, _| {
1510 context.token_count = Some(20);
1511 });
1512
1513 context.update(cx, |context, cx| {
1514 context.mark_cache_anchors(cache_configuration, true, cx)
1515 });
1516 assert_eq!(
1517 messages_cache(&context, cx)
1518 .iter()
1519 .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1520 .collect::<Vec<bool>>(),
1521 vec![true, true, false],
1522 "Last message should not be an anchor on speculative request."
1523 );
1524
1525 context
1526 .update(cx, |context, cx| {
1527 context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1528 })
1529 .unwrap();
1530
1531 context.update(cx, |context, cx| {
1532 context.mark_cache_anchors(cache_configuration, false, cx)
1533 });
1534 assert_eq!(
1535 messages_cache(&context, cx)
1536 .iter()
1537 .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1538 .collect::<Vec<bool>>(),
1539 vec![false, true, true, false],
1540 "Most recent message should also be cached if not a speculative request."
1541 );
1542 context.update(cx, |context, cx| {
1543 context.update_cache_status_for_completion(cx)
1544 });
1545 assert_eq!(
1546 messages_cache(&context, cx)
1547 .iter()
1548 .map(|(_, cache)| cache
1549 .as_ref()
1550 .map_or(None, |cache| Some(cache.status.clone())))
1551 .collect::<Vec<Option<CacheStatus>>>(),
1552 vec![
1553 Some(CacheStatus::Cached),
1554 Some(CacheStatus::Cached),
1555 Some(CacheStatus::Cached),
1556 None
1557 ],
1558 "All user messages prior to anchor should be marked as cached."
1559 );
1560
1561 buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1562 context.update(cx, |context, cx| {
1563 context.mark_cache_anchors(cache_configuration, false, cx)
1564 });
1565 assert_eq!(
1566 messages_cache(&context, cx)
1567 .iter()
1568 .map(|(_, cache)| cache
1569 .as_ref()
1570 .map_or(None, |cache| Some(cache.status.clone())))
1571 .collect::<Vec<Option<CacheStatus>>>(),
1572 vec![
1573 Some(CacheStatus::Cached),
1574 Some(CacheStatus::Cached),
1575 Some(CacheStatus::Pending),
1576 None
1577 ],
1578 "Modifying a message should invalidate it's cache but leave previous messages."
1579 );
1580 buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1581 context.update(cx, |context, cx| {
1582 context.mark_cache_anchors(cache_configuration, false, cx)
1583 });
1584 assert_eq!(
1585 messages_cache(&context, cx)
1586 .iter()
1587 .map(|(_, cache)| cache
1588 .as_ref()
1589 .map_or(None, |cache| Some(cache.status.clone())))
1590 .collect::<Vec<Option<CacheStatus>>>(),
1591 vec![
1592 Some(CacheStatus::Pending),
1593 Some(CacheStatus::Pending),
1594 Some(CacheStatus::Pending),
1595 None
1596 ],
1597 "Modifying a message should invalidate all future messages."
1598 );
1599}
1600
1601fn messages(context: &Entity<AssistantContext>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
1602 context
1603 .read(cx)
1604 .messages(cx)
1605 .map(|message| (message.id, message.role, message.offset_range))
1606 .collect()
1607}
1608
1609fn messages_cache(
1610 context: &Entity<AssistantContext>,
1611 cx: &App,
1612) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1613 context
1614 .read(cx)
1615 .messages(cx)
1616 .map(|message| (message.id, message.cache.clone()))
1617 .collect()
1618}
1619
/// Test-only slash command whose name is the wrapped string; it offers no
/// argument completions and, when run, emits a fixed text output naming itself.
#[derive(Clone)]
struct FakeSlashCommand(String);
1622
1623impl SlashCommand for FakeSlashCommand {
1624 fn name(&self) -> String {
1625 self.0.clone()
1626 }
1627
1628 fn description(&self) -> String {
1629 format!("Fake slash command: {}", self.0)
1630 }
1631
1632 fn menu_text(&self) -> String {
1633 format!("Run fake command: {}", self.0)
1634 }
1635
1636 fn complete_argument(
1637 self: Arc<Self>,
1638 _arguments: &[String],
1639 _cancel: Arc<AtomicBool>,
1640 _workspace: Option<WeakEntity<Workspace>>,
1641 _window: &mut Window,
1642 _cx: &mut App,
1643 ) -> Task<Result<Vec<ArgumentCompletion>>> {
1644 Task::ready(Ok(vec![]))
1645 }
1646
1647 fn requires_argument(&self) -> bool {
1648 false
1649 }
1650
1651 fn run(
1652 self: Arc<Self>,
1653 _arguments: &[String],
1654 _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1655 _context_buffer: BufferSnapshot,
1656 _workspace: WeakEntity<Workspace>,
1657 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1658 _window: &mut Window,
1659 _cx: &mut App,
1660 ) -> Task<SlashCommandResult> {
1661 Task::ready(Ok(SlashCommandOutput {
1662 text: format!("Executed fake command: {}", self.0),
1663 sections: vec![],
1664 run_commands_in_text: false,
1665 }
1666 .to_event_stream()))
1667 }
1668}