Never use the indentation that comes from OpenAI

Antonio Scandurra created

Change summary

crates/ai/src/assistant.rs |  51 +++--
crates/ai/src/codegen.rs   | 338 +++++++++++++++++++++++++++------------
2 files changed, 257 insertions(+), 132 deletions(-)

Detailed changes

crates/ai/src/assistant.rs 🔗

@@ -1,6 +1,6 @@
 use crate::{
     assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
-    codegen::{self, Codegen, OpenAICompletionProvider},
+    codegen::{self, Codegen, CodegenKind, OpenAICompletionProvider},
     stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage,
     Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL,
 };
@@ -270,24 +270,28 @@ impl AssistantPanel {
 
         let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
         let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
-        let selection = editor.read(cx).selections.newest_anchor().clone();
-        let range = selection.start.bias_left(&snapshot)..selection.end.bias_right(&snapshot);
         let provider = Arc::new(OpenAICompletionProvider::new(
             api_key,
             cx.background().clone(),
         ));
-        let codegen =
-            cx.add_model(|cx| Codegen::new(editor.read(cx).buffer().clone(), range, provider, cx));
-        let assist_kind = if editor.read(cx).selections.newest::<usize>(cx).is_empty() {
-            InlineAssistKind::Generate
+        let selection = editor.read(cx).selections.newest_anchor().clone();
+        let codegen_kind = if editor.read(cx).selections.newest::<usize>(cx).is_empty() {
+            CodegenKind::Generate {
+                position: selection.start,
+            }
         } else {
-            InlineAssistKind::Transform
+            CodegenKind::Transform {
+                range: selection.start..selection.end,
+            }
         };
+        let codegen = cx.add_model(|cx| {
+            Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
+        });
+
         let measurements = Rc::new(Cell::new(BlockMeasurements::default()));
         let inline_assistant = cx.add_view(|cx| {
             let assistant = InlineAssistant::new(
                 inline_assist_id,
-                assist_kind,
                 measurements.clone(),
                 self.include_conversation_in_next_inline_assist,
                 self.inline_prompt_history.clone(),
@@ -330,7 +334,6 @@ impl AssistantPanel {
         self.pending_inline_assists.insert(
             inline_assist_id,
             PendingInlineAssist {
-                kind: assist_kind,
                 editor: editor.downgrade(),
                 inline_assistant: Some((block_id, inline_assistant.clone())),
                 codegen: codegen.clone(),
@@ -348,6 +351,14 @@ impl AssistantPanel {
                             }
                         }
                     }),
+                    cx.observe(&codegen, {
+                        let editor = editor.downgrade();
+                        move |this, _, cx| {
+                            if let Some(editor) = editor.upgrade(cx) {
+                                this.update_highlights_for_editor(&editor, cx);
+                            }
+                        }
+                    }),
                     cx.subscribe(&codegen, move |this, codegen, event, cx| match event {
                         codegen::Event::Undone => {
                             this.finish_inline_assist(inline_assist_id, false, cx)
@@ -542,8 +553,8 @@ impl AssistantPanel {
         if let Some(language_name) = language_name {
             writeln!(prompt, "You're an expert {language_name} engineer.").unwrap();
         }
-        match pending_assist.kind {
-            InlineAssistKind::Transform => {
+        match pending_assist.codegen.read(cx).kind() {
+            CodegenKind::Transform { .. } => {
                 writeln!(
                     prompt,
                     "You're currently working inside an editor on this file:"
@@ -583,7 +594,7 @@ impl AssistantPanel {
                 )
                 .unwrap();
             }
-            InlineAssistKind::Generate => {
+            CodegenKind::Generate { .. } => {
                 writeln!(
                     prompt,
                     "You're currently working inside an editor on this file:"
@@ -2649,12 +2660,6 @@ enum InlineAssistantEvent {
     },
 }
 
-#[derive(Copy, Clone)]
-enum InlineAssistKind {
-    Transform,
-    Generate,
-}
-
 struct InlineAssistant {
     id: usize,
     prompt_editor: ViewHandle<Editor>,
@@ -2769,7 +2774,6 @@ impl View for InlineAssistant {
 impl InlineAssistant {
     fn new(
         id: usize,
-        kind: InlineAssistKind,
         measurements: Rc<Cell<BlockMeasurements>>,
         include_conversation: bool,
         prompt_history: VecDeque<String>,
@@ -2781,9 +2785,9 @@ impl InlineAssistant {
                 Some(Arc::new(|theme| theme.assistant.inline.editor.clone())),
                 cx,
             );
-            let placeholder = match kind {
-                InlineAssistKind::Transform => "Enter transformation prompt…",
-                InlineAssistKind::Generate => "Enter generation prompt…",
+            let placeholder = match codegen.read(cx).kind() {
+                CodegenKind::Transform { .. } => "Enter transformation prompt…",
+                CodegenKind::Generate { .. } => "Enter generation prompt…",
             };
             editor.set_placeholder_text(placeholder, cx);
             editor
@@ -2929,7 +2933,6 @@ struct BlockMeasurements {
 }
 
 struct PendingInlineAssist {
-    kind: InlineAssistKind,
     editor: WeakViewHandle<Editor>,
     inline_assistant: Option<(BlockId, ViewHandle<InlineAssistant>)>,
     codegen: ModelHandle<Codegen>,

crates/ai/src/codegen.rs 🔗

@@ -4,12 +4,14 @@ use crate::{
     OpenAIRequest,
 };
 use anyhow::Result;
-use editor::{multi_buffer, Anchor, MultiBuffer, ToOffset, ToPoint};
+use editor::{
+    multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
+};
 use futures::{
     channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt,
 };
 use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task};
-use language::{IndentSize, Point, Rope, TransactionId};
+use language::{Rope, TransactionId};
 use std::{cmp, future, ops::Range, sync::Arc};
 
 pub trait CompletionProvider {
@@ -57,10 +59,17 @@ pub enum Event {
     Undone,
 }
 
+#[derive(Clone)]
+pub enum CodegenKind {
+    Transform { range: Range<Anchor> },
+    Generate { position: Anchor },
+}
+
 pub struct Codegen {
     provider: Arc<dyn CompletionProvider>,
     buffer: ModelHandle<MultiBuffer>,
-    range: Range<Anchor>,
+    snapshot: MultiBufferSnapshot,
+    kind: CodegenKind,
     last_equal_ranges: Vec<Range<Anchor>>,
     transaction_id: Option<TransactionId>,
     error: Option<anyhow::Error>,
@@ -76,14 +85,31 @@ impl Entity for Codegen {
 impl Codegen {
     pub fn new(
         buffer: ModelHandle<MultiBuffer>,
-        range: Range<Anchor>,
+        mut kind: CodegenKind,
         provider: Arc<dyn CompletionProvider>,
         cx: &mut ModelContext<Self>,
     ) -> Self {
+        let snapshot = buffer.read(cx).snapshot(cx);
+        match &mut kind {
+            CodegenKind::Transform { range } => {
+                let mut point_range = range.to_point(&snapshot);
+                point_range.start.column = 0;
+                if point_range.end.column > 0 || point_range.start.row == point_range.end.row {
+                    point_range.end.column = snapshot.line_len(point_range.end.row);
+                }
+                range.start = snapshot.anchor_before(point_range.start);
+                range.end = snapshot.anchor_after(point_range.end);
+            }
+            CodegenKind::Generate { position } => {
+                *position = position.bias_right(&snapshot);
+            }
+        }
+
         Self {
             provider,
             buffer: buffer.clone(),
-            range,
+            snapshot,
+            kind,
             last_equal_ranges: Default::default(),
             transaction_id: Default::default(),
             error: Default::default(),
@@ -109,7 +135,14 @@ impl Codegen {
     }
 
     pub fn range(&self) -> Range<Anchor> {
-        self.range.clone()
+        match &self.kind {
+            CodegenKind::Transform { range } => range.clone(),
+            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
+        }
+    }
+
+    pub fn kind(&self) -> &CodegenKind {
+        &self.kind
     }
 
     pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
@@ -125,56 +158,18 @@ impl Codegen {
     }
 
     pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
-        let range = self.range.clone();
-        let snapshot = self.buffer.read(cx).snapshot(cx);
+        let range = self.range();
+        let snapshot = self.snapshot.clone();
         let selected_text = snapshot
             .text_for_range(range.start..range.end)
             .collect::<Rope>();
 
         let selection_start = range.start.to_point(&snapshot);
-        let selection_end = range.end.to_point(&snapshot);
-
-        let mut base_indent: Option<IndentSize> = None;
-        let mut start_row = selection_start.row;
-        if snapshot.is_line_blank(start_row) {
-            if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) {
-                start_row = prev_non_blank_row;
-            }
-        }
-        for row in start_row..=selection_end.row {
-            if snapshot.is_line_blank(row) {
-                continue;
-            }
-
-            let line_indent = snapshot.indent_size_for_line(row);
-            if let Some(base_indent) = base_indent.as_mut() {
-                if line_indent.len < base_indent.len {
-                    *base_indent = line_indent;
-                }
-            } else {
-                base_indent = Some(line_indent);
-            }
-        }
-
-        let mut normalized_selected_text = selected_text.clone();
-        if let Some(base_indent) = base_indent {
-            for row in selection_start.row..=selection_end.row {
-                let selection_row = row - selection_start.row;
-                let line_start =
-                    normalized_selected_text.point_to_offset(Point::new(selection_row, 0));
-                let indent_len = if row == selection_start.row {
-                    base_indent.len.saturating_sub(selection_start.column)
-                } else {
-                    let line_len = normalized_selected_text.line_len(selection_row);
-                    cmp::min(line_len, base_indent.len)
-                };
-                let indent_end = cmp::min(
-                    line_start + indent_len as usize,
-                    normalized_selected_text.len(),
-                );
-                normalized_selected_text.replace(line_start..indent_end, "");
-            }
-        }
+        let suggested_line_indent = snapshot
+            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
+            .into_values()
+            .next()
+            .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
 
         let response = self.provider.complete(prompt);
         self.generation = cx.spawn_weak(|this, mut cx| {
@@ -188,66 +183,58 @@ impl Codegen {
                         futures::pin_mut!(chunks);
                         let mut diff = StreamingDiff::new(selected_text.to_string());
 
-                        let mut indent_len;
-                        let indent_text;
-                        if let Some(base_indent) = base_indent {
-                            indent_len = base_indent.len;
-                            indent_text = match base_indent.kind {
-                                language::IndentKind::Space => " ",
-                                language::IndentKind::Tab => "\t",
-                            };
-                        } else {
-                            indent_len = 0;
-                            indent_text = "";
-                        };
-
-                        let mut first_line_len = 0;
-                        let mut first_line_non_whitespace_char_ix = None;
-                        let mut first_line = true;
                         let mut new_text = String::new();
+                        let mut base_indent = None;
+                        let mut line_indent = None;
+                        let mut first_line = true;
 
                         while let Some(chunk) = chunks.next().await {
                             let chunk = chunk?;
 
-                            let mut lines = chunk.split('\n');
-                            if let Some(mut line) = lines.next() {
-                                if first_line {
-                                    if first_line_non_whitespace_char_ix.is_none() {
-                                        if let Some(mut char_ix) =
-                                            line.find(|ch: char| !ch.is_whitespace())
-                                        {
-                                            line = &line[char_ix..];
-                                            char_ix += first_line_len;
-                                            first_line_non_whitespace_char_ix = Some(char_ix);
-                                            let first_line_indent = char_ix
-                                                .saturating_sub(selection_start.column as usize)
-                                                as usize;
-                                            new_text
-                                                .push_str(&indent_text.repeat(first_line_indent));
-                                            indent_len = indent_len.saturating_sub(char_ix as u32);
+                            let mut lines = chunk.split('\n').peekable();
+                            while let Some(line) = lines.next() {
+                                new_text.push_str(line);
+                                if line_indent.is_none() {
+                                    if let Some(non_whitespace_ch_ix) =
+                                        new_text.find(|ch: char| !ch.is_whitespace())
+                                    {
+                                        line_indent = Some(non_whitespace_ch_ix);
+                                        base_indent = base_indent.or(line_indent);
+
+                                        let line_indent = line_indent.unwrap();
+                                        let base_indent = base_indent.unwrap();
+                                        let indent_delta = line_indent as i32 - base_indent as i32;
+                                        let mut corrected_indent_len = cmp::max(
+                                            0,
+                                            suggested_line_indent.len as i32 + indent_delta,
+                                        )
+                                            as usize;
+                                        if first_line {
+                                            corrected_indent_len = corrected_indent_len
+                                                .saturating_sub(selection_start.column as usize);
                                         }
-                                    }
-                                    first_line_len += line.len();
-                                }
 
-                                if first_line_non_whitespace_char_ix.is_some() {
-                                    new_text.push_str(line);
+                                        let indent_char = suggested_line_indent.char();
+                                        let mut indent_buffer = [0; 4];
+                                        let indent_str =
+                                            indent_char.encode_utf8(&mut indent_buffer);
+                                        new_text.replace_range(
+                                            ..line_indent,
+                                            &indent_str.repeat(corrected_indent_len),
+                                        );
+                                    }
                                 }
-                            }
 
-                            for line in lines {
-                                first_line = false;
-                                new_text.push('\n');
-                                if !line.is_empty() {
-                                    new_text.push_str(&indent_text.repeat(indent_len as usize));
+                                if lines.peek().is_some() {
+                                    hunks_tx.send(diff.push_new(&new_text)).await?;
+                                    hunks_tx.send(diff.push_new("\n")).await?;
+                                    new_text.clear();
+                                    line_indent = None;
+                                    first_line = false;
                                 }
-                                new_text.push_str(line);
                             }
-
-                            let hunks = diff.push_new(&new_text);
-                            hunks_tx.send(hunks).await?;
-                            new_text.clear();
                         }
+                        hunks_tx.send(diff.push_new(&new_text)).await?;
                         hunks_tx.send(diff.finish()).await?;
 
                         anyhow::Ok(())
@@ -285,7 +272,7 @@ impl Codegen {
                                             let edit_end = edit_start + len;
                                             let edit_range = snapshot.anchor_after(edit_start)
                                                 ..snapshot.anchor_before(edit_end);
-                                            edit_start += len;
+                                            edit_start = edit_end;
                                             this.last_equal_ranges.push(edit_range);
                                             None
                                         }
@@ -410,16 +397,20 @@ mod tests {
     use futures::stream;
     use gpui::{executor::Deterministic, TestAppContext};
     use indoc::indoc;
-    use language::{tree_sitter_rust, Buffer, Language, LanguageConfig};
+    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
     use parking_lot::Mutex;
     use rand::prelude::*;
+    use settings::SettingsStore;
 
     #[gpui::test(iterations = 10)]
-    async fn test_autoindent(
+    async fn test_transform_autoindent(
         cx: &mut TestAppContext,
         mut rng: StdRng,
         deterministic: Arc<Deterministic>,
     ) {
+        cx.set_global(cx.read(SettingsStore::test));
+        cx.update(language_settings::init);
+
         let text = indoc! {"
             fn main() {
                 let x = 0;
@@ -436,15 +427,146 @@ mod tests {
             snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
         });
         let provider = Arc::new(TestCompletionProvider::new());
-        let codegen = cx.add_model(|cx| Codegen::new(buffer.clone(), range, provider.clone(), cx));
+        let codegen = cx.add_model(|cx| {
+            Codegen::new(
+                buffer.clone(),
+                CodegenKind::Transform { range },
+                provider.clone(),
+                cx,
+            )
+        });
         codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
 
-        let mut new_text = indoc! {"
-                   let mut x = 0;
-            while x < 10 {
-                           x += 1;
-               }
+        let mut new_text = concat!(
+            "       let mut x = 0;\n",
+            "       while x < 10 {\n",
+            "           x += 1;\n",
+            "       }",
+        );
+        while !new_text.is_empty() {
+            let max_len = cmp::min(new_text.len(), 10);
+            let len = rng.gen_range(1..=max_len);
+            let (chunk, suffix) = new_text.split_at(len);
+            provider.send_completion(chunk);
+            new_text = suffix;
+            deterministic.run_until_parked();
+        }
+        provider.finish_completion();
+        deterministic.run_until_parked();
+
+        assert_eq!(
+            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+            indoc! {"
+                fn main() {
+                    let mut x = 0;
+                    while x < 10 {
+                        x += 1;
+                    }
+                }
+            "}
+        );
+    }
+
+    #[gpui::test(iterations = 10)]
+    async fn test_autoindent_when_generating_past_indentation(
+        cx: &mut TestAppContext,
+        mut rng: StdRng,
+        deterministic: Arc<Deterministic>,
+    ) {
+        cx.set_global(cx.read(SettingsStore::test));
+        cx.update(language_settings::init);
+
+        let text = indoc! {"
+            fn main() {
+                le
+            }
         "};
+        let buffer =
+            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
+        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
+        let position = buffer.read_with(cx, |buffer, cx| {
+            let snapshot = buffer.snapshot(cx);
+            snapshot.anchor_before(Point::new(1, 6))
+        });
+        let provider = Arc::new(TestCompletionProvider::new());
+        let codegen = cx.add_model(|cx| {
+            Codegen::new(
+                buffer.clone(),
+                CodegenKind::Generate { position },
+                provider.clone(),
+                cx,
+            )
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+        let mut new_text = concat!(
+            "t mut x = 0;\n",
+            "while x < 10 {\n",
+            "    x += 1;\n",
+            "}", //
+        );
+        while !new_text.is_empty() {
+            let max_len = cmp::min(new_text.len(), 10);
+            let len = rng.gen_range(1..=max_len);
+            let (chunk, suffix) = new_text.split_at(len);
+            provider.send_completion(chunk);
+            new_text = suffix;
+            deterministic.run_until_parked();
+        }
+        provider.finish_completion();
+        deterministic.run_until_parked();
+
+        assert_eq!(
+            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+            indoc! {"
+                fn main() {
+                    let mut x = 0;
+                    while x < 10 {
+                        x += 1;
+                    }
+                }
+            "}
+        );
+    }
+
+    #[gpui::test(iterations = 10)]
+    async fn test_autoindent_when_generating_before_indentation(
+        cx: &mut TestAppContext,
+        mut rng: StdRng,
+        deterministic: Arc<Deterministic>,
+    ) {
+        cx.set_global(cx.read(SettingsStore::test));
+        cx.update(language_settings::init);
+
+        let text = concat!(
+            "fn main() {\n",
+            "  \n",
+            "}\n" //
+        );
+        let buffer =
+            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
+        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
+        let position = buffer.read_with(cx, |buffer, cx| {
+            let snapshot = buffer.snapshot(cx);
+            snapshot.anchor_before(Point::new(1, 2))
+        });
+        let provider = Arc::new(TestCompletionProvider::new());
+        let codegen = cx.add_model(|cx| {
+            Codegen::new(
+                buffer.clone(),
+                CodegenKind::Generate { position },
+                provider.clone(),
+                cx,
+            )
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+        let mut new_text = concat!(
+            "let mut x = 0;\n",
+            "while x < 10 {\n",
+            "    x += 1;\n",
+            "}", //
+        );
         while !new_text.is_empty() {
             let max_len = cmp::min(new_text.len(), 10);
             let len = rng.gen_range(1..=max_len);