1use super::{AssistantEdit, MessageCacheMetadata};
2use crate::{
3 assistant_panel, prompt_library, slash_command::file_command, AssistantEditKind, CacheStatus,
4 Context, ContextEvent, ContextId, ContextOperation, InvokedSlashCommandId, MessageId,
5 MessageStatus, PromptBuilder,
6};
7use anyhow::Result;
8use assistant_slash_command::{
9 ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
10 SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
11};
12use assistant_tool::ToolWorkingSet;
13use collections::{HashMap, HashSet};
14use fs::FakeFs;
15use futures::{
16 channel::mpsc,
17 stream::{self, StreamExt},
18};
19use gpui::{prelude::*, AppContext, Model, SharedString, Task, TestAppContext, WeakView};
20use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
21use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
22use parking_lot::Mutex;
23use pretty_assertions::assert_eq;
24use project::Project;
25use rand::prelude::*;
26use serde_json::json;
27use settings::SettingsStore;
28use std::{
29 cell::RefCell,
30 env,
31 ops::Range,
32 path::Path,
33 rc::Rc,
34 sync::{atomic::AtomicBool, Arc},
35};
36use text::{network::Network, OffsetRangeExt as _, ReplicaId, ToOffset};
37use ui::{IconName, WindowContext};
38use unindent::Unindent;
39use util::{
40 test::{generate_marked_text, marked_text_ranges},
41 RandomCharIter,
42};
43use workspace::Workspace;
44
// Verifies that message boundaries follow buffer edits: inserting messages
// after existing ones, typing inside them, merging messages when an edit
// deletes across a boundary, and un-merging / re-merging on undo / redo.
// Each assertion lists `(message id, role, offset range)` for every message.
#[gpui::test]
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
    let settings_store = SettingsStore::test(cx);
    LanguageModelRegistry::test(cx);
    cx.set_global(settings_store);
    assistant_panel::init(cx);
    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context = cx.new_model(|cx| {
        Context::local(
            registry,
            None,
            None,
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            Arc::new(ToolWorkingSet::default()),
            cx,
        )
    });
    let buffer = context.read(cx).buffer.clone();

    // A fresh context starts with a single, empty user message.
    let message_1 = context.read(cx).message_anchors[0].clone();
    assert_eq!(
        messages(&context, cx),
        vec![(message_1.id, Role::User, 0..0)]
    );

    // Inserting a message after message_1 grows message_1 by one character
    // (the separator) and starts the new message right after it.
    let message_2 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..1),
            (message_2.id, Role::Assistant, 1..1)
        ]
    );

    // Edits land in whichever message contains the edit position.
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..3)
        ]
    );

    let message_3 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_3.id, Role::User, 4..4)
        ]
    );

    // Inserting after message_2 a second time places the new message between
    // message_2 and the previously inserted message_3.
    let message_4 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_4.id, Role::User, 4..5),
            (message_3.id, Role::User, 5..5),
        ]
    );

    buffer.update(cx, |buffer, cx| {
        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_4.id, Role::User, 4..6),
            (message_3.id, Role::User, 6..7),
        ]
    );

    // Deleting across message boundaries merges the messages.
    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..3),
            (message_3.id, Role::User, 3..4),
        ]
    );

    // Undoing the deletion should also undo the merge.
    buffer.update(cx, |buffer, cx| buffer.undo(cx));
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_4.id, Role::User, 4..6),
            (message_3.id, Role::User, 6..7),
        ]
    );

    // Redoing the deletion should also redo the merge.
    buffer.update(cx, |buffer, cx| buffer.redo(cx));
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..3),
            (message_3.id, Role::User, 3..4),
        ]
    );

    // Ensure we can still insert after a merged message.
    let message_5 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..3),
            (message_5.id, Role::System, 3..4),
            (message_3.id, Role::User, 4..5)
        ]
    );
}
185
// Verifies `Context::split_message`. The call returns a pair of optional
// messages; for the cursor splits below only the second element is used.
// A cursor split reuses an existing newline when one is available; otherwise
// it inserts one. Splitting a non-empty range (14..16) yields two new messages,
// and the first covers the selected text.
#[gpui::test]
fn test_message_splitting(cx: &mut AppContext) {
    let settings_store = SettingsStore::test(cx);
    cx.set_global(settings_store);
    LanguageModelRegistry::test(cx);
    assistant_panel::init(cx);
    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));

    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context = cx.new_model(|cx| {
        Context::local(
            registry.clone(),
            None,
            None,
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            Arc::new(ToolWorkingSet::default()),
            cx,
        )
    });
    let buffer = context.read(cx).buffer.clone();

    let message_1 = context.read(cx).message_anchors[0].clone();
    assert_eq!(
        messages(&context, cx),
        vec![(message_1.id, Role::User, 0..0)]
    );

    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
    });

    let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
    let message_2 = message_2.unwrap();

    // We recycle newlines in the middle of a split message
    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_2.id, Role::User, 4..16),
        ]
    );

    let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
    let message_3 = message_3.unwrap();

    // We don't recycle newlines at the end of a split message
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..17),
        ]
    );

    // Offset 9 sits just before the newline after "bbb", which gets recycled,
    // so the text is unchanged.
    let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
    let message_4 = message_4.unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..9),
            (message_4.id, Role::User, 9..17),
        ]
    );

    // Splitting at the same offset again is at the end of message_2, so a fresh
    // newline is inserted.
    let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
    let message_5 = message_5.unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..9),
            (message_4.id, Role::User, 9..10),
            (message_5.id, Role::User, 10..18),
        ]
    );

    // Splitting the range 14..16 ("dd" of "ddd") produces two new messages:
    // one holding the selected text and one holding the remainder.
    let (message_6, message_7) =
        context.update(cx, |context, cx| context.split_message(14..16, cx));
    let message_6 = message_6.unwrap();
    let message_7 = message_7.unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..9),
            (message_4.id, Role::User, 9..10),
            (message_5.id, Role::User, 10..14),
            (message_6.id, Role::User, 14..17),
            (message_7.id, Role::User, 17..19),
        ]
    );
}
290
// Verifies `Context::messages_for_offsets`. Each offset maps to the message
// that contains it, and repeated hits on the same message are reported once
// (offsets 0 and 1 both yield message_1). An offset equal to the end of the
// buffer maps to the last message, even when that message is empty.
#[gpui::test]
fn test_messages_for_offsets(cx: &mut AppContext) {
    let settings_store = SettingsStore::test(cx);
    LanguageModelRegistry::test(cx);
    cx.set_global(settings_store);
    assistant_panel::init(cx);
    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context = cx.new_model(|cx| {
        Context::local(
            registry,
            None,
            None,
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            Arc::new(ToolWorkingSet::default()),
            cx,
        )
    });
    let buffer = context.read(cx).buffer.clone();

    let message_1 = context.read(cx).message_anchors[0].clone();
    assert_eq!(
        messages(&context, cx),
        vec![(message_1.id, Role::User, 0..0)]
    );

    // Build three messages "aaa\n", "bbb\n", "ccc" by alternating message
    // insertion with typing into the newest message.
    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
    let message_2 = context
        .update(cx, |context, cx| {
            context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
        })
        .unwrap();
    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

    let message_3 = context
        .update(cx, |context, cx| {
            context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
        })
        .unwrap();
    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_2.id, Role::User, 4..8),
            (message_3.id, Role::User, 8..11)
        ]
    );

    assert_eq!(
        message_ids_for_offsets(&context, &[0, 4, 9], cx),
        [message_1.id, message_2.id, message_3.id]
    );
    // Offsets 0 and 1 both fall in message_1, so it is reported once; 11 is
    // the end of the buffer and belongs to message_3.
    assert_eq!(
        message_ids_for_offsets(&context, &[0, 1, 11], cx),
        [message_1.id, message_3.id]
    );

    // Appending an empty trailing message makes offset 12 resolve to it.
    let message_4 = context
        .update(cx, |context, cx| {
            context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
        })
        .unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_2.id, Role::User, 4..8),
            (message_3.id, Role::User, 8..12),
            (message_4.id, Role::User, 12..12)
        ]
    );
    assert_eq!(
        message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
        [message_1.id, message_2.id, message_3.id, message_4.id]
    );

    // Returns the ids of the messages `messages_for_offsets` reports for the
    // given buffer offsets, in the order it reports them.
    fn message_ids_for_offsets(
        context: &Model<Context>,
        offsets: &[usize],
        cx: &AppContext,
    ) -> Vec<MessageId> {
        context
            .read(cx)
            .messages_for_offsets(offsets.iter().copied(), cx)
            .into_iter()
            .map(|message| message.id)
            .collect()
    }
}
385
386#[gpui::test]
387async fn test_slash_commands(cx: &mut TestAppContext) {
388 let settings_store = cx.update(SettingsStore::test);
389 cx.set_global(settings_store);
390 cx.update(LanguageModelRegistry::test);
391 cx.update(Project::init_settings);
392 cx.update(assistant_panel::init);
393 let fs = FakeFs::new(cx.background_executor.clone());
394
395 fs.insert_tree(
396 "/test",
397 json!({
398 "src": {
399 "lib.rs": "fn one() -> usize { 1 }",
400 "main.rs": "
401 use crate::one;
402 fn main() { one(); }
403 ".unindent(),
404 }
405 }),
406 )
407 .await;
408
409 let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
410 slash_command_registry.register_command(file_command::FileSlashCommand, false);
411
412 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
413 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
414 let context = cx.new_model(|cx| {
415 Context::local(
416 registry.clone(),
417 None,
418 None,
419 prompt_builder.clone(),
420 Arc::new(SlashCommandWorkingSet::default()),
421 Arc::new(ToolWorkingSet::default()),
422 cx,
423 )
424 });
425
426 #[derive(Default)]
427 struct ContextRanges {
428 parsed_commands: HashSet<Range<language::Anchor>>,
429 command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
430 output_sections: HashSet<Range<language::Anchor>>,
431 }
432
433 let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
434 context.update(cx, |_, cx| {
435 cx.subscribe(&context, {
436 let context_ranges = context_ranges.clone();
437 move |context, _, event, _| {
438 let mut context_ranges = context_ranges.borrow_mut();
439 match event {
440 ContextEvent::InvokedSlashCommandChanged { command_id } => {
441 let command = context.invoked_slash_command(command_id).unwrap();
442 context_ranges
443 .command_outputs
444 .insert(*command_id, command.range.clone());
445 }
446 ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
447 for range in removed {
448 context_ranges.parsed_commands.remove(range);
449 }
450 for command in updated {
451 context_ranges
452 .parsed_commands
453 .insert(command.source_range.clone());
454 }
455 }
456 ContextEvent::SlashCommandOutputSectionAdded { section } => {
457 context_ranges.output_sections.insert(section.range.clone());
458 }
459 _ => {}
460 }
461 }
462 })
463 .detach();
464 });
465
466 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
467
468 // Insert a slash command
469 buffer.update(cx, |buffer, cx| {
470 buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
471 });
472 assert_text_and_context_ranges(
473 &buffer,
474 &context_ranges,
475 &"
476 «/file src/lib.rs»"
477 .unindent(),
478 cx,
479 );
480
481 // Edit the argument of the slash command.
482 buffer.update(cx, |buffer, cx| {
483 let edit_offset = buffer.text().find("lib.rs").unwrap();
484 buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
485 });
486 assert_text_and_context_ranges(
487 &buffer,
488 &context_ranges,
489 &"
490 «/file src/main.rs»"
491 .unindent(),
492 cx,
493 );
494
495 // Edit the name of the slash command, using one that doesn't exist.
496 buffer.update(cx, |buffer, cx| {
497 let edit_offset = buffer.text().find("/file").unwrap();
498 buffer.edit(
499 [(edit_offset..edit_offset + "/file".len(), "/unknown")],
500 None,
501 cx,
502 );
503 });
504 assert_text_and_context_ranges(
505 &buffer,
506 &context_ranges,
507 &"
508 /unknown src/main.rs"
509 .unindent(),
510 cx,
511 );
512
513 // Undoing the insertion of an non-existent slash command resorts the previous one.
514 buffer.update(cx, |buffer, cx| buffer.undo(cx));
515 assert_text_and_context_ranges(
516 &buffer,
517 &context_ranges,
518 &"
519 «/file src/main.rs»"
520 .unindent(),
521 cx,
522 );
523
524 let (command_output_tx, command_output_rx) = mpsc::unbounded();
525 context.update(cx, |context, cx| {
526 let command_source_range = context.parsed_slash_commands[0].source_range.clone();
527 context.insert_command_output(
528 command_source_range,
529 "file",
530 Task::ready(Ok(command_output_rx.boxed())),
531 true,
532 cx,
533 );
534 });
535 assert_text_and_context_ranges(
536 &buffer,
537 &context_ranges,
538 &"
539 ⟦«/file src/main.rs»
540 …⟧
541 "
542 .unindent(),
543 cx,
544 );
545
546 command_output_tx
547 .unbounded_send(Ok(SlashCommandEvent::StartSection {
548 icon: IconName::Ai,
549 label: "src/main.rs".into(),
550 metadata: None,
551 }))
552 .unwrap();
553 command_output_tx
554 .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
555 .unwrap();
556 cx.run_until_parked();
557 assert_text_and_context_ranges(
558 &buffer,
559 &context_ranges,
560 &"
561 ⟦«/file src/main.rs»
562 src/main.rs…⟧
563 "
564 .unindent(),
565 cx,
566 );
567
568 command_output_tx
569 .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
570 .unwrap();
571 cx.run_until_parked();
572 assert_text_and_context_ranges(
573 &buffer,
574 &context_ranges,
575 &"
576 ⟦«/file src/main.rs»
577 src/main.rs
578 fn main() {}…⟧
579 "
580 .unindent(),
581 cx,
582 );
583
584 command_output_tx
585 .unbounded_send(Ok(SlashCommandEvent::EndSection))
586 .unwrap();
587 cx.run_until_parked();
588 assert_text_and_context_ranges(
589 &buffer,
590 &context_ranges,
591 &"
592 ⟦«/file src/main.rs»
593 ⟪src/main.rs
594 fn main() {}⟫…⟧
595 "
596 .unindent(),
597 cx,
598 );
599
600 drop(command_output_tx);
601 cx.run_until_parked();
602 assert_text_and_context_ranges(
603 &buffer,
604 &context_ranges,
605 &"
606 ⟦⟪src/main.rs
607 fn main() {}⟫⟧
608 "
609 .unindent(),
610 cx,
611 );
612
613 #[track_caller]
614 fn assert_text_and_context_ranges(
615 buffer: &Model<Buffer>,
616 ranges: &RefCell<ContextRanges>,
617 expected_marked_text: &str,
618 cx: &mut TestAppContext,
619 ) {
620 let mut actual_marked_text = String::new();
621 buffer.update(cx, |buffer, _| {
622 struct Endpoint {
623 offset: usize,
624 marker: char,
625 }
626
627 let ranges = ranges.borrow();
628 let mut endpoints = Vec::new();
629 for range in ranges.command_outputs.values() {
630 endpoints.push(Endpoint {
631 offset: range.start.to_offset(buffer),
632 marker: '⟦',
633 });
634 }
635 for range in ranges.parsed_commands.iter() {
636 endpoints.push(Endpoint {
637 offset: range.start.to_offset(buffer),
638 marker: '«',
639 });
640 }
641 for range in ranges.output_sections.iter() {
642 endpoints.push(Endpoint {
643 offset: range.start.to_offset(buffer),
644 marker: '⟪',
645 });
646 }
647
648 for range in ranges.output_sections.iter() {
649 endpoints.push(Endpoint {
650 offset: range.end.to_offset(buffer),
651 marker: '⟫',
652 });
653 }
654 for range in ranges.parsed_commands.iter() {
655 endpoints.push(Endpoint {
656 offset: range.end.to_offset(buffer),
657 marker: '»',
658 });
659 }
660 for range in ranges.command_outputs.values() {
661 endpoints.push(Endpoint {
662 offset: range.end.to_offset(buffer),
663 marker: '⟧',
664 });
665 }
666
667 endpoints.sort_by_key(|endpoint| endpoint.offset);
668 let mut offset = 0;
669 for endpoint in endpoints {
670 actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
671 actual_marked_text.push(endpoint.marker);
672 offset = endpoint.offset;
673 }
674 actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
675 });
676
677 assert_eq!(actual_marked_text, expected_marked_text);
678 }
679}
680
681#[gpui::test]
682async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
683 cx.update(prompt_library::init);
684 let mut settings_store = cx.update(SettingsStore::test);
685 cx.update(|cx| {
686 settings_store
687 .set_user_settings(
688 r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
689 cx,
690 )
691 .unwrap()
692 });
693 cx.set_global(settings_store);
694 cx.update(language::init);
695 cx.update(Project::init_settings);
696 let fs = FakeFs::new(cx.executor());
697 let project = Project::test(fs, [Path::new("/root")], cx).await;
698 cx.update(LanguageModelRegistry::test);
699
700 cx.update(assistant_panel::init);
701 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
702
703 // Create a new context
704 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
705 let context = cx.new_model(|cx| {
706 Context::local(
707 registry.clone(),
708 Some(project),
709 None,
710 prompt_builder.clone(),
711 Arc::new(SlashCommandWorkingSet::default()),
712 Arc::new(ToolWorkingSet::default()),
713 cx,
714 )
715 });
716
717 // Insert an assistant message to simulate a response.
718 let assistant_message_id = context.update(cx, |context, cx| {
719 let user_message_id = context.messages(cx).next().unwrap().id;
720 context
721 .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
722 .unwrap()
723 .id
724 });
725
726 // No edit tags
727 edit(
728 &context,
729 "
730
731 «one
732 two
733 »",
734 cx,
735 );
736 expect_patches(
737 &context,
738 "
739
740 one
741 two
742 ",
743 &[],
744 cx,
745 );
746
747 // Partial edit step tag is added
748 edit(
749 &context,
750 "
751
752 one
753 two
754 «
755 <patch»",
756 cx,
757 );
758 expect_patches(
759 &context,
760 "
761
762 one
763 two
764
765 <patch",
766 &[],
767 cx,
768 );
769
770 // The rest of the step tag is added. The unclosed
771 // step is treated as incomplete.
772 edit(
773 &context,
774 "
775
776 one
777 two
778
779 <patch«>
780 <edit>»",
781 cx,
782 );
783 expect_patches(
784 &context,
785 "
786
787 one
788 two
789
790 «<patch>
791 <edit>»",
792 &[&[]],
793 cx,
794 );
795
796 // The full patch is added
797 edit(
798 &context,
799 "
800
801 one
802 two
803
804 <patch>
805 <edit>«
806 <description>add a `two` function</description>
807 <path>src/lib.rs</path>
808 <operation>insert_after</operation>
809 <old_text>fn one</old_text>
810 <new_text>
811 fn two() {}
812 </new_text>
813 </edit>
814 </patch>
815
816 also,»",
817 cx,
818 );
819 expect_patches(
820 &context,
821 "
822
823 one
824 two
825
826 «<patch>
827 <edit>
828 <description>add a `two` function</description>
829 <path>src/lib.rs</path>
830 <operation>insert_after</operation>
831 <old_text>fn one</old_text>
832 <new_text>
833 fn two() {}
834 </new_text>
835 </edit>
836 </patch>
837 »
838 also,",
839 &[&[AssistantEdit {
840 path: "src/lib.rs".into(),
841 kind: AssistantEditKind::InsertAfter {
842 old_text: "fn one".into(),
843 new_text: "fn two() {}".into(),
844 description: Some("add a `two` function".into()),
845 },
846 }]],
847 cx,
848 );
849
850 // The step is manually edited.
851 edit(
852 &context,
853 "
854
855 one
856 two
857
858 <patch>
859 <edit>
860 <description>add a `two` function</description>
861 <path>src/lib.rs</path>
862 <operation>insert_after</operation>
863 <old_text>«fn zero»</old_text>
864 <new_text>
865 fn two() {}
866 </new_text>
867 </edit>
868 </patch>
869
870 also,",
871 cx,
872 );
873 expect_patches(
874 &context,
875 "
876
877 one
878 two
879
880 «<patch>
881 <edit>
882 <description>add a `two` function</description>
883 <path>src/lib.rs</path>
884 <operation>insert_after</operation>
885 <old_text>fn zero</old_text>
886 <new_text>
887 fn two() {}
888 </new_text>
889 </edit>
890 </patch>
891 »
892 also,",
893 &[&[AssistantEdit {
894 path: "src/lib.rs".into(),
895 kind: AssistantEditKind::InsertAfter {
896 old_text: "fn zero".into(),
897 new_text: "fn two() {}".into(),
898 description: Some("add a `two` function".into()),
899 },
900 }]],
901 cx,
902 );
903
904 // When setting the message role to User, the steps are cleared.
905 context.update(cx, |context, cx| {
906 context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
907 context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
908 });
909 expect_patches(
910 &context,
911 "
912
913 one
914 two
915
916 <patch>
917 <edit>
918 <description>add a `two` function</description>
919 <path>src/lib.rs</path>
920 <operation>insert_after</operation>
921 <old_text>fn zero</old_text>
922 <new_text>
923 fn two() {}
924 </new_text>
925 </edit>
926 </patch>
927
928 also,",
929 &[],
930 cx,
931 );
932
933 // When setting the message role back to Assistant, the steps are reparsed.
934 context.update(cx, |context, cx| {
935 context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
936 });
937 expect_patches(
938 &context,
939 "
940
941 one
942 two
943
944 «<patch>
945 <edit>
946 <description>add a `two` function</description>
947 <path>src/lib.rs</path>
948 <operation>insert_after</operation>
949 <old_text>fn zero</old_text>
950 <new_text>
951 fn two() {}
952 </new_text>
953 </edit>
954 </patch>
955 »
956 also,",
957 &[&[AssistantEdit {
958 path: "src/lib.rs".into(),
959 kind: AssistantEditKind::InsertAfter {
960 old_text: "fn zero".into(),
961 new_text: "fn two() {}".into(),
962 description: Some("add a `two` function".into()),
963 },
964 }]],
965 cx,
966 );
967
968 // Ensure steps are re-parsed when deserializing.
969 let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
970 let deserialized_context = cx.new_model(|cx| {
971 Context::deserialize(
972 serialized_context,
973 Default::default(),
974 registry.clone(),
975 prompt_builder.clone(),
976 Arc::new(SlashCommandWorkingSet::default()),
977 Arc::new(ToolWorkingSet::default()),
978 None,
979 None,
980 cx,
981 )
982 });
983 expect_patches(
984 &deserialized_context,
985 "
986
987 one
988 two
989
990 «<patch>
991 <edit>
992 <description>add a `two` function</description>
993 <path>src/lib.rs</path>
994 <operation>insert_after</operation>
995 <old_text>fn zero</old_text>
996 <new_text>
997 fn two() {}
998 </new_text>
999 </edit>
1000 </patch>
1001 »
1002 also,",
1003 &[&[AssistantEdit {
1004 path: "src/lib.rs".into(),
1005 kind: AssistantEditKind::InsertAfter {
1006 old_text: "fn zero".into(),
1007 new_text: "fn two() {}".into(),
1008 description: Some("add a `two` function".into()),
1009 },
1010 }]],
1011 cx,
1012 );
1013
1014 fn edit(context: &Model<Context>, new_text_marked_with_edits: &str, cx: &mut TestAppContext) {
1015 context.update(cx, |context, cx| {
1016 context.buffer.update(cx, |buffer, cx| {
1017 buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
1018 });
1019 });
1020 cx.executor().run_until_parked();
1021 }
1022
1023 #[track_caller]
1024 fn expect_patches(
1025 context: &Model<Context>,
1026 expected_marked_text: &str,
1027 expected_suggestions: &[&[AssistantEdit]],
1028 cx: &mut TestAppContext,
1029 ) {
1030 let expected_marked_text = expected_marked_text.unindent();
1031 let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);
1032
1033 let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
1034 context.buffer.read_with(cx, |buffer, _| {
1035 let ranges = context
1036 .patches
1037 .iter()
1038 .map(|entry| entry.range.to_offset(buffer))
1039 .collect::<Vec<_>>();
1040 (
1041 buffer.text(),
1042 ranges,
1043 context
1044 .patches
1045 .iter()
1046 .map(|step| step.edits.clone())
1047 .collect::<Vec<_>>(),
1048 )
1049 })
1050 });
1051
1052 assert_eq!(buffer_text, expected_text);
1053
1054 let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
1055 assert_eq!(actual_marked_text, expected_marked_text);
1056
1057 assert_eq!(
1058 patches
1059 .iter()
1060 .map(|patch| {
1061 patch
1062 .iter()
1063 .map(|edit| {
1064 let edit = edit.as_ref().unwrap();
1065 AssistantEdit {
1066 path: edit.path.clone(),
1067 kind: edit.kind.clone(),
1068 }
1069 })
1070 .collect::<Vec<_>>()
1071 })
1072 .collect::<Vec<_>>(),
1073 expected_suggestions
1074 );
1075 }
1076}
1077
// Verifies that `Context::serialize` followed by `Context::deserialize`
// reproduces the same buffer text and the same message ids, roles and ranges,
// including after an undo has removed a message.
#[gpui::test]
async fn test_serialization(cx: &mut TestAppContext) {
    let settings_store = cx.update(SettingsStore::test);
    cx.set_global(settings_store);
    cx.update(LanguageModelRegistry::test);
    cx.update(assistant_panel::init);
    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context = cx.new_model(|cx| {
        Context::local(
            registry.clone(),
            None,
            None,
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            Arc::new(ToolWorkingSet::default()),
            cx,
        )
    });
    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
    let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
    let message_1 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
            .unwrap()
    });
    let message_2 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
            .unwrap()
    });
    // Finalize this transaction so the undo below reverts only the message_3
    // insertion. The assertions that follow show these edits survive the undo.
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
        buffer.finalize_last_transaction();
    });
    let _message_3 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
            .unwrap()
    });
    buffer.update(cx, |buffer, cx| buffer.undo(cx));
    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
    assert_eq!(
        cx.read(|cx| messages(&context, cx)),
        [
            (message_0, Role::User, 0..2),
            (message_1.id, Role::Assistant, 2..6),
            (message_2.id, Role::System, 6..6),
        ]
    );

    // Round-trip the context and check that the deserialized copy matches.
    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
    let deserialized_context = cx.new_model(|cx| {
        Context::deserialize(
            serialized_context,
            Default::default(),
            registry.clone(),
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            Arc::new(ToolWorkingSet::default()),
            None,
            None,
            cx,
        )
    });
    let deserialized_buffer =
        deserialized_context.read_with(cx, |context, _| context.buffer.clone());
    assert_eq!(
        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
        "a\nb\nc\n"
    );
    assert_eq!(
        cx.read(|cx| messages(&deserialized_context, cx)),
        [
            (message_0, Role::User, 0..2),
            (message_1.id, Role::Assistant, 2..6),
            (message_2.id, Role::System, 6..6),
        ]
    );
}
1158
1159#[gpui::test(iterations = 100)]
1160async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
1161 let min_peers = env::var("MIN_PEERS")
1162 .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
1163 .unwrap_or(2);
1164 let max_peers = env::var("MAX_PEERS")
1165 .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
1166 .unwrap_or(5);
1167 let operations = env::var("OPERATIONS")
1168 .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
1169 .unwrap_or(50);
1170
1171 let settings_store = cx.update(SettingsStore::test);
1172 cx.set_global(settings_store);
1173 cx.update(LanguageModelRegistry::test);
1174
1175 cx.update(assistant_panel::init);
1176 let slash_commands = cx.update(SlashCommandRegistry::default_global);
1177 slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
1178 slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
1179 slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
1180
1181 let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
1182 let network = Arc::new(Mutex::new(Network::new(rng.clone())));
1183 let mut contexts = Vec::new();
1184
1185 let num_peers = rng.gen_range(min_peers..=max_peers);
1186 let context_id = ContextId::new();
1187 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1188 for i in 0..num_peers {
1189 let context = cx.new_model(|cx| {
1190 Context::new(
1191 context_id.clone(),
1192 i as ReplicaId,
1193 language::Capability::ReadWrite,
1194 registry.clone(),
1195 prompt_builder.clone(),
1196 Arc::new(SlashCommandWorkingSet::default()),
1197 Arc::new(ToolWorkingSet::default()),
1198 None,
1199 None,
1200 cx,
1201 )
1202 });
1203
1204 cx.update(|cx| {
1205 cx.subscribe(&context, {
1206 let network = network.clone();
1207 move |_, event, _| {
1208 if let ContextEvent::Operation(op) = event {
1209 network
1210 .lock()
1211 .broadcast(i as ReplicaId, vec![op.to_proto()]);
1212 }
1213 }
1214 })
1215 .detach();
1216 });
1217
1218 contexts.push(context);
1219 network.lock().add_peer(i as ReplicaId);
1220 }
1221
1222 let mut mutation_count = operations;
1223
1224 while mutation_count > 0
1225 || !network.lock().is_idle()
1226 || network.lock().contains_disconnected_peers()
1227 {
1228 let context_index = rng.gen_range(0..contexts.len());
1229 let context = &contexts[context_index];
1230
1231 match rng.gen_range(0..100) {
1232 0..=29 if mutation_count > 0 => {
1233 log::info!("Context {}: edit buffer", context_index);
1234 context.update(cx, |context, cx| {
1235 context
1236 .buffer
1237 .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
1238 });
1239 mutation_count -= 1;
1240 }
1241 30..=44 if mutation_count > 0 => {
1242 context.update(cx, |context, cx| {
1243 let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
1244 log::info!("Context {}: split message at {:?}", context_index, range);
1245 context.split_message(range, cx);
1246 });
1247 mutation_count -= 1;
1248 }
1249 45..=59 if mutation_count > 0 => {
1250 context.update(cx, |context, cx| {
1251 if let Some(message) = context.messages(cx).choose(&mut rng) {
1252 let role = *[Role::User, Role::Assistant, Role::System]
1253 .choose(&mut rng)
1254 .unwrap();
1255 log::info!(
1256 "Context {}: insert message after {:?} with {:?}",
1257 context_index,
1258 message.id,
1259 role
1260 );
1261 context.insert_message_after(message.id, role, MessageStatus::Done, cx);
1262 }
1263 });
1264 mutation_count -= 1;
1265 }
1266 60..=74 if mutation_count > 0 => {
1267 context.update(cx, |context, cx| {
1268 let command_text = "/".to_string()
1269 + slash_commands
1270 .command_names()
1271 .choose(&mut rng)
1272 .unwrap()
1273 .clone()
1274 .as_ref();
1275
1276 let command_range = context.buffer.update(cx, |buffer, cx| {
1277 let offset = buffer.random_byte_range(0, &mut rng).start;
1278 buffer.edit(
1279 [(offset..offset, format!("\n{}\n", command_text))],
1280 None,
1281 cx,
1282 );
1283 offset + 1..offset + 1 + command_text.len()
1284 });
1285
1286 let output_text = RandomCharIter::new(&mut rng)
1287 .filter(|c| *c != '\r')
1288 .take(10)
1289 .collect::<String>();
1290
1291 let mut events = vec![Ok(SlashCommandEvent::StartMessage {
1292 role: Role::User,
1293 merge_same_roles: true,
1294 })];
1295
1296 let num_sections = rng.gen_range(0..=3);
1297 let mut section_start = 0;
1298 for _ in 0..num_sections {
1299 let mut section_end = rng.gen_range(section_start..=output_text.len());
1300 while !output_text.is_char_boundary(section_end) {
1301 section_end += 1;
1302 }
1303 events.push(Ok(SlashCommandEvent::StartSection {
1304 icon: IconName::Ai,
1305 label: "section".into(),
1306 metadata: None,
1307 }));
1308 events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
1309 text: output_text[section_start..section_end].to_string(),
1310 run_commands_in_text: false,
1311 })));
1312 events.push(Ok(SlashCommandEvent::EndSection));
1313 section_start = section_end;
1314 }
1315
1316 if section_start < output_text.len() {
1317 events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
1318 text: output_text[section_start..].to_string(),
1319 run_commands_in_text: false,
1320 })));
1321 }
1322
1323 log::info!(
1324 "Context {}: insert slash command output at {:?} with {:?} events",
1325 context_index,
1326 command_range,
1327 events.len()
1328 );
1329
1330 let command_range = context.buffer.read(cx).anchor_after(command_range.start)
1331 ..context.buffer.read(cx).anchor_after(command_range.end);
1332 context.insert_command_output(
1333 command_range,
1334 "/command",
1335 Task::ready(Ok(stream::iter(events).boxed())),
1336 true,
1337 cx,
1338 );
1339 });
1340 cx.run_until_parked();
1341 mutation_count -= 1;
1342 }
1343 75..=84 if mutation_count > 0 => {
1344 context.update(cx, |context, cx| {
1345 if let Some(message) = context.messages(cx).choose(&mut rng) {
1346 let new_status = match rng.gen_range(0..3) {
1347 0 => MessageStatus::Done,
1348 1 => MessageStatus::Pending,
1349 _ => MessageStatus::Error(SharedString::from("Random error")),
1350 };
1351 log::info!(
1352 "Context {}: update message {:?} status to {:?}",
1353 context_index,
1354 message.id,
1355 new_status
1356 );
1357 context.update_metadata(message.id, cx, |metadata| {
1358 metadata.status = new_status;
1359 });
1360 }
1361 });
1362 mutation_count -= 1;
1363 }
1364 _ => {
1365 let replica_id = context_index as ReplicaId;
1366 if network.lock().is_disconnected(replica_id) {
1367 network.lock().reconnect_peer(replica_id, 0);
1368
1369 let (ops_to_send, ops_to_receive) = cx.read(|cx| {
1370 let host_context = &contexts[0].read(cx);
1371 let guest_context = context.read(cx);
1372 (
1373 guest_context.serialize_ops(&host_context.version(cx), cx),
1374 host_context.serialize_ops(&guest_context.version(cx), cx),
1375 )
1376 });
1377 let ops_to_send = ops_to_send.await;
1378 let ops_to_receive = ops_to_receive
1379 .await
1380 .into_iter()
1381 .map(ContextOperation::from_proto)
1382 .collect::<Result<Vec<_>>>()
1383 .unwrap();
1384 log::info!(
1385 "Context {}: reconnecting. Sent {} operations, received {} operations",
1386 context_index,
1387 ops_to_send.len(),
1388 ops_to_receive.len()
1389 );
1390
1391 network.lock().broadcast(replica_id, ops_to_send);
1392 context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
1393 } else if rng.gen_bool(0.1) && replica_id != 0 {
1394 log::info!("Context {}: disconnecting", context_index);
1395 network.lock().disconnect_peer(replica_id);
1396 } else if network.lock().has_unreceived(replica_id) {
1397 log::info!("Context {}: applying operations", context_index);
1398 let ops = network.lock().receive(replica_id);
1399 let ops = ops
1400 .into_iter()
1401 .map(ContextOperation::from_proto)
1402 .collect::<Result<Vec<_>>>()
1403 .unwrap();
1404 context.update(cx, |context, cx| context.apply_ops(ops, cx));
1405 }
1406 }
1407 }
1408 }
1409
1410 cx.read(|cx| {
1411 let first_context = contexts[0].read(cx);
1412 for context in &contexts[1..] {
1413 let context = context.read(cx);
1414 assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
1415 assert_eq!(
1416 context.buffer.read(cx).text(),
1417 first_context.buffer.read(cx).text(),
1418 "Context {} text != Context 0 text",
1419 context.buffer.read(cx).replica_id()
1420 );
1421 assert_eq!(
1422 context.message_anchors,
1423 first_context.message_anchors,
1424 "Context {} messages != Context 0 messages",
1425 context.buffer.read(cx).replica_id()
1426 );
1427 assert_eq!(
1428 context.messages_metadata,
1429 first_context.messages_metadata,
1430 "Context {} message metadata != Context 0 message metadata",
1431 context.buffer.read(cx).replica_id()
1432 );
1433 assert_eq!(
1434 context.slash_command_output_sections,
1435 first_context.slash_command_output_sections,
1436 "Context {} slash command output sections != Context 0 slash command output sections",
1437 context.buffer.read(cx).replica_id()
1438 );
1439 }
1440 });
1441}
1442
1443#[gpui::test]
1444fn test_mark_cache_anchors(cx: &mut AppContext) {
1445 let settings_store = SettingsStore::test(cx);
1446 LanguageModelRegistry::test(cx);
1447 cx.set_global(settings_store);
1448 assistant_panel::init(cx);
1449 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1450 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1451 let context = cx.new_model(|cx| {
1452 Context::local(
1453 registry,
1454 None,
1455 None,
1456 prompt_builder.clone(),
1457 Arc::new(SlashCommandWorkingSet::default()),
1458 Arc::new(ToolWorkingSet::default()),
1459 cx,
1460 )
1461 });
1462 let buffer = context.read(cx).buffer.clone();
1463
1464 // Create a test cache configuration
1465 let cache_configuration = &Some(LanguageModelCacheConfiguration {
1466 max_cache_anchors: 3,
1467 should_speculate: true,
1468 min_total_token: 10,
1469 });
1470
1471 let message_1 = context.read(cx).message_anchors[0].clone();
1472
1473 context.update(cx, |context, cx| {
1474 context.mark_cache_anchors(cache_configuration, false, cx)
1475 });
1476
1477 assert_eq!(
1478 messages_cache(&context, cx)
1479 .iter()
1480 .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1481 .count(),
1482 0,
1483 "Empty messages should not have any cache anchors."
1484 );
1485
1486 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1487 let message_2 = context
1488 .update(cx, |context, cx| {
1489 context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1490 })
1491 .unwrap();
1492
1493 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1494 let message_3 = context
1495 .update(cx, |context, cx| {
1496 context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1497 })
1498 .unwrap();
1499 buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1500
1501 context.update(cx, |context, cx| {
1502 context.mark_cache_anchors(cache_configuration, false, cx)
1503 });
1504 assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1505 assert_eq!(
1506 messages_cache(&context, cx)
1507 .iter()
1508 .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1509 .count(),
1510 0,
1511 "Messages should not be marked for cache before going over the token minimum."
1512 );
1513 context.update(cx, |context, _| {
1514 context.token_count = Some(20);
1515 });
1516
1517 context.update(cx, |context, cx| {
1518 context.mark_cache_anchors(cache_configuration, true, cx)
1519 });
1520 assert_eq!(
1521 messages_cache(&context, cx)
1522 .iter()
1523 .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1524 .collect::<Vec<bool>>(),
1525 vec![true, true, false],
1526 "Last message should not be an anchor on speculative request."
1527 );
1528
1529 context
1530 .update(cx, |context, cx| {
1531 context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1532 })
1533 .unwrap();
1534
1535 context.update(cx, |context, cx| {
1536 context.mark_cache_anchors(cache_configuration, false, cx)
1537 });
1538 assert_eq!(
1539 messages_cache(&context, cx)
1540 .iter()
1541 .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1542 .collect::<Vec<bool>>(),
1543 vec![false, true, true, false],
1544 "Most recent message should also be cached if not a speculative request."
1545 );
1546 context.update(cx, |context, cx| {
1547 context.update_cache_status_for_completion(cx)
1548 });
1549 assert_eq!(
1550 messages_cache(&context, cx)
1551 .iter()
1552 .map(|(_, cache)| cache
1553 .as_ref()
1554 .map_or(None, |cache| Some(cache.status.clone())))
1555 .collect::<Vec<Option<CacheStatus>>>(),
1556 vec![
1557 Some(CacheStatus::Cached),
1558 Some(CacheStatus::Cached),
1559 Some(CacheStatus::Cached),
1560 None
1561 ],
1562 "All user messages prior to anchor should be marked as cached."
1563 );
1564
1565 buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1566 context.update(cx, |context, cx| {
1567 context.mark_cache_anchors(cache_configuration, false, cx)
1568 });
1569 assert_eq!(
1570 messages_cache(&context, cx)
1571 .iter()
1572 .map(|(_, cache)| cache
1573 .as_ref()
1574 .map_or(None, |cache| Some(cache.status.clone())))
1575 .collect::<Vec<Option<CacheStatus>>>(),
1576 vec![
1577 Some(CacheStatus::Cached),
1578 Some(CacheStatus::Cached),
1579 Some(CacheStatus::Pending),
1580 None
1581 ],
1582 "Modifying a message should invalidate it's cache but leave previous messages."
1583 );
1584 buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1585 context.update(cx, |context, cx| {
1586 context.mark_cache_anchors(cache_configuration, false, cx)
1587 });
1588 assert_eq!(
1589 messages_cache(&context, cx)
1590 .iter()
1591 .map(|(_, cache)| cache
1592 .as_ref()
1593 .map_or(None, |cache| Some(cache.status.clone())))
1594 .collect::<Vec<Option<CacheStatus>>>(),
1595 vec![
1596 Some(CacheStatus::Pending),
1597 Some(CacheStatus::Pending),
1598 Some(CacheStatus::Pending),
1599 None
1600 ],
1601 "Modifying a message should invalidate all future messages."
1602 );
1603}
1604
1605fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
1606 context
1607 .read(cx)
1608 .messages(cx)
1609 .map(|message| (message.id, message.role, message.offset_range))
1610 .collect()
1611}
1612
1613fn messages_cache(
1614 context: &Model<Context>,
1615 cx: &AppContext,
1616) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1617 context
1618 .read(cx)
1619 .messages(cx)
1620 .map(|message| (message.id, message.cache.clone()))
1621 .collect()
1622}
1623
1624#[derive(Clone)]
1625struct FakeSlashCommand(String);
1626
1627impl SlashCommand for FakeSlashCommand {
1628 fn name(&self) -> String {
1629 self.0.clone()
1630 }
1631
1632 fn description(&self) -> String {
1633 format!("Fake slash command: {}", self.0)
1634 }
1635
1636 fn menu_text(&self) -> String {
1637 format!("Run fake command: {}", self.0)
1638 }
1639
1640 fn complete_argument(
1641 self: Arc<Self>,
1642 _arguments: &[String],
1643 _cancel: Arc<AtomicBool>,
1644 _workspace: Option<WeakView<Workspace>>,
1645 _cx: &mut WindowContext,
1646 ) -> Task<Result<Vec<ArgumentCompletion>>> {
1647 Task::ready(Ok(vec![]))
1648 }
1649
1650 fn requires_argument(&self) -> bool {
1651 false
1652 }
1653
1654 fn run(
1655 self: Arc<Self>,
1656 _arguments: &[String],
1657 _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1658 _context_buffer: BufferSnapshot,
1659 _workspace: WeakView<Workspace>,
1660 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1661 _cx: &mut WindowContext,
1662 ) -> Task<SlashCommandResult> {
1663 Task::ready(Ok(SlashCommandOutput {
1664 text: format!("Executed fake command: {}", self.0),
1665 sections: vec![],
1666 run_commands_in_text: false,
1667 }
1668 .to_event_stream()))
1669 }
1670}