Cargo.lock
@@ -116,6 +116,7 @@ dependencies = [
"serde",
"serde_json",
"settings",
+ "similar",
"smol",
"theme",
"tiktoken-rs 0.4.5",
Commit by Antonio Scandurra
Cargo.lock | 1 +
crates/ai/Cargo.toml | 1 +
crates/ai/src/ai.rs | 107 ++++++++++++++++-
crates/ai/src/assistant.rs | 99 ----------------
crates/ai/src/refactor.rs | 233 ++++++++++++++++++++++++++++++++++++---
prompt.md | 11 -
6 files changed, 315 insertions(+), 137 deletions(-)
@@ -116,6 +116,7 @@ dependencies = [
"serde",
"serde_json",
"settings",
+ "similar",
"smol",
"theme",
"tiktoken-rs 0.4.5",
@@ -29,6 +29,7 @@ regex.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
+similar = "1.3"
smol.workspace = true
tiktoken-rs = "0.4"
@@ -2,27 +2,31 @@ pub mod assistant;
mod assistant_settings;
mod refactor;
-use anyhow::Result;
+use anyhow::{anyhow, Result};
pub use assistant::AssistantPanel;
use chrono::{DateTime, Local};
use collections::HashMap;
use fs::Fs;
-use futures::StreamExt;
-use gpui::AppContext;
+use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
+use gpui::{executor::Background, AppContext};
+use isahc::{http::StatusCode, Request, RequestExt};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::{
cmp::Reverse,
ffi::OsStr,
fmt::{self, Display},
+ io,
path::PathBuf,
sync::Arc,
};
use util::paths::CONVERSATIONS_DIR;
+const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
+
// Data types for chat completion requests
#[derive(Debug, Serialize)]
-struct OpenAIRequest {
+pub struct OpenAIRequest {
model: String,
messages: Vec<RequestMessage>,
stream: bool,
@@ -116,7 +120,7 @@ struct RequestMessage {
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-struct ResponseMessage {
+pub struct ResponseMessage {
role: Option<Role>,
content: Option<String>,
}
@@ -150,7 +154,7 @@ impl Display for Role {
}
#[derive(Deserialize, Debug)]
-struct OpenAIResponseStreamEvent {
+pub struct OpenAIResponseStreamEvent {
pub id: Option<String>,
pub object: String,
pub created: u32,
@@ -160,14 +164,14 @@ struct OpenAIResponseStreamEvent {
}
#[derive(Deserialize, Debug)]
-struct Usage {
+pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
-struct ChatChoiceDelta {
+pub struct ChatChoiceDelta {
pub index: u32,
pub delta: ResponseMessage,
pub finish_reason: Option<String>,
@@ -190,4 +194,91 @@ struct OpenAIChoice {
pub fn init(cx: &mut AppContext) {
assistant::init(cx);
+ refactor::init(cx);
+}
+
+pub async fn stream_completion(
+ api_key: String,
+ executor: Arc<Background>,
+ mut request: OpenAIRequest,
+) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
+ request.stream = true;
+
+ let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
+
+ let json_data = serde_json::to_string(&request)?;
+ let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key))
+ .body(json_data)?
+ .send_async()
+ .await?;
+
+ let status = response.status();
+ if status == StatusCode::OK {
+ executor
+ .spawn(async move {
+ let mut lines = BufReader::new(response.body_mut()).lines();
+
+ fn parse_line(
+ line: Result<String, io::Error>,
+ ) -> Result<Option<OpenAIResponseStreamEvent>> {
+ if let Some(data) = line?.strip_prefix("data: ") {
+ let event = serde_json::from_str(&data)?;
+ Ok(Some(event))
+ } else {
+ Ok(None)
+ }
+ }
+
+ while let Some(line) = lines.next().await {
+ if let Some(event) = parse_line(line).transpose() {
+ let done = event.as_ref().map_or(false, |event| {
+ event
+ .choices
+ .last()
+ .map_or(false, |choice| choice.finish_reason.is_some())
+ });
+ if tx.unbounded_send(event).is_err() {
+ break;
+ }
+
+ if done {
+ break;
+ }
+ }
+ }
+
+ anyhow::Ok(())
+ })
+ .detach();
+
+ Ok(rx)
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ #[derive(Deserialize)]
+ struct OpenAIResponse {
+ error: OpenAIError,
+ }
+
+ #[derive(Deserialize)]
+ struct OpenAIError {
+ message: String,
+ }
+
+ match serde_json::from_str::<OpenAIResponse>(&body) {
+ Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+ "Failed to connect to OpenAI API: {}",
+ response.error.message,
+ )),
+
+ _ => Err(anyhow!(
+ "Failed to connect to OpenAI API: {} {}",
+ response.status(),
+ body,
+ )),
+ }
+ }
}
@@ -1,7 +1,7 @@
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings},
- MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
- RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
+ stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage,
+ Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL,
};
use anyhow::{anyhow, Result};
use chrono::{DateTime, Local};
@@ -12,26 +12,23 @@ use editor::{
Anchor, Editor, ToOffset,
};
use fs::Fs;
-use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
+use futures::StreamExt;
use gpui::{
actions,
elements::*,
- executor::Background,
geometry::vector::{vec2f, Vector2F},
platform::{CursorStyle, MouseButton},
Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
};
-use isahc::{http::StatusCode, Request, RequestExt};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
use search::BufferSearchBar;
-use serde::Deserialize;
use settings::SettingsStore;
use std::{
cell::RefCell,
cmp, env,
fmt::Write,
- io, iter,
+ iter,
ops::Range,
path::{Path, PathBuf},
rc::Rc,
@@ -46,8 +43,6 @@ use workspace::{
Save, ToggleZoom, Toolbar, Workspace,
};
-const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
-
actions!(
assistant,
[
@@ -2144,92 +2139,6 @@ impl Message {
}
}
-async fn stream_completion(
- api_key: String,
- executor: Arc<Background>,
- mut request: OpenAIRequest,
-) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
- request.stream = true;
-
- let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
-
- let json_data = serde_json::to_string(&request)?;
- let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", api_key))
- .body(json_data)?
- .send_async()
- .await?;
-
- let status = response.status();
- if status == StatusCode::OK {
- executor
- .spawn(async move {
- let mut lines = BufReader::new(response.body_mut()).lines();
-
- fn parse_line(
- line: Result<String, io::Error>,
- ) -> Result<Option<OpenAIResponseStreamEvent>> {
- if let Some(data) = line?.strip_prefix("data: ") {
- let event = serde_json::from_str(&data)?;
- Ok(Some(event))
- } else {
- Ok(None)
- }
- }
-
- while let Some(line) = lines.next().await {
- if let Some(event) = parse_line(line).transpose() {
- let done = event.as_ref().map_or(false, |event| {
- event
- .choices
- .last()
- .map_or(false, |choice| choice.finish_reason.is_some())
- });
- if tx.unbounded_send(event).is_err() {
- break;
- }
-
- if done {
- break;
- }
- }
- }
-
- anyhow::Ok(())
- })
- .detach();
-
- Ok(rx)
- } else {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
-
- #[derive(Deserialize)]
- struct OpenAIResponse {
- error: OpenAIError,
- }
-
- #[derive(Deserialize)]
- struct OpenAIError {
- message: String,
- }
-
- match serde_json::from_str::<OpenAIResponse>(&body) {
- Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
- "Failed to connect to OpenAI API: {}",
- response.error.message,
- )),
-
- _ => Err(anyhow!(
- "Failed to connect to OpenAI API: {} {}",
- response.status(),
- body,
- )),
- }
- }
-}
-
#[cfg(test)]
mod tests {
use super::*;
@@ -1,16 +1,24 @@
-use collections::HashMap;
-use editor::Editor;
+use crate::{stream_completion, OpenAIRequest, RequestMessage, Role};
+use collections::{BTreeMap, BTreeSet, HashMap, HashSet};
+use editor::{Anchor, Editor, MultiBuffer, MultiBufferSnapshot, ToOffset};
+use futures::{io::BufWriter, AsyncReadExt, AsyncWriteExt, StreamExt};
use gpui::{
actions, elements::*, AnyViewHandle, AppContext, Entity, Task, View, ViewContext, ViewHandle,
+ WeakViewHandle,
};
-use std::sync::Arc;
+use menu::Confirm;
+use serde::Deserialize;
+use similar::ChangeTag;
+use std::{env, iter, ops::Range, sync::Arc};
+use util::TryFutureExt;
use workspace::{Modal, Workspace};
actions!(assistant, [Refactor]);
-fn init(cx: &mut AppContext) {
+pub fn init(cx: &mut AppContext) {
cx.set_global(RefactoringAssistant::new());
cx.add_action(RefactoringModal::deploy);
+ cx.add_action(RefactoringModal::confirm);
}
pub struct RefactoringAssistant {
@@ -24,10 +32,122 @@ impl RefactoringAssistant {
}
}
- fn refactor(&mut self, editor: &ViewHandle<Editor>, prompt: &str, cx: &mut AppContext) {}
+ fn refactor(&mut self, editor: &ViewHandle<Editor>, prompt: &str, cx: &mut AppContext) {
+ let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
+ let selection = editor.read(cx).selections.newest_anchor().clone();
+ let selected_text = buffer
+ .text_for_range(selection.start..selection.end)
+ .collect::<String>();
+ let language_name = buffer
+ .language_at(selection.start)
+ .map(|language| language.name());
+ let language_name = language_name.as_deref().unwrap_or("");
+ let request = OpenAIRequest {
+ model: "gpt-4".into(),
+ messages: vec![
+ RequestMessage {
+ role: Role::User,
+ content: format!(
+ "Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Avoid making remarks and reply only with the new code."
+ ),
+ }],
+ stream: true,
+ };
+ let api_key = env::var("OPENAI_API_KEY").unwrap();
+ let response = stream_completion(api_key, cx.background().clone(), request);
+ let editor = editor.downgrade();
+ self.pending_edits_by_editor.insert(
+ editor.id(),
+ cx.spawn(|mut cx| {
+ async move {
+ let selection_start = selection.start.to_offset(&buffer);
+
+ // Find unique words in the selected text to use as diff boundaries.
+ let mut duplicate_words = HashSet::default();
+ let mut unique_old_words = HashMap::default();
+ for (range, word) in words(&selected_text) {
+ if !duplicate_words.contains(word) {
+ if unique_old_words.insert(word, range.end).is_some() {
+ unique_old_words.remove(word);
+ duplicate_words.insert(word);
+ }
+ }
+ }
+
+ let mut new_text = String::new();
+ let mut messages = response.await?;
+ let mut new_word_search_start_ix = 0;
+ let mut last_old_word_end_ix = 0;
+
+ 'outer: loop {
+ let start = new_word_search_start_ix;
+ let mut words = words(&new_text[start..]);
+ while let Some((range, new_word)) = words.next() {
+ // We found a word in the new text that was unique in the old text. We can use
+ // it as a diff boundary, and start applying edits.
+ if let Some(old_word_end_ix) = unique_old_words.remove(new_word) {
+ if old_word_end_ix > last_old_word_end_ix {
+ drop(words);
+
+ let remainder = new_text.split_off(start + range.end);
+ let edits = diff(
+ selection_start + last_old_word_end_ix,
+ &selected_text[last_old_word_end_ix..old_word_end_ix],
+ &new_text,
+ &buffer,
+ );
+ editor.update(&mut cx, |editor, cx| {
+ editor
+ .buffer()
+ .update(cx, |buffer, cx| buffer.edit(edits, None, cx))
+ })?;
+
+ new_text = remainder;
+ new_word_search_start_ix = 0;
+ last_old_word_end_ix = old_word_end_ix;
+ continue 'outer;
+ }
+ }
+
+ new_word_search_start_ix = start + range.end;
+ }
+ drop(words);
+
+ // Buffer incoming text, stopping if the stream was exhausted.
+ if let Some(message) = messages.next().await {
+ let mut message = message?;
+ if let Some(choice) = message.choices.pop() {
+ if let Some(text) = choice.delta.content {
+ new_text.push_str(&text);
+ }
+ }
+ } else {
+ break;
+ }
+ }
+
+ let edits = diff(
+ selection_start + last_old_word_end_ix,
+ &selected_text[last_old_word_end_ix..],
+ &new_text,
+ &buffer,
+ );
+ editor.update(&mut cx, |editor, cx| {
+ editor
+ .buffer()
+ .update(cx, |buffer, cx| buffer.edit(edits, None, cx))
+ })?;
+
+ anyhow::Ok(())
+ }
+ .log_err()
+ }),
+ );
+ }
}
struct RefactoringModal {
+ editor: WeakViewHandle<Editor>,
prompt_editor: ViewHandle<Editor>,
has_focus: bool,
}
@@ -42,7 +162,7 @@ impl View for RefactoringModal {
}
fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
- todo!()
+ ChildView::new(&self.prompt_editor, cx).into_any()
}
fn focus_in(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
@@ -60,29 +180,96 @@ impl Modal for RefactoringModal {
}
fn dismiss_on_event(event: &Self::Event) -> bool {
- todo!()
+ // TODO
+ false
}
}
impl RefactoringModal {
fn deploy(workspace: &mut Workspace, _: &Refactor, cx: &mut ViewContext<Workspace>) {
- workspace.toggle_modal(cx, |_, cx| {
- let prompt_editor = cx.add_view(|cx| {
- Editor::auto_height(
- 4,
- Some(Arc::new(|theme| theme.search.editor.input.clone())),
- cx,
- )
+ if let Some(editor) = workspace
+ .active_item(cx)
+ .and_then(|item| Some(item.downcast::<Editor>()?.downgrade()))
+ {
+ workspace.toggle_modal(cx, |_, cx| {
+ let prompt_editor = cx.add_view(|cx| {
+ Editor::auto_height(
+ 4,
+ Some(Arc::new(|theme| theme.search.editor.input.clone())),
+ cx,
+ )
+ });
+ cx.add_view(|_| RefactoringModal {
+ editor,
+ prompt_editor,
+ has_focus: false,
+ })
});
- cx.add_view(|_| RefactoringModal {
- prompt_editor,
- has_focus: false,
- })
- });
+ }
+ }
+
+ fn confirm(&mut self, _: &Confirm, cx: &mut ViewContext<Self>) {
+ if let Some(editor) = self.editor.upgrade(cx) {
+ let prompt = self.prompt_editor.read(cx).text(cx);
+ cx.update_global(|assistant: &mut RefactoringAssistant, cx| {
+ assistant.refactor(&editor, &prompt, cx);
+ });
+ }
}
}
+fn words(text: &str) -> impl Iterator<Item = (Range<usize>, &str)> {
+ let mut word_start_ix = None;
+ let mut chars = text.char_indices();
+ iter::from_fn(move || {
+ while let Some((ix, ch)) = chars.next() {
+ if let Some(start_ix) = word_start_ix {
+ if !ch.is_alphanumeric() {
+ let word = &text[start_ix..ix];
+ word_start_ix.take();
+ return Some((start_ix..ix, word));
+ }
+ } else {
+ if ch.is_alphanumeric() {
+ word_start_ix = Some(ix);
+ }
+ }
+ }
+ None
+ })
+}
-// ABCDEFG
-// XCDEFG
-//
-//
+fn diff<'a>(
+ start_ix: usize,
+ old_text: &'a str,
+ new_text: &'a str,
+ old_buffer_snapshot: &MultiBufferSnapshot,
+) -> Vec<(Range<Anchor>, &'a str)> {
+ let mut edit_start = start_ix;
+ let mut edits = Vec::new();
+ let diff = similar::TextDiff::from_words(old_text, &new_text);
+ for change in diff.iter_all_changes() {
+ let value = change.value();
+ let edit_end = edit_start + value.len();
+ match change.tag() {
+ ChangeTag::Equal => {
+ edit_start = edit_end;
+ }
+ ChangeTag::Delete => {
+ edits.push((
+ old_buffer_snapshot.anchor_after(edit_start)
+ ..old_buffer_snapshot.anchor_before(edit_end),
+ "",
+ ));
+ edit_start = edit_end;
+ }
+ ChangeTag::Insert => {
+ edits.push((
+ old_buffer_snapshot.anchor_after(edit_start)
+ ..old_buffer_snapshot.anchor_after(edit_start),
+ value,
+ ));
+ }
+ }
+ }
+ edits
+}
@@ -1,11 +0,0 @@
-Given a snippet as the input, you must produce an array of edits. An edit has the following structure:
-
-{ skip: "skip", delete: "delete", insert: "insert" }
-
-`skip` is a string in the input that should be left unchanged. `delete` is a string in the input located right after the skipped text that should be deleted. `insert` is a new string that should be inserted after the end of the text in `skip`. It's crucial that a string in the input can only be skipped or deleted once and only once.
-
-Your task is to produce an array of edits. `delete` and `insert` can be empty if nothing changed. When `skip`, `delete` or `insert` are longer than 20 characters, split them into multiple edits.
-
-Check your reasoning by concatenating all the strings in `skip` and `delete`. If the text is the same as the input snippet then the edits are valid.
-
-It's crucial that you reply only with edits. No prose or remarks.