Extract a `strip_markdown_codeblock` function

Antonio Scandurra created

Change summary

crates/ai/src/assistant.rs | 197 ++++++++++++++++++++++-----------------
1 file changed, 110 insertions(+), 87 deletions(-)

Detailed changes

crates/ai/src/assistant.rs 🔗

@@ -16,7 +16,7 @@ use editor::{
     Anchor, Editor, MultiBufferSnapshot, ToOffset, ToPoint,
 };
 use fs::Fs;
-use futures::{channel::mpsc, SinkExt, StreamExt};
+use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
 use gpui::{
     actions,
     elements::*,
@@ -620,7 +620,10 @@ impl AssistantPanel {
 
                 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                 let diff = cx.background().spawn(async move {
-                    let mut messages = response.await?;
+                    let chunks = strip_markdown_codeblock(response.await?.filter_map(
+                        |message| async move { message.ok()?.choices.pop()?.delta.content },
+                    ));
+                    futures::pin_mut!(chunks);
                     let mut diff = StreamingDiff::new(selected_text.to_string());
 
                     let indentation_len;
@@ -636,93 +639,21 @@ impl AssistantPanel {
                         indentation_text = "";
                     };
 
-                    let mut inside_first_line = true;
-                    let mut starts_with_fenced_code_block = None;
-                    let mut has_pending_newline = false;
-                    let mut new_text = String::new();
-
-                    while let Some(message) = messages.next().await {
-                        let mut message = message?;
-                        if let Some(mut choice) = message.choices.pop() {
-                            if has_pending_newline {
-                                has_pending_newline = false;
-                                choice
-                                    .delta
-                                    .content
-                                    .get_or_insert(String::new())
-                                    .insert(0, '\n');
-                            }
-
-                            // Buffer a trailing codeblock fence. Note that we don't stop
-                            // right away because this may be an inner fence that we need
-                            // to insert into the editor.
-                            if starts_with_fenced_code_block.is_some()
-                                && choice.delta.content.as_deref() == Some("\n```")
-                            {
-                                new_text.push_str("\n```");
-                                continue;
-                            }
-
-                            // If this was the last completion and we started with a codeblock
-                            // fence and we ended with another codeblock fence, then we can
-                            // stop right away. Otherwise, whatever text we buffered will be
-                            // processed normally.
-                            if choice.finish_reason.is_some()
-                                && starts_with_fenced_code_block.unwrap_or(false)
-                                && new_text == "\n```"
-                            {
-                                break;
-                            }
-
-                            if let Some(text) = choice.delta.content {
-                                // Never push a newline if there's nothing after it. This is
-                                // useful to detect if the newline was pushed because of a
-                                // trailing codeblock fence.
-                                let text = if let Some(prefix) = text.strip_suffix('\n') {
-                                    has_pending_newline = true;
-                                    prefix
-                                } else {
-                                    text.as_str()
-                                };
-
-                                if text.is_empty() {
-                                    continue;
-                                }
-
-                                let mut lines = text.split('\n');
-                                if let Some(line) = lines.next() {
-                                    if starts_with_fenced_code_block.is_none() {
-                                        starts_with_fenced_code_block =
-                                            Some(line.starts_with("```"));
-                                    }
-
-                                    // Avoid pushing the first line if it's the start of a fenced code block.
-                                    if !inside_first_line || !starts_with_fenced_code_block.unwrap()
-                                    {
-                                        new_text.push_str(&line);
-                                    }
-                                }
+                    let mut new_text = indentation_text
+                        .repeat(indentation_len.saturating_sub(selection_start.column) as usize);
 
-                                for line in lines {
-                                    if inside_first_line && starts_with_fenced_code_block.unwrap() {
-                                        // If we were inside the first line and that line was the
-                                        // start of a fenced code block, we just need to push the
-                                        // leading indentation of the original selection.
-                                        new_text.push_str(&indentation_text.repeat(
-                                            indentation_len.saturating_sub(selection_start.column)
-                                                as usize,
-                                        ));
-                                    } else {
-                                        // Otherwise, we need to push a newline and the base indentation.
-                                        new_text.push('\n');
-                                        new_text.push_str(
-                                            &indentation_text.repeat(indentation_len as usize),
-                                        );
-                                    }
+                    while let Some(message) = chunks.next().await {
+                        let mut lines = message.split('\n');
+                        if let Some(first_line) = lines.next() {
+                            new_text.push_str(first_line);
+                        }
 
-                                    new_text.push_str(line);
-                                    inside_first_line = false;
-                                }
+                        for line in lines {
+                            new_text.push('\n');
+                            if !line.is_empty() {
+                                new_text
+                                    .push_str(&indentation_text.repeat(indentation_len as usize));
+                                new_text.push_str(line);
                             }
                         }
 
@@ -2919,10 +2850,58 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
     }
 }
 
+fn strip_markdown_codeblock(stream: impl Stream<Item = String>) -> impl Stream<Item = String> {
+    let mut first_line = true; // true until the first line has been classified (fence or not)
+    let mut buffer = String::new(); // text received but not yet emitted downstream
+    let mut starts_with_fenced_code_block = false; // set once the input is seen to open with "```"
+    stream.filter_map(move |chunk| { // runs once per incoming chunk; `None` emits nothing for it
+        buffer.push_str(&chunk);
+
+        if first_line {
+            if buffer == "" || buffer == "`" || buffer == "``" { // may still grow into "```"; wait
+                return futures::future::ready(None);
+            } else if buffer.starts_with("```") {
+                starts_with_fenced_code_block = true;
+                if let Some(newline_ix) = buffer.find('\n') { // the opening fence line is complete
+                    buffer.replace_range(..newline_ix + 1, ""); // drop the fence line and its newline
+                    first_line = false;
+                } else { // fence line (e.g. a language tag) is still incomplete; keep buffering
+                    return futures::future::ready(None);
+                }
+            }
+        }
+
+        let text = if starts_with_fenced_code_block { // hold back a suffix that could begin "\n```"
+            buffer
+                .strip_suffix("\n```")
+                .or_else(|| buffer.strip_suffix("\n``"))
+                .or_else(|| buffer.strip_suffix("\n`"))
+                .or_else(|| buffer.strip_suffix('\n'))
+                .unwrap_or(&buffer)
+        } else {
+            &buffer
+        };
+
+        if text.contains('\n') { // output has moved past the first line; stop fence detection
+            first_line = false;
+        }
+
+        let remainder = buffer.split_off(text.len()); // `text` is a prefix of `buffer`; keep the rest
+        let result = if buffer.is_empty() {
+            None
+        } else {
+            Some(buffer.clone())
+        };
+        buffer = remainder; // NOTE(review): text still held here when the stream ends is never emitted
+        futures::future::ready(result)
+    })
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::MessageId;
+    use futures::stream;
     use gpui::AppContext;
 
     #[gpui::test]
@@ -3291,6 +3270,50 @@ mod tests {
         );
     }
 
+    #[gpui::test]
+    async fn test_strip_markdown_codeblock() {
+        assert_eq!( // no fence: text passes through unchanged
+            strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
+                .collect::<String>()
+                .await,
+            "Lorem ipsum dolor"
+        );
+        assert_eq!( // opening fence only: the fence line is removed
+            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
+                .collect::<String>()
+                .await,
+            "Lorem ipsum dolor"
+        );
+        assert_eq!( // opening and closing fences: both are removed
+            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
+                .collect::<String>()
+                .await,
+            "Lorem ipsum dolor"
+        );
+        assert_eq!( // nested fences: only the outermost pair is removed
+            strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
+                .collect::<String>()
+                .await,
+            "```js\nLorem ipsum dolor\n```"
+        );
+        assert_eq!( // two leading backticks are not a fence: nothing is removed
+            strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
+                .collect::<String>()
+                .await,
+            "``\nLorem ipsum dolor\n```"
+        );
+
+        fn chunks(text: &str, size: usize) -> impl Stream<Item = String> { // `text` as a stream of pieces of at most `size` chars
+            stream::iter(
+                text.chars()
+                    .collect::<Vec<_>>()
+                    .chunks(size)
+                    .map(|chunk| chunk.iter().collect::<String>())
+                    .collect::<Vec<_>>(),
+            )
+        }
+    }
+
     fn messages(
         conversation: &ModelHandle<Conversation>,
         cx: &AppContext,