1use super::{AssistantEdit, MessageCacheMetadata};
2use crate::slash_command_working_set::SlashCommandWorkingSet;
3use crate::{
4 assistant_panel, prompt_library, slash_command::file_command, AssistantEditKind, CacheStatus,
5 Context, ContextEvent, ContextId, ContextOperation, InvokedSlashCommandId, MessageId,
6 MessageStatus, PromptBuilder,
7};
8use anyhow::Result;
9use assistant_slash_command::{
10 ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
11 SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult,
12};
13use assistant_tool::ToolWorkingSet;
14use collections::{HashMap, HashSet};
15use fs::FakeFs;
16use futures::{
17 channel::mpsc,
18 stream::{self, StreamExt},
19};
20use gpui::{prelude::*, AppContext, Model, SharedString, Task, TestAppContext, WeakView};
21use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
22use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
23use parking_lot::Mutex;
24use pretty_assertions::assert_eq;
25use project::Project;
26use rand::prelude::*;
27use serde_json::json;
28use settings::SettingsStore;
29use std::{
30 cell::RefCell,
31 env,
32 ops::Range,
33 path::Path,
34 rc::Rc,
35 sync::{atomic::AtomicBool, Arc},
36};
37use text::{network::Network, OffsetRangeExt as _, ReplicaId, ToOffset};
38use ui::{IconName, WindowContext};
39use unindent::Unindent;
40use util::{
41 test::{generate_marked_text, marked_text_ranges},
42 RandomCharIter,
43};
44use workspace::Workspace;
45
// Checks the offset range reported for every message while messages are inserted,
// the buffer is edited, and a deletion spanning several message boundaries is
// applied, undone, and redone.
46#[gpui::test]
47fn test_inserting_and_removing_messages(cx: &mut AppContext) {
 // Install the test globals and initialize the assistant panel before building the context.
48 let settings_store = SettingsStore::test(cx);
49 LanguageModelRegistry::test(cx);
50 cx.set_global(settings_store);
51 assistant_panel::init(cx);
52 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
53 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
54 let context = cx.new_model(|cx| {
55 Context::local(
56 registry,
57 None,
58 None,
59 prompt_builder.clone(),
60 Arc::new(SlashCommandWorkingSet::default()),
61 Arc::new(ToolWorkingSet::default()),
62 cx,
63 )
64 });
65 let buffer = context.read(cx).buffer.clone();
66
 // A freshly created context holds a single user message over the empty buffer.
67 let message_1 = context.read(cx).message_anchors[0].clone();
68 assert_eq!(
69 messages(&context, cx),
70 vec![(message_1.id, Role::User, 0..0)]
71 );
72
 // Inserting a message after message_1 grows the buffer by one character:
 // message_1 now spans 0..1 and the new, empty message starts at 1.
73 let message_2 = context.update(cx, |context, cx| {
74 context
75 .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
76 .unwrap()
77 });
78 assert_eq!(
79 messages(&context, cx),
80 vec![
81 (message_1.id, Role::User, 0..1),
82 (message_2.id, Role::Assistant, 1..1)
83 ]
84 );
85
 // Type one character into each message; both ranges grow accordingly.
86 buffer.update(cx, |buffer, cx| {
87 buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
88 });
89 assert_eq!(
90 messages(&context, cx),
91 vec![
92 (message_1.id, Role::User, 0..2),
93 (message_2.id, Role::Assistant, 2..3)
94 ]
95 );
96
 // Append a third message at the end of the conversation.
97 let message_3 = context.update(cx, |context, cx| {
98 context
99 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
100 .unwrap()
101 });
102 assert_eq!(
103 messages(&context, cx),
104 vec![
105 (message_1.id, Role::User, 0..2),
106 (message_2.id, Role::Assistant, 2..4),
107 (message_3.id, Role::User, 4..4)
108 ]
109 );
110
 // Inserting after message_2 a second time places the new message between
 // message_2 and message_3, as the expected ordering below shows.
111 let message_4 = context.update(cx, |context, cx| {
112 context
113 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
114 .unwrap()
115 });
116 assert_eq!(
117 messages(&context, cx),
118 vec![
119 (message_1.id, Role::User, 0..2),
120 (message_2.id, Role::Assistant, 2..4),
121 (message_4.id, Role::User, 4..5),
122 (message_3.id, Role::User, 5..5),
123 ]
124 );
125
 // Add text to the two trailing messages.
126 buffer.update(cx, |buffer, cx| {
127 buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
128 });
129 assert_eq!(
130 messages(&context, cx),
131 vec![
132 (message_1.id, Role::User, 0..2),
133 (message_2.id, Role::Assistant, 2..4),
134 (message_4.id, Role::User, 4..6),
135 (message_3.id, Role::User, 6..7),
136 ]
137 );
138
139 // Deleting across message boundaries merges the messages.
 // The deleted range 1..4 covers the starts of message_2 and message_4, so only
 // message_1 and message_3 remain in the listing.
140 buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
141 assert_eq!(
142 messages(&context, cx),
143 vec![
144 (message_1.id, Role::User, 0..3),
145 (message_3.id, Role::User, 3..4),
146 ]
147 );
148
149 // Undoing the deletion should also undo the merge.
150 buffer.update(cx, |buffer, cx| buffer.undo(cx));
151 assert_eq!(
152 messages(&context, cx),
153 vec![
154 (message_1.id, Role::User, 0..2),
155 (message_2.id, Role::Assistant, 2..4),
156 (message_4.id, Role::User, 4..6),
157 (message_3.id, Role::User, 6..7),
158 ]
159 );
160
161 // Redoing the deletion should also redo the merge.
162 buffer.update(cx, |buffer, cx| buffer.redo(cx));
163 assert_eq!(
164 messages(&context, cx),
165 vec![
166 (message_1.id, Role::User, 0..3),
167 (message_3.id, Role::User, 3..4),
168 ]
169 );
170
171 // Ensure we can still insert after a merged message.
172 let message_5 = context.update(cx, |context, cx| {
173 context
174 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
175 .unwrap()
176 });
177 assert_eq!(
178 messages(&context, cx),
179 vec![
180 (message_1.id, Role::User, 0..3),
181 (message_5.id, Role::System, 3..4),
182 (message_3.id, Role::User, 4..5)
183 ]
184 );
185}
186
// Checks how `split_message` divides messages, both with an empty range (a cursor
// position) and with a non-empty range, and when it inserts a newline versus
// reusing one that is already in the buffer.
187#[gpui::test]
188fn test_message_splitting(cx: &mut AppContext) {
 // Install the test globals and initialize the assistant panel before building the context.
189 let settings_store = SettingsStore::test(cx);
190 cx.set_global(settings_store);
191 LanguageModelRegistry::test(cx);
192 assistant_panel::init(cx);
193 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
194
195 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
196 let context = cx.new_model(|cx| {
197 Context::local(
198 registry.clone(),
199 None,
200 None,
201 prompt_builder.clone(),
202 Arc::new(SlashCommandWorkingSet::default()),
203 Arc::new(ToolWorkingSet::default()),
204 cx,
205 )
206 });
207 let buffer = context.read(cx).buffer.clone();
208
 // The context starts with one empty user message.
209 let message_1 = context.read(cx).message_anchors[0].clone();
210 assert_eq!(
211 messages(&context, cx),
212 vec![(message_1.id, Role::User, 0..0)]
213 );
214
 // Seed the single message with four newline-terminated lines.
215 buffer.update(cx, |buffer, cx| {
216 buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
217 });
218
 // Split at offset 3, directly before the newline that ends "aaa".
 // With an empty range only the second element of the returned pair is used here.
219 let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
220 let message_2 = message_2.unwrap();
221
222 // We recycle newlines in the middle of a split message
223 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
224 assert_eq!(
225 messages(&context, cx),
226 vec![
227 (message_1.id, Role::User, 0..4),
228 (message_2.id, Role::User, 4..16),
229 ]
230 );
231
 // Split at the same offset again; offset 3 is now the end of message_1's text.
232 let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
233 let message_3 = message_3.unwrap();
234
235 // We don't recycle newlines at the end of a split message
236 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
237 assert_eq!(
238 messages(&context, cx),
239 vec![
240 (message_1.id, Role::User, 0..4),
241 (message_3.id, Role::User, 4..5),
242 (message_2.id, Role::User, 5..17),
243 ]
244 );
245
 // Repeat both cases further into the buffer: the first split at offset 9 leaves
 // the text unchanged, the second one inserts a newline.
246 let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
247 let message_4 = message_4.unwrap();
248 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
249 assert_eq!(
250 messages(&context, cx),
251 vec![
252 (message_1.id, Role::User, 0..4),
253 (message_3.id, Role::User, 4..5),
254 (message_2.id, Role::User, 5..9),
255 (message_4.id, Role::User, 9..17),
256 ]
257 );
258
259 let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
260 let message_5 = message_5.unwrap();
261 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
262 assert_eq!(
263 messages(&context, cx),
264 vec![
265 (message_1.id, Role::User, 0..4),
266 (message_3.id, Role::User, 4..5),
267 (message_2.id, Role::User, 5..9),
268 (message_4.id, Role::User, 9..10),
269 (message_5.id, Role::User, 10..18),
270 ]
271 );
272
 // Splitting with a non-empty range (the "dd" at 14..16) yields two new messages:
 // one holding the selected text and one holding the remainder after it.
273 let (message_6, message_7) =
274 context.update(cx, |context, cx| context.split_message(14..16, cx));
275 let message_6 = message_6.unwrap();
276 let message_7 = message_7.unwrap();
277 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
278 assert_eq!(
279 messages(&context, cx),
280 vec![
281 (message_1.id, Role::User, 0..4),
282 (message_3.id, Role::User, 4..5),
283 (message_2.id, Role::User, 5..9),
284 (message_4.id, Role::User, 9..10),
285 (message_5.id, Role::User, 10..14),
286 (message_6.id, Role::User, 14..17),
287 (message_7.id, Role::User, 17..19),
288 ]
289 );
290}
291
// Checks that `messages_for_offsets` maps buffer offsets to the messages that
// contain them, including an offset at the very end of the buffer and an offset
// that lands on an empty trailing message.
292#[gpui::test]
293fn test_messages_for_offsets(cx: &mut AppContext) {
 // Install the test globals and initialize the assistant panel before building the context.
294 let settings_store = SettingsStore::test(cx);
295 LanguageModelRegistry::test(cx);
296 cx.set_global(settings_store);
297 assistant_panel::init(cx);
298 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
299 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
300 let context = cx.new_model(|cx| {
301 Context::local(
302 registry,
303 None,
304 None,
305 prompt_builder.clone(),
306 Arc::new(SlashCommandWorkingSet::default()),
307 Arc::new(ToolWorkingSet::default()),
308 cx,
309 )
310 });
311 let buffer = context.read(cx).buffer.clone();
312
313 let message_1 = context.read(cx).message_anchors[0].clone();
314 assert_eq!(
315 messages(&context, cx),
316 vec![(message_1.id, Role::User, 0..0)]
317 );
318
 // Build three messages containing "aaa", "bbb" and "ccc". Each insertion adds a
 // separating newline, which is why the follow-up edits land at offsets 4 and 8.
319 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
320 let message_2 = context
321 .update(cx, |context, cx| {
322 context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
323 })
324 .unwrap();
325 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
326
327 let message_3 = context
328 .update(cx, |context, cx| {
329 context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
330 })
331 .unwrap();
332 buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
333
334 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
335 assert_eq!(
336 messages(&context, cx),
337 vec![
338 (message_1.id, Role::User, 0..4),
339 (message_2.id, Role::User, 4..8),
340 (message_3.id, Role::User, 8..11)
341 ]
342 );
343
 // One offset inside each message resolves to that message.
344 assert_eq!(
345 message_ids_for_offsets(&context, &[0, 4, 9], cx),
346 [message_1.id, message_2.id, message_3.id]
347 );
 // Offsets 0 and 1 both fall in message_1, which is reported only once; offset 11
 // is the end of the buffer and resolves to the last message.
348 assert_eq!(
349 message_ids_for_offsets(&context, &[0, 1, 11], cx),
350 [message_1.id, message_3.id]
351 );
352
 // Add an empty fourth message at the end; offset 12 must resolve to it even
 // though its range (12..12) is empty.
353 let message_4 = context
354 .update(cx, |context, cx| {
355 context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
356 })
357 .unwrap();
358 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
359 assert_eq!(
360 messages(&context, cx),
361 vec![
362 (message_1.id, Role::User, 0..4),
363 (message_2.id, Role::User, 4..8),
364 (message_3.id, Role::User, 8..12),
365 (message_4.id, Role::User, 12..12)
366 ]
367 );
368 assert_eq!(
369 message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
370 [message_1.id, message_2.id, message_3.id, message_4.id]
371 );
372
 // Helper: returns the ids of the messages that `messages_for_offsets` yields for
 // the given offsets, in the order they are returned.
373 fn message_ids_for_offsets(
374 context: &Model<Context>,
375 offsets: &[usize],
376 cx: &AppContext,
377 ) -> Vec<MessageId> {
378 context
379 .read(cx)
380 .messages_for_offsets(offsets.iter().copied(), cx)
381 .into_iter()
382 .map(|message| message.id)
383 .collect()
384 }
385}
386
// Checks slash-command handling in a context: detection of parsed commands as the
// buffer is edited, and incremental insertion of streamed command output.
//
// Expected states are written as marked text using three bracket pairs:
//   «…»  source range of a parsed slash command
//   ⟦…⟧  range of an invoked slash command (command plus its output)
//   ⟪…⟫  range of a slash command output section
387#[gpui::test]
388async fn test_slash_commands(cx: &mut TestAppContext) {
 // Install the test globals and initialize the assistant panel.
389 let settings_store = cx.update(SettingsStore::test);
390 cx.set_global(settings_store);
391 cx.update(LanguageModelRegistry::test);
392 cx.update(Project::init_settings);
393 cx.update(assistant_panel::init);
394 let fs = FakeFs::new(cx.background_executor.clone());
395
 // Populate the fake file system with a small source tree.
396 fs.insert_tree(
397 "/test",
398 json!({
399 "src": {
400 "lib.rs": "fn one() -> usize { 1 }",
401 "main.rs": "
402 use crate::one;
403 fn main() { one(); }
404 ".unindent(),
405 }
406 }),
407 )
408 .await;
409
 // Register `/file` so that it is recognized as a known command when parsing.
410 let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
411 slash_command_registry.register_command(file_command::FileSlashCommand, false);
412
413 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
414 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
415 let context = cx.new_model(|cx| {
416 Context::local(
417 registry.clone(),
418 None,
419 None,
420 prompt_builder.clone(),
421 Arc::new(SlashCommandWorkingSet::default()),
422 Arc::new(ToolWorkingSet::default()),
423 cx,
424 )
425 });
426
 // Ranges collected from context events; the assertions below compare these
 // against the marked-up expected text.
427 #[derive(Default)]
428 struct ContextRanges {
429 parsed_commands: HashSet<Range<language::Anchor>>,
430 command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
431 output_sections: HashSet<Range<language::Anchor>>,
432 }
433
 // The state is shared between the event subscription and the assertion helper.
434 let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
435 context.update(cx, |_, cx| {
436 cx.subscribe(&context, {
437 let context_ranges = context_ranges.clone();
438 move |context, _, event, _| {
439 let mut context_ranges = context_ranges.borrow_mut();
440 match event {
 // Record (or overwrite) the current range of the invoked command.
441 ContextEvent::InvokedSlashCommandChanged { command_id } => {
442 let command = context.invoked_slash_command(command_id).unwrap();
443 context_ranges
444 .command_outputs
445 .insert(*command_id, command.range.clone());
446 }
 // Keep the parsed-command set in sync with the reported removals and updates.
447 ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
448 for range in removed {
449 context_ranges.parsed_commands.remove(range);
450 }
451 for command in updated {
452 context_ranges
453 .parsed_commands
454 .insert(command.source_range.clone());
455 }
456 }
457 ContextEvent::SlashCommandOutputSectionAdded { section } => {
458 context_ranges.output_sections.insert(section.range.clone());
459 }
460 _ => {}
461 }
462 }
463 })
464 .detach();
465 });
466
467 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
468
469 // Insert a slash command
470 buffer.update(cx, |buffer, cx| {
471 buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
472 });
473 assert_text_and_context_ranges(
474 &buffer,
475 &context_ranges,
476 &"
477 «/file src/lib.rs»"
478 .unindent(),
479 cx,
480 );
481
482 // Edit the argument of the slash command.
483 buffer.update(cx, |buffer, cx| {
484 let edit_offset = buffer.text().find("lib.rs").unwrap();
485 buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
486 });
487 assert_text_and_context_ranges(
488 &buffer,
489 &context_ranges,
490 &"
491 «/file src/main.rs»"
492 .unindent(),
493 cx,
494 );
495
496 // Edit the name of the slash command, using one that doesn't exist.
497 buffer.update(cx, |buffer, cx| {
498 let edit_offset = buffer.text().find("/file").unwrap();
499 buffer.edit(
500 [(edit_offset..edit_offset + "/file".len(), "/unknown")],
501 None,
502 cx,
503 );
504 });
505 assert_text_and_context_ranges(
506 &buffer,
507 &context_ranges,
508 &"
509 /unknown src/main.rs"
510 .unindent(),
511 cx,
512 );
513
514 // Undoing the insertion of a non-existent slash command restores the previous one.
515 buffer.update(cx, |buffer, cx| buffer.undo(cx));
516 assert_text_and_context_ranges(
517 &buffer,
518 &context_ranges,
519 &"
520 «/file src/main.rs»"
521 .unindent(),
522 cx,
523 );
524
 // Invoke the parsed command with an output stream driven manually through
 // `command_output_tx`, so every intermediate state can be asserted.
525 let (command_output_tx, command_output_rx) = mpsc::unbounded();
526 context.update(cx, |context, cx| {
527 let command_source_range = context.parsed_slash_commands[0].source_range.clone();
528 context.insert_command_output(
529 command_source_range,
530 "file",
531 Task::ready(Ok(command_output_rx.boxed())),
532 true,
533 cx,
534 );
535 });
 // Before any output arrives, the expected text contains a `…` placeholder line
 // inside the invoked-command range.
536 assert_text_and_context_ranges(
537 &buffer,
538 &context_ranges,
539 &"
540 ⟦«/file src/main.rs»
541 …⟧
542 "
543 .unindent(),
544 cx,
545 );
546
 // Open a section and stream the first chunk of content; the text is expected to
 // appear in front of the placeholder.
547 command_output_tx
548 .unbounded_send(Ok(SlashCommandEvent::StartSection {
549 icon: IconName::Ai,
550 label: "src/main.rs".into(),
551 metadata: None,
552 }))
553 .unwrap();
554 command_output_tx
555 .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
556 .unwrap();
557 cx.run_until_parked();
558 assert_text_and_context_ranges(
559 &buffer,
560 &context_ranges,
561 &"
562 ⟦«/file src/main.rs»
563 src/main.rs…⟧
564 "
565 .unindent(),
566 cx,
567 );
568
 // A second content chunk is appended to the first.
569 command_output_tx
570 .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
571 .unwrap();
572 cx.run_until_parked();
573 assert_text_and_context_ranges(
574 &buffer,
575 &context_ranges,
576 &"
577 ⟦«/file src/main.rs»
578 src/main.rs
579 fn main() {}…⟧
580 "
581 .unindent(),
582 cx,
583 );
584
 // Ending the section is what causes the ⟪…⟫ output-section range to be reported.
585 command_output_tx
586 .unbounded_send(Ok(SlashCommandEvent::EndSection))
587 .unwrap();
588 cx.run_until_parked();
589 assert_text_and_context_ranges(
590 &buffer,
591 &context_ranges,
592 &"
593 ⟦«/file src/main.rs»
594 ⟪src/main.rs
595 fn main() {}⟫…⟧
596 "
597 .unindent(),
598 cx,
599 );
600
 // Dropping the sender closes the stream. The expected text then no longer contains
 // the command source or the placeholder, only the section output.
601 drop(command_output_tx);
602 cx.run_until_parked();
603 assert_text_and_context_ranges(
604 &buffer,
605 &context_ranges,
606 &"
607 ⟦⟪src/main.rs
608 fn main() {}⟫⟧
609 "
610 .unindent(),
611 cx,
612 );
613
 // Helper: renders the buffer text with a marker character inserted at every
 // recorded range endpoint and compares the result with `expected_marked_text`.
614 #[track_caller]
615 fn assert_text_and_context_ranges(
616 buffer: &Model<Buffer>,
617 ranges: &RefCell<ContextRanges>,
618 expected_marked_text: &str,
619 cx: &mut TestAppContext,
620 ) {
621 let mut actual_marked_text = String::new();
622 buffer.update(cx, |buffer, _| {
623 struct Endpoint {
624 offset: usize,
625 marker: char,
626 }
627
628 let ranges = ranges.borrow();
629 let mut endpoints = Vec::new();
 // Opening markers are pushed outermost-first (⟦, «, ⟪) ...
630 for range in ranges.command_outputs.values() {
631 endpoints.push(Endpoint {
632 offset: range.start.to_offset(buffer),
633 marker: '⟦',
634 });
635 }
636 for range in ranges.parsed_commands.iter() {
637 endpoints.push(Endpoint {
638 offset: range.start.to_offset(buffer),
639 marker: '«',
640 });
641 }
642 for range in ranges.output_sections.iter() {
643 endpoints.push(Endpoint {
644 offset: range.start.to_offset(buffer),
645 marker: '⟪',
646 });
647 }
648
 // ... and closing markers innermost-first (⟫, », ⟧).
649 for range in ranges.output_sections.iter() {
650 endpoints.push(Endpoint {
651 offset: range.end.to_offset(buffer),
652 marker: '⟫',
653 });
654 }
655 for range in ranges.parsed_commands.iter() {
656 endpoints.push(Endpoint {
657 offset: range.end.to_offset(buffer),
658 marker: '»',
659 });
660 }
661 for range in ranges.command_outputs.values() {
662 endpoints.push(Endpoint {
663 offset: range.end.to_offset(buffer),
664 marker: '⟧',
665 });
666 }
667
 // `sort_by_key` is a stable sort, so markers sharing an offset keep the push
 // order established above.
668 endpoints.sort_by_key(|endpoint| endpoint.offset);
 // Interleave the buffer text with the markers, walking the endpoints in order.
669 let mut offset = 0;
670 for endpoint in endpoints {
671 actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
672 actual_marked_text.push(endpoint.marker);
673 offset = endpoint.offset;
674 }
675 actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
676 });
677
678 assert_eq!(actual_marked_text, expected_marked_text);
679 }
680}
681
// Checks that `<patch>` blocks in an assistant message are parsed into patches as
// the message text is built up and edited, that the patches follow the message's
// role, and that they are parsed again after a serialize/deserialize round trip.
//
// In the `edit` calls, «…» marks the text being inserted. In the `expect_patches`
// calls, «…» marks the buffer range expected for each parsed patch.
682#[gpui::test]
683async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
684 cx.update(prompt_library::init);
 // The settings used for this test turn on `enable_experimental_live_diffs`.
685 let mut settings_store = cx.update(SettingsStore::test);
686 cx.update(|cx| {
687 settings_store
688 .set_user_settings(
689 r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
690 cx,
691 )
692 .unwrap()
693 });
694 cx.set_global(settings_store);
695 cx.update(language::init);
696 cx.update(Project::init_settings);
697 let fs = FakeFs::new(cx.executor());
698 let project = Project::test(fs, [Path::new("/root")], cx).await;
699 cx.update(LanguageModelRegistry::test);
700
701 cx.update(assistant_panel::init);
702 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
703
704 // Create a new context
705 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
706 let context = cx.new_model(|cx| {
707 Context::local(
708 registry.clone(),
709 Some(project),
710 None,
711 prompt_builder.clone(),
712 Arc::new(SlashCommandWorkingSet::default()),
713 Arc::new(ToolWorkingSet::default()),
714 cx,
715 )
716 });
717
718 // Insert an assistant message to simulate a response.
719 let assistant_message_id = context.update(cx, |context, cx| {
720 let user_message_id = context.messages(cx).next().unwrap().id;
721 context
722 .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
723 .unwrap()
724 .id
725 });
726
727 // No edit tags
728 edit(
729 &context,
730 "
731
732 «one
733 two
734 »",
735 cx,
736 );
737 expect_patches(
738 &context,
739 "
740
741 one
742 two
743 ",
744 &[],
745 cx,
746 );
747
748 // Partial edit step tag is added
749 edit(
750 &context,
751 "
752
753 one
754 two
755 «
756 <patch»",
757 cx,
758 );
759 expect_patches(
760 &context,
761 "
762
763 one
764 two
765
766 <patch",
767 &[],
768 cx,
769 );
770
771 // The rest of the step tag is added. The unclosed
772 // step is treated as incomplete.
 // One patch is expected, and its list of edits is still empty.
773 edit(
774 &context,
775 "
776
777 one
778 two
779
780 <patch«>
781 <edit>»",
782 cx,
783 );
784 expect_patches(
785 &context,
786 "
787
788 one
789 two
790
791 «<patch>
792 <edit>»",
793 &[&[]],
794 cx,
795 );
796
797 // The full patch is added
798 edit(
799 &context,
800 "
801
802 one
803 two
804
805 <patch>
806 <edit>«
807 <description>add a `two` function</description>
808 <path>src/lib.rs</path>
809 <operation>insert_after</operation>
810 <old_text>fn one</old_text>
811 <new_text>
812 fn two() {}
813 </new_text>
814 </edit>
815 </patch>
816
817 also,»",
818 cx,
819 );
820 expect_patches(
821 &context,
822 "
823
824 one
825 two
826
827 «<patch>
828 <edit>
829 <description>add a `two` function</description>
830 <path>src/lib.rs</path>
831 <operation>insert_after</operation>
832 <old_text>fn one</old_text>
833 <new_text>
834 fn two() {}
835 </new_text>
836 </edit>
837 </patch>
838 »
839 also,",
840 &[&[AssistantEdit {
841 path: "src/lib.rs".into(),
842 kind: AssistantEditKind::InsertAfter {
843 old_text: "fn one".into(),
844 new_text: "fn two() {}".into(),
845 description: Some("add a `two` function".into()),
846 },
847 }]],
848 cx,
849 );
850
851 // The step is manually edited.
 // The parsed edit is expected to pick up the changed `old_text`.
852 edit(
853 &context,
854 "
855
856 one
857 two
858
859 <patch>
860 <edit>
861 <description>add a `two` function</description>
862 <path>src/lib.rs</path>
863 <operation>insert_after</operation>
864 <old_text>«fn zero»</old_text>
865 <new_text>
866 fn two() {}
867 </new_text>
868 </edit>
869 </patch>
870
871 also,",
872 cx,
873 );
874 expect_patches(
875 &context,
876 "
877
878 one
879 two
880
881 «<patch>
882 <edit>
883 <description>add a `two` function</description>
884 <path>src/lib.rs</path>
885 <operation>insert_after</operation>
886 <old_text>fn zero</old_text>
887 <new_text>
888 fn two() {}
889 </new_text>
890 </edit>
891 </patch>
892 »
893 also,",
894 &[&[AssistantEdit {
895 path: "src/lib.rs".into(),
896 kind: AssistantEditKind::InsertAfter {
897 old_text: "fn zero".into(),
898 new_text: "fn two() {}".into(),
899 description: Some("add a `two` function".into()),
900 },
901 }]],
902 cx,
903 );
904
905 // When setting the message role to User, the steps are cleared.
 // NOTE(review): the role is cycled twice here, presumably Assistant -> System -> User;
 // confirm the cycle order against `cycle_message_roles`.
906 context.update(cx, |context, cx| {
907 context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
908 context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
909 });
910 expect_patches(
911 &context,
912 "
913
914 one
915 two
916
917 <patch>
918 <edit>
919 <description>add a `two` function</description>
920 <path>src/lib.rs</path>
921 <operation>insert_after</operation>
922 <old_text>fn zero</old_text>
923 <new_text>
924 fn two() {}
925 </new_text>
926 </edit>
927 </patch>
928
929 also,",
930 &[],
931 cx,
932 );
933
934 // When setting the message role back to Assistant, the steps are reparsed.
935 context.update(cx, |context, cx| {
936 context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
937 });
938 expect_patches(
939 &context,
940 "
941
942 one
943 two
944
945 «<patch>
946 <edit>
947 <description>add a `two` function</description>
948 <path>src/lib.rs</path>
949 <operation>insert_after</operation>
950 <old_text>fn zero</old_text>
951 <new_text>
952 fn two() {}
953 </new_text>
954 </edit>
955 </patch>
956 »
957 also,",
958 &[&[AssistantEdit {
959 path: "src/lib.rs".into(),
960 kind: AssistantEditKind::InsertAfter {
961 old_text: "fn zero".into(),
962 new_text: "fn two() {}".into(),
963 description: Some("add a `two` function".into()),
964 },
965 }]],
966 cx,
967 );
968
969 // Ensure steps are re-parsed when deserializing.
970 let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
971 let deserialized_context = cx.new_model(|cx| {
972 Context::deserialize(
973 serialized_context,
974 Default::default(),
975 registry.clone(),
976 prompt_builder.clone(),
977 Arc::new(SlashCommandWorkingSet::default()),
978 Arc::new(ToolWorkingSet::default()),
979 None,
980 None,
981 cx,
982 )
983 });
984 expect_patches(
985 &deserialized_context,
986 "
987
988 one
989 two
990
991 «<patch>
992 <edit>
993 <description>add a `two` function</description>
994 <path>src/lib.rs</path>
995 <operation>insert_after</operation>
996 <old_text>fn zero</old_text>
997 <new_text>
998 fn two() {}
999 </new_text>
1000 </edit>
1001 </patch>
1002 »
1003 also,",
1004 &[&[AssistantEdit {
1005 path: "src/lib.rs".into(),
1006 kind: AssistantEditKind::InsertAfter {
1007 old_text: "fn zero".into(),
1008 new_text: "fn two() {}".into(),
1009 description: Some("add a `two` function".into()),
1010 },
1011 }]],
1012 cx,
1013 );
1014
 // Helper: applies the marked-text edit (after unindenting it) to the context's
 // buffer, then runs the executor until it is parked so that pending work triggered
 // by the edit completes before the next assertion.
1015 fn edit(context: &Model<Context>, new_text_marked_with_edits: &str, cx: &mut TestAppContext) {
1016 context.update(cx, |context, cx| {
1017 context.buffer.update(cx, |buffer, cx| {
1018 buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
1019 });
1020 });
1021 cx.executor().run_until_parked();
1022 }
1023
 // Helper: asserts three things about the context: its buffer text equals the
 // expected text with the markers removed, its patch ranges match the marked ranges,
 // and the edits of each patch equal `expected_suggestions`.
1024 #[track_caller]
1025 fn expect_patches(
1026 context: &Model<Context>,
1027 expected_marked_text: &str,
1028 expected_suggestions: &[&[AssistantEdit]],
1029 cx: &mut TestAppContext,
1030 ) {
1031 let expected_marked_text = expected_marked_text.unindent();
 // Only the unmarked text is needed here; the ranges are checked below by
 // regenerating the marked text from the actual patch ranges.
1032 let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);
1033
1034 let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
1035 context.buffer.read_with(cx, |buffer, _| {
1036 let ranges = context
1037 .patches
1038 .iter()
1039 .map(|entry| entry.range.to_offset(buffer))
1040 .collect::<Vec<_>>();
1041 (
1042 buffer.text(),
1043 ranges,
1044 context
1045 .patches
1046 .iter()
1047 .map(|step| step.edits.clone())
1048 .collect::<Vec<_>>(),
1049 )
1050 })
1051 });
1052
1053 assert_eq!(buffer_text, expected_text);
1054
1055 let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
1056 assert_eq!(actual_marked_text, expected_marked_text);
1057
1058 assert_eq!(
1059 patches
1060 .iter()
1061 .map(|patch| {
1062 patch
1063 .iter()
1064 .map(|edit| {
 // Every stored edit is unwrapped, so an entry that does not hold a
 // successfully parsed edit fails the test at this point.
1065 let edit = edit.as_ref().unwrap();
1066 AssistantEdit {
1067 path: edit.path.clone(),
1068 kind: edit.kind.clone(),
1069 }
1070 })
1071 .collect::<Vec<_>>()
1072 })
1073 .collect::<Vec<_>>(),
1074 expected_suggestions
1075 );
1076 }
1077}
1078
// Checks that serializing a context and deserializing it again reproduces the same
// buffer text and the same messages (ids, roles and ranges), and that a message
// whose insertion was undone is absent from both.
1079#[gpui::test]
1080async fn test_serialization(cx: &mut TestAppContext) {
 // Install the test globals and initialize the assistant panel.
1081 let settings_store = cx.update(SettingsStore::test);
1082 cx.set_global(settings_store);
1083 cx.update(LanguageModelRegistry::test);
1084 cx.update(assistant_panel::init);
1085 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
1086 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1087 let context = cx.new_model(|cx| {
1088 Context::local(
1089 registry.clone(),
1090 None,
1091 None,
1092 prompt_builder.clone(),
1093 Arc::new(SlashCommandWorkingSet::default()),
1094 Arc::new(ToolWorkingSet::default()),
1095 cx,
1096 )
1097 });
1098 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 // Build a User -> Assistant -> System conversation on top of the initial message.
1099 let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
1100 let message_1 = context.update(cx, |context, cx| {
1101 context
1102 .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
1103 .unwrap()
1104 });
1105 let message_2 = context.update(cx, |context, cx| {
1106 context
1107 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
1108 .unwrap()
1109 });
 // Fill in the first two messages.
 // NOTE(review): `finalize_last_transaction` presumably keeps this edit out of the
 // undo step below, so that the undo reverts only the later insertion; confirm.
1110 buffer.update(cx, |buffer, cx| {
1111 buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
1112 buffer.finalize_last_transaction();
1113 });
 // Insert a fourth message and immediately undo. The assertions below show that the
 // text edit survives while this message is gone.
1114 let _message_3 = context.update(cx, |context, cx| {
1115 context
1116 .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
1117 .unwrap()
1118 });
1119 buffer.update(cx, |buffer, cx| buffer.undo(cx));
1120 assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
1121 assert_eq!(
1122 cx.read(|cx| messages(&context, cx)),
1123 [
1124 (message_0, Role::User, 0..2),
1125 (message_1.id, Role::Assistant, 2..6),
1126 (message_2.id, Role::System, 6..6),
1127 ]
1128 );
1129
 // Round-trip through the serialized form and expect an identical text and message list.
1130 let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
1131 let deserialized_context = cx.new_model(|cx| {
1132 Context::deserialize(
1133 serialized_context,
1134 Default::default(),
1135 registry.clone(),
1136 prompt_builder.clone(),
1137 Arc::new(SlashCommandWorkingSet::default()),
1138 Arc::new(ToolWorkingSet::default()),
1139 None,
1140 None,
1141 cx,
1142 )
1143 });
1144 let deserialized_buffer =
1145 deserialized_context.read_with(cx, |context, _| context.buffer.clone());
1146 assert_eq!(
1147 deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
1148 "a\nb\nc\n"
1149 );
1150 assert_eq!(
1151 cx.read(|cx| messages(&deserialized_context, cx)),
1152 [
1153 (message_0, Role::User, 0..2),
1154 (message_1.id, Role::Assistant, 2..6),
1155 (message_2.id, Role::System, 6..6),
1156 ]
1157 );
1158}
1159
1160#[gpui::test(iterations = 100)]
1161async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
1162 let min_peers = env::var("MIN_PEERS")
1163 .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
1164 .unwrap_or(2);
1165 let max_peers = env::var("MAX_PEERS")
1166 .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
1167 .unwrap_or(5);
1168 let operations = env::var("OPERATIONS")
1169 .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
1170 .unwrap_or(50);
1171
1172 let settings_store = cx.update(SettingsStore::test);
1173 cx.set_global(settings_store);
1174 cx.update(LanguageModelRegistry::test);
1175
1176 cx.update(assistant_panel::init);
1177 let slash_commands = cx.update(SlashCommandRegistry::default_global);
1178 slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
1179 slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
1180 slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
1181
1182 let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
1183 let network = Arc::new(Mutex::new(Network::new(rng.clone())));
1184 let mut contexts = Vec::new();
1185
1186 let num_peers = rng.gen_range(min_peers..=max_peers);
1187 let context_id = ContextId::new();
1188 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1189 for i in 0..num_peers {
1190 let context = cx.new_model(|cx| {
1191 Context::new(
1192 context_id.clone(),
1193 i as ReplicaId,
1194 language::Capability::ReadWrite,
1195 registry.clone(),
1196 prompt_builder.clone(),
1197 Arc::new(SlashCommandWorkingSet::default()),
1198 Arc::new(ToolWorkingSet::default()),
1199 None,
1200 None,
1201 cx,
1202 )
1203 });
1204
1205 cx.update(|cx| {
1206 cx.subscribe(&context, {
1207 let network = network.clone();
1208 move |_, event, _| {
1209 if let ContextEvent::Operation(op) = event {
1210 network
1211 .lock()
1212 .broadcast(i as ReplicaId, vec![op.to_proto()]);
1213 }
1214 }
1215 })
1216 .detach();
1217 });
1218
1219 contexts.push(context);
1220 network.lock().add_peer(i as ReplicaId);
1221 }
1222
1223 let mut mutation_count = operations;
1224
1225 while mutation_count > 0
1226 || !network.lock().is_idle()
1227 || network.lock().contains_disconnected_peers()
1228 {
1229 let context_index = rng.gen_range(0..contexts.len());
1230 let context = &contexts[context_index];
1231
1232 match rng.gen_range(0..100) {
1233 0..=29 if mutation_count > 0 => {
1234 log::info!("Context {}: edit buffer", context_index);
1235 context.update(cx, |context, cx| {
1236 context
1237 .buffer
1238 .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
1239 });
1240 mutation_count -= 1;
1241 }
1242 30..=44 if mutation_count > 0 => {
1243 context.update(cx, |context, cx| {
1244 let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
1245 log::info!("Context {}: split message at {:?}", context_index, range);
1246 context.split_message(range, cx);
1247 });
1248 mutation_count -= 1;
1249 }
1250 45..=59 if mutation_count > 0 => {
1251 context.update(cx, |context, cx| {
1252 if let Some(message) = context.messages(cx).choose(&mut rng) {
1253 let role = *[Role::User, Role::Assistant, Role::System]
1254 .choose(&mut rng)
1255 .unwrap();
1256 log::info!(
1257 "Context {}: insert message after {:?} with {:?}",
1258 context_index,
1259 message.id,
1260 role
1261 );
1262 context.insert_message_after(message.id, role, MessageStatus::Done, cx);
1263 }
1264 });
1265 mutation_count -= 1;
1266 }
1267 60..=74 if mutation_count > 0 => {
1268 context.update(cx, |context, cx| {
1269 let command_text = "/".to_string()
1270 + slash_commands
1271 .command_names()
1272 .choose(&mut rng)
1273 .unwrap()
1274 .clone()
1275 .as_ref();
1276
1277 let command_range = context.buffer.update(cx, |buffer, cx| {
1278 let offset = buffer.random_byte_range(0, &mut rng).start;
1279 buffer.edit(
1280 [(offset..offset, format!("\n{}\n", command_text))],
1281 None,
1282 cx,
1283 );
1284 offset + 1..offset + 1 + command_text.len()
1285 });
1286
1287 let output_text = RandomCharIter::new(&mut rng)
1288 .filter(|c| *c != '\r')
1289 .take(10)
1290 .collect::<String>();
1291
1292 let mut events = vec![Ok(SlashCommandEvent::StartMessage {
1293 role: Role::User,
1294 merge_same_roles: true,
1295 })];
1296
1297 let num_sections = rng.gen_range(0..=3);
1298 let mut section_start = 0;
1299 for _ in 0..num_sections {
1300 let mut section_end = rng.gen_range(section_start..=output_text.len());
1301 while !output_text.is_char_boundary(section_end) {
1302 section_end += 1;
1303 }
1304 events.push(Ok(SlashCommandEvent::StartSection {
1305 icon: IconName::Ai,
1306 label: "section".into(),
1307 metadata: None,
1308 }));
1309 events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
1310 text: output_text[section_start..section_end].to_string(),
1311 run_commands_in_text: false,
1312 })));
1313 events.push(Ok(SlashCommandEvent::EndSection));
1314 section_start = section_end;
1315 }
1316
1317 if section_start < output_text.len() {
1318 events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
1319 text: output_text[section_start..].to_string(),
1320 run_commands_in_text: false,
1321 })));
1322 }
1323
1324 log::info!(
1325 "Context {}: insert slash command output at {:?} with {:?} events",
1326 context_index,
1327 command_range,
1328 events.len()
1329 );
1330
1331 let command_range = context.buffer.read(cx).anchor_after(command_range.start)
1332 ..context.buffer.read(cx).anchor_after(command_range.end);
1333 context.insert_command_output(
1334 command_range,
1335 "/command",
1336 Task::ready(Ok(stream::iter(events).boxed())),
1337 true,
1338 cx,
1339 );
1340 });
1341 cx.run_until_parked();
1342 mutation_count -= 1;
1343 }
1344 75..=84 if mutation_count > 0 => {
1345 context.update(cx, |context, cx| {
1346 if let Some(message) = context.messages(cx).choose(&mut rng) {
1347 let new_status = match rng.gen_range(0..3) {
1348 0 => MessageStatus::Done,
1349 1 => MessageStatus::Pending,
1350 _ => MessageStatus::Error(SharedString::from("Random error")),
1351 };
1352 log::info!(
1353 "Context {}: update message {:?} status to {:?}",
1354 context_index,
1355 message.id,
1356 new_status
1357 );
1358 context.update_metadata(message.id, cx, |metadata| {
1359 metadata.status = new_status;
1360 });
1361 }
1362 });
1363 mutation_count -= 1;
1364 }
1365 _ => {
1366 let replica_id = context_index as ReplicaId;
1367 if network.lock().is_disconnected(replica_id) {
1368 network.lock().reconnect_peer(replica_id, 0);
1369
1370 let (ops_to_send, ops_to_receive) = cx.read(|cx| {
1371 let host_context = &contexts[0].read(cx);
1372 let guest_context = context.read(cx);
1373 (
1374 guest_context.serialize_ops(&host_context.version(cx), cx),
1375 host_context.serialize_ops(&guest_context.version(cx), cx),
1376 )
1377 });
1378 let ops_to_send = ops_to_send.await;
1379 let ops_to_receive = ops_to_receive
1380 .await
1381 .into_iter()
1382 .map(ContextOperation::from_proto)
1383 .collect::<Result<Vec<_>>>()
1384 .unwrap();
1385 log::info!(
1386 "Context {}: reconnecting. Sent {} operations, received {} operations",
1387 context_index,
1388 ops_to_send.len(),
1389 ops_to_receive.len()
1390 );
1391
1392 network.lock().broadcast(replica_id, ops_to_send);
1393 context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
1394 } else if rng.gen_bool(0.1) && replica_id != 0 {
1395 log::info!("Context {}: disconnecting", context_index);
1396 network.lock().disconnect_peer(replica_id);
1397 } else if network.lock().has_unreceived(replica_id) {
1398 log::info!("Context {}: applying operations", context_index);
1399 let ops = network.lock().receive(replica_id);
1400 let ops = ops
1401 .into_iter()
1402 .map(ContextOperation::from_proto)
1403 .collect::<Result<Vec<_>>>()
1404 .unwrap();
1405 context.update(cx, |context, cx| context.apply_ops(ops, cx));
1406 }
1407 }
1408 }
1409 }
1410
1411 cx.read(|cx| {
1412 let first_context = contexts[0].read(cx);
1413 for context in &contexts[1..] {
1414 let context = context.read(cx);
1415 assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
1416 assert_eq!(
1417 context.buffer.read(cx).text(),
1418 first_context.buffer.read(cx).text(),
1419 "Context {} text != Context 0 text",
1420 context.buffer.read(cx).replica_id()
1421 );
1422 assert_eq!(
1423 context.message_anchors,
1424 first_context.message_anchors,
1425 "Context {} messages != Context 0 messages",
1426 context.buffer.read(cx).replica_id()
1427 );
1428 assert_eq!(
1429 context.messages_metadata,
1430 first_context.messages_metadata,
1431 "Context {} message metadata != Context 0 message metadata",
1432 context.buffer.read(cx).replica_id()
1433 );
1434 assert_eq!(
1435 context.slash_command_output_sections,
1436 first_context.slash_command_output_sections,
1437 "Context {} slash command output sections != Context 0 slash command output sections",
1438 context.buffer.read(cx).replica_id()
1439 );
1440 }
1441 });
1442}
1443
1444#[gpui::test]
1445fn test_mark_cache_anchors(cx: &mut AppContext) {
1446 let settings_store = SettingsStore::test(cx);
1447 LanguageModelRegistry::test(cx);
1448 cx.set_global(settings_store);
1449 assistant_panel::init(cx);
1450 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1451 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1452 let context = cx.new_model(|cx| {
1453 Context::local(
1454 registry,
1455 None,
1456 None,
1457 prompt_builder.clone(),
1458 Arc::new(SlashCommandWorkingSet::default()),
1459 Arc::new(ToolWorkingSet::default()),
1460 cx,
1461 )
1462 });
1463 let buffer = context.read(cx).buffer.clone();
1464
1465 // Create a test cache configuration
1466 let cache_configuration = &Some(LanguageModelCacheConfiguration {
1467 max_cache_anchors: 3,
1468 should_speculate: true,
1469 min_total_token: 10,
1470 });
1471
1472 let message_1 = context.read(cx).message_anchors[0].clone();
1473
1474 context.update(cx, |context, cx| {
1475 context.mark_cache_anchors(cache_configuration, false, cx)
1476 });
1477
1478 assert_eq!(
1479 messages_cache(&context, cx)
1480 .iter()
1481 .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1482 .count(),
1483 0,
1484 "Empty messages should not have any cache anchors."
1485 );
1486
1487 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1488 let message_2 = context
1489 .update(cx, |context, cx| {
1490 context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1491 })
1492 .unwrap();
1493
1494 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1495 let message_3 = context
1496 .update(cx, |context, cx| {
1497 context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1498 })
1499 .unwrap();
1500 buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1501
1502 context.update(cx, |context, cx| {
1503 context.mark_cache_anchors(cache_configuration, false, cx)
1504 });
1505 assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1506 assert_eq!(
1507 messages_cache(&context, cx)
1508 .iter()
1509 .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1510 .count(),
1511 0,
1512 "Messages should not be marked for cache before going over the token minimum."
1513 );
1514 context.update(cx, |context, _| {
1515 context.token_count = Some(20);
1516 });
1517
1518 context.update(cx, |context, cx| {
1519 context.mark_cache_anchors(cache_configuration, true, cx)
1520 });
1521 assert_eq!(
1522 messages_cache(&context, cx)
1523 .iter()
1524 .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1525 .collect::<Vec<bool>>(),
1526 vec![true, true, false],
1527 "Last message should not be an anchor on speculative request."
1528 );
1529
1530 context
1531 .update(cx, |context, cx| {
1532 context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1533 })
1534 .unwrap();
1535
1536 context.update(cx, |context, cx| {
1537 context.mark_cache_anchors(cache_configuration, false, cx)
1538 });
1539 assert_eq!(
1540 messages_cache(&context, cx)
1541 .iter()
1542 .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1543 .collect::<Vec<bool>>(),
1544 vec![false, true, true, false],
1545 "Most recent message should also be cached if not a speculative request."
1546 );
1547 context.update(cx, |context, cx| {
1548 context.update_cache_status_for_completion(cx)
1549 });
1550 assert_eq!(
1551 messages_cache(&context, cx)
1552 .iter()
1553 .map(|(_, cache)| cache
1554 .as_ref()
1555 .map_or(None, |cache| Some(cache.status.clone())))
1556 .collect::<Vec<Option<CacheStatus>>>(),
1557 vec![
1558 Some(CacheStatus::Cached),
1559 Some(CacheStatus::Cached),
1560 Some(CacheStatus::Cached),
1561 None
1562 ],
1563 "All user messages prior to anchor should be marked as cached."
1564 );
1565
1566 buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1567 context.update(cx, |context, cx| {
1568 context.mark_cache_anchors(cache_configuration, false, cx)
1569 });
1570 assert_eq!(
1571 messages_cache(&context, cx)
1572 .iter()
1573 .map(|(_, cache)| cache
1574 .as_ref()
1575 .map_or(None, |cache| Some(cache.status.clone())))
1576 .collect::<Vec<Option<CacheStatus>>>(),
1577 vec![
1578 Some(CacheStatus::Cached),
1579 Some(CacheStatus::Cached),
1580 Some(CacheStatus::Pending),
1581 None
1582 ],
1583 "Modifying a message should invalidate it's cache but leave previous messages."
1584 );
1585 buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1586 context.update(cx, |context, cx| {
1587 context.mark_cache_anchors(cache_configuration, false, cx)
1588 });
1589 assert_eq!(
1590 messages_cache(&context, cx)
1591 .iter()
1592 .map(|(_, cache)| cache
1593 .as_ref()
1594 .map_or(None, |cache| Some(cache.status.clone())))
1595 .collect::<Vec<Option<CacheStatus>>>(),
1596 vec![
1597 Some(CacheStatus::Pending),
1598 Some(CacheStatus::Pending),
1599 Some(CacheStatus::Pending),
1600 None
1601 ],
1602 "Modifying a message should invalidate all future messages."
1603 );
1604}
1605
1606fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
1607 context
1608 .read(cx)
1609 .messages(cx)
1610 .map(|message| (message.id, message.role, message.offset_range))
1611 .collect()
1612}
1613
1614fn messages_cache(
1615 context: &Model<Context>,
1616 cx: &AppContext,
1617) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1618 context
1619 .read(cx)
1620 .messages(cx)
1621 .map(|message| (message.id, message.cache.clone()))
1622 .collect()
1623}
1624
1625#[derive(Clone)]
1626struct FakeSlashCommand(String);
1627
1628impl SlashCommand for FakeSlashCommand {
1629 fn name(&self) -> String {
1630 self.0.clone()
1631 }
1632
1633 fn description(&self) -> String {
1634 format!("Fake slash command: {}", self.0)
1635 }
1636
1637 fn menu_text(&self) -> String {
1638 format!("Run fake command: {}", self.0)
1639 }
1640
1641 fn complete_argument(
1642 self: Arc<Self>,
1643 _arguments: &[String],
1644 _cancel: Arc<AtomicBool>,
1645 _workspace: Option<WeakView<Workspace>>,
1646 _cx: &mut WindowContext,
1647 ) -> Task<Result<Vec<ArgumentCompletion>>> {
1648 Task::ready(Ok(vec![]))
1649 }
1650
1651 fn requires_argument(&self) -> bool {
1652 false
1653 }
1654
1655 fn run(
1656 self: Arc<Self>,
1657 _arguments: &[String],
1658 _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1659 _context_buffer: BufferSnapshot,
1660 _workspace: WeakView<Workspace>,
1661 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1662 _cx: &mut WindowContext,
1663 ) -> Task<SlashCommandResult> {
1664 Task::ready(Ok(SlashCommandOutput {
1665 text: format!("Executed fake command: {}", self.0),
1666 sections: vec![],
1667 run_commands_in_text: false,
1668 }
1669 .to_event_stream()))
1670 }
1671}