1use crate::{
2 assistant_panel, prompt_library, slash_command::file_command, workflow::tool, Context,
3 ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
4};
5use anyhow::Result;
6use assistant_slash_command::{
7 ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
8 SlashCommandRegistry,
9};
10use collections::HashSet;
11use fs::{FakeFs, Fs as _};
12use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView};
13use indoc::indoc;
14use language::{Buffer, LanguageRegistry, LspAdapterDelegate};
15use language_model::{LanguageModelRegistry, Role};
16use parking_lot::Mutex;
17use project::Project;
18use rand::prelude::*;
19use rope::Point;
20use serde_json::json;
21use settings::SettingsStore;
22use std::{
23 cell::RefCell,
24 env,
25 ops::Range,
26 path::Path,
27 rc::Rc,
28 sync::{atomic::AtomicBool, Arc},
29};
30use text::{network::Network, OffsetRangeExt as _, ReplicaId, ToPoint as _};
31use ui::{Context as _, WindowContext};
32use unindent::Unindent;
33use util::{test::marked_text_ranges, RandomCharIter};
34use workspace::Workspace;
35
// Exercises message bookkeeping in a local `Context`: inserting messages after
// existing ones, editing the shared buffer, and undoing/redoing a deletion that
// crosses message boundaries. After every step the full `(id, role, range)`
// list produced by the `messages` helper is asserted.
#[gpui::test]
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
    let settings_store = SettingsStore::test(cx);
    LanguageModelRegistry::test(cx);
    cx.set_global(settings_store);
    assistant_panel::init(cx);
    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context =
        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
    let buffer = context.read(cx).buffer.clone();

    // A new local context starts with a single, empty user message.
    let message_1 = context.read(cx).message_anchors[0].clone();
    assert_eq!(
        messages(&context, cx),
        vec![(message_1.id, Role::User, 0..0)]
    );

    // Inserting a message grows the preceding message's range by one character
    // (0..0 -> 0..1); the new message starts right after it.
    let message_2 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..1),
            (message_2.id, Role::Assistant, 1..1)
        ]
    );

    // Text inserted at offset 0 lands in message_1; text inserted at offset 1
    // (message_2's start) lands in message_2.
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..3)
        ]
    );

    let message_3 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_3.id, Role::User, 4..4)
        ]
    );

    // Inserting after message_2 again places the new message between
    // message_2 and message_3.
    let message_4 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_4.id, Role::User, 4..5),
            (message_3.id, Role::User, 5..5),
        ]
    );

    buffer.update(cx, |buffer, cx| {
        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_4.id, Role::User, 4..6),
            (message_3.id, Role::User, 6..7),
        ]
    );

    // Deleting across message boundaries merges the messages.
    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..3),
            (message_3.id, Role::User, 3..4),
        ]
    );

    // Undoing the deletion should also undo the merge.
    buffer.update(cx, |buffer, cx| buffer.undo(cx));
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_4.id, Role::User, 4..6),
            (message_3.id, Role::User, 6..7),
        ]
    );

    // Redoing the deletion should also redo the merge.
    buffer.update(cx, |buffer, cx| buffer.redo(cx));
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..3),
            (message_3.id, Role::User, 3..4),
        ]
    );

    // Ensure we can still insert after a merged message.
    let message_5 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..3),
            (message_5.id, Role::System, 3..4),
            (message_3.id, Role::User, 4..5)
        ]
    );
}
167
// Exercises `Context::split_message` with both cursor positions (empty ranges)
// and a non-empty selection. It checks when an existing newline is reused as
// the split point and when a new one is inserted into the buffer.
#[gpui::test]
fn test_message_splitting(cx: &mut AppContext) {
    let settings_store = SettingsStore::test(cx);
    cx.set_global(settings_store);
    LanguageModelRegistry::test(cx);
    assistant_panel::init(cx);
    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));

    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context =
        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
    let buffer = context.read(cx).buffer.clone();

    let message_1 = context.read(cx).message_anchors[0].clone();
    assert_eq!(
        messages(&context, cx),
        vec![(message_1.id, Role::User, 0..0)]
    );

    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
    });

    // `split_message` returns a pair of optional messages. For an empty range
    // only the second element is used here: the message starting after the split.
    let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
    let message_2 = message_2.unwrap();

    // We recycle newlines in the middle of a split message
    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_2.id, Role::User, 4..16),
        ]
    );

    let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
    let message_3 = message_3.unwrap();

    // We don't recycle newlines at the end of a split message
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..17),
        ]
    );

    // Offset 9 is the start of the `ccc` line. The preceding newline is reused,
    // so the text is unchanged.
    let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
    let message_4 = message_4.unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..9),
            (message_4.id, Role::User, 9..17),
        ]
    );

    // Splitting at the same offset again inserts a newline. message_4 now
    // covers only that newline and message_5 covers the rest.
    let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
    let message_5 = message_5.unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..9),
            (message_4.id, Role::User, 9..10),
            (message_5.id, Role::User, 10..18),
        ]
    );

    // Splitting a non-empty selection (14..16, the first two `d`s) yields two
    // messages: message_6 for the selected text and message_7 for what follows.
    // A newline is inserted after the selection.
    let (message_6, message_7) =
        context.update(cx, |context, cx| context.split_message(14..16, cx));
    let message_6 = message_6.unwrap();
    let message_7 = message_7.unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..9),
            (message_4.id, Role::User, 9..10),
            (message_5.id, Role::User, 10..14),
            (message_6.id, Role::User, 14..17),
            (message_7.id, Role::User, 17..19),
        ]
    );
}
263
// Exercises `Context::messages_for_offsets`: mapping buffer offsets back to
// the messages that contain them, including end-of-buffer offsets and an
// empty trailing message.
#[gpui::test]
fn test_messages_for_offsets(cx: &mut AppContext) {
    let settings_store = SettingsStore::test(cx);
    LanguageModelRegistry::test(cx);
    cx.set_global(settings_store);
    assistant_panel::init(cx);
    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context =
        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
    let buffer = context.read(cx).buffer.clone();

    let message_1 = context.read(cx).message_anchors[0].clone();
    assert_eq!(
        messages(&context, cx),
        vec![(message_1.id, Role::User, 0..0)]
    );

    // Build three messages "aaa\n", "bbb\n", "ccc". Each insertion adds the
    // separator, and the text is then typed into the new message.
    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
    let message_2 = context
        .update(cx, |context, cx| {
            context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
        })
        .unwrap();
    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

    let message_3 = context
        .update(cx, |context, cx| {
            context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
        })
        .unwrap();
    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_2.id, Role::User, 4..8),
            (message_3.id, Role::User, 8..11)
        ]
    );

    assert_eq!(
        message_ids_for_offsets(&context, &[0, 4, 9], cx),
        [message_1.id, message_2.id, message_3.id]
    );
    // Two offsets inside message_1 yield it once. Offset 11 (end of buffer)
    // maps to the last message.
    assert_eq!(
        message_ids_for_offsets(&context, &[0, 1, 11], cx),
        [message_1.id, message_3.id]
    );

    let message_4 = context
        .update(cx, |context, cx| {
            context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
        })
        .unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
    assert_eq!(
        messages(&context, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_2.id, Role::User, 4..8),
            (message_3.id, Role::User, 8..12),
            (message_4.id, Role::User, 12..12)
        ]
    );
    // The empty trailing message (12..12) is still reachable via offset 12.
    assert_eq!(
        message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
        [message_1.id, message_2.id, message_3.id, message_4.id]
    );

    // Returns the ids of the messages that `messages_for_offsets` reports for
    // `offsets`, in the order it returns them.
    fn message_ids_for_offsets(
        context: &Model<Context>,
        offsets: &[usize],
        cx: &AppContext,
    ) -> Vec<MessageId> {
        context
            .read(cx)
            .messages_for_offsets(offsets.iter().copied(), cx)
            .into_iter()
            .map(|message| message.id)
            .collect()
    }
}
349
350#[gpui::test]
351async fn test_slash_commands(cx: &mut TestAppContext) {
352 let settings_store = cx.update(SettingsStore::test);
353 cx.set_global(settings_store);
354 cx.update(LanguageModelRegistry::test);
355 cx.update(Project::init_settings);
356 cx.update(assistant_panel::init);
357 let fs = FakeFs::new(cx.background_executor.clone());
358
359 fs.insert_tree(
360 "/test",
361 json!({
362 "src": {
363 "lib.rs": "fn one() -> usize { 1 }",
364 "main.rs": "
365 use crate::one;
366 fn main() { one(); }
367 ".unindent(),
368 }
369 }),
370 )
371 .await;
372
373 let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
374 slash_command_registry.register_command(file_command::FileSlashCommand, false);
375
376 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
377 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
378 let context =
379 cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
380
381 let output_ranges = Rc::new(RefCell::new(HashSet::default()));
382 context.update(cx, |_, cx| {
383 cx.subscribe(&context, {
384 let ranges = output_ranges.clone();
385 move |_, _, event, _| match event {
386 ContextEvent::PendingSlashCommandsUpdated { removed, updated } => {
387 for range in removed {
388 ranges.borrow_mut().remove(range);
389 }
390 for command in updated {
391 ranges.borrow_mut().insert(command.source_range.clone());
392 }
393 }
394 _ => {}
395 }
396 })
397 .detach();
398 });
399
400 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
401
402 // Insert a slash command
403 buffer.update(cx, |buffer, cx| {
404 buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
405 });
406 assert_text_and_output_ranges(
407 &buffer,
408 &output_ranges.borrow(),
409 "
410 «/file src/lib.rs»
411 "
412 .unindent()
413 .trim_end(),
414 cx,
415 );
416
417 // Edit the argument of the slash command.
418 buffer.update(cx, |buffer, cx| {
419 let edit_offset = buffer.text().find("lib.rs").unwrap();
420 buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
421 });
422 assert_text_and_output_ranges(
423 &buffer,
424 &output_ranges.borrow(),
425 "
426 «/file src/main.rs»
427 "
428 .unindent()
429 .trim_end(),
430 cx,
431 );
432
433 // Edit the name of the slash command, using one that doesn't exist.
434 buffer.update(cx, |buffer, cx| {
435 let edit_offset = buffer.text().find("/file").unwrap();
436 buffer.edit(
437 [(edit_offset..edit_offset + "/file".len(), "/unknown")],
438 None,
439 cx,
440 );
441 });
442 assert_text_and_output_ranges(
443 &buffer,
444 &output_ranges.borrow(),
445 "
446 /unknown src/main.rs
447 "
448 .unindent()
449 .trim_end(),
450 cx,
451 );
452
453 #[track_caller]
454 fn assert_text_and_output_ranges(
455 buffer: &Model<Buffer>,
456 ranges: &HashSet<Range<language::Anchor>>,
457 expected_marked_text: &str,
458 cx: &mut TestAppContext,
459 ) {
460 let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false);
461 let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| {
462 let mut ranges = ranges
463 .iter()
464 .map(|range| range.to_offset(buffer))
465 .collect::<Vec<_>>();
466 ranges.sort_by_key(|a| a.start);
467 (buffer.text(), ranges)
468 });
469
470 assert_eq!(actual_text, expected_text);
471 assert_eq!(actual_ranges, expected_ranges);
472 }
473}
474
// End-to-end check of workflow-step handling. A fake model streams a response
// containing two `<step>` blocks. Both become pending workflow steps, and the
// first one becomes resolved once the fake model answers the step-resolution
// tool call.
#[gpui::test]
async fn test_edit_step_parsing(cx: &mut TestAppContext) {
    cx.update(prompt_library::init);
    let settings_store = cx.update(SettingsStore::test);
    cx.set_global(settings_store);
    cx.update(Project::init_settings);
    let fs = FakeFs::new(cx.executor());
    fs.as_fake()
        .insert_tree(
            "/root",
            json!({
                "hello.rs": r#"
                    fn hello() {
                        println!("Hello, World!");
                    }
                "#.unindent()
            }),
        )
        .await;
    let project = Project::test(fs, [Path::new("/root")], cx).await;
    cx.update(LanguageModelRegistry::test);

    // The test registry's active model is driven below through `as_fake()`.
    let model = cx.read(|cx| {
        LanguageModelRegistry::read_global(cx)
            .active_model()
            .unwrap()
    });
    cx.update(assistant_panel::init);
    let registry = Arc::new(LanguageRegistry::test(cx.executor()));

    // Create a new context
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context = cx.new_model(|cx| {
        Context::local(
            registry.clone(),
            Some(project),
            None,
            prompt_builder.clone(),
            cx,
        )
    });
    let buffer = context.read_with(cx, |context, _| context.buffer.clone());

    // Simulate user input
    let user_message = indoc! {r#"
        Please add unnecessary complexity to this code:

        ```hello.rs
        fn main() {
            println!("Hello, World!");
        }
        ```
    "#};
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, user_message)], None, cx);
    });

    // Simulate LLM response with edit steps
    let llm_response = indoc! {r#"
        Sure, I can help you with that. Here's a step-by-step process:

        <step>
        First, let's extract the greeting into a separate function:

        ```rust
        fn greet() {
            println!("Hello, World!");
        }

        fn main() {
            greet();
        }
        ```
        </step>

        <step>
        Now, let's make the greeting customizable:

        ```rust
        fn greet(name: &str) {
            println!("Hello, {}!", name);
        }

        fn main() {
            greet("World");
        }
        ```
        </step>

        These changes make the code more modular and flexible.
    "#};

    // Simulate the assist method to trigger the LLM response
    context.update(cx, |context, cx| context.assist(cx));
    cx.run_until_parked();

    // Retrieve the assistant response message's start from the context
    let response_start_row = context.read_with(cx, |context, cx| {
        let buffer = context.buffer.read(cx);
        context.message_anchors[1].start.to_point(buffer).row
    });

    // Simulate the LLM completion
    model
        .as_fake()
        .stream_last_completion_response(llm_response.to_string());
    model.as_fake().end_last_completion_stream();

    // Wait for the completion to be processed
    cx.run_until_parked();

    // Verify that the edit steps were parsed correctly
    context.read_with(cx, |context, cx| {
        assert_eq!(
            workflow_steps(context, cx),
            vec![
                (
                    Point::new(response_start_row + 2, 0)..Point::new(response_start_row + 12, 3),
                    WorkflowStepTestStatus::Pending
                ),
                (
                    Point::new(response_start_row + 14, 0)..Point::new(response_start_row + 24, 3),
                    WorkflowStepTestStatus::Pending
                ),
            ]
        );
    });

    // Answer the most recent tool call with a resolution for one step.
    model
        .as_fake()
        .respond_to_last_tool_use(tool::WorkflowStepResolutionTool {
            step_title: "Title".into(),
            suggestions: vec![tool::WorkflowSuggestionTool {
                path: "/root/hello.rs".into(),
                // Simulate a symbol name that's slightly different than our outline query
                kind: tool::WorkflowSuggestionToolKind::Update {
                    symbol: "fn main()".into(),
                    description: "Extract a greeting function".into(),
                },
            }],
        });

    // Wait for tool use to be processed.
    cx.run_until_parked();

    // Verify that the first edit step is not pending anymore.
    context.read_with(cx, |context, cx| {
        assert_eq!(
            workflow_steps(context, cx),
            vec![
                (
                    Point::new(response_start_row + 2, 0)..Point::new(response_start_row + 12, 3),
                    WorkflowStepTestStatus::Resolved
                ),
                (
                    Point::new(response_start_row + 14, 0)..Point::new(response_start_row + 24, 3),
                    WorkflowStepTestStatus::Pending
                ),
            ]
        );
    });

    // Coarse summary of a step's `resolution` for assertions.
    #[derive(Copy, Clone, Debug, Eq, PartialEq)]
    enum WorkflowStepTestStatus {
        Pending,
        Resolved,
        Error,
    }

    // Lists each workflow step as its buffer range in points, paired with its
    // status: `None` maps to Pending, `Some(Ok)` to Resolved, `Some(Err)` to Error.
    fn workflow_steps(
        context: &Context,
        cx: &AppContext,
    ) -> Vec<(Range<Point>, WorkflowStepTestStatus)> {
        context
            .workflow_steps
            .iter()
            .map(|step| {
                let buffer = context.buffer.read(cx);
                let status = match &step.step.read(cx).resolution {
                    None => WorkflowStepTestStatus::Pending,
                    Some(Ok(_)) => WorkflowStepTestStatus::Resolved,
                    Some(Err(_)) => WorkflowStepTestStatus::Error,
                };
                (step.range.to_point(buffer), status)
            })
            .collect()
    }
}
663
// Round-trips a context through `serialize` / `Context::deserialize` and checks
// that the buffer text and the message list survive unchanged.
#[gpui::test]
async fn test_serialization(cx: &mut TestAppContext) {
    let settings_store = cx.update(SettingsStore::test);
    cx.set_global(settings_store);
    cx.update(LanguageModelRegistry::test);
    cx.update(assistant_panel::init);
    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context =
        cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
    let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
    let message_1 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
            .unwrap()
    });
    let message_2 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
            .unwrap()
    });
    // Presumably finalizing here keeps this edit in its own undo transaction,
    // separate from the message insertion that follows.
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
        buffer.finalize_last_transaction();
    });
    let _message_3 = context.update(cx, |context, cx| {
        context
            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
            .unwrap()
    });
    // The undo removes the `_message_3` insertion, so only three messages remain.
    buffer.update(cx, |buffer, cx| buffer.undo(cx));
    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
    assert_eq!(
        cx.read(|cx| messages(&context, cx)),
        [
            (message_0, Role::User, 0..2),
            (message_1.id, Role::Assistant, 2..6),
            (message_2.id, Role::System, 6..6),
        ]
    );

    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
    let deserialized_context = cx.new_model(|cx| {
        Context::deserialize(
            serialized_context,
            Default::default(),
            registry.clone(),
            prompt_builder.clone(),
            None,
            None,
            cx,
        )
    });
    let deserialized_buffer =
        deserialized_context.read_with(cx, |context, _| context.buffer.clone());
    assert_eq!(
        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
        "a\nb\nc\n"
    );
    // Message ids, roles and ranges match the original context exactly.
    assert_eq!(
        cx.read(|cx| messages(&deserialized_context, cx)),
        [
            (message_0, Role::User, 0..2),
            (message_1.id, Role::Assistant, 2..6),
            (message_2.id, Role::System, 6..6),
        ]
    );
}
733
// Randomized collaboration test. Several replicas of the same context exchange
// operations over a simulated network while random mutations, disconnects and
// reconnects happen. At the end every replica must match replica 0.
//
// Tunable via env vars: MIN_PEERS (default 2), MAX_PEERS (default 5) and
// OPERATIONS (default 50, the number of mutations to perform).
#[gpui::test(iterations = 100)]
async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
    let min_peers = env::var("MIN_PEERS")
        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
        .unwrap_or(2);
    let max_peers = env::var("MAX_PEERS")
        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
        .unwrap_or(5);
    let operations = env::var("OPERATIONS")
        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
        .unwrap_or(50);

    let settings_store = cx.update(SettingsStore::test);
    cx.set_global(settings_store);
    cx.update(LanguageModelRegistry::test);

    cx.update(assistant_panel::init);
    let slash_commands = cx.update(SlashCommandRegistry::default_global);
    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);

    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
    let mut contexts = Vec::new();

    // All replicas share one `ContextId`; the loop index doubles as the replica id.
    let num_peers = rng.gen_range(min_peers..=max_peers);
    let context_id = ContextId::new();
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    for i in 0..num_peers {
        let context = cx.new_model(|cx| {
            Context::new(
                context_id.clone(),
                i as ReplicaId,
                language::Capability::ReadWrite,
                registry.clone(),
                prompt_builder.clone(),
                None,
                None,
                cx,
            )
        });

        // Broadcast every local operation from this replica to the network.
        cx.update(|cx| {
            cx.subscribe(&context, {
                let network = network.clone();
                move |_, event, _| {
                    if let ContextEvent::Operation(op) = event {
                        network
                            .lock()
                            .broadcast(i as ReplicaId, vec![op.to_proto()]);
                    }
                }
            })
            .detach();
        });

        contexts.push(context);
        network.lock().add_peer(i as ReplicaId);
    }

    let mut mutation_count = operations;

    // Keep going until all mutations are done, every message is delivered and
    // no peer is left disconnected.
    while mutation_count > 0
        || !network.lock().is_idle()
        || network.lock().contains_disconnected_peers()
    {
        let context_index = rng.gen_range(0..contexts.len());
        let context = &contexts[context_index];

        // Rolls 0..=84 mutate while mutations remain. Otherwise the `_` arm
        // handles networking (reconnect, disconnect, or deliver operations).
        match rng.gen_range(0..100) {
            0..=29 if mutation_count > 0 => {
                log::info!("Context {}: edit buffer", context_index);
                context.update(cx, |context, cx| {
                    context
                        .buffer
                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
                });
                mutation_count -= 1;
            }
            30..=44 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
                    log::info!("Context {}: split message at {:?}", context_index, range);
                    context.split_message(range, cx);
                });
                mutation_count -= 1;
            }
            45..=59 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let role = *[Role::User, Role::Assistant, Role::System]
                            .choose(&mut rng)
                            .unwrap();
                        log::info!(
                            "Context {}: insert message after {:?} with {:?}",
                            context_index,
                            message.id,
                            role
                        );
                        context.insert_message_after(message.id, role, MessageStatus::Done, cx);
                    }
                });
                mutation_count -= 1;
            }
            60..=74 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    // Insert a randomly chosen registered command on its own line.
                    let command_text = "/".to_string()
                        + slash_commands
                            .command_names()
                            .choose(&mut rng)
                            .unwrap()
                            .clone()
                            .as_ref();

                    // The returned range covers the command text, skipping the
                    // leading newline.
                    let command_range = context.buffer.update(cx, |buffer, cx| {
                        let offset = buffer.random_byte_range(0, &mut rng).start;
                        buffer.edit(
                            [(offset..offset, format!("\n{}\n", command_text))],
                            None,
                            cx,
                        );
                        offset + 1..offset + 1 + command_text.len()
                    });

                    let output_len = rng.gen_range(1..=10);
                    let output_text = RandomCharIter::new(&mut rng)
                        .filter(|c| *c != '\r')
                        .take(output_len)
                        .collect::<String>();

                    // NOTE(review): `output_len` counts chars, but these ranges
                    // may be interpreted as byte offsets into `output_text` —
                    // confirm how multi-byte chars are handled.
                    let num_sections = rng.gen_range(0..=3);
                    let mut sections = Vec::with_capacity(num_sections);
                    for _ in 0..num_sections {
                        let section_start = rng.gen_range(0..output_len);
                        let section_end = rng.gen_range(section_start..=output_len);
                        sections.push(SlashCommandOutputSection {
                            range: section_start..section_end,
                            icon: ui::IconName::Ai,
                            label: "section".into(),
                        });
                    }

                    log::info!(
                        "Context {}: insert slash command output at {:?} with {:?}",
                        context_index,
                        command_range,
                        sections
                    );

                    let command_range = context.buffer.read(cx).anchor_after(command_range.start)
                        ..context.buffer.read(cx).anchor_after(command_range.end);
                    context.insert_command_output(
                        command_range,
                        Task::ready(Ok(SlashCommandOutput {
                            text: output_text,
                            sections,
                            run_commands_in_text: false,
                        })),
                        true,
                        cx,
                    );
                });
                cx.run_until_parked();
                mutation_count -= 1;
            }
            75..=84 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let new_status = match rng.gen_range(0..3) {
                            0 => MessageStatus::Done,
                            1 => MessageStatus::Pending,
                            _ => MessageStatus::Error(SharedString::from("Random error")),
                        };
                        log::info!(
                            "Context {}: update message {:?} status to {:?}",
                            context_index,
                            message.id,
                            new_status
                        );
                        context.update_metadata(message.id, cx, |metadata| {
                            metadata.status = new_status;
                        });
                    }
                });
                mutation_count -= 1;
            }
            _ => {
                let replica_id = context_index as ReplicaId;
                if network.lock().is_disconnected(replica_id) {
                    network.lock().reconnect_peer(replica_id, 0);

                    // Resync with replica 0 (the host). Each side serializes
                    // the operations the other side's version has not seen.
                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
                        let host_context = &contexts[0].read(cx);
                        let guest_context = context.read(cx);
                        (
                            guest_context.serialize_ops(&host_context.version(cx), cx),
                            host_context.serialize_ops(&guest_context.version(cx), cx),
                        )
                    });
                    let ops_to_send = ops_to_send.await;
                    let ops_to_receive = ops_to_receive
                        .await
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    log::info!(
                        "Context {}: reconnecting. Sent {} operations, received {} operations",
                        context_index,
                        ops_to_send.len(),
                        ops_to_receive.len()
                    );

                    network.lock().broadcast(replica_id, ops_to_send);
                    context
                        .update(cx, |context, cx| context.apply_ops(ops_to_receive, cx))
                        .unwrap();
                } else if rng.gen_bool(0.1) && replica_id != 0 {
                    // Replica 0 is never disconnected.
                    log::info!("Context {}: disconnecting", context_index);
                    network.lock().disconnect_peer(replica_id);
                } else if network.lock().has_unreceived(replica_id) {
                    log::info!("Context {}: applying operations", context_index);
                    let ops = network.lock().receive(replica_id);
                    let ops = ops
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    context
                        .update(cx, |context, cx| context.apply_ops(ops, cx))
                        .unwrap();
                }
            }
        }
    }

    // Convergence: every replica has no pending operations and matches replica 0.
    cx.read(|cx| {
        let first_context = contexts[0].read(cx);
        for context in &contexts[1..] {
            let context = context.read(cx);
            assert!(context.pending_ops.is_empty());
            assert_eq!(
                context.buffer.read(cx).text(),
                first_context.buffer.read(cx).text(),
                "Context {} text != Context 0 text",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.message_anchors,
                first_context.message_anchors,
                "Context {} messages != Context 0 messages",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.messages_metadata,
                first_context.messages_metadata,
                "Context {} message metadata != Context 0 message metadata",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.slash_command_output_sections,
                first_context.slash_command_output_sections,
                "Context {} slash command output sections != Context 0 slash command output sections",
                context.buffer.read(cx).replica_id()
            );
        }
    });
}
1003
1004fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
1005 context
1006 .read(cx)
1007 .messages(cx)
1008 .map(|message| (message.id, message.role, message.offset_range))
1009 .collect()
1010}
1011
1012#[derive(Clone)]
1013struct FakeSlashCommand(String);
1014
1015impl SlashCommand for FakeSlashCommand {
1016 fn name(&self) -> String {
1017 self.0.clone()
1018 }
1019
1020 fn description(&self) -> String {
1021 format!("Fake slash command: {}", self.0)
1022 }
1023
1024 fn menu_text(&self) -> String {
1025 format!("Run fake command: {}", self.0)
1026 }
1027
1028 fn complete_argument(
1029 self: Arc<Self>,
1030 _arguments: &[String],
1031 _cancel: Arc<AtomicBool>,
1032 _workspace: Option<WeakView<Workspace>>,
1033 _cx: &mut WindowContext,
1034 ) -> Task<Result<Vec<ArgumentCompletion>>> {
1035 Task::ready(Ok(vec![]))
1036 }
1037
1038 fn requires_argument(&self) -> bool {
1039 false
1040 }
1041
1042 fn run(
1043 self: Arc<Self>,
1044 _arguments: &[String],
1045 _workspace: WeakView<Workspace>,
1046 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1047 _cx: &mut WindowContext,
1048 ) -> Task<Result<SlashCommandOutput>> {
1049 Task::ready(Ok(SlashCommandOutput {
1050 text: format!("Executed fake command: {}", self.0),
1051 sections: vec![],
1052 run_commands_in_text: false,
1053 }))
1054 }
1055}