diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index d3a4b665328c5b5200e0cf230bd05a020f47ee3c..a7efd5c0fed5a5f70aea5987daefdffa2697967e 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -1,7 +1,7 @@ pub mod assistant_panel; pub mod assistant_settings; -mod codegen; mod completion_provider; +mod inline_assistant; mod model_selector; mod prompt_library; mod prompts; @@ -18,6 +18,7 @@ use client::{proto, Client}; use command_palette_hooks::CommandPaletteFilter; pub(crate) use completion_provider::*; use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal}; +pub(crate) use inline_assistant::*; pub(crate) use model_selector::*; pub(crate) use saved_conversation::*; use semantic_index::{CloudEmbeddingProvider, SemanticIndex}; @@ -31,6 +32,7 @@ use std::{ fmt::{self, Display}, sync::Arc, }; +pub(crate) use streaming_diff::*; use util::paths::EMBEDDINGS_DIR; actions!( @@ -273,10 +275,11 @@ pub fn init(client: Arc, cx: &mut AppContext) { .detach(); prompt_library::init(cx); - completion_provider::init(client, cx); + completion_provider::init(client.clone(), cx); assistant_slash_command::init(cx); register_slash_commands(cx); assistant_panel::init(cx); + inline_assistant::init(client.telemetry().clone(), cx); CommandPaletteFilter::update_global(cx, |filter, _cx| { filter.hide_namespace(Assistant::NAMESPACE); diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 9e05e37793d6b67c281bedbcb298fd644dadaf43..223a8d0fe67f19d6663d2643c63a4df2932ed3b1 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -1,31 +1,27 @@ use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings}, - codegen::{self, Codegen, CodegenKind}, prompt_library::open_prompt_library, - prompts::generate_content_prompt, search::*, slash_command::{ default_command::DefaultSlashCommand, SlashCommandCompletionProvider, SlashCommandLine, SlashCommandRegistry, }, ApplyEdit, Assist, CompletionProvider, ConfirmCommand, CycleMessageRole, InlineAssist, - LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, MessageStatus, - ModelSelector, QuoteSelection, ResetKey, Role, SavedConversation, SavedConversationMetadata, - SavedMessage, Split, ToggleFocus, ToggleHistory, ToggleModelSelector, + InlineAssistant, LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, + MessageStatus, ModelSelector, QuoteSelection, ResetKey, Role, SavedConversation, + SavedConversationMetadata, SavedMessage, Split, ToggleFocus, ToggleHistory, + ToggleModelSelector, }; use anyhow::{anyhow, Result}; use assistant_slash_command::{SlashCommand, SlashCommandOutput, SlashCommandOutputSection}; use client::telemetry::Telemetry; -use collections::{hash_map, BTreeSet, HashMap, HashSet, VecDeque}; -use editor::{actions::ShowCompletions, GutterDimensions}; +use collections::{BTreeSet, HashMap, HashSet}; +use editor::actions::ShowCompletions; use editor::{ - actions::{FoldAt, MoveDown, MoveToEndOfLine, MoveUp, Newline, UnfoldAt}, - display_map::{ - BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, Flap, ToDisplayPoint, - }, + actions::{FoldAt, MoveToEndOfLine, Newline, UnfoldAt}, + display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, Flap, ToDisplayPoint}, scroll::{Autoscroll, AutoscrollStrategy}, - Anchor, Editor, EditorElement, EditorEvent, EditorStyle, MultiBufferSnapshot, RowExt, - ToOffset as _, ToPoint, + Anchor, Editor, EditorEvent, RowExt, ToOffset as _, ToPoint, }; use editor::{display_map::FlapId, FoldPlaceholder}; use file_icons::FileIcons; @@ -33,19 +29,17 @@ use fs::Fs; use futures::future::Shared; use futures::{FutureExt, StreamExt}; use gpui::{ - div, point, relative, rems, uniform_list, Action, AnyElement, AnyView, AppContext, - AsyncAppContext, AsyncWindowContext, ClipboardItem, Context, Empty, EventEmitter, FocusHandle, - FocusableView, FontStyle, FontWeight, HighlightStyle, InteractiveElement, IntoElement, Model, - ModelContext, ParentElement, Pixels, Render, SharedString, StatefulInteractiveElement, Styled, - Subscription, Task, TextStyle, UniformListScrollHandle, View, ViewContext, VisualContext, - WeakModel, WeakView, WhiteSpace, WindowContext, + div, point, rems, uniform_list, Action, AnyElement, AnyView, AppContext, AsyncAppContext, + AsyncWindowContext, ClipboardItem, Context, Empty, EventEmitter, FocusHandle, FocusableView, + InteractiveElement, IntoElement, Model, ModelContext, ParentElement, Pixels, Render, + SharedString, StatefulInteractiveElement, Styled, Subscription, Task, UniformListScrollHandle, + UpdateGlobal, View, ViewContext, VisualContext, WeakView, WindowContext, }; use language::{ language_settings::SoftWrap, AnchorRangeExt, AutoindentMode, Buffer, LanguageRegistry, LspAdapterDelegate, OffsetRangeExt as _, Point, ToOffset as _, }; use multi_buffer::MultiBufferRow; -use parking_lot::Mutex; use project::{Project, ProjectLspAdapterDelegate, ProjectTransaction}; use search::{buffer_search::DivRegistrar, BufferSearchBar}; use settings::Settings; @@ -59,19 +53,18 @@ use std::{ time::{Duration, Instant}, }; use telemetry_events::AssistantKind; -use theme::ThemeSettings; use ui::{ popover_menu, prelude::*, ButtonLike, ContextMenu, ElevationIndex, KeyBinding, PopoverMenuHandle, Tab, TabBar, Tooltip, }; use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt}; use uuid::Uuid; +use workspace::NewFile; use workspace::{ dock::{DockPosition, Panel, PanelEvent}, searchable::Direction, - Save, Toast, ToggleZoom, Toolbar, Workspace, + Save, ToggleZoom, Toolbar, Workspace, }; -use workspace::{notifications::NotificationId, NewFile}; pub fn init(cx: &mut AppContext) { cx.observe_new_views( @@ -110,10 +103,6 @@ pub struct AssistantPanel { fs: Arc, telemetry: Arc, _subscriptions: Vec, - next_inline_assist_id: usize, - pending_inline_assists: HashMap, - pending_inline_assist_ids_by_editor: HashMap, Vec>, - inline_prompt_history: VecDeque, _watch_saved_conversations: Task>, authentication_prompt: Option, model_menu_handle: PopoverMenuHandle, @@ -125,8 +114,6 @@ struct ActiveConversationEditor { } impl AssistantPanel { - const INLINE_PROMPT_HISTORY_MAX_LEN: usize = 20; - pub fn load( workspace: WeakView, cx: AsyncWindowContext, @@ -205,10 +192,6 @@ impl AssistantPanel { width: None, height: None, _subscriptions: subscriptions, - next_inline_assist_id: 0, - pending_inline_assists: Default::default(), - pending_inline_assist_ids_by_editor: Default::default(), - inline_prompt_history: Default::default(), _watch_saved_conversations, authentication_prompt: None, model_menu_handle: PopoverMenuHandle::default(), @@ -295,26 +278,30 @@ impl AssistantPanel { } }); - let show_include_conversation; + let include_conversation; let active_editor; if let Some(conversation_editor) = conversation_editor { active_editor = conversation_editor; - show_include_conversation = false; + include_conversation = false; } else if let Some(workspace_editor) = workspace .active_item(cx) .and_then(|item| item.act_as::(cx)) { active_editor = workspace_editor; - show_include_conversation = true; + include_conversation = true; } else { return; }; - let project = workspace.project().clone(); if assistant.update(cx, |assistant, cx| assistant.is_authenticated(cx)) { - assistant.update(cx, |assistant, cx| { - assistant.new_inline_assist(&active_editor, &project, show_include_conversation, cx) - }); + InlineAssistant::update_global(cx, |assistant, cx| { + assistant.assist( + &active_editor, + Some(cx.view().downgrade()), + include_conversation, + cx, + ) + }) } else { let assistant = assistant.downgrade(); cx.spawn(|workspace, mut cx| async move { @@ -322,14 +309,16 @@ impl AssistantPanel { .update(&mut cx, |assistant, cx| assistant.authenticate(cx))? .await?; if assistant.update(&mut cx, |assistant, cx| assistant.is_authenticated(cx))? { - assistant.update(&mut cx, |assistant, cx| { - assistant.new_inline_assist( - &active_editor, - &project, - show_include_conversation, - cx, - ) - })?; + cx.update(|cx| { + InlineAssistant::update_global(cx, |assistant, cx| { + assistant.assist( + &active_editor, + Some(workspace), + include_conversation, + cx, + ) + }) + })? } else { workspace.update(&mut cx, |workspace, cx| { workspace.focus_panel::(cx) @@ -342,439 +331,17 @@ impl AssistantPanel { } } - fn new_inline_assist( - &mut self, - editor: &View, - project: &Model, - include_conversation: bool, - cx: &mut ViewContext, - ) { - let selection = editor.read(cx).selections.newest_anchor().clone(); - if selection.start.excerpt_id != selection.end.excerpt_id { - return; - } - let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); - - // Extend the selection to the start and the end of the line. - let mut point_selection = selection.map(|selection| selection.to_point(&snapshot)); - if point_selection.end > point_selection.start { - point_selection.start.column = 0; - // If the selection ends at the start of the line, we don't want to include it. - if point_selection.end.column == 0 { - point_selection.end.row -= 1; - } - point_selection.end.column = snapshot.line_len(MultiBufferRow(point_selection.end.row)); - } - - let codegen_kind = if point_selection.start == point_selection.end { - CodegenKind::Generate { - position: snapshot.anchor_after(point_selection.start), - } - } else { - CodegenKind::Transform { - range: snapshot.anchor_before(point_selection.start) - ..snapshot.anchor_after(point_selection.end), - } - }; - - let inline_assist_id = post_inc(&mut self.next_inline_assist_id); - let telemetry = self.telemetry.clone(); - - let codegen = cx.new_model(|cx| { - Codegen::new( - editor.read(cx).buffer().clone(), - codegen_kind, - Some(telemetry), - cx, - ) - }); - - let measurements = Arc::new(Mutex::new(GutterDimensions::default())); - let inline_assistant = cx.new_view(|cx| { - InlineAssistant::new( - inline_assist_id, - measurements.clone(), - include_conversation, - self.inline_prompt_history.clone(), - codegen.clone(), - cx, - ) - }); - let block_id = editor.update(cx, |editor, cx| { - editor.change_selections(None, cx, |selections| { - selections.select_anchor_ranges([selection.head()..selection.head()]) - }); - editor.insert_blocks( - [BlockProperties { - style: BlockStyle::Flex, - position: snapshot.anchor_before(Point::new(point_selection.head().row, 0)), - height: 2, - render: Box::new({ - let inline_assistant = inline_assistant.clone(); - move |cx: &mut BlockContext| { - *measurements.lock() = *cx.gutter_dimensions; - inline_assistant.clone().into_any_element() - } - }), - disposition: if selection.reversed { - BlockDisposition::Above - } else { - BlockDisposition::Below - }, - }], - Some(Autoscroll::Strategy(AutoscrollStrategy::Newest)), - cx, - )[0] - }); - - self.pending_inline_assists.insert( - inline_assist_id, - PendingInlineAssist { - editor: editor.downgrade(), - inline_assistant: Some((block_id, inline_assistant.clone())), - codegen: codegen.clone(), - project: project.downgrade(), - _subscriptions: vec![ - cx.subscribe(&inline_assistant, Self::handle_inline_assistant_event), - cx.subscribe(editor, { - let inline_assistant = inline_assistant.downgrade(); - move |_, editor, event, cx| { - if let Some(inline_assistant) = inline_assistant.upgrade() { - if let EditorEvent::SelectionsChanged { local } = event { - if *local - && inline_assistant.focus_handle(cx).contains_focused(cx) - { - cx.focus_view(&editor); - } - } - } - } - }), - cx.observe(&codegen, { - let editor = editor.downgrade(); - move |this, _, cx| { - if let Some(editor) = editor.upgrade() { - this.update_highlights_for_editor(&editor, cx); - } - } - }), - cx.subscribe(&codegen, move |this, codegen, event, cx| match event { - codegen::Event::Undone => { - this.finish_inline_assist(inline_assist_id, false, cx) - } - codegen::Event::Finished => { - let pending_assist = if let Some(pending_assist) = - this.pending_inline_assists.get(&inline_assist_id) - { - pending_assist - } else { - return; - }; - - let error = codegen - .read(cx) - .error() - .map(|error| format!("Inline assistant error: {}", error)); - if let Some(error) = error { - if pending_assist.inline_assistant.is_none() { - if let Some(workspace) = this.workspace.upgrade() { - workspace.update(cx, |workspace, cx| { - struct InlineAssistantError; - - let id = - NotificationId::identified::( - inline_assist_id, - ); - - workspace.show_toast(Toast::new(id, error), cx); - }) - } - - this.finish_inline_assist(inline_assist_id, false, cx); - } - } else { - this.finish_inline_assist(inline_assist_id, false, cx); - } - } - }), - ], - }, - ); - - self.pending_inline_assist_ids_by_editor - .entry(editor.downgrade()) - .or_default() - .push(inline_assist_id); - self.update_highlights_for_editor(editor, cx); - } - - fn handle_inline_assistant_event( - &mut self, - inline_assistant: View, - event: &InlineAssistantEvent, - cx: &mut ViewContext, - ) { - let assist_id = inline_assistant.read(cx).id; - match event { - InlineAssistantEvent::Confirmed { - prompt, - include_conversation, - } => { - self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx); - } - InlineAssistantEvent::Canceled => { - self.finish_inline_assist(assist_id, true, cx); - } - InlineAssistantEvent::Dismissed => { - self.hide_inline_assist(assist_id, cx); - } - } - } - fn cancel_last_inline_assist( - workspace: &mut Workspace, + _workspace: &mut Workspace, _: &editor::actions::Cancel, cx: &mut ViewContext, ) { - if let Some(panel) = workspace.panel::(cx) { - if let Some(editor) = workspace - .active_item(cx) - .and_then(|item| item.downcast::()) - { - let handled = panel.update(cx, |panel, cx| { - if let Some(assist_id) = panel - .pending_inline_assist_ids_by_editor - .get(&editor.downgrade()) - .and_then(|assist_ids| assist_ids.last().copied()) - { - panel.finish_inline_assist(assist_id, true, cx); - true - } else { - false - } - }); - if handled { - return; - } - } - } - - cx.propagate(); - } - - fn finish_inline_assist(&mut self, assist_id: usize, undo: bool, cx: &mut ViewContext) { - self.hide_inline_assist(assist_id, cx); - - if let Some(pending_assist) = self.pending_inline_assists.remove(&assist_id) { - if let hash_map::Entry::Occupied(mut entry) = self - .pending_inline_assist_ids_by_editor - .entry(pending_assist.editor.clone()) - { - entry.get_mut().retain(|id| *id != assist_id); - if entry.get().is_empty() { - entry.remove(); - } - } - - if let Some(editor) = pending_assist.editor.upgrade() { - self.update_highlights_for_editor(&editor, cx); - - if undo { - pending_assist - .codegen - .update(cx, |codegen, cx| codegen.undo(cx)); - } - } - } - } - - fn hide_inline_assist(&mut self, assist_id: usize, cx: &mut ViewContext) { - if let Some(pending_assist) = self.pending_inline_assists.get_mut(&assist_id) { - if let Some(editor) = pending_assist.editor.upgrade() { - if let Some((block_id, inline_assistant)) = pending_assist.inline_assistant.take() { - editor.update(cx, |editor, cx| { - editor.remove_blocks(HashSet::from_iter([block_id]), None, cx); - if inline_assistant.focus_handle(cx).contains_focused(cx) { - editor.focus(cx); - } - }); - } - } - } - } - - fn confirm_inline_assist( - &mut self, - inline_assist_id: usize, - user_prompt: &str, - include_conversation: bool, - cx: &mut ViewContext, - ) { - let conversation = if include_conversation { - self.active_conversation_editor() - .map(|editor| editor.read(cx).conversation.clone()) - } else { - None - }; - - let pending_assist = - if let Some(pending_assist) = self.pending_inline_assists.get_mut(&inline_assist_id) { - pending_assist - } else { - return; - }; - - let editor = if let Some(editor) = pending_assist.editor.upgrade() { - editor - } else { - return; - }; - - let project = pending_assist.project.clone(); - - let project_name = project.upgrade().map(|project| { - project - .read(cx) - .worktree_root_names(cx) - .collect::>() - .join("/") - }); - - self.inline_prompt_history - .retain(|prompt| prompt != user_prompt); - self.inline_prompt_history.push_back(user_prompt.into()); - if self.inline_prompt_history.len() > Self::INLINE_PROMPT_HISTORY_MAX_LEN { - self.inline_prompt_history.pop_front(); - } - - let codegen = pending_assist.codegen.clone(); - let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); - let range = codegen.read(cx).range(); - let start = snapshot.point_to_buffer_offset(range.start); - let end = snapshot.point_to_buffer_offset(range.end); - let (buffer, range) = if let Some((start, end)) = start.zip(end) { - let (start_buffer, start_buffer_offset) = start; - let (end_buffer, end_buffer_offset) = end; - if start_buffer.remote_id() == end_buffer.remote_id() { - (start_buffer.clone(), start_buffer_offset..end_buffer_offset) - } else { - self.finish_inline_assist(inline_assist_id, false, cx); - return; - } - } else { - self.finish_inline_assist(inline_assist_id, false, cx); - return; - }; - - let language = buffer.language_at(range.start); - let language_name = if let Some(language) = language.as_ref() { - if Arc::ptr_eq(language, &language::PLAIN_TEXT) { - None - } else { - Some(language.name()) - } - } else { - None - }; - - // Higher Temperature increases the randomness of model outputs. - // If Markdown or No Language is Known, increase the randomness for more creative output - // If Code, decrease temperature to get more deterministic outputs - let temperature = if let Some(language) = language_name.clone() { - if language.as_ref() == "Markdown" { - 1.0 - } else { - 0.5 - } - } else { - 1.0 - }; - - let user_prompt = user_prompt.to_string(); - - let prompt = cx.background_executor().spawn(async move { - let language_name = language_name.as_deref(); - generate_content_prompt(user_prompt, language_name, buffer, range, project_name) + let canceled = InlineAssistant::update_global(cx, |assistant, cx| { + assistant.cancel_last_inline_assist(cx) }); - - let mut messages = Vec::new(); - if let Some(conversation) = conversation { - let conversation = conversation.read(cx); - let buffer = conversation.buffer.read(cx); - messages.extend( - conversation - .messages(cx) - .map(|message| message.to_request_message(buffer)), - ); + if !canceled { + cx.propagate(); } - let model = CompletionProvider::global(cx).model(); - - cx.spawn(|_, mut cx| async move { - // I Don't know if we want to return a ? here. - let prompt = prompt.await?; - - messages.push(LanguageModelRequestMessage { - role: Role::User, - content: prompt, - }); - - let request = LanguageModelRequest { - model, - messages, - stop: vec!["|END|>".to_string()], - temperature, - }; - - codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?; - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - } - - fn update_highlights_for_editor(&self, editor: &View, cx: &mut ViewContext) { - let mut background_ranges = Vec::new(); - let mut foreground_ranges = Vec::new(); - let empty_inline_assist_ids = Vec::new(); - let inline_assist_ids = self - .pending_inline_assist_ids_by_editor - .get(&editor.downgrade()) - .unwrap_or(&empty_inline_assist_ids); - - for inline_assist_id in inline_assist_ids { - if let Some(pending_assist) = self.pending_inline_assists.get(inline_assist_id) { - let codegen = pending_assist.codegen.read(cx); - background_ranges.push(codegen.range()); - foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned()); - } - } - - let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); - merge_ranges(&mut background_ranges, &snapshot); - merge_ranges(&mut foreground_ranges, &snapshot); - editor.update(cx, |editor, cx| { - if background_ranges.is_empty() { - editor.clear_background_highlights::(cx); - } else { - editor.highlight_background::( - &background_ranges, - |theme| theme.editor_active_line_background, // TODO use the appropriate color - cx, - ); - } - - if foreground_ranges.is_empty() { - editor.clear_highlights::(cx); - } else { - editor.highlight_text::( - foreground_ranges, - HighlightStyle { - fade_out: Some(0.6), - ..Default::default() - }, - cx, - ); - } - }); } fn new_conversation(&mut self, cx: &mut ViewContext) -> Option> { @@ -919,6 +486,15 @@ impl AssistantPanel { Some(&self.active_conversation_editor.as_ref()?.editor) } + pub fn active_conversation(&self, cx: &AppContext) -> Option> { + Some( + self.active_conversation_editor()? + .read(cx) + .conversation + .clone(), + ) + } + fn render_popover_button(&self, cx: &mut ViewContext) -> impl IntoElement { let assistant = cx.view().clone(); let zoomed = self.zoomed; @@ -2045,7 +1621,7 @@ impl Conversation { user_messages } - fn to_completion_request(&self, cx: &mut ModelContext) -> LanguageModelRequest { + pub fn to_completion_request(&self, cx: &AppContext) -> LanguageModelRequest { let messages = self .messages(cx) .filter(|message| matches!(message.status, MessageStatus::Done)) @@ -2919,16 +2495,8 @@ impl ConversationEditor { ) } }; - let render_trailer = { - let confirm_command = confirm_command.clone(); - move |row, _, cx: &mut WindowContext| { - render_pending_slash_command_trailer( - row, - confirm_command.clone(), - cx, - ) - } - }; + let render_trailer = + |_row, _unfold, _cx: &mut WindowContext| Empty.into_any(); let start = buffer .anchor_in_excerpt(excerpt_id, command.source_range.start) @@ -3517,239 +3085,6 @@ impl Message { } } -enum InlineAssistantEvent { - Confirmed { - prompt: String, - include_conversation: bool, - }, - Canceled, - Dismissed, -} - -struct InlineAssistant { - id: usize, - prompt_editor: View, - confirmed: bool, - include_conversation: bool, - gutter_dimensions: Arc>, - prompt_history: VecDeque, - prompt_history_ix: Option, - pending_prompt: String, - codegen: Model, - _subscriptions: Vec, -} - -impl EventEmitter for InlineAssistant {} - -impl Render for InlineAssistant { - fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - let gutter_dimensions = *self.gutter_dimensions.lock(); - let icon_size = IconSize::default(); - h_flex() - .w_full() - .py_2() - .border_y_1() - .border_color(cx.theme().colors().border) - .bg(cx.theme().colors().editor_background) - .on_action(cx.listener(Self::confirm)) - .on_action(cx.listener(Self::cancel)) - .on_action(cx.listener(Self::move_up)) - .on_action(cx.listener(Self::move_down)) - .child( - h_flex() - .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0)) - .pr(gutter_dimensions.fold_area_width()) - .justify_end() - .children(if let Some(error) = self.codegen.read(cx).error() { - let error_message = SharedString::from(error.to_string()); - Some( - div() - .id("error") - .tooltip(move |cx| Tooltip::text(error_message.clone(), cx)) - .child( - Icon::new(IconName::XCircle) - .size(icon_size) - .color(Color::Error), - ), - ) - } else { - None - }), - ) - .child(h_flex().flex_1().child(self.render_prompt_editor(cx))) - } -} - -impl FocusableView for InlineAssistant { - fn focus_handle(&self, cx: &AppContext) -> FocusHandle { - self.prompt_editor.focus_handle(cx) - } -} - -impl InlineAssistant { - #[allow(clippy::too_many_arguments)] - fn new( - id: usize, - gutter_dimensions: Arc>, - include_conversation: bool, - prompt_history: VecDeque, - codegen: Model, - cx: &mut ViewContext, - ) -> Self { - let prompt_editor = cx.new_view(|cx| { - let mut editor = Editor::single_line(cx); - let placeholder = match codegen.read(cx).kind() { - CodegenKind::Transform { .. } => "Enter transformation prompt…", - CodegenKind::Generate { .. } => "Enter generation prompt…", - }; - editor.set_placeholder_text(placeholder, cx); - editor - }); - cx.focus_view(&prompt_editor); - - let subscriptions = vec![ - cx.observe(&codegen, Self::handle_codegen_changed), - cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events), - ]; - - Self { - id, - prompt_editor, - confirmed: false, - include_conversation, - gutter_dimensions, - prompt_history, - prompt_history_ix: None, - pending_prompt: String::new(), - codegen, - _subscriptions: subscriptions, - } - } - - fn handle_prompt_editor_events( - &mut self, - _: View, - event: &EditorEvent, - cx: &mut ViewContext, - ) { - if let EditorEvent::Edited = event { - self.pending_prompt = self.prompt_editor.read(cx).text(cx); - cx.notify(); - } - } - - fn handle_codegen_changed(&mut self, _: Model, cx: &mut ViewContext) { - let is_read_only = !self.codegen.read(cx).idle(); - self.prompt_editor.update(cx, |editor, cx| { - let was_read_only = editor.read_only(cx); - if was_read_only != is_read_only { - if is_read_only { - editor.set_read_only(true); - } else { - self.confirmed = false; - editor.set_read_only(false); - } - } - }); - cx.notify(); - } - - fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext) { - cx.emit(InlineAssistantEvent::Canceled); - } - - fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { - if self.confirmed { - cx.emit(InlineAssistantEvent::Dismissed); - } else { - let prompt = self.prompt_editor.read(cx).text(cx); - self.prompt_editor - .update(cx, |editor, _cx| editor.set_read_only(true)); - cx.emit(InlineAssistantEvent::Confirmed { - prompt, - include_conversation: self.include_conversation, - }); - self.confirmed = true; - cx.notify(); - } - } - - fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext) { - if let Some(ix) = self.prompt_history_ix { - if ix > 0 { - self.prompt_history_ix = Some(ix - 1); - let prompt = self.prompt_history[ix - 1].clone(); - self.set_prompt(&prompt, cx); - } - } else if !self.prompt_history.is_empty() { - self.prompt_history_ix = Some(self.prompt_history.len() - 1); - let prompt = self.prompt_history[self.prompt_history.len() - 1].clone(); - self.set_prompt(&prompt, cx); - } - } - - fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext) { - if let Some(ix) = self.prompt_history_ix { - if ix < self.prompt_history.len() - 1 { - self.prompt_history_ix = Some(ix + 1); - let prompt = self.prompt_history[ix + 1].clone(); - self.set_prompt(&prompt, cx); - } else { - self.prompt_history_ix = None; - let pending_prompt = self.pending_prompt.clone(); - self.set_prompt(&pending_prompt, cx); - } - } - } - - fn set_prompt(&mut self, prompt: &str, cx: &mut ViewContext) { - self.prompt_editor.update(cx, |editor, cx| { - editor.buffer().update(cx, |buffer, cx| { - let len = buffer.len(cx); - buffer.edit([(0..len, prompt)], None, cx); - }); - }); - } - - fn render_prompt_editor(&self, cx: &mut ViewContext) -> impl IntoElement { - let settings = ThemeSettings::get_global(cx); - let text_style = TextStyle { - color: if self.prompt_editor.read(cx).read_only(cx) { - cx.theme().colors().text_disabled - } else { - cx.theme().colors().text - }, - font_family: settings.ui_font.family.clone(), - font_features: settings.ui_font.features.clone(), - font_size: rems(0.875).into(), - font_weight: FontWeight::NORMAL, - font_style: FontStyle::Normal, - line_height: relative(1.3), - background_color: None, - underline: None, - strikethrough: None, - white_space: WhiteSpace::Normal, - }; - EditorElement::new( - &self.prompt_editor, - EditorStyle { - background: cx.theme().colors().editor_background, - local_player: cx.theme().players().local(), - text: text_style, - ..Default::default() - }, - ) - } -} - -struct PendingInlineAssist { - editor: WeakView, - inline_assistant: Option<(BlockId, View)>, - codegen: Model, - _subscriptions: Vec, - project: WeakModel, -} - type ToggleFold = Arc; fn render_slash_command_output_toggle( @@ -3801,47 +3136,6 @@ fn render_pending_slash_command_gutter_decoration( icon.into_any_element() } -fn render_pending_slash_command_trailer( - _row: MultiBufferRow, - _confirm_command: Arc, - _cx: &mut WindowContext, -) -> AnyElement { - Empty.into_any() - // ButtonLike::new(("run_button", row.0)) - // .style(ButtonStyle::Filled) - // .size(ButtonSize::Compact) - // .layer(ElevationIndex::ModalSurface) - // .children( - // KeyBinding::for_action(&Confirm, cx) - // .map(|binding| binding.icon_size(IconSize::XSmall).into_any_element()), - // ) - // .child(Label::new("Run").size(LabelSize::XSmall)) - // .on_click(move |_, cx| confirm_command(cx)) - // .into_any_element() -} - -fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { - ranges.sort_unstable_by(|a, b| { - a.start - .cmp(&b.start, buffer) - .then_with(|| b.end.cmp(&a.end, buffer)) - }); - - let mut ix = 0; - while ix + 1 < ranges.len() { - let b = ranges[ix + 1].clone(); - let a = &mut ranges[ix]; - if a.end.cmp(&b.start, buffer).is_gt() { - if a.end.cmp(&b.end, buffer).is_lt() { - a.end = b.end; - } - ranges.remove(ix + 1); - } else { - ix += 1; - } - } -} - fn make_lsp_adapter_delegate( project: &Model, cx: &mut AppContext, diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs deleted file mode 100644 index 2a725189ebf186d439072a26bea909832850daa8..0000000000000000000000000000000000000000 --- a/crates/assistant/src/codegen.rs +++ /dev/null @@ -1,704 +0,0 @@ -use crate::{ - streaming_diff::{Hunk, StreamingDiff}, - CompletionProvider, LanguageModelRequest, -}; -use anyhow::Result; -use client::telemetry::Telemetry; -use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint}; -use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; -use gpui::{EventEmitter, Model, ModelContext, Task}; -use language::{Rope, TransactionId}; -use multi_buffer::MultiBufferRow; -use std::{cmp, future, ops::Range, sync::Arc, time::Instant}; - -#[derive(Debug)] -pub enum Event { - Finished, - Undone, -} - -#[derive(Clone)] -pub enum CodegenKind { - Transform { range: Range }, - Generate { position: Anchor }, -} - -pub struct Codegen { - buffer: Model, - snapshot: MultiBufferSnapshot, - kind: CodegenKind, - last_equal_ranges: Vec>, - transaction_id: Option, - error: Option, - generation: Task<()>, - idle: bool, - telemetry: Option>, - _subscription: gpui::Subscription, -} - -impl EventEmitter for Codegen {} - -impl Codegen { - pub fn new( - buffer: Model, - kind: CodegenKind, - telemetry: Option>, - cx: &mut ModelContext, - ) -> Self { - let snapshot = buffer.read(cx).snapshot(cx); - Self { - buffer: buffer.clone(), - snapshot, - kind, - last_equal_ranges: Default::default(), - transaction_id: Default::default(), - error: Default::default(), - idle: true, - generation: Task::ready(()), - telemetry, - _subscription: cx.subscribe(&buffer, Self::handle_buffer_event), - } - } - - fn handle_buffer_event( - &mut self, - _buffer: Model, - event: &multi_buffer::Event, - cx: &mut ModelContext, - ) { - if let multi_buffer::Event::TransactionUndone { transaction_id } = event { - if self.transaction_id == Some(*transaction_id) { - self.transaction_id = None; - self.generation = Task::ready(()); - cx.emit(Event::Undone); - } - } - } - - pub fn range(&self) -> Range { - match &self.kind { - CodegenKind::Transform { range } => range.clone(), - CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position, - } - } - - pub fn kind(&self) -> &CodegenKind { - &self.kind - } - - pub fn last_equal_ranges(&self) -> &[Range] { - &self.last_equal_ranges - } - - pub fn idle(&self) -> bool { - self.idle - } - - pub fn error(&self) -> Option<&anyhow::Error> { - self.error.as_ref() - } - - pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext) { - let range = self.range(); - let snapshot = self.snapshot.clone(); - let selected_text = snapshot - .text_for_range(range.start..range.end) - .collect::(); - - let selection_start = range.start.to_point(&snapshot); - let suggested_line_indent = snapshot - .suggested_indents(selection_start.row..selection_start.row + 1, cx) - .into_values() - .next() - .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row))); - - let model_telemetry_id = prompt.model.telemetry_id(); - let response = CompletionProvider::global(cx).complete(prompt); - let telemetry = self.telemetry.clone(); - self.generation = cx.spawn(|this, mut cx| { - async move { - let generate = async { - let mut edit_start = range.start.to_offset(&snapshot); - - let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1); - let diff: Task> = - cx.background_executor().spawn(async move { - let mut response_latency = None; - let request_start = Instant::now(); - let diff = async { - let chunks = strip_invalid_spans_from_codeblock(response.await?); - futures::pin_mut!(chunks); - let mut diff = StreamingDiff::new(selected_text.to_string()); - - let mut new_text = String::new(); - let mut base_indent = None; - let mut line_indent = None; - let mut first_line = true; - - while let Some(chunk) = chunks.next().await { - if response_latency.is_none() { - response_latency = Some(request_start.elapsed()); - } - let chunk = chunk?; - - let mut lines = chunk.split('\n').peekable(); - while let Some(line) = lines.next() { - new_text.push_str(line); - if line_indent.is_none() { - if let Some(non_whitespace_ch_ix) = - new_text.find(|ch: char| !ch.is_whitespace()) - { - line_indent = Some(non_whitespace_ch_ix); - base_indent = base_indent.or(line_indent); - - let line_indent = line_indent.unwrap(); - let base_indent = base_indent.unwrap(); - let indent_delta = - line_indent as i32 - base_indent as i32; - let mut corrected_indent_len = cmp::max( - 0, - suggested_line_indent.len as i32 + indent_delta, - ) - as usize; - if first_line { - corrected_indent_len = corrected_indent_len - .saturating_sub( - selection_start.column as usize, - ); - } - - let indent_char = suggested_line_indent.char(); - let mut indent_buffer = [0; 4]; - let indent_str = - indent_char.encode_utf8(&mut indent_buffer); - new_text.replace_range( - ..line_indent, - &indent_str.repeat(corrected_indent_len), - ); - } - } - - if line_indent.is_some() { - hunks_tx.send(diff.push_new(&new_text)).await?; - new_text.clear(); - } - - if lines.peek().is_some() { - hunks_tx.send(diff.push_new("\n")).await?; - line_indent = None; - first_line = false; - } - } - } - hunks_tx.send(diff.push_new(&new_text)).await?; - hunks_tx.send(diff.finish()).await?; - - anyhow::Ok(()) - }; - - let result = diff.await; - - let error_message = - result.as_ref().err().map(|error| error.to_string()); - if let Some(telemetry) = telemetry { - telemetry.report_assistant_event( - None, - telemetry_events::AssistantKind::Inline, - model_telemetry_id, - response_latency, - error_message, - ); - } - - result?; - Ok(()) - }); - - while let Some(hunks) = hunks_rx.next().await { - this.update(&mut cx, |this, cx| { - this.last_equal_ranges.clear(); - - let transaction = this.buffer.update(cx, |buffer, cx| { - // Avoid grouping assistant edits with user edits. - buffer.finalize_last_transaction(cx); - - buffer.start_transaction(cx); - buffer.edit( - hunks.into_iter().filter_map(|hunk| match hunk { - Hunk::Insert { text } => { - let edit_start = snapshot.anchor_after(edit_start); - Some((edit_start..edit_start, text)) - } - Hunk::Remove { len } => { - let edit_end = edit_start + len; - let edit_range = snapshot.anchor_after(edit_start) - ..snapshot.anchor_before(edit_end); - edit_start = edit_end; - Some((edit_range, String::new())) - } - Hunk::Keep { len } => { - let edit_end = edit_start + len; - let edit_range = snapshot.anchor_after(edit_start) - ..snapshot.anchor_before(edit_end); - edit_start = edit_end; - this.last_equal_ranges.push(edit_range); - None - } - }), - None, - cx, - ); - - buffer.end_transaction(cx) - }); - - if let Some(transaction) = transaction { - if let Some(first_transaction) = this.transaction_id { - // Group all assistant edits into the first transaction. - this.buffer.update(cx, |buffer, cx| { - buffer.merge_transactions( - transaction, - first_transaction, - cx, - ) - }); - } else { - this.transaction_id = Some(transaction); - this.buffer.update(cx, |buffer, cx| { - buffer.finalize_last_transaction(cx) - }); - } - } - - cx.notify(); - })?; - } - - diff.await?; - - anyhow::Ok(()) - }; - - let result = generate.await; - this.update(&mut cx, |this, cx| { - this.last_equal_ranges.clear(); - this.idle = true; - if let Err(error) = result { - this.error = Some(error); - } - cx.emit(Event::Finished); - cx.notify(); - }) - .ok(); - } - }); - self.error.take(); - self.idle = false; - cx.notify(); - } - - pub fn undo(&mut self, cx: &mut ModelContext) { - if let Some(transaction_id) = self.transaction_id { - self.buffer - .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx)); - } - } -} - -fn strip_invalid_spans_from_codeblock( - stream: impl Stream>, -) -> impl Stream> { - let mut first_line = true; - let mut buffer = String::new(); - let mut starts_with_markdown_codeblock = false; - let mut includes_start_or_end_span = false; - stream.filter_map(move |chunk| { - let chunk = match chunk { - Ok(chunk) => chunk, - Err(err) => return future::ready(Some(Err(err))), - }; - buffer.push_str(&chunk); - - if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") { - includes_start_or_end_span = true; - - buffer = buffer - .strip_prefix("<|S|>") - .or_else(|| buffer.strip_prefix("<|S|")) - .unwrap_or(&buffer) - .to_string(); - } else if buffer.ends_with("|E|>") { - includes_start_or_end_span = true; - } else if buffer.starts_with("<|") - || buffer.starts_with("<|S") - || buffer.starts_with("<|S|") - || buffer.ends_with('|') - || buffer.ends_with("|E") - || buffer.ends_with("|E|") - { - return future::ready(None); - } - - if first_line { - if buffer.is_empty() || buffer == "`" || buffer == "``" { - return future::ready(None); - } else if buffer.starts_with("```") { - starts_with_markdown_codeblock = true; - if let Some(newline_ix) = buffer.find('\n') { - buffer.replace_range(..newline_ix + 1, ""); - first_line = false; - } else { - return future::ready(None); - } - } - } - - let mut text = buffer.to_string(); - if starts_with_markdown_codeblock { - text = text - .strip_suffix("\n```\n") - .or_else(|| text.strip_suffix("\n```")) - .or_else(|| text.strip_suffix("\n``")) - .or_else(|| text.strip_suffix("\n`")) - .or_else(|| text.strip_suffix('\n')) - .unwrap_or(&text) - .to_string(); - } - - if includes_start_or_end_span { - text = text - .strip_suffix("|E|>") - .or_else(|| text.strip_suffix("E|>")) - .or_else(|| text.strip_prefix("|>")) - .or_else(|| text.strip_prefix('>')) - .unwrap_or(&text) - .to_string(); - }; - - if text.contains('\n') { - first_line = false; - } - - let remainder = buffer.split_off(text.len()); - let result = if buffer.is_empty() { - None - } else { - Some(Ok(buffer.clone())) - }; - - buffer = remainder; - future::ready(result) - }) -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use crate::FakeCompletionProvider; - - use super::*; - use futures::stream::{self}; - use gpui::{Context, TestAppContext}; - use indoc::indoc; - use language::{ - language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher, - Point, - }; - use rand::prelude::*; - use serde::Serialize; - use settings::SettingsStore; - - #[derive(Serialize)] - pub struct DummyCompletionRequest { - pub name: String, - } - - #[gpui::test(iterations = 10)] - async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) { - let provider = FakeCompletionProvider::default(); - cx.set_global(cx.update(SettingsStore::test)); - cx.set_global(CompletionProvider::Fake(provider.clone())); - cx.update(language_settings::init); - - let text = indoc! {" - fn main() { - let x = 0; - for _ in 0..10 { - x += 1; - } - } - "}; - let buffer = - cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); - let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); - let range = buffer.read_with(cx, |buffer, cx| { - let snapshot = buffer.snapshot(cx); - snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) - }); - let codegen = cx.new_model(|cx| { - Codegen::new(buffer.clone(), CodegenKind::Transform { range }, None, cx) - }); - - let request = LanguageModelRequest::default(); - codegen.update(cx, |codegen, cx| codegen.start(request, cx)); - - let mut new_text = concat!( - " let mut x = 0;\n", - " while x < 10 {\n", - " x += 1;\n", - " }", - ); - while !new_text.is_empty() { - let max_len = cmp::min(new_text.len(), 10); - let len = rng.gen_range(1..=max_len); - let (chunk, suffix) = new_text.split_at(len); - provider.send_completion(chunk.into()); - new_text = suffix; - cx.background_executor.run_until_parked(); - } - provider.finish_completion(); - cx.background_executor.run_until_parked(); - - assert_eq!( - buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), - indoc! {" - fn main() { - let mut x = 0; - while x < 10 { - x += 1; - } - } - "} - ); - } - - #[gpui::test(iterations = 10)] - async fn test_autoindent_when_generating_past_indentation( - cx: &mut TestAppContext, - mut rng: StdRng, - ) { - let provider = FakeCompletionProvider::default(); - cx.set_global(CompletionProvider::Fake(provider.clone())); - cx.set_global(cx.update(SettingsStore::test)); - cx.update(language_settings::init); - - let text = indoc! {" - fn main() { - le - } - "}; - let buffer = - cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); - let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); - let position = buffer.read_with(cx, |buffer, cx| { - let snapshot = buffer.snapshot(cx); - snapshot.anchor_before(Point::new(1, 6)) - }); - let codegen = cx.new_model(|cx| { - Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx) - }); - - let request = LanguageModelRequest::default(); - codegen.update(cx, |codegen, cx| codegen.start(request, cx)); - - let mut new_text = concat!( - "t mut x = 0;\n", - "while x < 10 {\n", - " x += 1;\n", - "}", // - ); - while !new_text.is_empty() { - let max_len = cmp::min(new_text.len(), 10); - let len = rng.gen_range(1..=max_len); - let (chunk, suffix) = new_text.split_at(len); - provider.send_completion(chunk.into()); - new_text = suffix; - cx.background_executor.run_until_parked(); - } - provider.finish_completion(); - cx.background_executor.run_until_parked(); - - assert_eq!( - buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), - indoc! {" - fn main() { - let mut x = 0; - while x < 10 { - x += 1; - } - } - "} - ); - } - - #[gpui::test(iterations = 10)] - async fn test_autoindent_when_generating_before_indentation( - cx: &mut TestAppContext, - mut rng: StdRng, - ) { - let provider = FakeCompletionProvider::default(); - cx.set_global(CompletionProvider::Fake(provider.clone())); - cx.set_global(cx.update(SettingsStore::test)); - cx.update(language_settings::init); - - let text = concat!( - "fn main() {\n", - " \n", - "}\n" // - ); - let buffer = - cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); - let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); - let position = buffer.read_with(cx, |buffer, cx| { - let snapshot = buffer.snapshot(cx); - snapshot.anchor_before(Point::new(1, 2)) - }); - let codegen = cx.new_model(|cx| { - Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx) - }); - - let request = LanguageModelRequest::default(); - codegen.update(cx, |codegen, cx| codegen.start(request, cx)); - - let mut new_text = concat!( - "let mut x = 0;\n", - "while x < 10 {\n", - " x += 1;\n", - "}", // - ); - while !new_text.is_empty() { - let max_len = cmp::min(new_text.len(), 10); - let len = rng.gen_range(1..=max_len); - let (chunk, suffix) = new_text.split_at(len); - provider.send_completion(chunk.into()); - new_text = suffix; - cx.background_executor.run_until_parked(); - } - provider.finish_completion(); - cx.background_executor.run_until_parked(); - - assert_eq!( - buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), - indoc! {" - fn main() { - let mut x = 0; - while x < 10 { - x += 1; - } - } - "} - ); - } - - #[gpui::test] - async fn test_strip_invalid_spans_from_codeblock() { - assert_eq!( - strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_invalid_spans_from_codeblock(chunks( - "```html\n```js\nLorem ipsum dolor\n```\n```", - 2 - )) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "```js\nLorem ipsum dolor\n```" - ); - assert_eq!( - strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "``\nLorem ipsum dolor\n```" - ); - assert_eq!( - strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum" - ); - - assert_eq!( - strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum" - ); - - assert_eq!( - strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum" - ); - assert_eq!( - strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum" - ); - fn chunks(text: &str, size: usize) -> impl Stream> { - stream::iter( - text.chars() - .collect::>() - .chunks(size) - .map(|chunk| Ok(chunk.iter().collect::())) - .collect::>(), - ) - } - } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::language()), - ) - .with_indents_query( - r#" - (call_expression) @indent - (field_expression) @indent - (_ "(" ")" @end) @indent - (_ "{" "}" @end) @indent - "#, - ) - .unwrap() - } -} diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs new file mode 100644 index 0000000000000000000000000000000000000000..0ad72b906f110125cf350b5d41f9376dd919e4ed --- /dev/null +++ b/crates/assistant/src/inline_assistant.rs @@ -0,0 +1,1449 @@ +use crate::{ + prompts::generate_content_prompt, AssistantPanel, CompletionProvider, Hunk, + LanguageModelRequest, LanguageModelRequestMessage, Role, StreamingDiff, +}; +use anyhow::Result; +use client::telemetry::Telemetry; +use collections::{hash_map, HashMap, HashSet, VecDeque}; +use editor::{ + actions::{MoveDown, MoveUp}, + display_map::{BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle}, + scroll::{Autoscroll, AutoscrollStrategy}, + Anchor, Editor, EditorElement, EditorEvent, EditorStyle, GutterDimensions, MultiBuffer, + MultiBufferSnapshot, ToOffset, ToPoint, +}; +use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; +use gpui::{ + AnyWindowHandle, AppContext, EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight, + Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View, + ViewContext, WeakView, WhiteSpace, WindowContext, +}; +use language::{Point, TransactionId}; +use multi_buffer::MultiBufferRow; +use parking_lot::Mutex; +use rope::Rope; +use settings::Settings; +use std::{cmp, future, ops::Range, sync::Arc, time::Instant}; +use theme::ThemeSettings; +use ui::{prelude::*, Tooltip}; +use workspace::{notifications::NotificationId, Toast, Workspace}; + +pub fn init(telemetry: Arc, cx: &mut AppContext) { + cx.set_global(InlineAssistant::new(telemetry)); +} + +const PROMPT_HISTORY_MAX_LEN: usize = 20; + +pub struct InlineAssistant { + next_assist_id: InlineAssistId, + pending_assists: HashMap, + pending_assist_ids_by_editor: HashMap, EditorPendingAssists>, + prompt_history: VecDeque, + telemetry: Option>, +} + +struct EditorPendingAssists { + window: AnyWindowHandle, + assist_ids: Vec, +} + +impl Global for InlineAssistant {} + +impl InlineAssistant { + pub fn new(telemetry: Arc) -> Self { + Self { + next_assist_id: InlineAssistId::default(), + pending_assists: HashMap::default(), + pending_assist_ids_by_editor: HashMap::default(), + prompt_history: VecDeque::default(), + telemetry: Some(telemetry), + } + } + + pub fn assist( + &mut self, + editor: &View, + workspace: Option>, + include_conversation: bool, + cx: &mut WindowContext, + ) { + let selection = editor.read(cx).selections.newest_anchor().clone(); + if selection.start.excerpt_id != selection.end.excerpt_id { + return; + } + let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); + + // Extend the selection to the start and the end of the line. + let mut point_selection = selection.map(|selection| selection.to_point(&snapshot)); + if point_selection.end > point_selection.start { + point_selection.start.column = 0; + // If the selection ends at the start of the line, we don't want to include it. + if point_selection.end.column == 0 { + point_selection.end.row -= 1; + } + point_selection.end.column = snapshot.line_len(MultiBufferRow(point_selection.end.row)); + } + + let codegen_kind = if point_selection.start == point_selection.end { + CodegenKind::Generate { + position: snapshot.anchor_after(point_selection.start), + } + } else { + CodegenKind::Transform { + range: snapshot.anchor_before(point_selection.start) + ..snapshot.anchor_after(point_selection.end), + } + }; + + let inline_assist_id = self.next_assist_id.post_inc(); + let codegen = cx.new_model(|cx| { + Codegen::new( + editor.read(cx).buffer().clone(), + codegen_kind, + self.telemetry.clone(), + cx, + ) + }); + + let measurements = Arc::new(Mutex::new(GutterDimensions::default())); + let inline_assistant = cx.new_view(|cx| { + InlineAssistEditor::new( + inline_assist_id, + measurements.clone(), + self.prompt_history.clone(), + codegen.clone(), + cx, + ) + }); + let block_id = editor.update(cx, |editor, cx| { + editor.change_selections(None, cx, |selections| { + selections.select_anchor_ranges([selection.head()..selection.head()]) + }); + editor.insert_blocks( + [BlockProperties { + style: BlockStyle::Flex, + position: snapshot.anchor_before(Point::new(point_selection.head().row, 0)), + height: 2, + render: Box::new({ + let inline_assistant = inline_assistant.clone(); + move |cx: &mut BlockContext| { + *measurements.lock() = *cx.gutter_dimensions; + inline_assistant.clone().into_any_element() + } + }), + disposition: if selection.reversed { + BlockDisposition::Above + } else { + BlockDisposition::Below + }, + }], + Some(Autoscroll::Strategy(AutoscrollStrategy::Newest)), + cx, + )[0] + }); + + self.pending_assists.insert( + inline_assist_id, + PendingInlineAssist { + include_conversation, + editor: editor.downgrade(), + inline_assistant: Some((block_id, inline_assistant.clone())), + codegen: codegen.clone(), + workspace, + _subscriptions: vec![ + cx.subscribe(&inline_assistant, |inline_assistant, event, cx| { + InlineAssistant::update_global(cx, |this, cx| { + this.handle_inline_assistant_event(inline_assistant, event, cx) + }) + }), + cx.subscribe(editor, { + let inline_assistant = inline_assistant.downgrade(); + move |editor, event, cx| { + if let Some(inline_assistant) = inline_assistant.upgrade() { + if let EditorEvent::SelectionsChanged { local } = event { + if *local + && inline_assistant.focus_handle(cx).contains_focused(cx) + { + cx.focus_view(&editor); + } + } + } + } + }), + cx.observe(&codegen, { + let editor = editor.downgrade(); + move |_, cx| { + if let Some(editor) = editor.upgrade() { + InlineAssistant::update_global(cx, |this, cx| { + this.update_highlights_for_editor(&editor, cx); + }) + } + } + }), + cx.subscribe(&codegen, move |codegen, event, cx| { + InlineAssistant::update_global(cx, |this, cx| match event { + CodegenEvent::Undone => { + this.finish_inline_assist(inline_assist_id, false, cx) + } + CodegenEvent::Finished => { + let pending_assist = if let Some(pending_assist) = + this.pending_assists.get(&inline_assist_id) + { + pending_assist + } else { + return; + }; + + let error = codegen + .read(cx) + .error() + .map(|error| format!("Inline assistant error: {}", error)); + if let Some(error) = error { + if pending_assist.inline_assistant.is_none() { + if let Some(workspace) = pending_assist + .workspace + .as_ref() + .and_then(|workspace| workspace.upgrade()) + { + workspace.update(cx, |workspace, cx| { + struct InlineAssistantError; + + let id = NotificationId::identified::< + InlineAssistantError, + >( + inline_assist_id.0 + ); + + workspace.show_toast(Toast::new(id, error), cx); + }) + } + + this.finish_inline_assist(inline_assist_id, false, cx); + } + } else { + this.finish_inline_assist(inline_assist_id, false, cx); + } + } + }) + }), + ], + }, + ); + + self.pending_assist_ids_by_editor + .entry(editor.downgrade()) + .or_insert_with(|| EditorPendingAssists { + window: cx.window_handle(), + assist_ids: Vec::new(), + }) + .assist_ids + .push(inline_assist_id); + self.update_highlights_for_editor(editor, cx); + } + + fn handle_inline_assistant_event( + &mut self, + inline_assistant: View, + event: &InlineAssistEditorEvent, + cx: &mut WindowContext, + ) { + let assist_id = inline_assistant.read(cx).id; + match event { + InlineAssistEditorEvent::Confirmed { prompt } => { + self.confirm_inline_assist(assist_id, prompt, cx); + } + InlineAssistEditorEvent::Canceled => { + self.finish_inline_assist(assist_id, true, cx); + } + InlineAssistEditorEvent::Dismissed => { + self.hide_inline_assist(assist_id, cx); + } + } + } + + pub fn cancel_last_inline_assist(&mut self, cx: &mut WindowContext) -> bool { + for (editor, pending_assists) in &self.pending_assist_ids_by_editor { + if pending_assists.window == cx.window_handle() { + if let Some(editor) = editor.upgrade() { + if editor.read(cx).is_focused(cx) { + if let Some(assist_id) = pending_assists.assist_ids.last().copied() { + self.finish_inline_assist(assist_id, true, cx); + return true; + } + } + } + } + } + false + } + + fn finish_inline_assist( + &mut self, + assist_id: InlineAssistId, + undo: bool, + cx: &mut WindowContext, + ) { + self.hide_inline_assist(assist_id, cx); + + if let Some(pending_assist) = self.pending_assists.remove(&assist_id) { + if let hash_map::Entry::Occupied(mut entry) = self + .pending_assist_ids_by_editor + .entry(pending_assist.editor.clone()) + { + entry.get_mut().assist_ids.retain(|id| *id != assist_id); + if entry.get().assist_ids.is_empty() { + entry.remove(); + } + } + + if let Some(editor) = pending_assist.editor.upgrade() { + self.update_highlights_for_editor(&editor, cx); + + if undo { + pending_assist + .codegen + .update(cx, |codegen, cx| codegen.undo(cx)); + } + } + } + } + + fn hide_inline_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) { + if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) { + if let Some(editor) = pending_assist.editor.upgrade() { + if let Some((block_id, inline_assistant)) = pending_assist.inline_assistant.take() { + editor.update(cx, |editor, cx| { + editor.remove_blocks(HashSet::from_iter([block_id]), None, cx); + if inline_assistant.focus_handle(cx).contains_focused(cx) { + editor.focus(cx); + } + }); + } + } + } + } + + fn confirm_inline_assist( + &mut self, + assist_id: InlineAssistId, + user_prompt: &str, + cx: &mut WindowContext, + ) { + let pending_assist = if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) + { + pending_assist + } else { + return; + }; + + let conversation = if pending_assist.include_conversation { + pending_assist.workspace.as_ref().and_then(|workspace| { + let workspace = workspace.upgrade()?.read(cx); + let assistant_panel = workspace.panel::(cx)?; + assistant_panel.read(cx).active_conversation(cx) + }) + } else { + None + }; + + let editor = if let Some(editor) = pending_assist.editor.upgrade() { + editor + } else { + return; + }; + + let project_name = pending_assist.workspace.as_ref().and_then(|workspace| { + let workspace = workspace.upgrade()?; + Some( + workspace + .read(cx) + .project() + .read(cx) + .worktree_root_names(cx) + .collect::>() + .join("/"), + ) + }); + + self.prompt_history.retain(|prompt| prompt != user_prompt); + self.prompt_history.push_back(user_prompt.into()); + if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN { + self.prompt_history.pop_front(); + } + + let codegen = pending_assist.codegen.clone(); + let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); + let range = codegen.read(cx).range(); + let start = snapshot.point_to_buffer_offset(range.start); + let end = snapshot.point_to_buffer_offset(range.end); + let (buffer, range) = if let Some((start, end)) = start.zip(end) { + let (start_buffer, start_buffer_offset) = start; + let (end_buffer, end_buffer_offset) = end; + if start_buffer.remote_id() == end_buffer.remote_id() { + (start_buffer.clone(), start_buffer_offset..end_buffer_offset) + } else { + self.finish_inline_assist(assist_id, false, cx); + return; + } + } else { + self.finish_inline_assist(assist_id, false, cx); + return; + }; + + let language = buffer.language_at(range.start); + let language_name = if let Some(language) = language.as_ref() { + if Arc::ptr_eq(language, &language::PLAIN_TEXT) { + None + } else { + Some(language.name()) + } + } else { + None + }; + + // Higher Temperature increases the randomness of model outputs. + // If Markdown or No Language is Known, increase the randomness for more creative output + // If Code, decrease temperature to get more deterministic outputs + let temperature = if let Some(language) = language_name.clone() { + if language.as_ref() == "Markdown" { + 1.0 + } else { + 0.5 + } + } else { + 1.0 + }; + + let user_prompt = user_prompt.to_string(); + + let prompt = cx.background_executor().spawn(async move { + let language_name = language_name.as_deref(); + generate_content_prompt(user_prompt, language_name, buffer, range, project_name) + }); + + let mut messages = Vec::new(); + if let Some(conversation) = conversation { + let request = conversation.read(cx).to_completion_request(cx); + messages = request.messages; + } + let model = CompletionProvider::global(cx).model(); + + cx.spawn(|mut cx| async move { + let prompt = prompt.await?; + + messages.push(LanguageModelRequestMessage { + role: Role::User, + content: prompt, + }); + + let request = LanguageModelRequest { + model, + messages, + stop: vec!["|END|>".to_string()], + temperature, + }; + + codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + + fn update_highlights_for_editor(&self, editor: &View, cx: &mut WindowContext) { + let mut background_ranges = Vec::new(); + let mut foreground_ranges = Vec::new(); + let empty_inline_assist_ids = Vec::new(); + let inline_assist_ids = self + .pending_assist_ids_by_editor + .get(&editor.downgrade()) + .map_or(&empty_inline_assist_ids, |pending_assists| { + &pending_assists.assist_ids + }); + + for inline_assist_id in inline_assist_ids { + if let Some(pending_assist) = self.pending_assists.get(inline_assist_id) { + let codegen = pending_assist.codegen.read(cx); + background_ranges.push(codegen.range()); + foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned()); + } + } + + let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); + merge_ranges(&mut background_ranges, &snapshot); + merge_ranges(&mut foreground_ranges, &snapshot); + editor.update(cx, |editor, cx| { + if background_ranges.is_empty() { + editor.clear_background_highlights::(cx); + } else { + editor.highlight_background::( + &background_ranges, + |theme| theme.editor_active_line_background, // TODO use the appropriate color + cx, + ); + } + + if foreground_ranges.is_empty() { + editor.clear_highlights::(cx); + } else { + editor.highlight_text::( + foreground_ranges, + HighlightStyle { + fade_out: Some(0.6), + ..Default::default() + }, + cx, + ); + } + }); + } +} + +#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] +struct InlineAssistId(usize); + +impl InlineAssistId { + fn post_inc(&mut self) -> InlineAssistId { + let id = *self; + self.0 += 1; + id + } +} + +enum InlineAssistEditorEvent { + Confirmed { prompt: String }, + Canceled, + Dismissed, +} + +struct InlineAssistEditor { + id: InlineAssistId, + prompt_editor: View, + confirmed: bool, + gutter_dimensions: Arc>, + prompt_history: VecDeque, + prompt_history_ix: Option, + pending_prompt: String, + codegen: Model, + _subscriptions: Vec, +} + +impl EventEmitter for InlineAssistEditor {} + +impl Render for InlineAssistEditor { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + let gutter_dimensions = *self.gutter_dimensions.lock(); + let icon_size = IconSize::default(); + h_flex() + .w_full() + .py_2() + .border_y_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().editor_background) + .on_action(cx.listener(Self::confirm)) + .on_action(cx.listener(Self::cancel)) + .on_action(cx.listener(Self::move_up)) + .on_action(cx.listener(Self::move_down)) + .child( + h_flex() + .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0)) + .pr(gutter_dimensions.fold_area_width()) + .justify_end() + .children(if let Some(error) = self.codegen.read(cx).error() { + let error_message = SharedString::from(error.to_string()); + Some( + div() + .id("error") + .tooltip(move |cx| Tooltip::text(error_message.clone(), cx)) + .child( + Icon::new(IconName::XCircle) + .size(icon_size) + .color(Color::Error), + ), + ) + } else { + None + }), + ) + .child(h_flex().flex_1().child(self.render_prompt_editor(cx))) + } +} + +impl FocusableView for InlineAssistEditor { + fn focus_handle(&self, cx: &AppContext) -> FocusHandle { + self.prompt_editor.focus_handle(cx) + } +} + +impl InlineAssistEditor { + #[allow(clippy::too_many_arguments)] + fn new( + id: InlineAssistId, + gutter_dimensions: Arc>, + prompt_history: VecDeque, + codegen: Model, + cx: &mut ViewContext, + ) -> Self { + let prompt_editor = cx.new_view(|cx| { + let mut editor = Editor::single_line(cx); + let placeholder = match codegen.read(cx).kind() { + CodegenKind::Transform { .. } => "Enter transformation prompt…", + CodegenKind::Generate { .. } => "Enter generation prompt…", + }; + editor.set_placeholder_text(placeholder, cx); + editor + }); + cx.focus_view(&prompt_editor); + + let subscriptions = vec![ + cx.observe(&codegen, Self::handle_codegen_changed), + cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events), + ]; + + Self { + id, + prompt_editor, + confirmed: false, + gutter_dimensions, + prompt_history, + prompt_history_ix: None, + pending_prompt: String::new(), + codegen, + _subscriptions: subscriptions, + } + } + + fn handle_prompt_editor_events( + &mut self, + _: View, + event: &EditorEvent, + cx: &mut ViewContext, + ) { + if let EditorEvent::Edited = event { + self.pending_prompt = self.prompt_editor.read(cx).text(cx); + cx.notify(); + } + } + + fn handle_codegen_changed(&mut self, _: Model, cx: &mut ViewContext) { + let is_read_only = !self.codegen.read(cx).idle(); + self.prompt_editor.update(cx, |editor, cx| { + let was_read_only = editor.read_only(cx); + if was_read_only != is_read_only { + if is_read_only { + editor.set_read_only(true); + } else { + self.confirmed = false; + editor.set_read_only(false); + } + } + }); + cx.notify(); + } + + fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext) { + cx.emit(InlineAssistEditorEvent::Canceled); + } + + fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { + if self.confirmed { + cx.emit(InlineAssistEditorEvent::Dismissed); + } else { + let prompt = self.prompt_editor.read(cx).text(cx); + self.prompt_editor + .update(cx, |editor, _cx| editor.set_read_only(true)); + cx.emit(InlineAssistEditorEvent::Confirmed { prompt }); + self.confirmed = true; + cx.notify(); + } + } + + fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext) { + if let Some(ix) = self.prompt_history_ix { + if ix > 0 { + self.prompt_history_ix = Some(ix - 1); + let prompt = self.prompt_history[ix - 1].clone(); + self.set_prompt(&prompt, cx); + } + } else if !self.prompt_history.is_empty() { + self.prompt_history_ix = Some(self.prompt_history.len() - 1); + let prompt = self.prompt_history[self.prompt_history.len() - 1].clone(); + self.set_prompt(&prompt, cx); + } + } + + fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext) { + if let Some(ix) = self.prompt_history_ix { + if ix < self.prompt_history.len() - 1 { + self.prompt_history_ix = Some(ix + 1); + let prompt = self.prompt_history[ix + 1].clone(); + self.set_prompt(&prompt, cx); + } else { + self.prompt_history_ix = None; + let pending_prompt = self.pending_prompt.clone(); + self.set_prompt(&pending_prompt, cx); + } + } + } + + fn set_prompt(&mut self, prompt: &str, cx: &mut ViewContext) { + self.prompt_editor.update(cx, |editor, cx| { + editor.buffer().update(cx, |buffer, cx| { + let len = buffer.len(cx); + buffer.edit([(0..len, prompt)], None, cx); + }); + }); + } + + fn render_prompt_editor(&self, cx: &mut ViewContext) -> impl IntoElement { + let settings = ThemeSettings::get_global(cx); + let text_style = TextStyle { + color: if self.prompt_editor.read(cx).read_only(cx) { + cx.theme().colors().text_disabled + } else { + cx.theme().colors().text + }, + font_family: settings.ui_font.family.clone(), + font_features: settings.ui_font.features.clone(), + font_size: rems(0.875).into(), + font_weight: FontWeight::NORMAL, + font_style: FontStyle::Normal, + line_height: relative(1.3), + background_color: None, + underline: None, + strikethrough: None, + white_space: WhiteSpace::Normal, + }; + EditorElement::new( + &self.prompt_editor, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + ..Default::default() + }, + ) + } +} + +struct PendingInlineAssist { + editor: WeakView, + inline_assistant: Option<(BlockId, View)>, + codegen: Model, + _subscriptions: Vec, + workspace: Option>, + include_conversation: bool, +} + +#[derive(Debug)] +pub enum CodegenEvent { + Finished, + Undone, +} + +#[derive(Clone)] +pub enum CodegenKind { + Transform { range: Range }, + Generate { position: Anchor }, +} + +pub struct Codegen { + buffer: Model, + snapshot: MultiBufferSnapshot, + kind: CodegenKind, + last_equal_ranges: Vec>, + transaction_id: Option, + error: Option, + generation: Task<()>, + idle: bool, + telemetry: Option>, + _subscription: gpui::Subscription, +} + +impl EventEmitter for Codegen {} + +impl Codegen { + pub fn new( + buffer: Model, + kind: CodegenKind, + telemetry: Option>, + cx: &mut ModelContext, + ) -> Self { + let snapshot = buffer.read(cx).snapshot(cx); + Self { + buffer: buffer.clone(), + snapshot, + kind, + last_equal_ranges: Default::default(), + transaction_id: Default::default(), + error: Default::default(), + idle: true, + generation: Task::ready(()), + telemetry, + _subscription: cx.subscribe(&buffer, Self::handle_buffer_event), + } + } + + fn handle_buffer_event( + &mut self, + _buffer: Model, + event: &multi_buffer::Event, + cx: &mut ModelContext, + ) { + if let multi_buffer::Event::TransactionUndone { transaction_id } = event { + if self.transaction_id == Some(*transaction_id) { + self.transaction_id = None; + self.generation = Task::ready(()); + cx.emit(CodegenEvent::Undone); + } + } + } + + pub fn range(&self) -> Range { + match &self.kind { + CodegenKind::Transform { range } => range.clone(), + CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position, + } + } + + pub fn kind(&self) -> &CodegenKind { + &self.kind + } + + pub fn last_equal_ranges(&self) -> &[Range] { + &self.last_equal_ranges + } + + pub fn idle(&self) -> bool { + self.idle + } + + pub fn error(&self) -> Option<&anyhow::Error> { + self.error.as_ref() + } + + pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext) { + let range = self.range(); + let snapshot = self.snapshot.clone(); + let selected_text = snapshot + .text_for_range(range.start..range.end) + .collect::(); + + let selection_start = range.start.to_point(&snapshot); + let suggested_line_indent = snapshot + .suggested_indents(selection_start.row..selection_start.row + 1, cx) + .into_values() + .next() + .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row))); + + let model_telemetry_id = prompt.model.telemetry_id(); + let response = CompletionProvider::global(cx).complete(prompt); + let telemetry = self.telemetry.clone(); + self.generation = cx.spawn(|this, mut cx| { + async move { + let generate = async { + let mut edit_start = range.start.to_offset(&snapshot); + + let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1); + let diff: Task> = + cx.background_executor().spawn(async move { + let mut response_latency = None; + let request_start = Instant::now(); + let diff = async { + let chunks = strip_invalid_spans_from_codeblock(response.await?); + futures::pin_mut!(chunks); + let mut diff = StreamingDiff::new(selected_text.to_string()); + + let mut new_text = String::new(); + let mut base_indent = None; + let mut line_indent = None; + let mut first_line = true; + + while let Some(chunk) = chunks.next().await { + if response_latency.is_none() { + response_latency = Some(request_start.elapsed()); + } + let chunk = chunk?; + + let mut lines = chunk.split('\n').peekable(); + while let Some(line) = lines.next() { + new_text.push_str(line); + if line_indent.is_none() { + if let Some(non_whitespace_ch_ix) = + new_text.find(|ch: char| !ch.is_whitespace()) + { + line_indent = Some(non_whitespace_ch_ix); + base_indent = base_indent.or(line_indent); + + let line_indent = line_indent.unwrap(); + let base_indent = base_indent.unwrap(); + let indent_delta = + line_indent as i32 - base_indent as i32; + let mut corrected_indent_len = cmp::max( + 0, + suggested_line_indent.len as i32 + indent_delta, + ) + as usize; + if first_line { + corrected_indent_len = corrected_indent_len + .saturating_sub( + selection_start.column as usize, + ); + } + + let indent_char = suggested_line_indent.char(); + let mut indent_buffer = [0; 4]; + let indent_str = + indent_char.encode_utf8(&mut indent_buffer); + new_text.replace_range( + ..line_indent, + &indent_str.repeat(corrected_indent_len), + ); + } + } + + if line_indent.is_some() { + hunks_tx.send(diff.push_new(&new_text)).await?; + new_text.clear(); + } + + if lines.peek().is_some() { + hunks_tx.send(diff.push_new("\n")).await?; + line_indent = None; + first_line = false; + } + } + } + hunks_tx.send(diff.push_new(&new_text)).await?; + hunks_tx.send(diff.finish()).await?; + + anyhow::Ok(()) + }; + + let result = diff.await; + + let error_message = + result.as_ref().err().map(|error| error.to_string()); + if let Some(telemetry) = telemetry { + telemetry.report_assistant_event( + None, + telemetry_events::AssistantKind::Inline, + model_telemetry_id, + response_latency, + error_message, + ); + } + + result?; + Ok(()) + }); + + while let Some(hunks) = hunks_rx.next().await { + this.update(&mut cx, |this, cx| { + this.last_equal_ranges.clear(); + + let transaction = this.buffer.update(cx, |buffer, cx| { + // Avoid grouping assistant edits with user edits. + buffer.finalize_last_transaction(cx); + + buffer.start_transaction(cx); + buffer.edit( + hunks.into_iter().filter_map(|hunk| match hunk { + Hunk::Insert { text } => { + let edit_start = snapshot.anchor_after(edit_start); + Some((edit_start..edit_start, text)) + } + Hunk::Remove { len } => { + let edit_end = edit_start + len; + let edit_range = snapshot.anchor_after(edit_start) + ..snapshot.anchor_before(edit_end); + edit_start = edit_end; + Some((edit_range, String::new())) + } + Hunk::Keep { len } => { + let edit_end = edit_start + len; + let edit_range = snapshot.anchor_after(edit_start) + ..snapshot.anchor_before(edit_end); + edit_start = edit_end; + this.last_equal_ranges.push(edit_range); + None + } + }), + None, + cx, + ); + + buffer.end_transaction(cx) + }); + + if let Some(transaction) = transaction { + if let Some(first_transaction) = this.transaction_id { + // Group all assistant edits into the first transaction. + this.buffer.update(cx, |buffer, cx| { + buffer.merge_transactions( + transaction, + first_transaction, + cx, + ) + }); + } else { + this.transaction_id = Some(transaction); + this.buffer.update(cx, |buffer, cx| { + buffer.finalize_last_transaction(cx) + }); + } + } + + cx.notify(); + })?; + } + + diff.await?; + + anyhow::Ok(()) + }; + + let result = generate.await; + this.update(&mut cx, |this, cx| { + this.last_equal_ranges.clear(); + this.idle = true; + if let Err(error) = result { + this.error = Some(error); + } + cx.emit(CodegenEvent::Finished); + cx.notify(); + }) + .ok(); + } + }); + self.error.take(); + self.idle = false; + cx.notify(); + } + + pub fn undo(&mut self, cx: &mut ModelContext) { + if let Some(transaction_id) = self.transaction_id { + self.buffer + .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx)); + } + } +} + +fn strip_invalid_spans_from_codeblock( + stream: impl Stream>, +) -> impl Stream> { + let mut first_line = true; + let mut buffer = String::new(); + let mut starts_with_markdown_codeblock = false; + let mut includes_start_or_end_span = false; + stream.filter_map(move |chunk| { + let chunk = match chunk { + Ok(chunk) => chunk, + Err(err) => return future::ready(Some(Err(err))), + }; + buffer.push_str(&chunk); + + if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") { + includes_start_or_end_span = true; + + buffer = buffer + .strip_prefix("<|S|>") + .or_else(|| buffer.strip_prefix("<|S|")) + .unwrap_or(&buffer) + .to_string(); + } else if buffer.ends_with("|E|>") { + includes_start_or_end_span = true; + } else if buffer.starts_with("<|") + || buffer.starts_with("<|S") + || buffer.starts_with("<|S|") + || buffer.ends_with('|') + || buffer.ends_with("|E") + || buffer.ends_with("|E|") + { + return future::ready(None); + } + + if first_line { + if buffer.is_empty() || buffer == "`" || buffer == "``" { + return future::ready(None); + } else if buffer.starts_with("```") { + starts_with_markdown_codeblock = true; + if let Some(newline_ix) = buffer.find('\n') { + buffer.replace_range(..newline_ix + 1, ""); + first_line = false; + } else { + return future::ready(None); + } + } + } + + let mut text = buffer.to_string(); + if starts_with_markdown_codeblock { + text = text + .strip_suffix("\n```\n") + .or_else(|| text.strip_suffix("\n```")) + .or_else(|| text.strip_suffix("\n``")) + .or_else(|| text.strip_suffix("\n`")) + .or_else(|| text.strip_suffix('\n')) + .unwrap_or(&text) + .to_string(); + } + + if includes_start_or_end_span { + text = text + .strip_suffix("|E|>") + .or_else(|| text.strip_suffix("E|>")) + .or_else(|| text.strip_prefix("|>")) + .or_else(|| text.strip_prefix('>')) + .unwrap_or(&text) + .to_string(); + }; + + if text.contains('\n') { + first_line = false; + } + + let remainder = buffer.split_off(text.len()); + let result = if buffer.is_empty() { + None + } else { + Some(Ok(buffer.clone())) + }; + + buffer = remainder; + future::ready(result) + }) +} + +fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { + ranges.sort_unstable_by(|a, b| { + a.start + .cmp(&b.start, buffer) + .then_with(|| b.end.cmp(&a.end, buffer)) + }); + + let mut ix = 0; + while ix + 1 < ranges.len() { + let b = ranges[ix + 1].clone(); + let a = &mut ranges[ix]; + if a.end.cmp(&b.start, buffer).is_gt() { + if a.end.cmp(&b.end, buffer).is_lt() { + a.end = b.end; + } + ranges.remove(ix + 1); + } else { + ix += 1; + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::FakeCompletionProvider; + + use super::*; + use futures::stream::{self}; + use gpui::{Context, TestAppContext}; + use indoc::indoc; + use language::{ + language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher, + Point, + }; + use rand::prelude::*; + use serde::Serialize; + use settings::SettingsStore; + + #[derive(Serialize)] + pub struct DummyCompletionRequest { + pub name: String, + } + + #[gpui::test(iterations = 10)] + async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) { + let provider = FakeCompletionProvider::default(); + cx.set_global(cx.update(SettingsStore::test)); + cx.set_global(CompletionProvider::Fake(provider.clone())); + cx.update(language_settings::init); + + let text = indoc! {" + fn main() { + let x = 0; + for _ in 0..10 { + x += 1; + } + } + "}; + let buffer = + cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); + let range = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) + }); + let codegen = cx.new_model(|cx| { + Codegen::new(buffer.clone(), CodegenKind::Transform { range }, None, cx) + }); + + let request = LanguageModelRequest::default(); + codegen.update(cx, |codegen, cx| codegen.start(request, cx)); + + let mut new_text = concat!( + " let mut x = 0;\n", + " while x < 10 {\n", + " x += 1;\n", + " }", + ); + while !new_text.is_empty() { + let max_len = cmp::min(new_text.len(), 10); + let len = rng.gen_range(1..=max_len); + let (chunk, suffix) = new_text.split_at(len); + provider.send_completion(chunk.into()); + new_text = suffix; + cx.background_executor.run_until_parked(); + } + provider.finish_completion(); + cx.background_executor.run_until_parked(); + + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + fn main() { + let mut x = 0; + while x < 10 { + x += 1; + } + } + "} + ); + } + + #[gpui::test(iterations = 10)] + async fn test_autoindent_when_generating_past_indentation( + cx: &mut TestAppContext, + mut rng: StdRng, + ) { + let provider = FakeCompletionProvider::default(); + cx.set_global(CompletionProvider::Fake(provider.clone())); + cx.set_global(cx.update(SettingsStore::test)); + cx.update(language_settings::init); + + let text = indoc! {" + fn main() { + le + } + "}; + let buffer = + cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); + let position = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(1, 6)) + }); + let codegen = cx.new_model(|cx| { + Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx) + }); + + let request = LanguageModelRequest::default(); + codegen.update(cx, |codegen, cx| codegen.start(request, cx)); + + let mut new_text = concat!( + "t mut x = 0;\n", + "while x < 10 {\n", + " x += 1;\n", + "}", // + ); + while !new_text.is_empty() { + let max_len = cmp::min(new_text.len(), 10); + let len = rng.gen_range(1..=max_len); + let (chunk, suffix) = new_text.split_at(len); + provider.send_completion(chunk.into()); + new_text = suffix; + cx.background_executor.run_until_parked(); + } + provider.finish_completion(); + cx.background_executor.run_until_parked(); + + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + fn main() { + let mut x = 0; + while x < 10 { + x += 1; + } + } + "} + ); + } + + #[gpui::test(iterations = 10)] + async fn test_autoindent_when_generating_before_indentation( + cx: &mut TestAppContext, + mut rng: StdRng, + ) { + let provider = FakeCompletionProvider::default(); + cx.set_global(CompletionProvider::Fake(provider.clone())); + cx.set_global(cx.update(SettingsStore::test)); + cx.update(language_settings::init); + + let text = concat!( + "fn main() {\n", + " \n", + "}\n" // + ); + let buffer = + cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); + let position = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(1, 2)) + }); + let codegen = cx.new_model(|cx| { + Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx) + }); + + let request = LanguageModelRequest::default(); + codegen.update(cx, |codegen, cx| codegen.start(request, cx)); + + let mut new_text = concat!( + "let mut x = 0;\n", + "while x < 10 {\n", + " x += 1;\n", + "}", // + ); + while !new_text.is_empty() { + let max_len = cmp::min(new_text.len(), 10); + let len = rng.gen_range(1..=max_len); + let (chunk, suffix) = new_text.split_at(len); + provider.send_completion(chunk.into()); + new_text = suffix; + cx.background_executor.run_until_parked(); + } + provider.finish_completion(); + cx.background_executor.run_until_parked(); + + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + fn main() { + let mut x = 0; + while x < 10 { + x += 1; + } + } + "} + ); + } + + #[gpui::test] + async fn test_strip_invalid_spans_from_codeblock() { + assert_eq!( + strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_invalid_spans_from_codeblock(chunks( + "```html\n```js\nLorem ipsum dolor\n```\n```", + 2 + )) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "```js\nLorem ipsum dolor\n```" + ); + assert_eq!( + strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "``\nLorem ipsum dolor\n```" + ); + assert_eq!( + strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum" + ); + + assert_eq!( + strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum" + ); + + assert_eq!( + strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum" + ); + assert_eq!( + strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum" + ); + fn chunks(text: &str, size: usize) -> impl Stream> { + stream::iter( + text.chars() + .collect::>() + .chunks(size) + .map(|chunk| Ok(chunk.iter().collect::())) + .collect::>(), + ) + } + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ) + .with_indents_query( + r#" + (call_expression) @indent + (field_expression) @indent + (_ "(" ")" @end) @indent + (_ "{" "}" @end) @indent + "#, + ) + .unwrap() + } +} diff --git a/crates/assistant/src/prompt_library.rs b/crates/assistant/src/prompt_library.rs index a56ce5356c31349687d38f00e18968868f005d38..b42d828a10dc3234215f4666617c6c380aec4c19 100644 --- a/crates/assistant/src/prompt_library.rs +++ b/crates/assistant/src/prompt_library.rs @@ -1,6 +1,6 @@ use crate::{ - slash_command::SlashCommandCompletionProvider, CompletionProvider, LanguageModelRequest, - LanguageModelRequestMessage, Role, + slash_command::SlashCommandCompletionProvider, AssistantPanel, CompletionProvider, + InlineAssist, InlineAssistant, LanguageModelRequest, LanguageModelRequestMessage, Role, }; use anyhow::{anyhow, Result}; use assistant_slash_command::SlashCommandRegistry; @@ -14,8 +14,8 @@ use futures::{ use fuzzy::StringMatchCandidate; use gpui::{ actions, point, size, AnyElement, AppContext, BackgroundExecutor, Bounds, DevicePixels, - EventEmitter, Global, PromptLevel, ReadGlobal, Subscription, Task, TitlebarOptions, View, - WindowBounds, WindowHandle, WindowOptions, + EventEmitter, Global, PromptLevel, ReadGlobal, Subscription, Task, TitlebarOptions, + UpdateGlobal, View, WindowBounds, WindowHandle, WindowOptions, }; use heed::{types::SerdeBincode, Database, RoTxn}; use language::{language_settings::SoftWrap, Buffer, LanguageRegistry}; @@ -35,6 +35,7 @@ use ui::{ }; use util::{paths::PROMPTS_DIR, ResultExt, TryFutureExt}; use uuid::Uuid; +use workspace::Workspace; actions!( prompt_library, @@ -603,6 +604,49 @@ impl PromptLibrary { self.picker.update(cx, |picker, cx| picker.focus(cx)); } + pub fn inline_assist(&mut self, _: &InlineAssist, cx: &mut ViewContext) { + let Some(active_prompt_id) = self.active_prompt_id else { + cx.propagate(); + return; + }; + + let prompt_editor = &self.prompt_editors[&active_prompt_id].editor; + let provider = CompletionProvider::global(cx); + if provider.is_authenticated() { + InlineAssistant::update_global(cx, |assistant, cx| { + assistant.assist(&prompt_editor, None, false, cx) + }) + } else { + for window in cx.windows() { + if let Some(workspace) = window.downcast::() { + let panel = workspace + .update(cx, |workspace, cx| { + cx.activate_window(); + workspace.focus_panel::(cx) + }) + .ok() + .flatten(); + if panel.is_some() { + return; + } + } + } + } + } + + fn cancel_last_inline_assist( + &mut self, + _: &editor::actions::Cancel, + cx: &mut ViewContext, + ) { + let canceled = InlineAssistant::update_global(cx, |assistant, cx| { + assistant.cancel_last_inline_assist(cx) + }); + if !canceled { + cx.propagate(); + } + } + fn handle_prompt_editor_event( &mut self, prompt_id: PromptId, @@ -732,6 +776,8 @@ impl PromptLibrary { .child( div() .on_action(cx.listener(Self::focus_picker)) + .on_action(cx.listener(Self::inline_assist)) + .on_action(cx.listener(Self::cancel_last_inline_assist)) .flex_grow() .h_full() .pt(Spacing::Large.rems(cx)) diff --git a/crates/gpui/src/window.rs b/crates/gpui/src/window.rs index 7024b666cf8b5bdb02426fb68efebf6183631427..78cb5dc1157ccb963c0e5cac02838af10db68ead 100644 --- a/crates/gpui/src/window.rs +++ b/crates/gpui/src/window.rs @@ -1036,6 +1036,37 @@ impl<'a> WindowContext<'a> { }); } + /// Subscribe to events emitted by a model or view. + /// The entity to which you're subscribing must implement the [`EventEmitter`] trait. + /// The callback will be invoked a handle to the emitting entity (either a [`View`] or [`Model`]), the event, and a window context for the current window. + pub fn observe( + &mut self, + entity: &E, + mut on_notify: impl FnMut(E, &mut WindowContext<'_>) + 'static, + ) -> Subscription + where + E: Entity, + { + let entity_id = entity.entity_id(); + let entity = entity.downgrade(); + let window_handle = self.window.handle; + self.app.new_observer( + entity_id, + Box::new(move |cx| { + window_handle + .update(cx, |_, cx| { + if let Some(handle) = E::upgrade_from(&entity) { + on_notify(handle, cx); + true + } else { + false + } + }) + .unwrap_or(false) + }), + ) + } + /// Subscribe to events emitted by a model or view. /// The entity to which you're subscribing must implement the [`EventEmitter`] trait. /// The callback will be invoked a handle to the emitting entity (either a [`View`] or [`Model`]), the event, and a window context for the current window.