use crate::{
    assistant_panel, prompt_library, slash_command::file_command, workflow::tool, CacheStatus,
    Context, ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
};
use anyhow::Result;
use assistant_slash_command::{
    ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
    SlashCommandRegistry,
};
use collections::HashSet;
use fs::{FakeFs, Fs as _};
use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView};
use indoc::indoc;
use language::{Buffer, LanguageRegistry, LspAdapterDelegate};
use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
use parking_lot::Mutex;
use project::Project;
use rand::prelude::*;
use rope::Point;
use serde_json::json;
use settings::SettingsStore;
use std::{
    cell::RefCell,
    env,
    ops::Range,
    path::Path,
    rc::Rc,
    sync::{atomic::AtomicBool, Arc},
};
use text::{network::Network, OffsetRangeExt as _, ReplicaId, ToPoint as _};
use ui::{Context as _, WindowContext};
use unindent::Unindent;
use util::{test::marked_text_ranges, RandomCharIter};
use workspace::Workspace;

use super::MessageCacheMetadata;

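// Messages are anchored to buffer offsets: inserting messages, editing the buffer,
// and deleting across message boundaries (plus undo/redo of that deletion) should
// keep the (id, role, offset range) list consistent.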
#[gpui::test]
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
    let settings_store = SettingsStore::test(cx);
    LanguageModelRegistry::test(cx);
    cx.set_global(settings_store);
    assistant_panel::init(cx);
    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context =
        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
    let buffer = context.read(cx).buffer.clone();

    let message_1 = context.read(cx).message_anchors[0].clone();
    assert_eq!(
        messages(&context, cx),
        vec![(message_1.id, Role::User, 0..0)]
    );

    let message_2 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..1),
            (message_2.id, Role::Assistant, 1..1)
        ]
    );

    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..3)
        ]
    );

    let message_3 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_3.id, Role::User, 4..4)
        ]
    );

    let message_4 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_4.id, Role::User, 4..5),
            (message_3.id, Role::User, 5..5),
        ]
    );

    buffer.update(cx, |buffer, cx| {
        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_4.id, Role::User, 4..6),
            (message_3.id, Role::User, 6..7),
        ]
    );

    // Deleting across message boundaries merges the messages.
    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..3),
            (message_3.id, Role::User, 3..4),
        ]
    );

    // Undoing the deletion should also undo the merge.
    buffer.update(cx, |buffer, cx| buffer.undo(cx));
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_4.id, Role::User, 4..6),
            (message_3.id, Role::User, 6..7),
        ]
    );

    // Redoing the deletion should also redo the merge.
    buffer.update(cx, |buffer, cx| buffer.redo(cx));
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..3),
            (message_3.id, Role::User, 3..4),
        ]
    );

    // Ensure we can still insert after a merged message.
    let message_5 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..3),
            (message_5.id, Role::System, 3..4),
            (message_3.id, Role::User, 4..5)
        ]
    );
}

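// Splitting a message should reuse an existing newline at the split point when one
// is available and insert a new one otherwise, producing the expected offset ranges.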
#[gpui::test]
fn test_message_splitting(cx: &mut AppContext) {
    let settings_store = SettingsStore::test(cx);
    cx.set_global(settings_store);
    LanguageModelRegistry::test(cx);
    assistant_panel::init(cx);
    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));

    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context =
        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
    let buffer = context.read(cx).buffer.clone();

    let message_1 = context.read(cx).message_anchors[0].clone();
    assert_eq!(
        messages(&context, cx),
        vec![(message_1.id, Role::User, 0..0)]
    );

    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
    });

    let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
    let message_2 = message_2.unwrap();

    // We recycle newlines in the middle of a split message
    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_2.id, Role::User, 4..16),
        ]
    );

    let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
    let message_3 = message_3.unwrap();

    // We don't recycle newlines at the end of a split message
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..17),
        ]
    );

    let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
    let message_4 = message_4.unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..9),
            (message_4.id, Role::User, 9..17),
        ]
    );

    let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
    let message_5 = message_5.unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..9),
            (message_4.id, Role::User, 9..10),
            (message_5.id, Role::User, 10..18),
        ]
    );

    let (message_6, message_7) =
        context.update(cx, |context, cx| context.split_message(14..16, cx));
    let message_6 = message_6.unwrap();
    let message_7 = message_7.unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..9),
            (message_4.id, Role::User, 9..10),
            (message_5.id, Role::User, 10..14),
            (message_6.id, Role::User, 14..17),
            (message_7.id, Role::User, 17..19),
        ]
    );
}

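// `messages_for_offsets` should map buffer offsets to the messages containing them,
// returning each message at most once even if several offsets fall inside it.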
#[gpui::test]
fn test_messages_for_offsets(cx: &mut AppContext) {
    let settings_store = SettingsStore::test(cx);
    LanguageModelRegistry::test(cx);
    cx.set_global(settings_store);
    assistant_panel::init(cx);
    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context =
        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
    let buffer = context.read(cx).buffer.clone();

    let message_1 = context.read(cx).message_anchors[0].clone();
    assert_eq!(
        messages(&context, cx),
        vec![(message_1.id, Role::User, 0..0)]
    );

    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
    let message_2 = context
        .update(cx, |context, cx| {
            context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
        })
        .unwrap();
    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

    let message_3 = context
        .update(cx, |context, cx| {
            context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
        })
        .unwrap();
    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_2.id, Role::User, 4..8),
            (message_3.id, Role::User, 8..11)
        ]
    );

    assert_eq!(
        message_ids_for_offsets(&context, &[0, 4, 9], cx),
        [message_1.id, message_2.id, message_3.id]
    );
    assert_eq!(
        message_ids_for_offsets(&context, &[0, 1, 11], cx),
        [message_1.id, message_3.id]
    );

    let message_4 = context
        .update(cx, |context, cx| {
            context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
        })
        .unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_2.id, Role::User, 4..8),
            (message_3.id, Role::User, 8..12),
            (message_4.id, Role::User, 12..12)
        ]
    );
    assert_eq!(
        message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
        [message_1.id, message_2.id, message_3.id, message_4.id]
    );

    fn message_ids_for_offsets(
        context: &Model<Context>,
        offsets: &[usize],
        cx: &AppContext,
    ) -> Vec<MessageId> {
        context
            .read(cx)
            .messages_for_offsets(offsets.iter().copied(), cx)
            .into_iter()
            .map(|message| message.id)
            .collect()
    }
}

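// Typing a slash command into the buffer should produce a pending command range,
// editing its argument should keep that range in sync, and renaming it to an
// unknown command should drop the range.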
#[gpui::test]
async fn test_slash_commands(cx: &mut TestAppContext) {
    let settings_store = cx.update(SettingsStore::test);
    cx.set_global(settings_store);
    cx.update(LanguageModelRegistry::test);
    cx.update(Project::init_settings);
    cx.update(assistant_panel::init);
    let fs = FakeFs::new(cx.background_executor.clone());

    fs.insert_tree(
        "/test",
        json!({
            "src": {
                "lib.rs": "fn one() -> usize { 1 }",
                "main.rs": "
                    use crate::one;
                    fn main() { one(); }
                ".unindent(),
            }
        }),
    )
    .await;

    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
    slash_command_registry.register_command(file_command::FileSlashCommand, false);

    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context =
        cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));

    let output_ranges = Rc::new(RefCell::new(HashSet::default()));
    context.update(cx, |_, cx| {
        cx.subscribe(&context, {
            let ranges = output_ranges.clone();
            move |_, _, event, _| match event {
                ContextEvent::PendingSlashCommandsUpdated { removed, updated } => {
                    for range in removed {
                        ranges.borrow_mut().remove(range);
                    }
                    for command in updated {
                        ranges.borrow_mut().insert(command.source_range.clone());
                    }
                }
                _ => {}
            }
        })
        .detach();
    });

    let buffer = context.read_with(cx, |context, _| context.buffer.clone());

    // Insert a slash command
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
    });
    assert_text_and_output_ranges(
        &buffer,
        &output_ranges.borrow(),
        "
        «/file src/lib.rs»
        "
        .unindent()
        .trim_end(),
        cx,
    );

    // Edit the argument of the slash command.
    buffer.update(cx, |buffer, cx| {
        let edit_offset = buffer.text().find("lib.rs").unwrap();
        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
    });
    assert_text_and_output_ranges(
        &buffer,
        &output_ranges.borrow(),
        "
        «/file src/main.rs»
        "
        .unindent()
        .trim_end(),
        cx,
    );

    // Edit the name of the slash command, using one that doesn't exist.
    buffer.update(cx, |buffer, cx| {
        let edit_offset = buffer.text().find("/file").unwrap();
        buffer.edit(
            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
            None,
            cx,
        );
    });
    assert_text_and_output_ranges(
        &buffer,
        &output_ranges.borrow(),
        "
        /unknown src/main.rs
        "
        .unindent()
        .trim_end(),
        cx,
    );

    #[track_caller]
    fn assert_text_and_output_ranges(
        buffer: &Model<Buffer>,
        ranges: &HashSet<Range<language::Anchor>>,
        expected_marked_text: &str,
        cx: &mut TestAppContext,
    ) {
        let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false);
        let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| {
            let mut ranges = ranges
                .iter()
                .map(|range| range.to_offset(buffer))
                .collect::<Vec<_>>();
            ranges.sort_by_key(|a| a.start);
            (buffer.text(), ranges)
        });

        assert_eq!(actual_text, expected_text);
        assert_eq!(actual_ranges, expected_ranges);
    }
}

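// Streaming an assistant response containing <step> tags should be parsed into
// workflow steps, which move from pending to resolved once the fake model answers
// the step-resolution tool use.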
#[gpui::test]
async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
    cx.update(prompt_library::init);
    let settings_store = cx.update(SettingsStore::test);
    cx.set_global(settings_store);
    cx.update(Project::init_settings);
    let fs = FakeFs::new(cx.executor());
    fs.as_fake()
        .insert_tree(
            "/root",
            json!({
                "hello.rs": r#"
                    fn hello() {
                        println!("Hello, World!");
                    }
                "#.unindent()
            }),
        )
        .await;
    let project = Project::test(fs, [Path::new("/root")], cx).await;
    cx.update(LanguageModelRegistry::test);

    let model = cx.read(|cx| {
        LanguageModelRegistry::read_global(cx)
            .active_model()
            .unwrap()
    });
    cx.update(assistant_panel::init);
    let registry = Arc::new(LanguageRegistry::test(cx.executor()));

    // Create a new context
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context = cx.new_model(|cx| {
        Context::local(
            registry.clone(),
            Some(project),
            None,
            prompt_builder.clone(),
            cx,
        )
    });
    let buffer = context.read_with(cx, |context, _| context.buffer.clone());

    // Simulate user input
    let user_message = indoc! {r#"
        Please add unnecessary complexity to this code:

        ```hello.rs
        fn main() {
            println!("Hello, World!");
        }
        ```
    "#};
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, user_message)], None, cx);
    });

    // Simulate LLM response with edit steps
    let llm_response = indoc! {r#"
        Sure, I can help you with that. Here's a step-by-step process:

        <step>
        First, let's extract the greeting into a separate function:

        ```rust
        fn greet() {
            println!("Hello, World!");
        }

        fn main() {
            greet();
        }
        ```
        </step>

        <step>
        Now, let's make the greeting customizable:

        ```rust
        fn greet(name: &str) {
            println!("Hello, {}!", name);
        }

        fn main() {
            greet("World");
        }
        ```
        </step>

        These changes make the code more modular and flexible.
    "#};

    // Simulate the assist method to trigger the LLM response
    context.update(cx, |context, cx| context.assist(cx));
    cx.run_until_parked();

    // Retrieve the assistant response message's start from the context
    let response_start_row = context.read_with(cx, |context, cx| {
        let buffer = context.buffer.read(cx);
        context.message_anchors[1].start.to_point(buffer).row
    });

    // Simulate the LLM completion
    model
        .as_fake()
        .stream_last_completion_response(llm_response.to_string());
    model.as_fake().end_last_completion_stream();

    // Wait for the completion to be processed
    cx.run_until_parked();

    // Verify that the edit steps were parsed correctly
    context.read_with(cx, |context, cx| {
        assert_eq!(
            workflow_steps(context, cx),
            vec![
                (
                    Point::new(response_start_row + 2, 0)..Point::new(response_start_row + 12, 3),
                    WorkflowStepTestStatus::Pending
                ),
                (
                    Point::new(response_start_row + 14, 0)..Point::new(response_start_row + 24, 3),
                    WorkflowStepTestStatus::Pending
                ),
            ]
        );
    });

    model
        .as_fake()
        .respond_to_last_tool_use(tool::WorkflowStepResolutionTool {
            step_title: "Title".into(),
            suggestions: vec![tool::WorkflowSuggestionTool {
                path: "/root/hello.rs".into(),
                // Simulate a symbol name that's slightly different than our outline query
                kind: tool::WorkflowSuggestionToolKind::Update {
                    symbol: "fn main()".into(),
                    description: "Extract a greeting function".into(),
                },
            }],
        });

    // Wait for tool use to be processed.
    cx.run_until_parked();

    // Verify that the first edit step is not pending anymore.
    context.read_with(cx, |context, cx| {
        assert_eq!(
            workflow_steps(context, cx),
            vec![
                (
                    Point::new(response_start_row + 2, 0)..Point::new(response_start_row + 12, 3),
                    WorkflowStepTestStatus::Resolved
                ),
                (
                    Point::new(response_start_row + 14, 0)..Point::new(response_start_row + 24, 3),
                    WorkflowStepTestStatus::Pending
                ),
            ]
        );
    });

    #[derive(Copy, Clone, Debug, Eq, PartialEq)]
    enum WorkflowStepTestStatus {
        Pending,
        Resolved,
        Error,
    }

    fn workflow_steps(
        context: &Context,
        cx: &AppContext,
    ) -> Vec<(Range<Point>, WorkflowStepTestStatus)> {
        context
            .workflow_steps
            .iter()
            .map(|step| {
                let buffer = context.buffer.read(cx);
                let status = match &step.step.read(cx).resolution {
                    None => WorkflowStepTestStatus::Pending,
                    Some(Ok(_)) => WorkflowStepTestStatus::Resolved,
                    Some(Err(_)) => WorkflowStepTestStatus::Error,
                };
                (step.range.to_point(buffer), status)
            })
            .collect()
    }
}

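// Serializing a context and deserializing it into a new model should reproduce the
// same buffer text and message layout.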
#[gpui::test]
async fn test_serialization(cx: &mut TestAppContext) {
    let settings_store = cx.update(SettingsStore::test);
    cx.set_global(settings_store);
    cx.update(LanguageModelRegistry::test);
    cx.update(assistant_panel::init);
    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context =
        cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
    let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
    let message_1 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
            .unwrap()
    });
    let message_2 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
            .unwrap()
    });
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
        buffer.finalize_last_transaction();
    });
    let _message_3 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
            .unwrap()
    });
    buffer.update(cx, |buffer, cx| buffer.undo(cx));
    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
    assert_eq!(
        cx.read(|cx| messages(&context, cx)),
        [
            (message_0, Role::User, 0..2),
            (message_1.id, Role::Assistant, 2..6),
            (message_2.id, Role::System, 6..6),
        ]
    );

    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
    let deserialized_context = cx.new_model(|cx| {
        Context::deserialize(
            serialized_context,
            Default::default(),
            registry.clone(),
            prompt_builder.clone(),
            None,
            None,
            cx,
        )
    });
    let deserialized_buffer =
        deserialized_context.read_with(cx, |context, _| context.buffer.clone());
    assert_eq!(
        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
        "a\nb\nc\n"
    );
    assert_eq!(
        cx.read(|cx| messages(&deserialized_context, cx)),
        [
            (message_0, Role::User, 0..2),
            (message_1.id, Role::Assistant, 2..6),
            (message_2.id, Role::System, 6..6),
        ]
    );
}

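// Randomized collaboration test: several replicas apply random buffer edits, message
// operations, and slash command output while connecting and disconnecting, and must
// converge to identical state once the network is idle.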
#[gpui::test(iterations = 100)]
async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
    let min_peers = env::var("MIN_PEERS")
        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
        .unwrap_or(2);
    let max_peers = env::var("MAX_PEERS")
        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
        .unwrap_or(5);
    let operations = env::var("OPERATIONS")
        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
        .unwrap_or(50);

    let settings_store = cx.update(SettingsStore::test);
    cx.set_global(settings_store);
    cx.update(LanguageModelRegistry::test);

    cx.update(assistant_panel::init);
    let slash_commands = cx.update(SlashCommandRegistry::default_global);
    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);

    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
    let mut contexts = Vec::new();

    let num_peers = rng.gen_range(min_peers..=max_peers);
    let context_id = ContextId::new();
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    for i in 0..num_peers {
        let context = cx.new_model(|cx| {
            Context::new(
                context_id.clone(),
                i as ReplicaId,
                language::Capability::ReadWrite,
                registry.clone(),
                prompt_builder.clone(),
                None,
                None,
                cx,
            )
        });

        cx.update(|cx| {
            cx.subscribe(&context, {
                let network = network.clone();
                move |_, event, _| {
                    if let ContextEvent::Operation(op) = event {
                        network
                            .lock()
                            .broadcast(i as ReplicaId, vec![op.to_proto()]);
                    }
                }
            })
            .detach();
        });

        contexts.push(context);
        network.lock().add_peer(i as ReplicaId);
    }

    let mut mutation_count = operations;

    while mutation_count > 0
        || !network.lock().is_idle()
        || network.lock().contains_disconnected_peers()
    {
        let context_index = rng.gen_range(0..contexts.len());
        let context = &contexts[context_index];

        match rng.gen_range(0..100) {
            0..=29 if mutation_count > 0 => {
                log::info!("Context {}: edit buffer", context_index);
                context.update(cx, |context, cx| {
                    context
                        .buffer
                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
                });
                mutation_count -= 1;
            }
            30..=44 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
                    log::info!("Context {}: split message at {:?}", context_index, range);
                    context.split_message(range, cx);
                });
                mutation_count -= 1;
            }
            45..=59 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let role = *[Role::User, Role::Assistant, Role::System]
                            .choose(&mut rng)
                            .unwrap();
                        log::info!(
                            "Context {}: insert message after {:?} with {:?}",
                            context_index,
                            message.id,
                            role
                        );
                        context.insert_message_after(message.id, role, MessageStatus::Done, cx);
                    }
                });
                mutation_count -= 1;
            }
            60..=74 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    let command_text = "/".to_string()
                        + slash_commands
                            .command_names()
                            .choose(&mut rng)
                            .unwrap()
                            .clone()
                            .as_ref();

                    let command_range = context.buffer.update(cx, |buffer, cx| {
                        let offset = buffer.random_byte_range(0, &mut rng).start;
                        buffer.edit(
                            [(offset..offset, format!("\n{}\n", command_text))],
                            None,
                            cx,
                        );
                        offset + 1..offset + 1 + command_text.len()
                    });

                    let output_len = rng.gen_range(1..=10);
                    let output_text = RandomCharIter::new(&mut rng)
                        .filter(|c| *c != '\r')
                        .take(output_len)
                        .collect::<String>();

                    let num_sections = rng.gen_range(0..=3);
                    let mut sections = Vec::with_capacity(num_sections);
                    for _ in 0..num_sections {
                        let section_start = rng.gen_range(0..output_len);
                        let section_end = rng.gen_range(section_start..=output_len);
                        sections.push(SlashCommandOutputSection {
                            range: section_start..section_end,
                            icon: ui::IconName::Ai,
                            label: "section".into(),
                        });
                    }

                    log::info!(
                        "Context {}: insert slash command output at {:?} with {:?}",
                        context_index,
                        command_range,
                        sections
                    );

                    let command_range = context.buffer.read(cx).anchor_after(command_range.start)
                        ..context.buffer.read(cx).anchor_after(command_range.end);
                    context.insert_command_output(
                        command_range,
                        Task::ready(Ok(SlashCommandOutput {
                            text: output_text,
                            sections,
                            run_commands_in_text: false,
                        })),
                        true,
                        false,
                        cx,
                    );
                });
                cx.run_until_parked();
                mutation_count -= 1;
            }
            75..=84 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let new_status = match rng.gen_range(0..3) {
                            0 => MessageStatus::Done,
                            1 => MessageStatus::Pending,
                            _ => MessageStatus::Error(SharedString::from("Random error")),
                        };
                        log::info!(
                            "Context {}: update message {:?} status to {:?}",
                            context_index,
                            message.id,
                            new_status
                        );
                        context.update_metadata(message.id, cx, |metadata| {
                            metadata.status = new_status;
                        });
                    }
                });
                mutation_count -= 1;
            }
            _ => {
                let replica_id = context_index as ReplicaId;
                if network.lock().is_disconnected(replica_id) {
                    network.lock().reconnect_peer(replica_id, 0);

                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
                        let host_context = &contexts[0].read(cx);
                        let guest_context = context.read(cx);
                        (
                            guest_context.serialize_ops(&host_context.version(cx), cx),
                            host_context.serialize_ops(&guest_context.version(cx), cx),
                        )
                    });
                    let ops_to_send = ops_to_send.await;
                    let ops_to_receive = ops_to_receive
                        .await
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    log::info!(
                        "Context {}: reconnecting. Sent {} operations, received {} operations",
                        context_index,
                        ops_to_send.len(),
                        ops_to_receive.len()
                    );

                    network.lock().broadcast(replica_id, ops_to_send);
                    context
                        .update(cx, |context, cx| context.apply_ops(ops_to_receive, cx))
                        .unwrap();
                } else if rng.gen_bool(0.1) && replica_id != 0 {
                    log::info!("Context {}: disconnecting", context_index);
                    network.lock().disconnect_peer(replica_id);
                } else if network.lock().has_unreceived(replica_id) {
                    log::info!("Context {}: applying operations", context_index);
                    let ops = network.lock().receive(replica_id);
                    let ops = ops
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    context
                        .update(cx, |context, cx| context.apply_ops(ops, cx))
                        .unwrap();
                }
            }
        }
    }

    cx.read(|cx| {
        let first_context = contexts[0].read(cx);
        for context in &contexts[1..] {
            let context = context.read(cx);
            assert!(context.pending_ops.is_empty());
            assert_eq!(
                context.buffer.read(cx).text(),
                first_context.buffer.read(cx).text(),
                "Context {} text != Context 0 text",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.message_anchors,
                first_context.message_anchors,
                "Context {} messages != Context 0 messages",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.messages_metadata,
                first_context.messages_metadata,
                "Context {} message metadata != Context 0 message metadata",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.slash_command_output_sections,
                first_context.slash_command_output_sections,
                "Context {} slash command output sections != Context 0 slash command output sections",
                context.buffer.read(cx).replica_id()
            );
        }
    });
}

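// Cache anchors should only be marked once the token minimum is reached, the last
// message should be skipped on speculative requests, and edits should invalidate the
// cache status of the edited message and everything after it.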
#[gpui::test]
fn test_mark_cache_anchors(cx: &mut AppContext) {
    let settings_store = SettingsStore::test(cx);
    LanguageModelRegistry::test(cx);
    cx.set_global(settings_store);
    assistant_panel::init(cx);
    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context =
        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
    let buffer = context.read(cx).buffer.clone();

    // Create a test cache configuration
    let cache_configuration = &Some(LanguageModelCacheConfiguration {
        max_cache_anchors: 3,
        should_speculate: true,
        min_total_token: 10,
    });

    let message_1 = context.read(cx).message_anchors[0].clone();

    context.update(cx, |context, cx| {
        context.mark_cache_anchors(cache_configuration, false, cx)
    });

    assert_eq!(
        messages_cache(&context, cx)
            .iter()
            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
            .count(),
        0,
        "Empty messages should not have any cache anchors."
    );

    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
    let message_2 = context
        .update(cx, |context, cx| {
            context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
        })
        .unwrap();

    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
    let message_3 = context
        .update(cx, |context, cx| {
            context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
        })
        .unwrap();
    buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));

    context.update(cx, |context, cx| {
        context.mark_cache_anchors(cache_configuration, false, cx)
    });
    assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
    assert_eq!(
        messages_cache(&context, cx)
            .iter()
            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
            .count(),
        0,
        "Messages should not be marked for cache before going over the token minimum."
    );
    context.update(cx, |context, _| {
        context.token_count = Some(20);
    });

    context.update(cx, |context, cx| {
        context.mark_cache_anchors(cache_configuration, true, cx)
    });
    assert_eq!(
        messages_cache(&context, cx)
            .iter()
            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
            .collect::<Vec<bool>>(),
        vec![true, true, false],
        "Last message should not be an anchor on speculative request."
    );

    context
        .update(cx, |context, cx| {
            context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
        })
        .unwrap();

    context.update(cx, |context, cx| {
        context.mark_cache_anchors(cache_configuration, false, cx)
    });
    assert_eq!(
        messages_cache(&context, cx)
            .iter()
            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
            .collect::<Vec<bool>>(),
        vec![false, true, true, false],
        "Most recent message should also be cached if not a speculative request."
    );
    context.update(cx, |context, cx| {
        context.update_cache_status_for_completion(cx)
    });
    assert_eq!(
        messages_cache(&context, cx)
            .iter()
            .map(|(_, cache)| cache
                .as_ref()
                .map_or(None, |cache| Some(cache.status.clone())))
            .collect::<Vec<Option<CacheStatus>>>(),
        vec![
            Some(CacheStatus::Cached),
            Some(CacheStatus::Cached),
            Some(CacheStatus::Cached),
            None
        ],
        "All user messages prior to anchor should be marked as cached."
    );

    buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
    context.update(cx, |context, cx| {
        context.mark_cache_anchors(cache_configuration, false, cx)
    });
    assert_eq!(
        messages_cache(&context, cx)
            .iter()
            .map(|(_, cache)| cache
                .as_ref()
                .map_or(None, |cache| Some(cache.status.clone())))
            .collect::<Vec<Option<CacheStatus>>>(),
        vec![
            Some(CacheStatus::Cached),
            Some(CacheStatus::Cached),
            Some(CacheStatus::Pending),
            None
        ],
        "Modifying a message should invalidate its cache but leave previous messages."
    );
    buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
    context.update(cx, |context, cx| {
        context.mark_cache_anchors(cache_configuration, false, cx)
    });
    assert_eq!(
        messages_cache(&context, cx)
            .iter()
            .map(|(_, cache)| cache
                .as_ref()
                .map_or(None, |cache| Some(cache.status.clone())))
            .collect::<Vec<Option<CacheStatus>>>(),
        vec![
            Some(CacheStatus::Pending),
            Some(CacheStatus::Pending),
            Some(CacheStatus::Pending),
            None
        ],
        "Modifying a message should invalidate all future messages."
    );
}

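// Test helpers for snapshotting a context's messages and their cache metadata.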
fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
    context
        .read(cx)
        .messages(cx)
        .map(|message| (message.id, message.role, message.offset_range))
        .collect()
}

fn messages_cache(
    context: &Model<Context>,
    cx: &AppContext,
) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
    context
        .read(cx)
        .messages(cx)
        .map(|message| (message.id, message.cache.clone()))
        .collect()
}

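// Minimal no-op slash command used by the random collaboration test.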
#[derive(Clone)]
struct FakeSlashCommand(String);

impl SlashCommand for FakeSlashCommand {
    fn name(&self) -> String {
        self.0.clone()
    }

    fn description(&self) -> String {
        format!("Fake slash command: {}", self.0)
    }

    fn menu_text(&self) -> String {
        format!("Run fake command: {}", self.0)
    }

    fn complete_argument(
        self: Arc<Self>,
        _arguments: &[String],
        _cancel: Arc<AtomicBool>,
        _workspace: Option<WeakView<Workspace>>,
        _cx: &mut WindowContext,
    ) -> Task<Result<Vec<ArgumentCompletion>>> {
        Task::ready(Ok(vec![]))
    }

    fn requires_argument(&self) -> bool {
        false
    }

    fn run(
        self: Arc<Self>,
        _arguments: &[String],
        _workspace: WeakView<Workspace>,
        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
        _cx: &mut WindowContext,
    ) -> Task<Result<SlashCommandOutput>> {
        Task::ready(Ok(SlashCommandOutput {
            text: format!("Executed fake command: {}", self.0),
            sections: vec![],
            run_commands_in_text: false,
        }))
    }
}