Incrementally diff input coming from GPT

Created by Antonio Scandurra

Change summary

Cargo.lock                 |   1 
crates/ai/Cargo.toml       |   1 
crates/ai/src/ai.rs        | 107 ++++++++++++++++-
crates/ai/src/assistant.rs |  99 ----------------
crates/ai/src/refactor.rs  | 233 ++++++++++++++++++++++++++++++++++++---
prompt.md                  |  11 -
6 files changed, 315 insertions(+), 137 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -116,6 +116,7 @@ dependencies = [
  "serde",
  "serde_json",
  "settings",
+ "similar",
  "smol",
  "theme",
  "tiktoken-rs 0.4.5",

crates/ai/Cargo.toml 🔗

@@ -29,6 +29,7 @@ regex.workspace = true
 schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
+similar = "1.3"
 smol.workspace = true
 tiktoken-rs = "0.4"
 

crates/ai/src/ai.rs 🔗

@@ -2,27 +2,31 @@ pub mod assistant;
 mod assistant_settings;
 mod refactor;
 
-use anyhow::Result;
+use anyhow::{anyhow, Result};
 pub use assistant::AssistantPanel;
 use chrono::{DateTime, Local};
 use collections::HashMap;
 use fs::Fs;
-use futures::StreamExt;
-use gpui::AppContext;
+use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
+use gpui::{executor::Background, AppContext};
+use isahc::{http::StatusCode, Request, RequestExt};
 use regex::Regex;
 use serde::{Deserialize, Serialize};
 use std::{
     cmp::Reverse,
     ffi::OsStr,
     fmt::{self, Display},
+    io,
     path::PathBuf,
     sync::Arc,
 };
 use util::paths::CONVERSATIONS_DIR;
 
+const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
+
 // Data types for chat completion requests
 #[derive(Debug, Serialize)]
-struct OpenAIRequest {
+pub struct OpenAIRequest {
     model: String,
     messages: Vec<RequestMessage>,
     stream: bool,
@@ -116,7 +120,7 @@ struct RequestMessage {
 }
 
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-struct ResponseMessage {
+pub struct ResponseMessage {
     role: Option<Role>,
     content: Option<String>,
 }
@@ -150,7 +154,7 @@ impl Display for Role {
 }
 
 #[derive(Deserialize, Debug)]
-struct OpenAIResponseStreamEvent {
+pub struct OpenAIResponseStreamEvent {
     pub id: Option<String>,
     pub object: String,
     pub created: u32,
@@ -160,14 +164,14 @@ struct OpenAIResponseStreamEvent {
 }
 
 #[derive(Deserialize, Debug)]
-struct Usage {
+pub struct Usage {
     pub prompt_tokens: u32,
     pub completion_tokens: u32,
     pub total_tokens: u32,
 }
 
 #[derive(Deserialize, Debug)]
-struct ChatChoiceDelta {
+pub struct ChatChoiceDelta {
     pub index: u32,
     pub delta: ResponseMessage,
     pub finish_reason: Option<String>,
@@ -190,4 +194,91 @@ struct OpenAIChoice {
 
 pub fn init(cx: &mut AppContext) {
     assistant::init(cx);
+    refactor::init(cx);
+}
+
+pub async fn stream_completion(
+    api_key: String,
+    executor: Arc<Background>,
+    mut request: OpenAIRequest,
+) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
+    request.stream = true;
+
+    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
+
+    let json_data = serde_json::to_string(&request)?;
+    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("Bearer {}", api_key))
+        .body(json_data)?
+        .send_async()
+        .await?;
+
+    let status = response.status();
+    if status == StatusCode::OK {
+        executor
+            .spawn(async move {
+                let mut lines = BufReader::new(response.body_mut()).lines();
+
+                fn parse_line(
+                    line: Result<String, io::Error>,
+                ) -> Result<Option<OpenAIResponseStreamEvent>> {
+                    if let Some(data) = line?.strip_prefix("data: ") {
+                        let event = serde_json::from_str(&data)?;
+                        Ok(Some(event))
+                    } else {
+                        Ok(None)
+                    }
+                }
+
+                while let Some(line) = lines.next().await {
+                    if let Some(event) = parse_line(line).transpose() {
+                        let done = event.as_ref().map_or(false, |event| {
+                            event
+                                .choices
+                                .last()
+                                .map_or(false, |choice| choice.finish_reason.is_some())
+                        });
+                        if tx.unbounded_send(event).is_err() {
+                            break;
+                        }
+
+                        if done {
+                            break;
+                        }
+                    }
+                }
+
+                anyhow::Ok(())
+            })
+            .detach();
+
+        Ok(rx)
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        #[derive(Deserialize)]
+        struct OpenAIResponse {
+            error: OpenAIError,
+        }
+
+        #[derive(Deserialize)]
+        struct OpenAIError {
+            message: String,
+        }
+
+        match serde_json::from_str::<OpenAIResponse>(&body) {
+            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+                "Failed to connect to OpenAI API: {}",
+                response.error.message,
+            )),
+
+            _ => Err(anyhow!(
+                "Failed to connect to OpenAI API: {} {}",
+                response.status(),
+                body,
+            )),
+        }
+    }
 }

crates/ai/src/assistant.rs 🔗

@@ -1,7 +1,7 @@
 use crate::{
     assistant_settings::{AssistantDockPosition, AssistantSettings},
-    MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
-    RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
+    stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage,
+    Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL,
 };
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
@@ -12,26 +12,23 @@ use editor::{
     Anchor, Editor, ToOffset,
 };
 use fs::Fs;
-use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
+use futures::StreamExt;
 use gpui::{
     actions,
     elements::*,
-    executor::Background,
     geometry::vector::{vec2f, Vector2F},
     platform::{CursorStyle, MouseButton},
     Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
     Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
 };
-use isahc::{http::StatusCode, Request, RequestExt};
 use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
 use search::BufferSearchBar;
-use serde::Deserialize;
 use settings::SettingsStore;
 use std::{
     cell::RefCell,
     cmp, env,
     fmt::Write,
-    io, iter,
+    iter,
     ops::Range,
     path::{Path, PathBuf},
     rc::Rc,
@@ -46,8 +43,6 @@ use workspace::{
     Save, ToggleZoom, Toolbar, Workspace,
 };
 
-const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
-
 actions!(
     assistant,
     [
@@ -2144,92 +2139,6 @@ impl Message {
     }
 }
 
-async fn stream_completion(
-    api_key: String,
-    executor: Arc<Background>,
-    mut request: OpenAIRequest,
-) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
-    request.stream = true;
-
-    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
-
-    let json_data = serde_json::to_string(&request)?;
-    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
-        .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", api_key))
-        .body(json_data)?
-        .send_async()
-        .await?;
-
-    let status = response.status();
-    if status == StatusCode::OK {
-        executor
-            .spawn(async move {
-                let mut lines = BufReader::new(response.body_mut()).lines();
-
-                fn parse_line(
-                    line: Result<String, io::Error>,
-                ) -> Result<Option<OpenAIResponseStreamEvent>> {
-                    if let Some(data) = line?.strip_prefix("data: ") {
-                        let event = serde_json::from_str(&data)?;
-                        Ok(Some(event))
-                    } else {
-                        Ok(None)
-                    }
-                }
-
-                while let Some(line) = lines.next().await {
-                    if let Some(event) = parse_line(line).transpose() {
-                        let done = event.as_ref().map_or(false, |event| {
-                            event
-                                .choices
-                                .last()
-                                .map_or(false, |choice| choice.finish_reason.is_some())
-                        });
-                        if tx.unbounded_send(event).is_err() {
-                            break;
-                        }
-
-                        if done {
-                            break;
-                        }
-                    }
-                }
-
-                anyhow::Ok(())
-            })
-            .detach();
-
-        Ok(rx)
-    } else {
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-
-        #[derive(Deserialize)]
-        struct OpenAIResponse {
-            error: OpenAIError,
-        }
-
-        #[derive(Deserialize)]
-        struct OpenAIError {
-            message: String,
-        }
-
-        match serde_json::from_str::<OpenAIResponse>(&body) {
-            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
-                "Failed to connect to OpenAI API: {}",
-                response.error.message,
-            )),
-
-            _ => Err(anyhow!(
-                "Failed to connect to OpenAI API: {} {}",
-                response.status(),
-                body,
-            )),
-        }
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;

crates/ai/src/refactor.rs 🔗

@@ -1,16 +1,24 @@
-use collections::HashMap;
-use editor::Editor;
+use crate::{stream_completion, OpenAIRequest, RequestMessage, Role};
+use collections::{BTreeMap, BTreeSet, HashMap, HashSet};
+use editor::{Anchor, Editor, MultiBuffer, MultiBufferSnapshot, ToOffset};
+use futures::{io::BufWriter, AsyncReadExt, AsyncWriteExt, StreamExt};
 use gpui::{
     actions, elements::*, AnyViewHandle, AppContext, Entity, Task, View, ViewContext, ViewHandle,
+    WeakViewHandle,
 };
-use std::sync::Arc;
+use menu::Confirm;
+use serde::Deserialize;
+use similar::ChangeTag;
+use std::{env, iter, ops::Range, sync::Arc};
+use util::TryFutureExt;
 use workspace::{Modal, Workspace};
 
 actions!(assistant, [Refactor]);
 
-fn init(cx: &mut AppContext) {
+pub fn init(cx: &mut AppContext) {
     cx.set_global(RefactoringAssistant::new());
     cx.add_action(RefactoringModal::deploy);
+    cx.add_action(RefactoringModal::confirm);
 }
 
 pub struct RefactoringAssistant {
@@ -24,10 +32,122 @@ impl RefactoringAssistant {
         }
     }
 
-    fn refactor(&mut self, editor: &ViewHandle<Editor>, prompt: &str, cx: &mut AppContext) {}
+    fn refactor(&mut self, editor: &ViewHandle<Editor>, prompt: &str, cx: &mut AppContext) {
+        let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
+        let selection = editor.read(cx).selections.newest_anchor().clone();
+        let selected_text = buffer
+            .text_for_range(selection.start..selection.end)
+            .collect::<String>();
+        let language_name = buffer
+            .language_at(selection.start)
+            .map(|language| language.name());
+        let language_name = language_name.as_deref().unwrap_or("");
+        let request = OpenAIRequest {
+            model: "gpt-4".into(),
+            messages: vec![
+                RequestMessage {
+                role: Role::User,
+                content: format!(
+                    "Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Avoid making remarks and reply only with the new code."
+                ),
+            }],
+            stream: true,
+        };
+        let api_key = env::var("OPENAI_API_KEY").unwrap();
+        let response = stream_completion(api_key, cx.background().clone(), request);
+        let editor = editor.downgrade();
+        self.pending_edits_by_editor.insert(
+            editor.id(),
+            cx.spawn(|mut cx| {
+                async move {
+                    let selection_start = selection.start.to_offset(&buffer);
+
+                    // Find unique words in the selected text to use as diff boundaries.
+                    let mut duplicate_words = HashSet::default();
+                    let mut unique_old_words = HashMap::default();
+                    for (range, word) in words(&selected_text) {
+                        if !duplicate_words.contains(word) {
+                            if unique_old_words.insert(word, range.end).is_some() {
+                                unique_old_words.remove(word);
+                                duplicate_words.insert(word);
+                            }
+                        }
+                    }
+
+                    let mut new_text = String::new();
+                    let mut messages = response.await?;
+                    let mut new_word_search_start_ix = 0;
+                    let mut last_old_word_end_ix = 0;
+
+                    'outer: loop {
+                        let start = new_word_search_start_ix;
+                        let mut words = words(&new_text[start..]);
+                        while let Some((range, new_word)) = words.next() {
+                            // We found a word in the new text that was unique in the old text. We can use
+                            // it as a diff boundary, and start applying edits.
+                            if let Some(old_word_end_ix) = unique_old_words.remove(new_word) {
+                                if old_word_end_ix > last_old_word_end_ix {
+                                    drop(words);
+
+                                    let remainder = new_text.split_off(start + range.end);
+                                    let edits = diff(
+                                        selection_start + last_old_word_end_ix,
+                                        &selected_text[last_old_word_end_ix..old_word_end_ix],
+                                        &new_text,
+                                        &buffer,
+                                    );
+                                    editor.update(&mut cx, |editor, cx| {
+                                        editor
+                                            .buffer()
+                                            .update(cx, |buffer, cx| buffer.edit(edits, None, cx))
+                                    })?;
+
+                                    new_text = remainder;
+                                    new_word_search_start_ix = 0;
+                                    last_old_word_end_ix = old_word_end_ix;
+                                    continue 'outer;
+                                }
+                            }
+
+                            new_word_search_start_ix = start + range.end;
+                        }
+                        drop(words);
+
+                        // Buffer incoming text, stopping if the stream was exhausted.
+                        if let Some(message) = messages.next().await {
+                            let mut message = message?;
+                            if let Some(choice) = message.choices.pop() {
+                                if let Some(text) = choice.delta.content {
+                                    new_text.push_str(&text);
+                                }
+                            }
+                        } else {
+                            break;
+                        }
+                    }
+
+                    let edits = diff(
+                        selection_start + last_old_word_end_ix,
+                        &selected_text[last_old_word_end_ix..],
+                        &new_text,
+                        &buffer,
+                    );
+                    editor.update(&mut cx, |editor, cx| {
+                        editor
+                            .buffer()
+                            .update(cx, |buffer, cx| buffer.edit(edits, None, cx))
+                    })?;
+
+                    anyhow::Ok(())
+                }
+                .log_err()
+            }),
+        );
+    }
 }
 
 struct RefactoringModal {
+    editor: WeakViewHandle<Editor>,
     prompt_editor: ViewHandle<Editor>,
     has_focus: bool,
 }
@@ -42,7 +162,7 @@ impl View for RefactoringModal {
     }
 
     fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
-        todo!()
+        ChildView::new(&self.prompt_editor, cx).into_any()
     }
 
     fn focus_in(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
@@ -60,29 +180,96 @@ impl Modal for RefactoringModal {
     }
 
     fn dismiss_on_event(event: &Self::Event) -> bool {
-        todo!()
+        // TODO
+        false
     }
 }
 
 impl RefactoringModal {
     fn deploy(workspace: &mut Workspace, _: &Refactor, cx: &mut ViewContext<Workspace>) {
-        workspace.toggle_modal(cx, |_, cx| {
-            let prompt_editor = cx.add_view(|cx| {
-                Editor::auto_height(
-                    4,
-                    Some(Arc::new(|theme| theme.search.editor.input.clone())),
-                    cx,
-                )
+        if let Some(editor) = workspace
+            .active_item(cx)
+            .and_then(|item| Some(item.downcast::<Editor>()?.downgrade()))
+        {
+            workspace.toggle_modal(cx, |_, cx| {
+                let prompt_editor = cx.add_view(|cx| {
+                    Editor::auto_height(
+                        4,
+                        Some(Arc::new(|theme| theme.search.editor.input.clone())),
+                        cx,
+                    )
+                });
+                cx.add_view(|_| RefactoringModal {
+                    editor,
+                    prompt_editor,
+                    has_focus: false,
+                })
             });
-            cx.add_view(|_| RefactoringModal {
-                prompt_editor,
-                has_focus: false,
-            })
-        });
+        }
+    }
+
+    fn confirm(&mut self, _: &Confirm, cx: &mut ViewContext<Self>) {
+        if let Some(editor) = self.editor.upgrade(cx) {
+            let prompt = self.prompt_editor.read(cx).text(cx);
+            cx.update_global(|assistant: &mut RefactoringAssistant, cx| {
+                assistant.refactor(&editor, &prompt, cx);
+            });
+        }
     }
 }
+fn words(text: &str) -> impl Iterator<Item = (Range<usize>, &str)> {
+    let mut word_start_ix = None;
+    let mut chars = text.char_indices();
+    iter::from_fn(move || {
+        while let Some((ix, ch)) = chars.next() {
+            if let Some(start_ix) = word_start_ix {
+                if !ch.is_alphanumeric() {
+                    let word = &text[start_ix..ix];
+                    word_start_ix.take();
+                    return Some((start_ix..ix, word));
+                }
+            } else {
+                if ch.is_alphanumeric() {
+                    word_start_ix = Some(ix);
+                }
+            }
+        }
+        None
+    })
+}
 
-// ABCDEFG
-// XCDEFG
-//
-//
+fn diff<'a>(
+    start_ix: usize,
+    old_text: &'a str,
+    new_text: &'a str,
+    old_buffer_snapshot: &MultiBufferSnapshot,
+) -> Vec<(Range<Anchor>, &'a str)> {
+    let mut edit_start = start_ix;
+    let mut edits = Vec::new();
+    let diff = similar::TextDiff::from_words(old_text, &new_text);
+    for change in diff.iter_all_changes() {
+        let value = change.value();
+        let edit_end = edit_start + value.len();
+        match change.tag() {
+            ChangeTag::Equal => {
+                edit_start = edit_end;
+            }
+            ChangeTag::Delete => {
+                edits.push((
+                    old_buffer_snapshot.anchor_after(edit_start)
+                        ..old_buffer_snapshot.anchor_before(edit_end),
+                    "",
+                ));
+                edit_start = edit_end;
+            }
+            ChangeTag::Insert => {
+                edits.push((
+                    old_buffer_snapshot.anchor_after(edit_start)
+                        ..old_buffer_snapshot.anchor_after(edit_start),
+                    value,
+                ));
+            }
+        }
+    }
+    edits
+}

prompt.md 🔗

@@ -1,11 +0,0 @@
-Given a snippet as the input, you must produce an array of edits. An edit has the following structure:
-
-{ skip: "skip", delete: "delete", insert: "insert" }
-
-`skip` is a string in the input that should be left unchanged. `delete` is a string in the input located right after the skipped text that should be deleted. `insert` is a new string that should be inserted after the end of the text in `skip`. It's crucial that a string in the input can only be skipped or deleted once and only once.
-
-Your task is to produce an array of edits. `delete` and `insert` can be empty if nothing changed. When `skip`, `delete` or `insert` are longer than 20 characters, split them into multiple edits.
-
-Check your reasoning by concatenating all the strings in `skip` and `delete`. If the text is the same as the input snippet then the edits are valid.
-
-It's crucial that you reply only with edits. No prose or remarks.