diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index d08c468baf93db6891a17216b66472f05fa5fe6f..5fa816a8175f4adc9ea88c44e89db676f5a0ce63 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -1,3 +1,6 @@ +#[cfg(test)] +mod context_tests; + use crate::{ prompts::PromptBuilder, slash_command::SlashCommandLine, workflow::WorkflowStep, MessageId, MessageStatus, @@ -308,7 +311,7 @@ pub struct MessageAnchor { #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] pub struct MessageMetadata { pub role: Role, - status: MessageStatus, + pub status: MessageStatus, timestamp: clock::Lamport, } @@ -530,7 +533,7 @@ impl Context { this } - fn serialize(&self, cx: &AppContext) -> SavedContext { + pub(crate) fn serialize(&self, cx: &AppContext) -> SavedContext { let buffer = self.buffer.read(cx); SavedContext { id: Some(self.id.clone()), @@ -1531,7 +1534,7 @@ impl Context { } } - fn insert_message_after( + pub fn insert_message_after( &mut self, message_id: MessageId, role: Role, @@ -2393,1055 +2396,3 @@ pub struct SavedContextMetadata { pub path: PathBuf, pub mtime: chrono::DateTime, } - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - assistant_panel, prompt_library, slash_command::file_command, workflow::tool, MessageId, - }; - use assistant_slash_command::{ArgumentCompletion, SlashCommand}; - use fs::FakeFs; - use gpui::{AppContext, TestAppContext, WeakView}; - use indoc::indoc; - use language::LspAdapterDelegate; - use parking_lot::Mutex; - use project::Project; - use rand::prelude::*; - use serde_json::json; - use settings::SettingsStore; - use std::{cell::RefCell, env, rc::Rc, sync::atomic::AtomicBool}; - use text::{network::Network, ToPoint}; - use ui::WindowContext; - use unindent::Unindent; - use util::{test::marked_text_ranges, RandomCharIter}; - use workspace::Workspace; - - #[gpui::test] - fn test_inserting_and_removing_messages(cx: &mut AppContext) { - let settings_store = SettingsStore::test(cx); - LanguageModelRegistry::test(cx); - cx.set_global(settings_store); - assistant_panel::init(cx); - let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = - cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); - let buffer = context.read(cx).buffer.clone(); - - let message_1 = context.read(cx).message_anchors[0].clone(); - assert_eq!( - messages(&context, cx), - vec![(message_1.id, Role::User, 0..0)] - ); - - let message_2 = context.update(cx, |context, cx| { - context - .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx) - .unwrap() - }); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..1), - (message_2.id, Role::Assistant, 1..1) - ] - ); - - buffer.update(cx, |buffer, cx| { - buffer.edit([(0..0, "1"), (1..1, "2")], None, cx) - }); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..2), - (message_2.id, Role::Assistant, 2..3) - ] - ); - - let message_3 = context.update(cx, |context, cx| { - context - .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) - .unwrap() - }); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..2), - (message_2.id, Role::Assistant, 2..4), - (message_3.id, Role::User, 4..4) - ] - ); - - let message_4 = context.update(cx, |context, cx| { - context - .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) - .unwrap() - }); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..2), - (message_2.id, Role::Assistant, 2..4), - (message_4.id, Role::User, 4..5), - (message_3.id, Role::User, 5..5), - ] - ); - - buffer.update(cx, |buffer, cx| { - buffer.edit([(4..4, "C"), (5..5, "D")], None, cx) - }); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..2), - (message_2.id, Role::Assistant, 2..4), - (message_4.id, Role::User, 4..6), - (message_3.id, Role::User, 6..7), - ] - ); - - // Deleting across message boundaries merges the messages. - buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx)); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..3), - (message_3.id, Role::User, 3..4), - ] - ); - - // Undoing the deletion should also undo the merge. - buffer.update(cx, |buffer, cx| buffer.undo(cx)); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..2), - (message_2.id, Role::Assistant, 2..4), - (message_4.id, Role::User, 4..6), - (message_3.id, Role::User, 6..7), - ] - ); - - // Redoing the deletion should also redo the merge. - buffer.update(cx, |buffer, cx| buffer.redo(cx)); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..3), - (message_3.id, Role::User, 3..4), - ] - ); - - // Ensure we can still insert after a merged message. - let message_5 = context.update(cx, |context, cx| { - context - .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx) - .unwrap() - }); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..3), - (message_5.id, Role::System, 3..4), - (message_3.id, Role::User, 4..5) - ] - ); - } - - #[gpui::test] - fn test_message_splitting(cx: &mut AppContext) { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - LanguageModelRegistry::test(cx); - assistant_panel::init(cx); - let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - - let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = - cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); - let buffer = context.read(cx).buffer.clone(); - - let message_1 = context.read(cx).message_anchors[0].clone(); - assert_eq!( - messages(&context, cx), - vec![(message_1.id, Role::User, 0..0)] - ); - - buffer.update(cx, |buffer, cx| { - buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx) - }); - - let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx)); - let message_2 = message_2.unwrap(); - - // We recycle newlines in the middle of a split message - assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n"); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..4), - (message_2.id, Role::User, 4..16), - ] - ); - - let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx)); - let message_3 = message_3.unwrap(); - - // We don't recycle newlines at the end of a split message - assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n"); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..4), - (message_3.id, Role::User, 4..5), - (message_2.id, Role::User, 5..17), - ] - ); - - let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx)); - let message_4 = message_4.unwrap(); - assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n"); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..4), - (message_3.id, Role::User, 4..5), - (message_2.id, Role::User, 5..9), - (message_4.id, Role::User, 9..17), - ] - ); - - let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx)); - let message_5 = message_5.unwrap(); - assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n"); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..4), - (message_3.id, Role::User, 4..5), - (message_2.id, Role::User, 5..9), - (message_4.id, Role::User, 9..10), - (message_5.id, Role::User, 10..18), - ] - ); - - let (message_6, message_7) = - context.update(cx, |context, cx| context.split_message(14..16, cx)); - let message_6 = message_6.unwrap(); - let message_7 = message_7.unwrap(); - assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n"); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..4), - (message_3.id, Role::User, 4..5), - (message_2.id, Role::User, 5..9), - (message_4.id, Role::User, 9..10), - (message_5.id, Role::User, 10..14), - (message_6.id, Role::User, 14..17), - (message_7.id, Role::User, 17..19), - ] - ); - } - - #[gpui::test] - fn test_messages_for_offsets(cx: &mut AppContext) { - let settings_store = SettingsStore::test(cx); - LanguageModelRegistry::test(cx); - cx.set_global(settings_store); - assistant_panel::init(cx); - let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = - cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); - let buffer = context.read(cx).buffer.clone(); - - let message_1 = context.read(cx).message_anchors[0].clone(); - assert_eq!( - messages(&context, cx), - vec![(message_1.id, Role::User, 0..0)] - ); - - buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx)); - let message_2 = context - .update(cx, |context, cx| { - context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx) - }) - .unwrap(); - buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx)); - - let message_3 = context - .update(cx, |context, cx| { - context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) - }) - .unwrap(); - buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx)); - - assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc"); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..4), - (message_2.id, Role::User, 4..8), - (message_3.id, Role::User, 8..11) - ] - ); - - assert_eq!( - message_ids_for_offsets(&context, &[0, 4, 9], cx), - [message_1.id, message_2.id, message_3.id] - ); - assert_eq!( - message_ids_for_offsets(&context, &[0, 1, 11], cx), - [message_1.id, message_3.id] - ); - - let message_4 = context - .update(cx, |context, cx| { - context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx) - }) - .unwrap(); - assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n"); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..4), - (message_2.id, Role::User, 4..8), - (message_3.id, Role::User, 8..12), - (message_4.id, Role::User, 12..12) - ] - ); - assert_eq!( - message_ids_for_offsets(&context, &[0, 4, 8, 12], cx), - [message_1.id, message_2.id, message_3.id, message_4.id] - ); - - fn message_ids_for_offsets( - context: &Model, - offsets: &[usize], - cx: &AppContext, - ) -> Vec { - context - .read(cx) - .messages_for_offsets(offsets.iter().copied(), cx) - .into_iter() - .map(|message| message.id) - .collect() - } - } - - #[gpui::test] - async fn test_slash_commands(cx: &mut TestAppContext) { - let settings_store = cx.update(SettingsStore::test); - cx.set_global(settings_store); - cx.update(LanguageModelRegistry::test); - cx.update(Project::init_settings); - cx.update(assistant_panel::init); - let fs = FakeFs::new(cx.background_executor.clone()); - - fs.insert_tree( - "/test", - json!({ - "src": { - "lib.rs": "fn one() -> usize { 1 }", - "main.rs": " - use crate::one; - fn main() { one(); } - ".unindent(), - } - }), - ) - .await; - - let slash_command_registry = cx.update(SlashCommandRegistry::default_global); - slash_command_registry.register_command(file_command::FileSlashCommand, false); - - let registry = Arc::new(LanguageRegistry::test(cx.executor())); - let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = cx.new_model(|cx| { - Context::local(registry.clone(), None, None, prompt_builder.clone(), cx) - }); - - let output_ranges = Rc::new(RefCell::new(HashSet::default())); - context.update(cx, |_, cx| { - cx.subscribe(&context, { - let ranges = output_ranges.clone(); - move |_, _, event, _| match event { - ContextEvent::PendingSlashCommandsUpdated { removed, updated } => { - for range in removed { - ranges.borrow_mut().remove(range); - } - for command in updated { - ranges.borrow_mut().insert(command.source_range.clone()); - } - } - _ => {} - } - }) - .detach(); - }); - - let buffer = context.read_with(cx, |context, _| context.buffer.clone()); - - // Insert a slash command - buffer.update(cx, |buffer, cx| { - buffer.edit([(0..0, "/file src/lib.rs")], None, cx); - }); - assert_text_and_output_ranges( - &buffer, - &output_ranges.borrow(), - " - «/file src/lib.rs» - " - .unindent() - .trim_end(), - cx, - ); - - // Edit the argument of the slash command. - buffer.update(cx, |buffer, cx| { - let edit_offset = buffer.text().find("lib.rs").unwrap(); - buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx); - }); - assert_text_and_output_ranges( - &buffer, - &output_ranges.borrow(), - " - «/file src/main.rs» - " - .unindent() - .trim_end(), - cx, - ); - - // Edit the name of the slash command, using one that doesn't exist. - buffer.update(cx, |buffer, cx| { - let edit_offset = buffer.text().find("/file").unwrap(); - buffer.edit( - [(edit_offset..edit_offset + "/file".len(), "/unknown")], - None, - cx, - ); - }); - assert_text_and_output_ranges( - &buffer, - &output_ranges.borrow(), - " - /unknown src/main.rs - " - .unindent() - .trim_end(), - cx, - ); - - #[track_caller] - fn assert_text_and_output_ranges( - buffer: &Model, - ranges: &HashSet>, - expected_marked_text: &str, - cx: &mut TestAppContext, - ) { - let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false); - let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| { - let mut ranges = ranges - .iter() - .map(|range| range.to_offset(buffer)) - .collect::>(); - ranges.sort_by_key(|a| a.start); - (buffer.text(), ranges) - }); - - assert_eq!(actual_text, expected_text); - assert_eq!(actual_ranges, expected_ranges); - } - } - - #[gpui::test] - async fn test_edit_step_parsing(cx: &mut TestAppContext) { - cx.update(prompt_library::init); - let settings_store = cx.update(SettingsStore::test); - cx.set_global(settings_store); - cx.update(Project::init_settings); - let fs = FakeFs::new(cx.executor()); - fs.as_fake() - .insert_tree( - "/root", - json!({ - "hello.rs": r#" - fn hello() { - println!("Hello, World!"); - } - "#.unindent() - }), - ) - .await; - let project = Project::test(fs, [Path::new("/root")], cx).await; - cx.update(LanguageModelRegistry::test); - - let model = cx.read(|cx| { - LanguageModelRegistry::read_global(cx) - .active_model() - .unwrap() - }); - cx.update(assistant_panel::init); - let registry = Arc::new(LanguageRegistry::test(cx.executor())); - - // Create a new context - let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = cx.new_model(|cx| { - Context::local( - registry.clone(), - Some(project), - None, - prompt_builder.clone(), - cx, - ) - }); - let buffer = context.read_with(cx, |context, _| context.buffer.clone()); - - // Simulate user input - let user_message = indoc! {r#" - Please add unnecessary complexity to this code: - - ```hello.rs - fn main() { - println!("Hello, World!"); - } - ``` - "#}; - buffer.update(cx, |buffer, cx| { - buffer.edit([(0..0, user_message)], None, cx); - }); - - // Simulate LLM response with edit steps - let llm_response = indoc! {r#" - Sure, I can help you with that. Here's a step-by-step process: - - - First, let's extract the greeting into a separate function: - - ```rust - fn greet() { - println!("Hello, World!"); - } - - fn main() { - greet(); - } - ``` - - - - Now, let's make the greeting customizable: - - ```rust - fn greet(name: &str) { - println!("Hello, {}!", name); - } - - fn main() { - greet("World"); - } - ``` - - - These changes make the code more modular and flexible. - "#}; - - // Simulate the assist method to trigger the LLM response - context.update(cx, |context, cx| context.assist(cx)); - cx.run_until_parked(); - - // Retrieve the assistant response message's start from the context - let response_start_row = context.read_with(cx, |context, cx| { - let buffer = context.buffer.read(cx); - context.message_anchors[1].start.to_point(buffer).row - }); - - // Simulate the LLM completion - model - .as_fake() - .stream_last_completion_response(llm_response.to_string()); - model.as_fake().end_last_completion_stream(); - - // Wait for the completion to be processed - cx.run_until_parked(); - - // Verify that the edit steps were parsed correctly - context.read_with(cx, |context, cx| { - assert_eq!( - workflow_steps(context, cx), - vec![ - ( - Point::new(response_start_row + 2, 0) - ..Point::new(response_start_row + 12, 3), - WorkflowStepTestStatus::Pending - ), - ( - Point::new(response_start_row + 14, 0) - ..Point::new(response_start_row + 24, 3), - WorkflowStepTestStatus::Pending - ), - ] - ); - }); - - model - .as_fake() - .respond_to_last_tool_use(tool::WorkflowStepResolutionTool { - step_title: "Title".into(), - suggestions: vec![tool::WorkflowSuggestionTool { - path: "/root/hello.rs".into(), - // Simulate a symbol name that's slightly different than our outline query - kind: tool::WorkflowSuggestionToolKind::Update { - symbol: "fn main()".into(), - description: "Extract a greeting function".into(), - }, - }], - }); - - // Wait for tool use to be processed. - cx.run_until_parked(); - - // Verify that the first edit step is not pending anymore. - context.read_with(cx, |context, cx| { - assert_eq!( - workflow_steps(context, cx), - vec![ - ( - Point::new(response_start_row + 2, 0) - ..Point::new(response_start_row + 12, 3), - WorkflowStepTestStatus::Resolved - ), - ( - Point::new(response_start_row + 14, 0) - ..Point::new(response_start_row + 24, 3), - WorkflowStepTestStatus::Pending - ), - ] - ); - }); - - #[derive(Copy, Clone, Debug, Eq, PartialEq)] - enum WorkflowStepTestStatus { - Pending, - Resolved, - Error, - } - - fn workflow_steps( - context: &Context, - cx: &AppContext, - ) -> Vec<(Range, WorkflowStepTestStatus)> { - context - .workflow_steps - .iter() - .map(|step| { - let buffer = context.buffer.read(cx); - let status = match &step.step.read(cx).resolution { - None => WorkflowStepTestStatus::Pending, - Some(Ok(_)) => WorkflowStepTestStatus::Resolved, - Some(Err(_)) => WorkflowStepTestStatus::Error, - }; - (step.range.to_point(buffer), status) - }) - .collect() - } - } - - #[gpui::test] - async fn test_serialization(cx: &mut TestAppContext) { - let settings_store = cx.update(SettingsStore::test); - cx.set_global(settings_store); - cx.update(LanguageModelRegistry::test); - cx.update(assistant_panel::init); - let registry = Arc::new(LanguageRegistry::test(cx.executor())); - let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = cx.new_model(|cx| { - Context::local(registry.clone(), None, None, prompt_builder.clone(), cx) - }); - let buffer = context.read_with(cx, |context, _| context.buffer.clone()); - let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id); - let message_1 = context.update(cx, |context, cx| { - context - .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx) - .unwrap() - }); - let message_2 = context.update(cx, |context, cx| { - context - .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx) - .unwrap() - }); - buffer.update(cx, |buffer, cx| { - buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx); - buffer.finalize_last_transaction(); - }); - let _message_3 = context.update(cx, |context, cx| { - context - .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx) - .unwrap() - }); - buffer.update(cx, |buffer, cx| buffer.undo(cx)); - assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n"); - assert_eq!( - cx.read(|cx| messages(&context, cx)), - [ - (message_0, Role::User, 0..2), - (message_1.id, Role::Assistant, 2..6), - (message_2.id, Role::System, 6..6), - ] - ); - - let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx)); - let deserialized_context = cx.new_model(|cx| { - Context::deserialize( - serialized_context, - Default::default(), - registry.clone(), - prompt_builder.clone(), - None, - None, - cx, - ) - }); - let deserialized_buffer = - deserialized_context.read_with(cx, |context, _| context.buffer.clone()); - assert_eq!( - deserialized_buffer.read_with(cx, |buffer, _| buffer.text()), - "a\nb\nc\n" - ); - assert_eq!( - cx.read(|cx| messages(&deserialized_context, cx)), - [ - (message_0, Role::User, 0..2), - (message_1.id, Role::Assistant, 2..6), - (message_2.id, Role::System, 6..6), - ] - ); - } - - #[gpui::test(iterations = 100)] - async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) { - let min_peers = env::var("MIN_PEERS") - .map(|i| i.parse().expect("invalid `MIN_PEERS` variable")) - .unwrap_or(2); - let max_peers = env::var("MAX_PEERS") - .map(|i| i.parse().expect("invalid `MAX_PEERS` variable")) - .unwrap_or(5); - let operations = env::var("OPERATIONS") - .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) - .unwrap_or(50); - - let settings_store = cx.update(SettingsStore::test); - cx.set_global(settings_store); - cx.update(LanguageModelRegistry::test); - - cx.update(assistant_panel::init); - let slash_commands = cx.update(SlashCommandRegistry::default_global); - slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false); - slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false); - slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false); - - let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone())); - let network = Arc::new(Mutex::new(Network::new(rng.clone()))); - let mut contexts = Vec::new(); - - let num_peers = rng.gen_range(min_peers..=max_peers); - let context_id = ContextId::new(); - let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - for i in 0..num_peers { - let context = cx.new_model(|cx| { - Context::new( - context_id.clone(), - i as ReplicaId, - language::Capability::ReadWrite, - registry.clone(), - prompt_builder.clone(), - None, - None, - cx, - ) - }); - - cx.update(|cx| { - cx.subscribe(&context, { - let network = network.clone(); - move |_, event, _| { - if let ContextEvent::Operation(op) = event { - network - .lock() - .broadcast(i as ReplicaId, vec![op.to_proto()]); - } - } - }) - .detach(); - }); - - contexts.push(context); - network.lock().add_peer(i as ReplicaId); - } - - let mut mutation_count = operations; - - while mutation_count > 0 - || !network.lock().is_idle() - || network.lock().contains_disconnected_peers() - { - let context_index = rng.gen_range(0..contexts.len()); - let context = &contexts[context_index]; - - match rng.gen_range(0..100) { - 0..=29 if mutation_count > 0 => { - log::info!("Context {}: edit buffer", context_index); - context.update(cx, |context, cx| { - context - .buffer - .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx)); - }); - mutation_count -= 1; - } - 30..=44 if mutation_count > 0 => { - context.update(cx, |context, cx| { - let range = context.buffer.read(cx).random_byte_range(0, &mut rng); - log::info!("Context {}: split message at {:?}", context_index, range); - context.split_message(range, cx); - }); - mutation_count -= 1; - } - 45..=59 if mutation_count > 0 => { - context.update(cx, |context, cx| { - if let Some(message) = context.messages(cx).choose(&mut rng) { - let role = *[Role::User, Role::Assistant, Role::System] - .choose(&mut rng) - .unwrap(); - log::info!( - "Context {}: insert message after {:?} with {:?}", - context_index, - message.id, - role - ); - context.insert_message_after(message.id, role, MessageStatus::Done, cx); - } - }); - mutation_count -= 1; - } - 60..=74 if mutation_count > 0 => { - context.update(cx, |context, cx| { - let command_text = "/".to_string() - + slash_commands - .command_names() - .choose(&mut rng) - .unwrap() - .clone() - .as_ref(); - - let command_range = context.buffer.update(cx, |buffer, cx| { - let offset = buffer.random_byte_range(0, &mut rng).start; - buffer.edit( - [(offset..offset, format!("\n{}\n", command_text))], - None, - cx, - ); - offset + 1..offset + 1 + command_text.len() - }); - - let output_len = rng.gen_range(1..=10); - let output_text = RandomCharIter::new(&mut rng) - .filter(|c| *c != '\r') - .take(output_len) - .collect::(); - - let num_sections = rng.gen_range(0..=3); - let mut sections = Vec::with_capacity(num_sections); - for _ in 0..num_sections { - let section_start = rng.gen_range(0..output_len); - let section_end = rng.gen_range(section_start..=output_len); - sections.push(SlashCommandOutputSection { - range: section_start..section_end, - icon: ui::IconName::Ai, - label: "section".into(), - }); - } - - log::info!( - "Context {}: insert slash command output at {:?} with {:?}", - context_index, - command_range, - sections - ); - - let command_range = - context.buffer.read(cx).anchor_after(command_range.start) - ..context.buffer.read(cx).anchor_after(command_range.end); - context.insert_command_output( - command_range, - Task::ready(Ok(SlashCommandOutput { - text: output_text, - sections, - run_commands_in_text: false, - })), - true, - cx, - ); - }); - cx.run_until_parked(); - mutation_count -= 1; - } - 75..=84 if mutation_count > 0 => { - context.update(cx, |context, cx| { - if let Some(message) = context.messages(cx).choose(&mut rng) { - let new_status = match rng.gen_range(0..3) { - 0 => MessageStatus::Done, - 1 => MessageStatus::Pending, - _ => MessageStatus::Error(SharedString::from("Random error")), - }; - log::info!( - "Context {}: update message {:?} status to {:?}", - context_index, - message.id, - new_status - ); - context.update_metadata(message.id, cx, |metadata| { - metadata.status = new_status; - }); - } - }); - mutation_count -= 1; - } - _ => { - let replica_id = context_index as ReplicaId; - if network.lock().is_disconnected(replica_id) { - network.lock().reconnect_peer(replica_id, 0); - - let (ops_to_send, ops_to_receive) = cx.read(|cx| { - let host_context = &contexts[0].read(cx); - let guest_context = context.read(cx); - ( - guest_context.serialize_ops(&host_context.version(cx), cx), - host_context.serialize_ops(&guest_context.version(cx), cx), - ) - }); - let ops_to_send = ops_to_send.await; - let ops_to_receive = ops_to_receive - .await - .into_iter() - .map(ContextOperation::from_proto) - .collect::>>() - .unwrap(); - log::info!( - "Context {}: reconnecting. Sent {} operations, received {} operations", - context_index, - ops_to_send.len(), - ops_to_receive.len() - ); - - network.lock().broadcast(replica_id, ops_to_send); - context - .update(cx, |context, cx| context.apply_ops(ops_to_receive, cx)) - .unwrap(); - } else if rng.gen_bool(0.1) && replica_id != 0 { - log::info!("Context {}: disconnecting", context_index); - network.lock().disconnect_peer(replica_id); - } else if network.lock().has_unreceived(replica_id) { - log::info!("Context {}: applying operations", context_index); - let ops = network.lock().receive(replica_id); - let ops = ops - .into_iter() - .map(ContextOperation::from_proto) - .collect::>>() - .unwrap(); - context - .update(cx, |context, cx| context.apply_ops(ops, cx)) - .unwrap(); - } - } - } - } - - cx.read(|cx| { - let first_context = contexts[0].read(cx); - for context in &contexts[1..] { - let context = context.read(cx); - assert!(context.pending_ops.is_empty()); - assert_eq!( - context.buffer.read(cx).text(), - first_context.buffer.read(cx).text(), - "Context {} text != Context 0 text", - context.buffer.read(cx).replica_id() - ); - assert_eq!( - context.message_anchors, - first_context.message_anchors, - "Context {} messages != Context 0 messages", - context.buffer.read(cx).replica_id() - ); - assert_eq!( - context.messages_metadata, - first_context.messages_metadata, - "Context {} message metadata != Context 0 message metadata", - context.buffer.read(cx).replica_id() - ); - assert_eq!( - context.slash_command_output_sections, - first_context.slash_command_output_sections, - "Context {} slash command output sections != Context 0 slash command output sections", - context.buffer.read(cx).replica_id() - ); - } - }); - } - - fn messages(context: &Model, cx: &AppContext) -> Vec<(MessageId, Role, Range)> { - context - .read(cx) - .messages(cx) - .map(|message| (message.id, message.role, message.offset_range)) - .collect() - } - - #[derive(Clone)] - struct FakeSlashCommand(String); - - impl SlashCommand for FakeSlashCommand { - fn name(&self) -> String { - self.0.clone() - } - - fn description(&self) -> String { - format!("Fake slash command: {}", self.0) - } - - fn menu_text(&self) -> String { - format!("Run fake command: {}", self.0) - } - - fn complete_argument( - self: Arc, - _arguments: &[String], - _cancel: Arc, - _workspace: Option>, - _cx: &mut WindowContext, - ) -> Task>> { - Task::ready(Ok(vec![])) - } - - fn requires_argument(&self) -> bool { - false - } - - fn run( - self: Arc, - _arguments: &[String], - _workspace: WeakView, - _delegate: Option>, - _cx: &mut WindowContext, - ) -> Task> { - Task::ready(Ok(SlashCommandOutput { - text: format!("Executed fake command: {}", self.0), - sections: vec![], - run_commands_in_text: false, - })) - } - } -} diff --git a/crates/assistant/src/context/context_tests.rs b/crates/assistant/src/context/context_tests.rs new file mode 100644 index 0000000000000000000000000000000000000000..3718c3781266996659b6821cab56e60f63329b8a --- /dev/null +++ b/crates/assistant/src/context/context_tests.rs @@ -0,0 +1,1055 @@ +use crate::{ + assistant_panel, prompt_library, slash_command::file_command, workflow::tool, Context, + ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder, +}; +use anyhow::Result; +use assistant_slash_command::{ + ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection, + SlashCommandRegistry, +}; +use collections::HashSet; +use fs::{FakeFs, Fs as _}; +use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView}; +use indoc::indoc; +use language::{Buffer, LanguageRegistry, LspAdapterDelegate}; +use language_model::{LanguageModelRegistry, Role}; +use parking_lot::Mutex; +use project::Project; +use rand::prelude::*; +use rope::Point; +use serde_json::json; +use settings::SettingsStore; +use std::{ + cell::RefCell, + env, + ops::Range, + path::Path, + rc::Rc, + sync::{atomic::AtomicBool, Arc}, +}; +use text::{network::Network, OffsetRangeExt as _, ReplicaId, ToPoint as _}; +use ui::{Context as _, WindowContext}; +use unindent::Unindent; +use util::{test::marked_text_ranges, RandomCharIter}; +use workspace::Workspace; + +#[gpui::test] +fn test_inserting_and_removing_messages(cx: &mut AppContext) { + let settings_store = SettingsStore::test(cx); + LanguageModelRegistry::test(cx); + cx.set_global(settings_store); + assistant_panel::init(cx); + let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = + cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); + let buffer = context.read(cx).buffer.clone(); + + let message_1 = context.read(cx).message_anchors[0].clone(); + assert_eq!( + messages(&context, cx), + vec![(message_1.id, Role::User, 0..0)] + ); + + let message_2 = context.update(cx, |context, cx| { + context + .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx) + .unwrap() + }); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..1), + (message_2.id, Role::Assistant, 1..1) + ] + ); + + buffer.update(cx, |buffer, cx| { + buffer.edit([(0..0, "1"), (1..1, "2")], None, cx) + }); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..2), + (message_2.id, Role::Assistant, 2..3) + ] + ); + + let message_3 = context.update(cx, |context, cx| { + context + .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) + .unwrap() + }); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..2), + (message_2.id, Role::Assistant, 2..4), + (message_3.id, Role::User, 4..4) + ] + ); + + let message_4 = context.update(cx, |context, cx| { + context + .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) + .unwrap() + }); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..2), + (message_2.id, Role::Assistant, 2..4), + (message_4.id, Role::User, 4..5), + (message_3.id, Role::User, 5..5), + ] + ); + + buffer.update(cx, |buffer, cx| { + buffer.edit([(4..4, "C"), (5..5, "D")], None, cx) + }); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..2), + (message_2.id, Role::Assistant, 2..4), + (message_4.id, Role::User, 4..6), + (message_3.id, Role::User, 6..7), + ] + ); + + // Deleting across message boundaries merges the messages. + buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx)); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..3), + (message_3.id, Role::User, 3..4), + ] + ); + + // Undoing the deletion should also undo the merge. + buffer.update(cx, |buffer, cx| buffer.undo(cx)); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..2), + (message_2.id, Role::Assistant, 2..4), + (message_4.id, Role::User, 4..6), + (message_3.id, Role::User, 6..7), + ] + ); + + // Redoing the deletion should also redo the merge. + buffer.update(cx, |buffer, cx| buffer.redo(cx)); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..3), + (message_3.id, Role::User, 3..4), + ] + ); + + // Ensure we can still insert after a merged message. + let message_5 = context.update(cx, |context, cx| { + context + .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx) + .unwrap() + }); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..3), + (message_5.id, Role::System, 3..4), + (message_3.id, Role::User, 4..5) + ] + ); +} + +#[gpui::test] +fn test_message_splitting(cx: &mut AppContext) { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + LanguageModelRegistry::test(cx); + assistant_panel::init(cx); + let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); + + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = + cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); + let buffer = context.read(cx).buffer.clone(); + + let message_1 = context.read(cx).message_anchors[0].clone(); + assert_eq!( + messages(&context, cx), + vec![(message_1.id, Role::User, 0..0)] + ); + + buffer.update(cx, |buffer, cx| { + buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx) + }); + + let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx)); + let message_2 = message_2.unwrap(); + + // We recycle newlines in the middle of a split message + assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n"); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_2.id, Role::User, 4..16), + ] + ); + + let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx)); + let message_3 = message_3.unwrap(); + + // We don't recycle newlines at the end of a split message + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n"); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..17), + ] + ); + + let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx)); + let message_4 = message_4.unwrap(); + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n"); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..9), + (message_4.id, Role::User, 9..17), + ] + ); + + let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx)); + let message_5 = message_5.unwrap(); + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n"); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..9), + (message_4.id, Role::User, 9..10), + (message_5.id, Role::User, 10..18), + ] + ); + + let (message_6, message_7) = + context.update(cx, |context, cx| context.split_message(14..16, cx)); + let message_6 = message_6.unwrap(); + let message_7 = message_7.unwrap(); + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n"); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..9), + (message_4.id, Role::User, 9..10), + (message_5.id, Role::User, 10..14), + (message_6.id, Role::User, 14..17), + (message_7.id, Role::User, 17..19), + ] + ); +} + +#[gpui::test] +fn test_messages_for_offsets(cx: &mut AppContext) { + let settings_store = SettingsStore::test(cx); + LanguageModelRegistry::test(cx); + cx.set_global(settings_store); + assistant_panel::init(cx); + let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = + cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); + let buffer = context.read(cx).buffer.clone(); + + let message_1 = context.read(cx).message_anchors[0].clone(); + assert_eq!( + messages(&context, cx), + vec![(message_1.id, Role::User, 0..0)] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx)); + let message_2 = context + .update(cx, |context, cx| { + context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx) + }) + .unwrap(); + buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx)); + + let message_3 = context + .update(cx, |context, cx| { + context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) + }) + .unwrap(); + buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx)); + + assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc"); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_2.id, Role::User, 4..8), + (message_3.id, Role::User, 8..11) + ] + ); + + assert_eq!( + message_ids_for_offsets(&context, &[0, 4, 9], cx), + [message_1.id, message_2.id, message_3.id] + ); + assert_eq!( + message_ids_for_offsets(&context, &[0, 1, 11], cx), + [message_1.id, message_3.id] + ); + + let message_4 = context + .update(cx, |context, cx| { + context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx) + }) + .unwrap(); + assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n"); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_2.id, Role::User, 4..8), + (message_3.id, Role::User, 8..12), + (message_4.id, Role::User, 12..12) + ] + ); + assert_eq!( + message_ids_for_offsets(&context, &[0, 4, 8, 12], cx), + [message_1.id, message_2.id, message_3.id, message_4.id] + ); + + fn message_ids_for_offsets( + context: &Model, + offsets: &[usize], + cx: &AppContext, + ) -> Vec { + context + .read(cx) + .messages_for_offsets(offsets.iter().copied(), cx) + .into_iter() + .map(|message| message.id) + .collect() + } +} + +#[gpui::test] +async fn test_slash_commands(cx: &mut TestAppContext) { + let settings_store = cx.update(SettingsStore::test); + cx.set_global(settings_store); + cx.update(LanguageModelRegistry::test); + cx.update(Project::init_settings); + cx.update(assistant_panel::init); + let fs = FakeFs::new(cx.background_executor.clone()); + + fs.insert_tree( + "/test", + json!({ + "src": { + "lib.rs": "fn one() -> usize { 1 }", + "main.rs": " + use crate::one; + fn main() { one(); } + ".unindent(), + } + }), + ) + .await; + + let slash_command_registry = cx.update(SlashCommandRegistry::default_global); + slash_command_registry.register_command(file_command::FileSlashCommand, false); + + let registry = Arc::new(LanguageRegistry::test(cx.executor())); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = + cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx)); + + let output_ranges = Rc::new(RefCell::new(HashSet::default())); + context.update(cx, |_, cx| { + cx.subscribe(&context, { + let ranges = output_ranges.clone(); + move |_, _, event, _| match event { + ContextEvent::PendingSlashCommandsUpdated { removed, updated } => { + for range in removed { + ranges.borrow_mut().remove(range); + } + for command in updated { + ranges.borrow_mut().insert(command.source_range.clone()); + } + } + _ => {} + } + }) + .detach(); + }); + + let buffer = context.read_with(cx, |context, _| context.buffer.clone()); + + // Insert a slash command + buffer.update(cx, |buffer, cx| { + buffer.edit([(0..0, "/file src/lib.rs")], None, cx); + }); + assert_text_and_output_ranges( + &buffer, + &output_ranges.borrow(), + " + «/file src/lib.rs» + " + .unindent() + .trim_end(), + cx, + ); + + // Edit the argument of the slash command. + buffer.update(cx, |buffer, cx| { + let edit_offset = buffer.text().find("lib.rs").unwrap(); + buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx); + }); + assert_text_and_output_ranges( + &buffer, + &output_ranges.borrow(), + " + «/file src/main.rs» + " + .unindent() + .trim_end(), + cx, + ); + + // Edit the name of the slash command, using one that doesn't exist. + buffer.update(cx, |buffer, cx| { + let edit_offset = buffer.text().find("/file").unwrap(); + buffer.edit( + [(edit_offset..edit_offset + "/file".len(), "/unknown")], + None, + cx, + ); + }); + assert_text_and_output_ranges( + &buffer, + &output_ranges.borrow(), + " + /unknown src/main.rs + " + .unindent() + .trim_end(), + cx, + ); + + #[track_caller] + fn assert_text_and_output_ranges( + buffer: &Model, + ranges: &HashSet>, + expected_marked_text: &str, + cx: &mut TestAppContext, + ) { + let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false); + let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| { + let mut ranges = ranges + .iter() + .map(|range| range.to_offset(buffer)) + .collect::>(); + ranges.sort_by_key(|a| a.start); + (buffer.text(), ranges) + }); + + assert_eq!(actual_text, expected_text); + assert_eq!(actual_ranges, expected_ranges); + } +} + +#[gpui::test] +async fn test_edit_step_parsing(cx: &mut TestAppContext) { + cx.update(prompt_library::init); + let settings_store = cx.update(SettingsStore::test); + cx.set_global(settings_store); + cx.update(Project::init_settings); + let fs = FakeFs::new(cx.executor()); + fs.as_fake() + .insert_tree( + "/root", + json!({ + "hello.rs": r#" + fn hello() { + println!("Hello, World!"); + } + "#.unindent() + }), + ) + .await; + let project = Project::test(fs, [Path::new("/root")], cx).await; + cx.update(LanguageModelRegistry::test); + + let model = cx.read(|cx| { + LanguageModelRegistry::read_global(cx) + .active_model() + .unwrap() + }); + cx.update(assistant_panel::init); + let registry = Arc::new(LanguageRegistry::test(cx.executor())); + + // Create a new context + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = cx.new_model(|cx| { + Context::local( + registry.clone(), + Some(project), + None, + prompt_builder.clone(), + cx, + ) + }); + let buffer = context.read_with(cx, |context, _| context.buffer.clone()); + + // Simulate user input + let user_message = indoc! {r#" + Please add unnecessary complexity to this code: + + ```hello.rs + fn main() { + println!("Hello, World!"); + } + ``` + "#}; + buffer.update(cx, |buffer, cx| { + buffer.edit([(0..0, user_message)], None, cx); + }); + + // Simulate LLM response with edit steps + let llm_response = indoc! {r#" + Sure, I can help you with that. Here's a step-by-step process: + + + First, let's extract the greeting into a separate function: + + ```rust + fn greet() { + println!("Hello, World!"); + } + + fn main() { + greet(); + } + ``` + + + + Now, let's make the greeting customizable: + + ```rust + fn greet(name: &str) { + println!("Hello, {}!", name); + } + + fn main() { + greet("World"); + } + ``` + + + These changes make the code more modular and flexible. + "#}; + + // Simulate the assist method to trigger the LLM response + context.update(cx, |context, cx| context.assist(cx)); + cx.run_until_parked(); + + // Retrieve the assistant response message's start from the context + let response_start_row = context.read_with(cx, |context, cx| { + let buffer = context.buffer.read(cx); + context.message_anchors[1].start.to_point(buffer).row + }); + + // Simulate the LLM completion + model + .as_fake() + .stream_last_completion_response(llm_response.to_string()); + model.as_fake().end_last_completion_stream(); + + // Wait for the completion to be processed + cx.run_until_parked(); + + // Verify that the edit steps were parsed correctly + context.read_with(cx, |context, cx| { + assert_eq!( + workflow_steps(context, cx), + vec![ + ( + Point::new(response_start_row + 2, 0)..Point::new(response_start_row + 12, 3), + WorkflowStepTestStatus::Pending + ), + ( + Point::new(response_start_row + 14, 0)..Point::new(response_start_row + 24, 3), + WorkflowStepTestStatus::Pending + ), + ] + ); + }); + + model + .as_fake() + .respond_to_last_tool_use(tool::WorkflowStepResolutionTool { + step_title: "Title".into(), + suggestions: vec![tool::WorkflowSuggestionTool { + path: "/root/hello.rs".into(), + // Simulate a symbol name that's slightly different than our outline query + kind: tool::WorkflowSuggestionToolKind::Update { + symbol: "fn main()".into(), + description: "Extract a greeting function".into(), + }, + }], + }); + + // Wait for tool use to be processed. + cx.run_until_parked(); + + // Verify that the first edit step is not pending anymore. + context.read_with(cx, |context, cx| { + assert_eq!( + workflow_steps(context, cx), + vec![ + ( + Point::new(response_start_row + 2, 0)..Point::new(response_start_row + 12, 3), + WorkflowStepTestStatus::Resolved + ), + ( + Point::new(response_start_row + 14, 0)..Point::new(response_start_row + 24, 3), + WorkflowStepTestStatus::Pending + ), + ] + ); + }); + + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + enum WorkflowStepTestStatus { + Pending, + Resolved, + Error, + } + + fn workflow_steps( + context: &Context, + cx: &AppContext, + ) -> Vec<(Range, WorkflowStepTestStatus)> { + context + .workflow_steps + .iter() + .map(|step| { + let buffer = context.buffer.read(cx); + let status = match &step.step.read(cx).resolution { + None => WorkflowStepTestStatus::Pending, + Some(Ok(_)) => WorkflowStepTestStatus::Resolved, + Some(Err(_)) => WorkflowStepTestStatus::Error, + }; + (step.range.to_point(buffer), status) + }) + .collect() + } +} + +#[gpui::test] +async fn test_serialization(cx: &mut TestAppContext) { + let settings_store = cx.update(SettingsStore::test); + cx.set_global(settings_store); + cx.update(LanguageModelRegistry::test); + cx.update(assistant_panel::init); + let registry = Arc::new(LanguageRegistry::test(cx.executor())); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = + cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx)); + let buffer = context.read_with(cx, |context, _| context.buffer.clone()); + let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id); + let message_1 = context.update(cx, |context, cx| { + context + .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx) + .unwrap() + }); + let message_2 = context.update(cx, |context, cx| { + context + .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx) + .unwrap() + }); + buffer.update(cx, |buffer, cx| { + buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx); + buffer.finalize_last_transaction(); + }); + let _message_3 = context.update(cx, |context, cx| { + context + .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx) + .unwrap() + }); + buffer.update(cx, |buffer, cx| buffer.undo(cx)); + assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n"); + assert_eq!( + cx.read(|cx| messages(&context, cx)), + [ + (message_0, Role::User, 0..2), + (message_1.id, Role::Assistant, 2..6), + (message_2.id, Role::System, 6..6), + ] + ); + + let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx)); + let deserialized_context = cx.new_model(|cx| { + Context::deserialize( + serialized_context, + Default::default(), + registry.clone(), + prompt_builder.clone(), + None, + None, + cx, + ) + }); + let deserialized_buffer = + deserialized_context.read_with(cx, |context, _| context.buffer.clone()); + assert_eq!( + deserialized_buffer.read_with(cx, |buffer, _| buffer.text()), + "a\nb\nc\n" + ); + assert_eq!( + cx.read(|cx| messages(&deserialized_context, cx)), + [ + (message_0, Role::User, 0..2), + (message_1.id, Role::Assistant, 2..6), + (message_2.id, Role::System, 6..6), + ] + ); +} + +#[gpui::test(iterations = 100)] +async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) { + let min_peers = env::var("MIN_PEERS") + .map(|i| i.parse().expect("invalid `MIN_PEERS` variable")) + .unwrap_or(2); + let max_peers = env::var("MAX_PEERS") + .map(|i| i.parse().expect("invalid `MAX_PEERS` variable")) + .unwrap_or(5); + let operations = env::var("OPERATIONS") + .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) + .unwrap_or(50); + + let settings_store = cx.update(SettingsStore::test); + cx.set_global(settings_store); + cx.update(LanguageModelRegistry::test); + + cx.update(assistant_panel::init); + let slash_commands = cx.update(SlashCommandRegistry::default_global); + slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false); + slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false); + slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false); + + let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone())); + let network = Arc::new(Mutex::new(Network::new(rng.clone()))); + let mut contexts = Vec::new(); + + let num_peers = rng.gen_range(min_peers..=max_peers); + let context_id = ContextId::new(); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + for i in 0..num_peers { + let context = cx.new_model(|cx| { + Context::new( + context_id.clone(), + i as ReplicaId, + language::Capability::ReadWrite, + registry.clone(), + prompt_builder.clone(), + None, + None, + cx, + ) + }); + + cx.update(|cx| { + cx.subscribe(&context, { + let network = network.clone(); + move |_, event, _| { + if let ContextEvent::Operation(op) = event { + network + .lock() + .broadcast(i as ReplicaId, vec![op.to_proto()]); + } + } + }) + .detach(); + }); + + contexts.push(context); + network.lock().add_peer(i as ReplicaId); + } + + let mut mutation_count = operations; + + while mutation_count > 0 + || !network.lock().is_idle() + || network.lock().contains_disconnected_peers() + { + let context_index = rng.gen_range(0..contexts.len()); + let context = &contexts[context_index]; + + match rng.gen_range(0..100) { + 0..=29 if mutation_count > 0 => { + log::info!("Context {}: edit buffer", context_index); + context.update(cx, |context, cx| { + context + .buffer + .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx)); + }); + mutation_count -= 1; + } + 30..=44 if mutation_count > 0 => { + context.update(cx, |context, cx| { + let range = context.buffer.read(cx).random_byte_range(0, &mut rng); + log::info!("Context {}: split message at {:?}", context_index, range); + context.split_message(range, cx); + }); + mutation_count -= 1; + } + 45..=59 if mutation_count > 0 => { + context.update(cx, |context, cx| { + if let Some(message) = context.messages(cx).choose(&mut rng) { + let role = *[Role::User, Role::Assistant, Role::System] + .choose(&mut rng) + .unwrap(); + log::info!( + "Context {}: insert message after {:?} with {:?}", + context_index, + message.id, + role + ); + context.insert_message_after(message.id, role, MessageStatus::Done, cx); + } + }); + mutation_count -= 1; + } + 60..=74 if mutation_count > 0 => { + context.update(cx, |context, cx| { + let command_text = "/".to_string() + + slash_commands + .command_names() + .choose(&mut rng) + .unwrap() + .clone() + .as_ref(); + + let command_range = context.buffer.update(cx, |buffer, cx| { + let offset = buffer.random_byte_range(0, &mut rng).start; + buffer.edit( + [(offset..offset, format!("\n{}\n", command_text))], + None, + cx, + ); + offset + 1..offset + 1 + command_text.len() + }); + + let output_len = rng.gen_range(1..=10); + let output_text = RandomCharIter::new(&mut rng) + .filter(|c| *c != '\r') + .take(output_len) + .collect::(); + + let num_sections = rng.gen_range(0..=3); + let mut sections = Vec::with_capacity(num_sections); + for _ in 0..num_sections { + let section_start = rng.gen_range(0..output_len); + let section_end = rng.gen_range(section_start..=output_len); + sections.push(SlashCommandOutputSection { + range: section_start..section_end, + icon: ui::IconName::Ai, + label: "section".into(), + }); + } + + log::info!( + "Context {}: insert slash command output at {:?} with {:?}", + context_index, + command_range, + sections + ); + + let command_range = context.buffer.read(cx).anchor_after(command_range.start) + ..context.buffer.read(cx).anchor_after(command_range.end); + context.insert_command_output( + command_range, + Task::ready(Ok(SlashCommandOutput { + text: output_text, + sections, + run_commands_in_text: false, + })), + true, + cx, + ); + }); + cx.run_until_parked(); + mutation_count -= 1; + } + 75..=84 if mutation_count > 0 => { + context.update(cx, |context, cx| { + if let Some(message) = context.messages(cx).choose(&mut rng) { + let new_status = match rng.gen_range(0..3) { + 0 => MessageStatus::Done, + 1 => MessageStatus::Pending, + _ => MessageStatus::Error(SharedString::from("Random error")), + }; + log::info!( + "Context {}: update message {:?} status to {:?}", + context_index, + message.id, + new_status + ); + context.update_metadata(message.id, cx, |metadata| { + metadata.status = new_status; + }); + } + }); + mutation_count -= 1; + } + _ => { + let replica_id = context_index as ReplicaId; + if network.lock().is_disconnected(replica_id) { + network.lock().reconnect_peer(replica_id, 0); + + let (ops_to_send, ops_to_receive) = cx.read(|cx| { + let host_context = &contexts[0].read(cx); + let guest_context = context.read(cx); + ( + guest_context.serialize_ops(&host_context.version(cx), cx), + host_context.serialize_ops(&guest_context.version(cx), cx), + ) + }); + let ops_to_send = ops_to_send.await; + let ops_to_receive = ops_to_receive + .await + .into_iter() + .map(ContextOperation::from_proto) + .collect::>>() + .unwrap(); + log::info!( + "Context {}: reconnecting. Sent {} operations, received {} operations", + context_index, + ops_to_send.len(), + ops_to_receive.len() + ); + + network.lock().broadcast(replica_id, ops_to_send); + context + .update(cx, |context, cx| context.apply_ops(ops_to_receive, cx)) + .unwrap(); + } else if rng.gen_bool(0.1) && replica_id != 0 { + log::info!("Context {}: disconnecting", context_index); + network.lock().disconnect_peer(replica_id); + } else if network.lock().has_unreceived(replica_id) { + log::info!("Context {}: applying operations", context_index); + let ops = network.lock().receive(replica_id); + let ops = ops + .into_iter() + .map(ContextOperation::from_proto) + .collect::>>() + .unwrap(); + context + .update(cx, |context, cx| context.apply_ops(ops, cx)) + .unwrap(); + } + } + } + } + + cx.read(|cx| { + let first_context = contexts[0].read(cx); + for context in &contexts[1..] { + let context = context.read(cx); + assert!(context.pending_ops.is_empty()); + assert_eq!( + context.buffer.read(cx).text(), + first_context.buffer.read(cx).text(), + "Context {} text != Context 0 text", + context.buffer.read(cx).replica_id() + ); + assert_eq!( + context.message_anchors, + first_context.message_anchors, + "Context {} messages != Context 0 messages", + context.buffer.read(cx).replica_id() + ); + assert_eq!( + context.messages_metadata, + first_context.messages_metadata, + "Context {} message metadata != Context 0 message metadata", + context.buffer.read(cx).replica_id() + ); + assert_eq!( + context.slash_command_output_sections, + first_context.slash_command_output_sections, + "Context {} slash command output sections != Context 0 slash command output sections", + context.buffer.read(cx).replica_id() + ); + } + }); +} + +fn messages(context: &Model, cx: &AppContext) -> Vec<(MessageId, Role, Range)> { + context + .read(cx) + .messages(cx) + .map(|message| (message.id, message.role, message.offset_range)) + .collect() +} + +#[derive(Clone)] +struct FakeSlashCommand(String); + +impl SlashCommand for FakeSlashCommand { + fn name(&self) -> String { + self.0.clone() + } + + fn description(&self) -> String { + format!("Fake slash command: {}", self.0) + } + + fn menu_text(&self) -> String { + format!("Run fake command: {}", self.0) + } + + fn complete_argument( + self: Arc, + _arguments: &[String], + _cancel: Arc, + _workspace: Option>, + _cx: &mut WindowContext, + ) -> Task>> { + Task::ready(Ok(vec![])) + } + + fn requires_argument(&self) -> bool { + false + } + + fn run( + self: Arc, + _arguments: &[String], + _workspace: WeakView, + _delegate: Option>, + _cx: &mut WindowContext, + ) -> Task> { + Task::ready(Ok(SlashCommandOutput { + text: format!("Executed fake command: {}", self.0), + sections: vec![], + run_commands_in_text: false, + })) + } +}