Normalize indentation when refactoring

Antonio Scandurra created

Change summary

crates/ai/src/refactor.rs | 84 +++++++++++++++++++++++++++++++++++++---
crates/rope/src/rope.rs   | 10 ++++
2 files changed, 87 insertions(+), 7 deletions(-)

Detailed changes

crates/ai/src/refactor.rs 🔗

@@ -1,13 +1,14 @@
 use crate::{diff::Diff, stream_completion, OpenAIRequest, RequestMessage, Role};
 use collections::HashMap;
-use editor::{Editor, ToOffset};
+use editor::{Editor, ToOffset, ToPoint};
 use futures::{channel::mpsc, SinkExt, StreamExt};
 use gpui::{
     actions, elements::*, platform::MouseButton, AnyViewHandle, AppContext, Entity, Task, View,
     ViewContext, ViewHandle, WeakViewHandle,
 };
+use language::{Point, Rope};
 use menu::{Cancel, Confirm};
-use std::{env, sync::Arc};
+use std::{cmp, env, sync::Arc};
 use util::TryFutureExt;
 use workspace::{Modal, Workspace};
 
@@ -36,7 +37,48 @@ impl RefactoringAssistant {
         let selection = editor.read(cx).selections.newest_anchor().clone();
         let selected_text = snapshot
             .text_for_range(selection.start..selection.end)
-            .collect::<String>();
+            .collect::<Rope>();
+
+        let mut normalized_selected_text = selected_text.clone();
+        let mut base_indentation: Option<language::IndentSize> = None;
+        let selection_start = selection.start.to_point(&snapshot);
+        let selection_end = selection.end.to_point(&snapshot);
+        if selection_start.row < selection_end.row {
+            for row in selection_start.row..=selection_end.row {
+                if snapshot.is_line_blank(row) {
+                    continue;
+                }
+
+                let line_indentation = snapshot.indent_size_for_line(row);
+                if let Some(base_indentation) = base_indentation.as_mut() {
+                    if line_indentation.len < base_indentation.len {
+                        *base_indentation = line_indentation;
+                    }
+                } else {
+                    base_indentation = Some(line_indentation);
+                }
+            }
+        }
+
+        if let Some(base_indentation) = base_indentation {
+            for row in selection_start.row..=selection_end.row {
+                let selection_row = row - selection_start.row;
+                let line_start =
+                    normalized_selected_text.point_to_offset(Point::new(selection_row, 0));
+                let indentation_len = if row == selection_start.row {
+                    base_indentation.len.saturating_sub(selection_start.column)
+                } else {
+                    let line_len = normalized_selected_text.line_len(selection_row);
+                    cmp::min(line_len, base_indentation.len)
+                };
+                let indentation_end = cmp::min(
+                    line_start + indentation_len as usize,
+                    normalized_selected_text.len(),
+                );
+                normalized_selected_text.replace(line_start..indentation_end, "");
+            }
+        }
+
         let language_name = snapshot
             .language_at(selection.start)
             .map(|language| language.name());
@@ -47,7 +89,7 @@ impl RefactoringAssistant {
                 RequestMessage {
                 role: Role::User,
                 content: format!(
-                    "Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Never make remarks and reply only with the new code. Never change the leading whitespace on each line."
+                    "Given the following {language_name} snippet:\n{normalized_selected_text}\n{prompt}. Never make remarks and reply only with the new code."
                 ),
             }],
             stream: true,
@@ -64,21 +106,49 @@ impl RefactoringAssistant {
                     let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                     let diff = cx.background().spawn(async move {
                         let mut messages = response.await?.ready_chunks(4);
-                        let mut diff = Diff::new(selected_text);
+                        let mut diff = Diff::new(selected_text.to_string());
 
+                        let indentation_len;
+                        let indentation_text;
+                        if let Some(base_indentation) = base_indentation {
+                            indentation_len = base_indentation.len;
+                            indentation_text = match base_indentation.kind {
+                                language::IndentKind::Space => " ",
+                                language::IndentKind::Tab => "\t",
+                            };
+                        } else {
+                            indentation_len = 0;
+                            indentation_text = "";
+                        };
+
+                        let mut new_text =
+                            indentation_text.repeat(
+                                indentation_len.saturating_sub(selection_start.column) as usize,
+                            );
                         while let Some(messages) = messages.next().await {
-                            let mut new_text = String::new();
                             for message in messages {
                                 let mut message = message?;
                                 if let Some(choice) = message.choices.pop() {
                                     if let Some(text) = choice.delta.content {
-                                        new_text.push_str(&text);
+                                        let mut lines = text.split('\n');
+                                        if let Some(first_line) = lines.next() {
+                                            new_text.push_str(&first_line);
+                                        }
+
+                                        for line in lines {
+                                            new_text.push('\n');
+                                            new_text.push_str(
+                                                &indentation_text.repeat(indentation_len as usize),
+                                            );
+                                            new_text.push_str(line);
+                                        }
                                     }
                                 }
                             }
 
                             let hunks = diff.push_new(&new_text);
                             hunks_tx.send(hunks).await?;
+                            new_text.clear();
                         }
                         hunks_tx.send(diff.finish()).await?;
 

crates/rope/src/rope.rs 🔗

@@ -384,6 +384,16 @@ impl<'a> From<&'a str> for Rope {
     }
 }
 
+impl<'a> FromIterator<&'a str> for Rope {
+    fn from_iter<T: IntoIterator<Item = &'a str>>(iter: T) -> Self {
+        let mut rope = Rope::new();
+        for chunk in iter {
+            rope.push(chunk);
+        }
+        rope
+    }
+}
+
 impl From<String> for Rope {
     fn from(text: String) -> Self {
         Rope::from(text.as_str())