@@ -16,7 +16,7 @@ use editor::{
Anchor, Editor, MultiBufferSnapshot, ToOffset, ToPoint,
};
use fs::Fs;
-use futures::{channel::mpsc, SinkExt, StreamExt};
+use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
use gpui::{
actions,
elements::*,
@@ -620,7 +620,10 @@ impl AssistantPanel {
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
let diff = cx.background().spawn(async move {
- let mut messages = response.await?;
+ let chunks = strip_markdown_codeblock(response.await?.filter_map(
+ |message| async move { message.ok()?.choices.pop()?.delta.content },
+ ));
+ futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let indentation_len;
@@ -636,93 +639,21 @@ impl AssistantPanel {
indentation_text = "";
};
- let mut inside_first_line = true;
- let mut starts_with_fenced_code_block = None;
- let mut has_pending_newline = false;
- let mut new_text = String::new();
-
- while let Some(message) = messages.next().await {
- let mut message = message?;
- if let Some(mut choice) = message.choices.pop() {
- if has_pending_newline {
- has_pending_newline = false;
- choice
- .delta
- .content
- .get_or_insert(String::new())
- .insert(0, '\n');
- }
-
- // Buffer a trailing codeblock fence. Note that we don't stop
- // right away because this may be an inner fence that we need
- // to insert into the editor.
- if starts_with_fenced_code_block.is_some()
- && choice.delta.content.as_deref() == Some("\n```")
- {
- new_text.push_str("\n```");
- continue;
- }
-
- // If this was the last completion and we started with a codeblock
- // fence and we ended with another codeblock fence, then we can
- // stop right away. Otherwise, whatever text we buffered will be
- // processed normally.
- if choice.finish_reason.is_some()
- && starts_with_fenced_code_block.unwrap_or(false)
- && new_text == "\n```"
- {
- break;
- }
-
- if let Some(text) = choice.delta.content {
- // Never push a newline if there's nothing after it. This is
- // useful to detect if the newline was pushed because of a
- // trailing codeblock fence.
- let text = if let Some(prefix) = text.strip_suffix('\n') {
- has_pending_newline = true;
- prefix
- } else {
- text.as_str()
- };
-
- if text.is_empty() {
- continue;
- }
-
- let mut lines = text.split('\n');
- if let Some(line) = lines.next() {
- if starts_with_fenced_code_block.is_none() {
- starts_with_fenced_code_block =
- Some(line.starts_with("```"));
- }
-
- // Avoid pushing the first line if it's the start of a fenced code block.
- if !inside_first_line || !starts_with_fenced_code_block.unwrap()
- {
- new_text.push_str(&line);
- }
- }
+ let mut new_text = indentation_text
+ .repeat(indentation_len.saturating_sub(selection_start.column) as usize);
- for line in lines {
- if inside_first_line && starts_with_fenced_code_block.unwrap() {
- // If we were inside the first line and that line was the
- // start of a fenced code block, we just need to push the
- // leading indentation of the original selection.
- new_text.push_str(&indentation_text.repeat(
- indentation_len.saturating_sub(selection_start.column)
- as usize,
- ));
- } else {
- // Otherwise, we need to push a newline and the base indentation.
- new_text.push('\n');
- new_text.push_str(
- &indentation_text.repeat(indentation_len as usize),
- );
- }
+ while let Some(message) = chunks.next().await {
+ let mut lines = message.split('\n');
+ if let Some(first_line) = lines.next() {
+ new_text.push_str(first_line);
+ }
- new_text.push_str(line);
- inside_first_line = false;
- }
+ for line in lines {
+ new_text.push('\n');
+ if !line.is_empty() {
+ new_text
+ .push_str(&indentation_text.repeat(indentation_len as usize));
+ new_text.push_str(line);
}
}
@@ -2919,10 +2850,58 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
}
}
+fn strip_markdown_codeblock(stream: impl Stream<Item = String>) -> impl Stream<Item = String> {
+ let mut first_line = true;
+ let mut buffer = String::new();
+ let mut starts_with_fenced_code_block = false;
+ stream.filter_map(move |chunk| {
+ buffer.push_str(&chunk);
+
+ if first_line {
+ if buffer == "" || buffer == "`" || buffer == "``" {
+ return futures::future::ready(None);
+ } else if buffer.starts_with("```") {
+ starts_with_fenced_code_block = true;
+ if let Some(newline_ix) = buffer.find('\n') {
+ buffer.replace_range(..newline_ix + 1, "");
+ first_line = false;
+ } else {
+ return futures::future::ready(None);
+ }
+ }
+ }
+
+ let text = if starts_with_fenced_code_block {
+ buffer
+ .strip_suffix("\n```")
+ .or_else(|| buffer.strip_suffix("\n``"))
+ .or_else(|| buffer.strip_suffix("\n`"))
+ .or_else(|| buffer.strip_suffix('\n'))
+ .unwrap_or(&buffer)
+ } else {
+ &buffer
+ };
+
+ if text.contains('\n') {
+ first_line = false;
+ }
+
+ let remainder = buffer.split_off(text.len());
+ let result = if buffer.is_empty() {
+ None
+ } else {
+ Some(buffer.clone())
+ };
+ buffer = remainder;
+ futures::future::ready(result)
+ })
+}
+
#[cfg(test)]
mod tests {
use super::*;
use crate::MessageId;
+ use futures::stream;
use gpui::AppContext;
#[gpui::test]
@@ -3291,6 +3270,50 @@ mod tests {
);
}
+ #[gpui::test]
+ async fn test_strip_markdown_codeblock() {
+ assert_eq!(
+ strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
+ .collect::<String>()
+ .await,
+ "Lorem ipsum dolor"
+ );
+ assert_eq!(
+ strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
+ .collect::<String>()
+ .await,
+ "Lorem ipsum dolor"
+ );
+ assert_eq!(
+ strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
+ .collect::<String>()
+ .await,
+ "Lorem ipsum dolor"
+ );
+ assert_eq!(
+ strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
+ .collect::<String>()
+ .await,
+ "```js\nLorem ipsum dolor\n```"
+ );
+ assert_eq!(
+ strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
+ .collect::<String>()
+ .await,
+ "``\nLorem ipsum dolor\n```"
+ );
+
+ fn chunks(text: &str, size: usize) -> impl Stream<Item = String> {
+ stream::iter(
+ text.chars()
+ .collect::<Vec<_>>()
+ .chunks(size)
+ .map(|chunk| chunk.iter().collect::<String>())
+ .collect::<Vec<_>>(),
+ )
+ }
+ }
+
fn messages(
conversation: &ModelHandle<Conversation>,
cx: &AppContext,