1use crate::{
2 AssistantContext, AssistantEdit, AssistantEditKind, CacheStatus, ContextEvent, ContextId,
3 ContextOperation, ContextSummary, InvokedSlashCommandId, MessageCacheMetadata, MessageId,
4 MessageStatus, RequestType,
5};
6use anyhow::Result;
7use assistant_slash_command::{
8 ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
9 SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
10};
11use assistant_slash_commands::FileSlashCommand;
12use collections::{HashMap, HashSet};
13use fs::FakeFs;
14use futures::{
15 channel::mpsc,
16 stream::{self, StreamExt},
17};
18use gpui::{App, Entity, SharedString, Task, TestAppContext, WeakEntity, prelude::*};
19use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
20use language_model::{
21 ConfiguredModel, LanguageModelCacheConfiguration, LanguageModelRegistry, Role,
22 fake_provider::{FakeLanguageModel, FakeLanguageModelProvider},
23};
24use parking_lot::Mutex;
25use pretty_assertions::assert_eq;
26use project::Project;
27use prompt_store::PromptBuilder;
28use rand::prelude::*;
29use serde_json::json;
30use settings::SettingsStore;
31use std::{
32 cell::RefCell,
33 env,
34 ops::Range,
35 path::Path,
36 rc::Rc,
37 sync::{Arc, atomic::AtomicBool},
38};
39use text::{OffsetRangeExt as _, ReplicaId, ToOffset, network::Network};
40use ui::{IconName, Window};
41use unindent::Unindent;
42use util::{
43 RandomCharIter,
44 test::{generate_marked_text, marked_text_ranges},
45};
46use workspace::Workspace;
47
48#[gpui::test]
49fn test_inserting_and_removing_messages(cx: &mut App) {
50 init_test(cx);
51
52 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
53 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
54 let context = cx.new(|cx| {
55 AssistantContext::local(
56 registry,
57 None,
58 None,
59 prompt_builder.clone(),
60 Arc::new(SlashCommandWorkingSet::default()),
61 cx,
62 )
63 });
64 let buffer = context.read(cx).buffer.clone();
65
66 let message_1 = context.read(cx).message_anchors[0].clone();
67 assert_eq!(
68 messages(&context, cx),
69 vec![(message_1.id, Role::User, 0..0)]
70 );
71
72 let message_2 = context.update(cx, |context, cx| {
73 context
74 .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
75 .unwrap()
76 });
77 assert_eq!(
78 messages(&context, cx),
79 vec![
80 (message_1.id, Role::User, 0..1),
81 (message_2.id, Role::Assistant, 1..1)
82 ]
83 );
84
85 buffer.update(cx, |buffer, cx| {
86 buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
87 });
88 assert_eq!(
89 messages(&context, cx),
90 vec![
91 (message_1.id, Role::User, 0..2),
92 (message_2.id, Role::Assistant, 2..3)
93 ]
94 );
95
96 let message_3 = context.update(cx, |context, cx| {
97 context
98 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
99 .unwrap()
100 });
101 assert_eq!(
102 messages(&context, cx),
103 vec![
104 (message_1.id, Role::User, 0..2),
105 (message_2.id, Role::Assistant, 2..4),
106 (message_3.id, Role::User, 4..4)
107 ]
108 );
109
110 let message_4 = context.update(cx, |context, cx| {
111 context
112 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
113 .unwrap()
114 });
115 assert_eq!(
116 messages(&context, cx),
117 vec![
118 (message_1.id, Role::User, 0..2),
119 (message_2.id, Role::Assistant, 2..4),
120 (message_4.id, Role::User, 4..5),
121 (message_3.id, Role::User, 5..5),
122 ]
123 );
124
125 buffer.update(cx, |buffer, cx| {
126 buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
127 });
128 assert_eq!(
129 messages(&context, cx),
130 vec![
131 (message_1.id, Role::User, 0..2),
132 (message_2.id, Role::Assistant, 2..4),
133 (message_4.id, Role::User, 4..6),
134 (message_3.id, Role::User, 6..7),
135 ]
136 );
137
138 // Deleting across message boundaries merges the messages.
139 buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
140 assert_eq!(
141 messages(&context, cx),
142 vec![
143 (message_1.id, Role::User, 0..3),
144 (message_3.id, Role::User, 3..4),
145 ]
146 );
147
148 // Undoing the deletion should also undo the merge.
149 buffer.update(cx, |buffer, cx| buffer.undo(cx));
150 assert_eq!(
151 messages(&context, cx),
152 vec![
153 (message_1.id, Role::User, 0..2),
154 (message_2.id, Role::Assistant, 2..4),
155 (message_4.id, Role::User, 4..6),
156 (message_3.id, Role::User, 6..7),
157 ]
158 );
159
160 // Redoing the deletion should also redo the merge.
161 buffer.update(cx, |buffer, cx| buffer.redo(cx));
162 assert_eq!(
163 messages(&context, cx),
164 vec![
165 (message_1.id, Role::User, 0..3),
166 (message_3.id, Role::User, 3..4),
167 ]
168 );
169
170 // Ensure we can still insert after a merged message.
171 let message_5 = context.update(cx, |context, cx| {
172 context
173 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
174 .unwrap()
175 });
176 assert_eq!(
177 messages(&context, cx),
178 vec![
179 (message_1.id, Role::User, 0..3),
180 (message_5.id, Role::System, 3..4),
181 (message_3.id, Role::User, 4..5)
182 ]
183 );
184}
185
186#[gpui::test]
187fn test_message_splitting(cx: &mut App) {
188 init_test(cx);
189
190 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
191
192 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
193 let context = cx.new(|cx| {
194 AssistantContext::local(
195 registry.clone(),
196 None,
197 None,
198 prompt_builder.clone(),
199 Arc::new(SlashCommandWorkingSet::default()),
200 cx,
201 )
202 });
203 let buffer = context.read(cx).buffer.clone();
204
205 let message_1 = context.read(cx).message_anchors[0].clone();
206 assert_eq!(
207 messages(&context, cx),
208 vec![(message_1.id, Role::User, 0..0)]
209 );
210
211 buffer.update(cx, |buffer, cx| {
212 buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
213 });
214
215 let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
216 let message_2 = message_2.unwrap();
217
218 // We recycle newlines in the middle of a split message
219 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
220 assert_eq!(
221 messages(&context, cx),
222 vec![
223 (message_1.id, Role::User, 0..4),
224 (message_2.id, Role::User, 4..16),
225 ]
226 );
227
228 let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
229 let message_3 = message_3.unwrap();
230
231 // We don't recycle newlines at the end of a split message
232 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
233 assert_eq!(
234 messages(&context, cx),
235 vec![
236 (message_1.id, Role::User, 0..4),
237 (message_3.id, Role::User, 4..5),
238 (message_2.id, Role::User, 5..17),
239 ]
240 );
241
242 let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
243 let message_4 = message_4.unwrap();
244 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
245 assert_eq!(
246 messages(&context, cx),
247 vec![
248 (message_1.id, Role::User, 0..4),
249 (message_3.id, Role::User, 4..5),
250 (message_2.id, Role::User, 5..9),
251 (message_4.id, Role::User, 9..17),
252 ]
253 );
254
255 let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
256 let message_5 = message_5.unwrap();
257 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
258 assert_eq!(
259 messages(&context, cx),
260 vec![
261 (message_1.id, Role::User, 0..4),
262 (message_3.id, Role::User, 4..5),
263 (message_2.id, Role::User, 5..9),
264 (message_4.id, Role::User, 9..10),
265 (message_5.id, Role::User, 10..18),
266 ]
267 );
268
269 let (message_6, message_7) =
270 context.update(cx, |context, cx| context.split_message(14..16, cx));
271 let message_6 = message_6.unwrap();
272 let message_7 = message_7.unwrap();
273 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
274 assert_eq!(
275 messages(&context, cx),
276 vec![
277 (message_1.id, Role::User, 0..4),
278 (message_3.id, Role::User, 4..5),
279 (message_2.id, Role::User, 5..9),
280 (message_4.id, Role::User, 9..10),
281 (message_5.id, Role::User, 10..14),
282 (message_6.id, Role::User, 14..17),
283 (message_7.id, Role::User, 17..19),
284 ]
285 );
286}
287
288#[gpui::test]
289fn test_messages_for_offsets(cx: &mut App) {
290 init_test(cx);
291
292 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
293 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
294 let context = cx.new(|cx| {
295 AssistantContext::local(
296 registry,
297 None,
298 None,
299 prompt_builder.clone(),
300 Arc::new(SlashCommandWorkingSet::default()),
301 cx,
302 )
303 });
304 let buffer = context.read(cx).buffer.clone();
305
306 let message_1 = context.read(cx).message_anchors[0].clone();
307 assert_eq!(
308 messages(&context, cx),
309 vec![(message_1.id, Role::User, 0..0)]
310 );
311
312 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
313 let message_2 = context
314 .update(cx, |context, cx| {
315 context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
316 })
317 .unwrap();
318 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
319
320 let message_3 = context
321 .update(cx, |context, cx| {
322 context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
323 })
324 .unwrap();
325 buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
326
327 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
328 assert_eq!(
329 messages(&context, cx),
330 vec![
331 (message_1.id, Role::User, 0..4),
332 (message_2.id, Role::User, 4..8),
333 (message_3.id, Role::User, 8..11)
334 ]
335 );
336
337 assert_eq!(
338 message_ids_for_offsets(&context, &[0, 4, 9], cx),
339 [message_1.id, message_2.id, message_3.id]
340 );
341 assert_eq!(
342 message_ids_for_offsets(&context, &[0, 1, 11], cx),
343 [message_1.id, message_3.id]
344 );
345
346 let message_4 = context
347 .update(cx, |context, cx| {
348 context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
349 })
350 .unwrap();
351 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
352 assert_eq!(
353 messages(&context, cx),
354 vec![
355 (message_1.id, Role::User, 0..4),
356 (message_2.id, Role::User, 4..8),
357 (message_3.id, Role::User, 8..12),
358 (message_4.id, Role::User, 12..12)
359 ]
360 );
361 assert_eq!(
362 message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
363 [message_1.id, message_2.id, message_3.id, message_4.id]
364 );
365
366 fn message_ids_for_offsets(
367 context: &Entity<AssistantContext>,
368 offsets: &[usize],
369 cx: &App,
370 ) -> Vec<MessageId> {
371 context
372 .read(cx)
373 .messages_for_offsets(offsets.iter().copied(), cx)
374 .into_iter()
375 .map(|message| message.id)
376 .collect()
377 }
378}
379
// Verifies that slash commands typed into the context buffer are tracked as
// parsed commands, and that output streamed for an invoked command is inserted,
// grouped into sections, and finalized as the event stream progresses.
#[gpui::test]
async fn test_slash_commands(cx: &mut TestAppContext) {
    cx.update(init_test);

    // NOTE(review): `fs` is populated here but never handed to a project or to
    // the context below; the command output is streamed manually instead.
    let fs = FakeFs::new(cx.background_executor.clone());

    fs.insert_tree(
        "/test",
        json!({
            "src": {
                "lib.rs": "fn one() -> usize { 1 }",
                "main.rs": "
                    use crate::one;
                    fn main() { one(); }
                ".unindent(),
            }
        }),
    )
    .await;

    // Register `/file` in the global slash command registry.
    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
    slash_command_registry.register_command(FileSlashCommand, false);

    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context = cx.new(|cx| {
        AssistantContext::local(
            registry.clone(),
            None,
            None,
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            cx,
        )
    });

    /// Ranges observed through the context's events, rendered as markers by
    /// `assert_text_and_context_ranges`.
    #[derive(Default)]
    struct ContextRanges {
        parsed_commands: HashSet<Range<language::Anchor>>,
        command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
        output_sections: HashSet<Range<language::Anchor>>,
    }

    // Mirror the context's slash-command events into `context_ranges` so the
    // assertions below can check them against the buffer text.
    let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
    context.update(cx, |_, cx| {
        cx.subscribe(&context, {
            let context_ranges = context_ranges.clone();
            move |context, _, event, _| {
                let mut context_ranges = context_ranges.borrow_mut();
                match event {
                    ContextEvent::InvokedSlashCommandChanged { command_id } => {
                        let command = context.invoked_slash_command(command_id).unwrap();
                        context_ranges
                            .command_outputs
                            .insert(*command_id, command.range.clone());
                    }
                    ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
                        for range in removed {
                            context_ranges.parsed_commands.remove(range);
                        }
                        for command in updated {
                            context_ranges
                                .parsed_commands
                                .insert(command.source_range.clone());
                        }
                    }
                    ContextEvent::SlashCommandOutputSectionAdded { section } => {
                        context_ranges.output_sections.insert(section.range.clone());
                    }
                    _ => {}
                }
            }
        })
        .detach();
    });

    let buffer = context.read_with(cx, |context, _| context.buffer.clone());

    // Insert a slash command
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
    });
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        «/file src/lib.rs»"
            .unindent(),
        cx,
    );

    // Edit the argument of the slash command.
    buffer.update(cx, |buffer, cx| {
        let edit_offset = buffer.text().find("lib.rs").unwrap();
        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
    });
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        «/file src/main.rs»"
            .unindent(),
        cx,
    );

    // Edit the name of the slash command, using one that doesn't exist.
    buffer.update(cx, |buffer, cx| {
        let edit_offset = buffer.text().find("/file").unwrap();
        buffer.edit(
            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
            None,
            cx,
        );
    });
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        /unknown src/main.rs"
            .unindent(),
        cx,
    );

    // Undoing the rename to a non-existent slash command restores the previous one.
    buffer.update(cx, |buffer, cx| buffer.undo(cx));
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        «/file src/main.rs»"
            .unindent(),
        cx,
    );

    // Invoke the parsed command with an output stream that the test drives by
    // hand through `command_output_tx`. The `…` in the expected strings below
    // is not a marker: it is literal buffer text that appears while the stream
    // is still open.
    let (command_output_tx, command_output_rx) = mpsc::unbounded();
    context.update(cx, |context, cx| {
        let command_source_range = context.parsed_slash_commands[0].source_range.clone();
        context.insert_command_output(
            command_source_range,
            "file",
            Task::ready(Ok(command_output_rx.boxed())),
            true,
            cx,
        );
    });
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦«/file src/main.rs»
        …⟧
        "
        .unindent(),
        cx,
    );

    command_output_tx
        .unbounded_send(Ok(SlashCommandEvent::StartSection {
            icon: IconName::Ai,
            label: "src/main.rs".into(),
            metadata: None,
        }))
        .unwrap();
    command_output_tx
        .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
        .unwrap();
    cx.run_until_parked();
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦«/file src/main.rs»
        src/main.rs…⟧
        "
        .unindent(),
        cx,
    );

    command_output_tx
        .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
        .unwrap();
    cx.run_until_parked();
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦«/file src/main.rs»
        src/main.rs
        fn main() {}…⟧
        "
        .unindent(),
        cx,
    );

    // Ending the section makes its range (⟪ ⟫) visible.
    command_output_tx
        .unbounded_send(Ok(SlashCommandEvent::EndSection))
        .unwrap();
    cx.run_until_parked();
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦«/file src/main.rs»
        ⟪src/main.rs
        fn main() {}⟫…⟧
        "
        .unindent(),
        cx,
    );

    // Closing the stream ends the command: per the expected text, the command
    // line and the `…` placeholder are gone and only the section output remains.
    drop(command_output_tx);
    cx.run_until_parked();
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦⟪src/main.rs
        fn main() {}⟫⟧
        "
        .unindent(),
        cx,
    );

    /// Renders `buffer`'s text with every tracked range drawn in as markers and
    /// asserts the result equals `expected_marked_text`.
    ///
    /// Markers: `⟦ ⟧` invoked-command output ranges, `« »` parsed command
    /// source ranges, `⟪ ⟫` output section ranges.
    #[track_caller]
    fn assert_text_and_context_ranges(
        buffer: &Entity<Buffer>,
        ranges: &RefCell<ContextRanges>,
        expected_marked_text: &str,
        cx: &mut TestAppContext,
    ) {
        let mut actual_marked_text = String::new();
        buffer.update(cx, |buffer, _| {
            struct Endpoint {
                offset: usize,
                marker: char,
            }

            // Start markers are pushed outermost-first and end markers
            // innermost-first. `sort_by_key` is stable, so when several
            // endpoints share an offset this push order decides how the
            // markers nest in the rendered text.
            let ranges = ranges.borrow();
            let mut endpoints = Vec::new();
            for range in ranges.command_outputs.values() {
                endpoints.push(Endpoint {
                    offset: range.start.to_offset(buffer),
                    marker: '⟦',
                });
            }
            for range in ranges.parsed_commands.iter() {
                endpoints.push(Endpoint {
                    offset: range.start.to_offset(buffer),
                    marker: '«',
                });
            }
            for range in ranges.output_sections.iter() {
                endpoints.push(Endpoint {
                    offset: range.start.to_offset(buffer),
                    marker: '⟪',
                });
            }

            for range in ranges.output_sections.iter() {
                endpoints.push(Endpoint {
                    offset: range.end.to_offset(buffer),
                    marker: '⟫',
                });
            }
            for range in ranges.parsed_commands.iter() {
                endpoints.push(Endpoint {
                    offset: range.end.to_offset(buffer),
                    marker: '»',
                });
            }
            for range in ranges.command_outputs.values() {
                endpoints.push(Endpoint {
                    offset: range.end.to_offset(buffer),
                    marker: '⟧',
                });
            }

            // Interleave buffer text with the markers in offset order.
            endpoints.sort_by_key(|endpoint| endpoint.offset);
            let mut offset = 0;
            for endpoint in endpoints {
                actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
                actual_marked_text.push(endpoint.marker);
                offset = endpoint.offset;
            }
            actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
        });

        assert_eq!(actual_marked_text, expected_marked_text);
    }
}
670
// Verifies that `<patch>` blocks in an assistant message are parsed into
// `context.patches` as the text grows, re-parsed after manual edits, cleared
// when the message is cycled to the User role, parsed again when it returns to
// Assistant, and re-parsed after a serialize/deserialize round trip.
#[gpui::test]
async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
    cx.update(|cx| {
        init_test(cx);
        cx.update_global(|settings_store: &mut SettingsStore, cx| {
            settings_store
                .set_user_settings(
                    r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
                    cx,
                )
                .unwrap()
        })
    });
    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [Path::new("/root")], cx).await;

    let registry = Arc::new(LanguageRegistry::test(cx.executor()));

    // Create a new context
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context = cx.new(|cx| {
        AssistantContext::local(
            registry.clone(),
            Some(project),
            None,
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            cx,
        )
    });

    // Insert an assistant message to simulate a response.
    let assistant_message_id = context.update(cx, |context, cx| {
        let user_message_id = context.messages(cx).next().unwrap().id;
        context
            .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
            .unwrap()
            .id
    });

    // In the `edit` calls below, `«…»` marks the text being changed relative to
    // the current buffer. In `expect_patches`, `«…»` marks each parsed patch's
    // range.

    // No edit tags
    edit(
        &context,
        "

        «one
        two
        »",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two
        ",
        &[],
        cx,
    );

    // Partial edit step tag is added
    edit(
        &context,
        "

        one
        two
        «
        <patch»",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two

        <patch",
        &[],
        cx,
    );

    // The rest of the step tag is added. The unclosed
    // step is treated as incomplete.
    edit(
        &context,
        "

        one
        two

        <patch«>
        <edit>»",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two

        «<patch>
        <edit>»",
        &[&[]],
        cx,
    );

    // The full patch is added
    edit(
        &context,
        "

        one
        two

        <patch>
        <edit>«
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn one</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>

        also,»",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two

        «<patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn one</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>
        »
        also,",
        &[&[AssistantEdit {
            path: "src/lib.rs".into(),
            kind: AssistantEditKind::InsertAfter {
                old_text: "fn one".into(),
                new_text: "fn two() {}".into(),
                description: Some("add a `two` function".into()),
            },
        }]],
        cx,
    );

    // The step is manually edited.
    edit(
        &context,
        "

        one
        two

        <patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>«fn zero»</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>

        also,",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two

        «<patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn zero</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>
        »
        also,",
        &[&[AssistantEdit {
            path: "src/lib.rs".into(),
            kind: AssistantEditKind::InsertAfter {
                old_text: "fn zero".into(),
                new_text: "fn two() {}".into(),
                description: Some("add a `two` function".into()),
            },
        }]],
        cx,
    );

    // When setting the message role to User, the steps are cleared.
    // (Two cycles take the message from Assistant to User.)
    context.update(cx, |context, cx| {
        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
    });
    expect_patches(
        &context,
        "

        one
        two

        <patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn zero</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>

        also,",
        &[],
        cx,
    );

    // When setting the message role back to Assistant, the steps are reparsed.
    context.update(cx, |context, cx| {
        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
    });
    expect_patches(
        &context,
        "

        one
        two

        «<patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn zero</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>
        »
        also,",
        &[&[AssistantEdit {
            path: "src/lib.rs".into(),
            kind: AssistantEditKind::InsertAfter {
                old_text: "fn zero".into(),
                new_text: "fn two() {}".into(),
                description: Some("add a `two` function".into()),
            },
        }]],
        cx,
    );

    // Ensure steps are re-parsed when deserializing.
    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
    let deserialized_context = cx.new(|cx| {
        AssistantContext::deserialize(
            serialized_context,
            Path::new("").into(),
            registry.clone(),
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            None,
            None,
            cx,
        )
    });
    expect_patches(
        &deserialized_context,
        "

        one
        two

        «<patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn zero</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>
        »
        also,",
        &[&[AssistantEdit {
            path: "src/lib.rs".into(),
            kind: AssistantEditKind::InsertAfter {
                old_text: "fn zero".into(),
                new_text: "fn two() {}".into(),
                description: Some("add a `two` function".into()),
            },
        }]],
        cx,
    );

    /// Unindents `new_text_marked_with_edits`, applies it to the context's
    /// buffer with `Buffer::edit_via_marked_text`, then runs the executor until
    /// parked. Presumably the run is needed so asynchronous patch parsing settles
    /// before the next assertion (TODO confirm).
    fn edit(
        context: &Entity<AssistantContext>,
        new_text_marked_with_edits: &str,
        cx: &mut TestAppContext,
    ) {
        context.update(cx, |context, cx| {
            context.buffer.update(cx, |buffer, cx| {
                buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
            });
        });
        cx.executor().run_until_parked();
    }

    /// Asserts three things: the buffer text equals `expected_marked_text` with
    /// its markers removed, each patch range in `context.patches` matches a
    /// `«…»` span, and each patch's edits equal the matching entry of
    /// `expected_suggestions`.
    #[track_caller]
    fn expect_patches(
        context: &Entity<AssistantContext>,
        expected_marked_text: &str,
        expected_suggestions: &[&[AssistantEdit]],
        cx: &mut TestAppContext,
    ) {
        let expected_marked_text = expected_marked_text.unindent();
        let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);

        // Snapshot the buffer text, each patch's offset range, and each patch's edits.
        let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
            context.buffer.read_with(cx, |buffer, _| {
                let ranges = context
                    .patches
                    .iter()
                    .map(|entry| entry.range.to_offset(buffer))
                    .collect::<Vec<_>>();
                (
                    buffer.text(),
                    ranges,
                    context
                        .patches
                        .iter()
                        .map(|step| step.edits.clone())
                        .collect::<Vec<_>>(),
                )
            })
        });

        assert_eq!(buffer_text, expected_text);

        // `expected_text` was just asserted equal to the buffer text, so rendering the
        // patch ranges onto it reproduces the buffer with markers.
        let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
        assert_eq!(actual_marked_text, expected_marked_text);

        // Every edit must unwrap successfully (otherwise this panics); only
        // `path` and `kind` are compared.
        assert_eq!(
            patches
                .iter()
                .map(|patch| {
                    patch
                        .iter()
                        .map(|edit| {
                            let edit = edit.as_ref().unwrap();
                            AssistantEdit {
                                path: edit.path.clone(),
                                kind: edit.kind.clone(),
                            }
                        })
                        .collect::<Vec<_>>()
                })
                .collect::<Vec<_>>(),
            expected_suggestions
        );
    }
}
1065
1066#[gpui::test]
1067async fn test_serialization(cx: &mut TestAppContext) {
1068 cx.update(init_test);
1069
1070 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
1071 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1072 let context = cx.new(|cx| {
1073 AssistantContext::local(
1074 registry.clone(),
1075 None,
1076 None,
1077 prompt_builder.clone(),
1078 Arc::new(SlashCommandWorkingSet::default()),
1079 cx,
1080 )
1081 });
1082 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
1083 let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
1084 let message_1 = context.update(cx, |context, cx| {
1085 context
1086 .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
1087 .unwrap()
1088 });
1089 let message_2 = context.update(cx, |context, cx| {
1090 context
1091 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
1092 .unwrap()
1093 });
1094 buffer.update(cx, |buffer, cx| {
1095 buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
1096 buffer.finalize_last_transaction();
1097 });
1098 let _message_3 = context.update(cx, |context, cx| {
1099 context
1100 .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
1101 .unwrap()
1102 });
1103 buffer.update(cx, |buffer, cx| buffer.undo(cx));
1104 assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
1105 assert_eq!(
1106 cx.read(|cx| messages(&context, cx)),
1107 [
1108 (message_0, Role::User, 0..2),
1109 (message_1.id, Role::Assistant, 2..6),
1110 (message_2.id, Role::System, 6..6),
1111 ]
1112 );
1113
1114 let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
1115 let deserialized_context = cx.new(|cx| {
1116 AssistantContext::deserialize(
1117 serialized_context,
1118 Path::new("").into(),
1119 registry.clone(),
1120 prompt_builder.clone(),
1121 Arc::new(SlashCommandWorkingSet::default()),
1122 None,
1123 None,
1124 cx,
1125 )
1126 });
1127 let deserialized_buffer =
1128 deserialized_context.read_with(cx, |context, _| context.buffer.clone());
1129 assert_eq!(
1130 deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
1131 "a\nb\nc\n"
1132 );
1133 assert_eq!(
1134 cx.read(|cx| messages(&deserialized_context, cx)),
1135 [
1136 (message_0, Role::User, 0..2),
1137 (message_1.id, Role::Assistant, 2..6),
1138 (message_2.id, Role::System, 6..6),
1139 ]
1140 );
1141}
1142
1143#[gpui::test(iterations = 100)]
1144async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
1145 cx.update(init_test);
1146
1147 let min_peers = env::var("MIN_PEERS")
1148 .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
1149 .unwrap_or(2);
1150 let max_peers = env::var("MAX_PEERS")
1151 .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
1152 .unwrap_or(5);
1153 let operations = env::var("OPERATIONS")
1154 .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
1155 .unwrap_or(50);
1156
1157 let slash_commands = cx.update(SlashCommandRegistry::default_global);
1158 slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
1159 slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
1160 slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
1161
1162 let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
1163 let network = Arc::new(Mutex::new(Network::new(rng.clone())));
1164 let mut contexts = Vec::new();
1165
1166 let num_peers = rng.gen_range(min_peers..=max_peers);
1167 let context_id = ContextId::new();
1168 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1169 for i in 0..num_peers {
1170 let context = cx.new(|cx| {
1171 AssistantContext::new(
1172 context_id.clone(),
1173 i as ReplicaId,
1174 language::Capability::ReadWrite,
1175 registry.clone(),
1176 prompt_builder.clone(),
1177 Arc::new(SlashCommandWorkingSet::default()),
1178 None,
1179 None,
1180 cx,
1181 )
1182 });
1183
1184 cx.update(|cx| {
1185 cx.subscribe(&context, {
1186 let network = network.clone();
1187 move |_, event, _| {
1188 if let ContextEvent::Operation(op) = event {
1189 network
1190 .lock()
1191 .broadcast(i as ReplicaId, vec![op.to_proto()]);
1192 }
1193 }
1194 })
1195 .detach();
1196 });
1197
1198 contexts.push(context);
1199 network.lock().add_peer(i as ReplicaId);
1200 }
1201
1202 let mut mutation_count = operations;
1203
1204 while mutation_count > 0
1205 || !network.lock().is_idle()
1206 || network.lock().contains_disconnected_peers()
1207 {
1208 let context_index = rng.gen_range(0..contexts.len());
1209 let context = &contexts[context_index];
1210
1211 match rng.gen_range(0..100) {
1212 0..=29 if mutation_count > 0 => {
1213 log::info!("Context {}: edit buffer", context_index);
1214 context.update(cx, |context, cx| {
1215 context
1216 .buffer
1217 .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
1218 });
1219 mutation_count -= 1;
1220 }
1221 30..=44 if mutation_count > 0 => {
1222 context.update(cx, |context, cx| {
1223 let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
1224 log::info!("Context {}: split message at {:?}", context_index, range);
1225 context.split_message(range, cx);
1226 });
1227 mutation_count -= 1;
1228 }
1229 45..=59 if mutation_count > 0 => {
1230 context.update(cx, |context, cx| {
1231 if let Some(message) = context.messages(cx).choose(&mut rng) {
1232 let role = *[Role::User, Role::Assistant, Role::System]
1233 .choose(&mut rng)
1234 .unwrap();
1235 log::info!(
1236 "Context {}: insert message after {:?} with {:?}",
1237 context_index,
1238 message.id,
1239 role
1240 );
1241 context.insert_message_after(message.id, role, MessageStatus::Done, cx);
1242 }
1243 });
1244 mutation_count -= 1;
1245 }
1246 60..=74 if mutation_count > 0 => {
1247 context.update(cx, |context, cx| {
1248 let command_text = "/".to_string()
1249 + slash_commands
1250 .command_names()
1251 .choose(&mut rng)
1252 .unwrap()
1253 .clone()
1254 .as_ref();
1255
1256 let command_range = context.buffer.update(cx, |buffer, cx| {
1257 let offset = buffer.random_byte_range(0, &mut rng).start;
1258 buffer.edit(
1259 [(offset..offset, format!("\n{}\n", command_text))],
1260 None,
1261 cx,
1262 );
1263 offset + 1..offset + 1 + command_text.len()
1264 });
1265
1266 let output_text = RandomCharIter::new(&mut rng)
1267 .filter(|c| *c != '\r')
1268 .take(10)
1269 .collect::<String>();
1270
1271 let mut events = vec![Ok(SlashCommandEvent::StartMessage {
1272 role: Role::User,
1273 merge_same_roles: true,
1274 })];
1275
1276 let num_sections = rng.gen_range(0..=3);
1277 let mut section_start = 0;
1278 for _ in 0..num_sections {
1279 let mut section_end = rng.gen_range(section_start..=output_text.len());
1280 while !output_text.is_char_boundary(section_end) {
1281 section_end += 1;
1282 }
1283 events.push(Ok(SlashCommandEvent::StartSection {
1284 icon: IconName::Ai,
1285 label: "section".into(),
1286 metadata: None,
1287 }));
1288 events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
1289 text: output_text[section_start..section_end].to_string(),
1290 run_commands_in_text: false,
1291 })));
1292 events.push(Ok(SlashCommandEvent::EndSection));
1293 section_start = section_end;
1294 }
1295
1296 if section_start < output_text.len() {
1297 events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
1298 text: output_text[section_start..].to_string(),
1299 run_commands_in_text: false,
1300 })));
1301 }
1302
1303 log::info!(
1304 "Context {}: insert slash command output at {:?} with {:?} events",
1305 context_index,
1306 command_range,
1307 events.len()
1308 );
1309
1310 let command_range = context.buffer.read(cx).anchor_after(command_range.start)
1311 ..context.buffer.read(cx).anchor_after(command_range.end);
1312 context.insert_command_output(
1313 command_range,
1314 "/command",
1315 Task::ready(Ok(stream::iter(events).boxed())),
1316 true,
1317 cx,
1318 );
1319 });
1320 cx.run_until_parked();
1321 mutation_count -= 1;
1322 }
1323 75..=84 if mutation_count > 0 => {
1324 context.update(cx, |context, cx| {
1325 if let Some(message) = context.messages(cx).choose(&mut rng) {
1326 let new_status = match rng.gen_range(0..3) {
1327 0 => MessageStatus::Done,
1328 1 => MessageStatus::Pending,
1329 _ => MessageStatus::Error(SharedString::from("Random error")),
1330 };
1331 log::info!(
1332 "Context {}: update message {:?} status to {:?}",
1333 context_index,
1334 message.id,
1335 new_status
1336 );
1337 context.update_metadata(message.id, cx, |metadata| {
1338 metadata.status = new_status;
1339 });
1340 }
1341 });
1342 mutation_count -= 1;
1343 }
1344 _ => {
1345 let replica_id = context_index as ReplicaId;
1346 if network.lock().is_disconnected(replica_id) {
1347 network.lock().reconnect_peer(replica_id, 0);
1348
1349 let (ops_to_send, ops_to_receive) = cx.read(|cx| {
1350 let host_context = &contexts[0].read(cx);
1351 let guest_context = context.read(cx);
1352 (
1353 guest_context.serialize_ops(&host_context.version(cx), cx),
1354 host_context.serialize_ops(&guest_context.version(cx), cx),
1355 )
1356 });
1357 let ops_to_send = ops_to_send.await;
1358 let ops_to_receive = ops_to_receive
1359 .await
1360 .into_iter()
1361 .map(ContextOperation::from_proto)
1362 .collect::<Result<Vec<_>>>()
1363 .unwrap();
1364 log::info!(
1365 "Context {}: reconnecting. Sent {} operations, received {} operations",
1366 context_index,
1367 ops_to_send.len(),
1368 ops_to_receive.len()
1369 );
1370
1371 network.lock().broadcast(replica_id, ops_to_send);
1372 context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
1373 } else if rng.gen_bool(0.1) && replica_id != 0 {
1374 log::info!("Context {}: disconnecting", context_index);
1375 network.lock().disconnect_peer(replica_id);
1376 } else if network.lock().has_unreceived(replica_id) {
1377 log::info!("Context {}: applying operations", context_index);
1378 let ops = network.lock().receive(replica_id);
1379 let ops = ops
1380 .into_iter()
1381 .map(ContextOperation::from_proto)
1382 .collect::<Result<Vec<_>>>()
1383 .unwrap();
1384 context.update(cx, |context, cx| context.apply_ops(ops, cx));
1385 }
1386 }
1387 }
1388 }
1389
1390 cx.read(|cx| {
1391 let first_context = contexts[0].read(cx);
1392 for context in &contexts[1..] {
1393 let context = context.read(cx);
1394 assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
1395 assert_eq!(
1396 context.buffer.read(cx).text(),
1397 first_context.buffer.read(cx).text(),
1398 "Context {} text != Context 0 text",
1399 context.buffer.read(cx).replica_id()
1400 );
1401 assert_eq!(
1402 context.message_anchors,
1403 first_context.message_anchors,
1404 "Context {} messages != Context 0 messages",
1405 context.buffer.read(cx).replica_id()
1406 );
1407 assert_eq!(
1408 context.messages_metadata,
1409 first_context.messages_metadata,
1410 "Context {} message metadata != Context 0 message metadata",
1411 context.buffer.read(cx).replica_id()
1412 );
1413 assert_eq!(
1414 context.slash_command_output_sections,
1415 first_context.slash_command_output_sections,
1416 "Context {} slash command output sections != Context 0 slash command output sections",
1417 context.buffer.read(cx).replica_id()
1418 );
1419 }
1420 });
1421}
1422
1423#[gpui::test]
1424fn test_mark_cache_anchors(cx: &mut App) {
1425 init_test(cx);
1426
1427 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1428 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1429 let context = cx.new(|cx| {
1430 AssistantContext::local(
1431 registry,
1432 None,
1433 None,
1434 prompt_builder.clone(),
1435 Arc::new(SlashCommandWorkingSet::default()),
1436 cx,
1437 )
1438 });
1439 let buffer = context.read(cx).buffer.clone();
1440
1441 // Create a test cache configuration
1442 let cache_configuration = &Some(LanguageModelCacheConfiguration {
1443 max_cache_anchors: 3,
1444 should_speculate: true,
1445 min_total_token: 10,
1446 });
1447
1448 let message_1 = context.read(cx).message_anchors[0].clone();
1449
1450 context.update(cx, |context, cx| {
1451 context.mark_cache_anchors(cache_configuration, false, cx)
1452 });
1453
1454 assert_eq!(
1455 messages_cache(&context, cx)
1456 .iter()
1457 .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1458 .count(),
1459 0,
1460 "Empty messages should not have any cache anchors."
1461 );
1462
1463 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1464 let message_2 = context
1465 .update(cx, |context, cx| {
1466 context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1467 })
1468 .unwrap();
1469
1470 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1471 let message_3 = context
1472 .update(cx, |context, cx| {
1473 context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1474 })
1475 .unwrap();
1476 buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1477
1478 context.update(cx, |context, cx| {
1479 context.mark_cache_anchors(cache_configuration, false, cx)
1480 });
1481 assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1482 assert_eq!(
1483 messages_cache(&context, cx)
1484 .iter()
1485 .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1486 .count(),
1487 0,
1488 "Messages should not be marked for cache before going over the token minimum."
1489 );
1490 context.update(cx, |context, _| {
1491 context.token_count = Some(20);
1492 });
1493
1494 context.update(cx, |context, cx| {
1495 context.mark_cache_anchors(cache_configuration, true, cx)
1496 });
1497 assert_eq!(
1498 messages_cache(&context, cx)
1499 .iter()
1500 .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1501 .collect::<Vec<bool>>(),
1502 vec![true, true, false],
1503 "Last message should not be an anchor on speculative request."
1504 );
1505
1506 context
1507 .update(cx, |context, cx| {
1508 context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1509 })
1510 .unwrap();
1511
1512 context.update(cx, |context, cx| {
1513 context.mark_cache_anchors(cache_configuration, false, cx)
1514 });
1515 assert_eq!(
1516 messages_cache(&context, cx)
1517 .iter()
1518 .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1519 .collect::<Vec<bool>>(),
1520 vec![false, true, true, false],
1521 "Most recent message should also be cached if not a speculative request."
1522 );
1523 context.update(cx, |context, cx| {
1524 context.update_cache_status_for_completion(cx)
1525 });
1526 assert_eq!(
1527 messages_cache(&context, cx)
1528 .iter()
1529 .map(|(_, cache)| cache
1530 .as_ref()
1531 .map_or(None, |cache| Some(cache.status.clone())))
1532 .collect::<Vec<Option<CacheStatus>>>(),
1533 vec![
1534 Some(CacheStatus::Cached),
1535 Some(CacheStatus::Cached),
1536 Some(CacheStatus::Cached),
1537 None
1538 ],
1539 "All user messages prior to anchor should be marked as cached."
1540 );
1541
1542 buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1543 context.update(cx, |context, cx| {
1544 context.mark_cache_anchors(cache_configuration, false, cx)
1545 });
1546 assert_eq!(
1547 messages_cache(&context, cx)
1548 .iter()
1549 .map(|(_, cache)| cache
1550 .as_ref()
1551 .map_or(None, |cache| Some(cache.status.clone())))
1552 .collect::<Vec<Option<CacheStatus>>>(),
1553 vec![
1554 Some(CacheStatus::Cached),
1555 Some(CacheStatus::Cached),
1556 Some(CacheStatus::Pending),
1557 None
1558 ],
1559 "Modifying a message should invalidate it's cache but leave previous messages."
1560 );
1561 buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1562 context.update(cx, |context, cx| {
1563 context.mark_cache_anchors(cache_configuration, false, cx)
1564 });
1565 assert_eq!(
1566 messages_cache(&context, cx)
1567 .iter()
1568 .map(|(_, cache)| cache
1569 .as_ref()
1570 .map_or(None, |cache| Some(cache.status.clone())))
1571 .collect::<Vec<Option<CacheStatus>>>(),
1572 vec![
1573 Some(CacheStatus::Pending),
1574 Some(CacheStatus::Pending),
1575 Some(CacheStatus::Pending),
1576 None
1577 ],
1578 "Modifying a message should invalidate all future messages."
1579 );
1580}
1581
1582#[gpui::test]
1583async fn test_summarization(cx: &mut TestAppContext) {
1584 let (context, fake_model) = setup_context_editor_with_fake_model(cx);
1585
1586 // Initial state should be pending
1587 context.read_with(cx, |context, _| {
1588 assert!(matches!(context.summary(), ContextSummary::Pending));
1589 assert_eq!(context.summary().or_default(), ContextSummary::DEFAULT);
1590 });
1591
1592 let message_1 = context.read_with(cx, |context, _cx| context.message_anchors[0].clone());
1593 context.update(cx, |context, cx| {
1594 context
1595 .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
1596 .unwrap();
1597 });
1598
1599 // Send a message
1600 context.update(cx, |context, cx| {
1601 context.assist(RequestType::Chat, cx);
1602 });
1603
1604 simulate_successful_response(&fake_model, cx);
1605
1606 // Should start generating summary when there are >= 2 messages
1607 context.read_with(cx, |context, _| {
1608 assert!(!context.summary().content().unwrap().done);
1609 });
1610
1611 cx.run_until_parked();
1612 fake_model.stream_last_completion_response("Brief".into());
1613 fake_model.stream_last_completion_response(" Introduction".into());
1614 fake_model.end_last_completion_stream();
1615 cx.run_until_parked();
1616
1617 // Summary should be set
1618 context.read_with(cx, |context, _| {
1619 assert_eq!(context.summary().or_default(), "Brief Introduction");
1620 });
1621
1622 // We should be able to manually set a summary
1623 context.update(cx, |context, cx| {
1624 context.set_custom_summary("Brief Intro".into(), cx);
1625 });
1626
1627 context.read_with(cx, |context, _| {
1628 assert_eq!(context.summary().or_default(), "Brief Intro");
1629 });
1630}
1631
1632#[gpui::test]
1633async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
1634 let (context, fake_model) = setup_context_editor_with_fake_model(cx);
1635
1636 test_summarize_error(&fake_model, &context, cx);
1637
1638 // Now we should be able to set a summary
1639 context.update(cx, |context, cx| {
1640 context.set_custom_summary("Brief Intro".into(), cx);
1641 });
1642
1643 context.read_with(cx, |context, _| {
1644 assert_eq!(context.summary().or_default(), "Brief Intro");
1645 });
1646}
1647
1648#[gpui::test]
1649async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
1650 let (context, fake_model) = setup_context_editor_with_fake_model(cx);
1651
1652 test_summarize_error(&fake_model, &context, cx);
1653
1654 // Sending another message should not trigger another summarize request
1655 context.update(cx, |context, cx| {
1656 context.assist(RequestType::Chat, cx);
1657 });
1658
1659 simulate_successful_response(&fake_model, cx);
1660
1661 context.read_with(cx, |context, _| {
1662 // State is still Error, not Generating
1663 assert!(matches!(context.summary(), ContextSummary::Error));
1664 });
1665
1666 // But the summarize request can be invoked manually
1667 context.update(cx, |context, cx| {
1668 context.summarize(true, cx);
1669 });
1670
1671 context.read_with(cx, |context, _| {
1672 assert!(!context.summary().content().unwrap().done);
1673 });
1674
1675 cx.run_until_parked();
1676 fake_model.stream_last_completion_response("A successful summary".into());
1677 fake_model.end_last_completion_stream();
1678 cx.run_until_parked();
1679
1680 context.read_with(cx, |context, _| {
1681 assert_eq!(context.summary().or_default(), "A successful summary");
1682 });
1683}
1684
1685fn test_summarize_error(
1686 model: &Arc<FakeLanguageModel>,
1687 context: &Entity<AssistantContext>,
1688 cx: &mut TestAppContext,
1689) {
1690 let message_1 = context.read_with(cx, |context, _cx| context.message_anchors[0].clone());
1691 context.update(cx, |context, cx| {
1692 context
1693 .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
1694 .unwrap();
1695 });
1696
1697 // Send a message
1698 context.update(cx, |context, cx| {
1699 context.assist(RequestType::Chat, cx);
1700 });
1701
1702 simulate_successful_response(&model, cx);
1703
1704 context.read_with(cx, |context, _| {
1705 assert!(!context.summary().content().unwrap().done);
1706 });
1707
1708 // Simulate summary request ending
1709 cx.run_until_parked();
1710 model.end_last_completion_stream();
1711 cx.run_until_parked();
1712
1713 // State is set to Error and default message
1714 context.read_with(cx, |context, _| {
1715 assert_eq!(*context.summary(), ContextSummary::Error);
1716 assert_eq!(context.summary().or_default(), ContextSummary::DEFAULT);
1717 });
1718}
1719
1720fn setup_context_editor_with_fake_model(
1721 cx: &mut TestAppContext,
1722) -> (Entity<AssistantContext>, Arc<FakeLanguageModel>) {
1723 let registry = Arc::new(LanguageRegistry::test(cx.executor().clone()));
1724
1725 let fake_provider = Arc::new(FakeLanguageModelProvider);
1726 let fake_model = Arc::new(fake_provider.test_model());
1727
1728 cx.update(|cx| {
1729 init_test(cx);
1730 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
1731 registry.set_default_model(
1732 Some(ConfiguredModel {
1733 provider: fake_provider.clone(),
1734 model: fake_model.clone(),
1735 }),
1736 cx,
1737 )
1738 })
1739 });
1740
1741 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1742 let context = cx.new(|cx| {
1743 AssistantContext::local(
1744 registry,
1745 None,
1746 None,
1747 prompt_builder.clone(),
1748 Arc::new(SlashCommandWorkingSet::default()),
1749 cx,
1750 )
1751 });
1752
1753 (context, fake_model)
1754}
1755
1756fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
1757 cx.run_until_parked();
1758 fake_model.stream_last_completion_response("Assistant response".into());
1759 fake_model.end_last_completion_stream();
1760 cx.run_until_parked();
1761}
1762
1763fn messages(context: &Entity<AssistantContext>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
1764 context
1765 .read(cx)
1766 .messages(cx)
1767 .map(|message| (message.id, message.role, message.offset_range))
1768 .collect()
1769}
1770
1771fn messages_cache(
1772 context: &Entity<AssistantContext>,
1773 cx: &App,
1774) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1775 context
1776 .read(cx)
1777 .messages(cx)
1778 .map(|message| (message.id, message.cache.clone()))
1779 .collect()
1780}
1781
/// Prepares `cx` for these tests by creating a test `SettingsStore` and running
/// the initializers for the prompt store, a test `LanguageModelRegistry`,
/// `language`, `assistant_settings`, and project settings.
fn init_test(cx: &mut App) {
    let settings_store = SettingsStore::test(cx);
    prompt_store::init(cx);
    LanguageModelRegistry::test(cx);
    // NOTE(review): the settings store is installed as a global only after the
    // prompt store and model registry are initialized. Keep this order unless it
    // is confirmed that those initializers do not depend on it.
    cx.set_global(settings_store);
    language::init(cx);
    assistant_settings::init(cx);
    Project::init_settings(cx);
}
1791
1792#[derive(Clone)]
1793struct FakeSlashCommand(String);
1794
1795impl SlashCommand for FakeSlashCommand {
1796 fn name(&self) -> String {
1797 self.0.clone()
1798 }
1799
1800 fn description(&self) -> String {
1801 format!("Fake slash command: {}", self.0)
1802 }
1803
1804 fn menu_text(&self) -> String {
1805 format!("Run fake command: {}", self.0)
1806 }
1807
1808 fn complete_argument(
1809 self: Arc<Self>,
1810 _arguments: &[String],
1811 _cancel: Arc<AtomicBool>,
1812 _workspace: Option<WeakEntity<Workspace>>,
1813 _window: &mut Window,
1814 _cx: &mut App,
1815 ) -> Task<Result<Vec<ArgumentCompletion>>> {
1816 Task::ready(Ok(vec![]))
1817 }
1818
1819 fn requires_argument(&self) -> bool {
1820 false
1821 }
1822
1823 fn run(
1824 self: Arc<Self>,
1825 _arguments: &[String],
1826 _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1827 _context_buffer: BufferSnapshot,
1828 _workspace: WeakEntity<Workspace>,
1829 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1830 _window: &mut Window,
1831 _cx: &mut App,
1832 ) -> Task<SlashCommandResult> {
1833 Task::ready(Ok(SlashCommandOutput {
1834 text: format!("Executed fake command: {}", self.0),
1835 sections: vec![],
1836 run_commands_in_text: false,
1837 }
1838 .to_event_stream()))
1839 }
1840}