crates/ai/src/ai.rs 🔗
@@ -1,5 +1,6 @@
pub mod assistant;
mod assistant_settings;
+mod codegen;
mod streaming_diff;
use anyhow::{anyhow, Result};
Antonio Scandurra created
crates/ai/src/ai.rs | 1
crates/ai/src/assistant.rs | 553 ++++++--------------------------
crates/ai/src/codegen.rs | 468 +++++++++++++++++++++++++++
crates/editor/src/editor.rs | 10
crates/editor/src/multi_buffer.rs | 43 +-
5 files changed, 607 insertions(+), 468 deletions(-)
@@ -1,5 +1,6 @@
pub mod assistant;
mod assistant_settings;
+mod codegen;
mod streaming_diff;
use anyhow::{anyhow, Result};
@@ -1,9 +1,8 @@
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
- stream_completion,
- streaming_diff::{Hunk, StreamingDiff},
- MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage, Role,
- SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL,
+ codegen::{self, Codegen, OpenAICompletionProvider},
+ stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage,
+ Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL,
};
use anyhow::{anyhow, Result};
use chrono::{DateTime, Local};
@@ -13,10 +12,10 @@ use editor::{
BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint,
},
scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
- Anchor, Editor, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset, ToPoint,
+ Anchor, Editor, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset,
};
use fs::Fs;
-use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
+use futures::StreamExt;
use gpui::{
actions,
elements::{
@@ -30,17 +29,14 @@ use gpui::{
ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle,
WindowContext,
};
-use language::{
- language_settings::SoftWrap, Buffer, LanguageRegistry, Point, Rope, ToOffset as _,
- TransactionId,
-};
+use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
use search::BufferSearchBar;
use settings::SettingsStore;
use std::{
cell::{Cell, RefCell},
cmp, env,
fmt::Write,
- future, iter,
+ iter,
ops::Range,
path::{Path, PathBuf},
rc::Rc,
@@ -266,10 +262,22 @@ impl AssistantPanel {
}
fn new_inline_assist(&mut self, editor: &ViewHandle<Editor>, cx: &mut ViewContext<Self>) {
+ let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
+ api_key
+ } else {
+ return;
+ };
+
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
let selection = editor.read(cx).selections.newest_anchor().clone();
let range = selection.start.bias_left(&snapshot)..selection.end.bias_right(&snapshot);
+ let provider = Arc::new(OpenAICompletionProvider::new(
+ api_key,
+ cx.background().clone(),
+ ));
+ let codegen =
+ cx.add_model(|cx| Codegen::new(editor.read(cx).buffer().clone(), range, provider, cx));
let assist_kind = if editor.read(cx).selections.newest::<usize>(cx).is_empty() {
InlineAssistKind::Generate
} else {
@@ -283,6 +291,7 @@ impl AssistantPanel {
measurements.clone(),
self.include_conversation_in_next_inline_assist,
self.inline_prompt_history.clone(),
+ codegen.clone(),
cx,
);
cx.focus_self();
@@ -323,44 +332,53 @@ impl AssistantPanel {
PendingInlineAssist {
kind: assist_kind,
editor: editor.downgrade(),
- range,
- highlighted_ranges: Default::default(),
inline_assistant: Some((block_id, inline_assistant.clone())),
- code_generation: Task::ready(None),
- transaction_id: None,
+ codegen: codegen.clone(),
_subscriptions: vec![
cx.subscribe(&inline_assistant, Self::handle_inline_assistant_event),
cx.subscribe(editor, {
let inline_assistant = inline_assistant.downgrade();
- move |this, editor, event, cx| {
+ move |_, editor, event, cx| {
if let Some(inline_assistant) = inline_assistant.upgrade(cx) {
- match event {
- editor::Event::SelectionsChanged { local } => {
- if *local && inline_assistant.read(cx).has_focus {
- cx.focus(&editor);
- }
+ if let editor::Event::SelectionsChanged { local } = event {
+ if *local && inline_assistant.read(cx).has_focus {
+ cx.focus(&editor);
}
- editor::Event::TransactionUndone {
- transaction_id: tx_id,
- } => {
- if let Some(pending_assist) =
- this.pending_inline_assists.get(&inline_assist_id)
- {
- if pending_assist.transaction_id == Some(*tx_id) {
- // Notice we are supplying `undo: false` here. This
- // is because there's no need to undo the transaction
- // because the user just did so.
- this.close_inline_assist(
- inline_assist_id,
- false,
- cx,
- );
- }
- }
+ }
+ }
+ }
+ }),
+ cx.subscribe(&codegen, move |this, codegen, event, cx| match event {
+ codegen::Event::Undone => {
+ this.finish_inline_assist(inline_assist_id, false, cx)
+ }
+ codegen::Event::Finished => {
+ let pending_assist = if let Some(pending_assist) =
+ this.pending_inline_assists.get(&inline_assist_id)
+ {
+ pending_assist
+ } else {
+ return;
+ };
+
+ let error = codegen
+ .read(cx)
+ .error()
+ .map(|error| format!("Inline assistant error: {}", error));
+ if let Some(error) = error {
+ if pending_assist.inline_assistant.is_none() {
+ if let Some(workspace) = this.workspace.upgrade(cx) {
+ workspace.update(cx, |workspace, cx| {
+ workspace.show_toast(
+ Toast::new(inline_assist_id, error),
+ cx,
+ );
+ })
}
- _ => {}
}
}
+
+ this.finish_inline_assist(inline_assist_id, false, cx);
}
}),
],
@@ -388,7 +406,7 @@ impl AssistantPanel {
self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx);
}
InlineAssistantEvent::Canceled => {
- self.close_inline_assist(assist_id, true, cx);
+ self.finish_inline_assist(assist_id, true, cx);
}
InlineAssistantEvent::Dismissed => {
self.hide_inline_assist(assist_id, cx);
@@ -417,7 +435,7 @@ impl AssistantPanel {
.get(&editor.downgrade())
.and_then(|assist_ids| assist_ids.last().copied())
{
- panel.close_inline_assist(assist_id, true, cx);
+ panel.finish_inline_assist(assist_id, true, cx);
true
} else {
false
@@ -432,7 +450,7 @@ impl AssistantPanel {
cx.propagate_action();
}
- fn close_inline_assist(&mut self, assist_id: usize, undo: bool, cx: &mut ViewContext<Self>) {
+ fn finish_inline_assist(&mut self, assist_id: usize, undo: bool, cx: &mut ViewContext<Self>) {
self.hide_inline_assist(assist_id, cx);
if let Some(pending_assist) = self.pending_inline_assists.remove(&assist_id) {
@@ -450,13 +468,9 @@ impl AssistantPanel {
self.update_highlights_for_editor(&editor, cx);
if undo {
- if let Some(transaction_id) = pending_assist.transaction_id {
- editor.update(cx, |editor, cx| {
- editor.buffer().update(cx, |buffer, cx| {
- buffer.undo_transaction(transaction_id, cx)
- });
- });
- }
+ pending_assist
+ .codegen
+ .update(cx, |codegen, cx| codegen.undo(cx));
}
}
}
@@ -481,12 +495,6 @@ impl AssistantPanel {
include_conversation: bool,
cx: &mut ViewContext<Self>,
) {
- let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
- api_key
- } else {
- return;
- };
-
let conversation = if include_conversation {
self.active_editor()
.map(|editor| editor.read(cx).conversation.clone())
@@ -514,56 +522,9 @@ impl AssistantPanel {
self.inline_prompt_history.pop_front();
}
- let range = pending_assist.range.clone();
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
- let selected_text = snapshot
- .text_for_range(range.start..range.end)
- .collect::<Rope>();
-
- let selection_start = range.start.to_point(&snapshot);
- let selection_end = range.end.to_point(&snapshot);
-
- let mut base_indent: Option<language::IndentSize> = None;
- let mut start_row = selection_start.row;
- if snapshot.is_line_blank(start_row) {
- if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) {
- start_row = prev_non_blank_row;
- }
- }
- for row in start_row..=selection_end.row {
- if snapshot.is_line_blank(row) {
- continue;
- }
-
- let line_indent = snapshot.indent_size_for_line(row);
- if let Some(base_indent) = base_indent.as_mut() {
- if line_indent.len < base_indent.len {
- *base_indent = line_indent;
- }
- } else {
- base_indent = Some(line_indent);
- }
- }
-
- let mut normalized_selected_text = selected_text.clone();
- if let Some(base_indent) = base_indent {
- for row in selection_start.row..=selection_end.row {
- let selection_row = row - selection_start.row;
- let line_start =
- normalized_selected_text.point_to_offset(Point::new(selection_row, 0));
- let indent_len = if row == selection_start.row {
- base_indent.len.saturating_sub(selection_start.column)
- } else {
- let line_len = normalized_selected_text.line_len(selection_row);
- cmp::min(line_len, base_indent.len)
- };
- let indent_end = cmp::min(
- line_start + indent_len as usize,
- normalized_selected_text.len(),
- );
- normalized_selected_text.replace(line_start..indent_end, "");
- }
- }
+ let range = pending_assist.codegen.read(cx).range();
+ let selected_text = snapshot.text_for_range(range.clone()).collect::<String>();
let language = snapshot.language_at(range.start);
let language_name = if let Some(language) = language.as_ref() {
@@ -608,7 +569,7 @@ impl AssistantPanel {
} else {
writeln!(prompt, "```").unwrap();
}
- writeln!(prompt, "{normalized_selected_text}").unwrap();
+ writeln!(prompt, "{selected_text}").unwrap();
writeln!(prompt, "```").unwrap();
writeln!(prompt).unwrap();
writeln!(
@@ -689,209 +650,9 @@ impl AssistantPanel {
messages,
stream: true,
};
- let response = stream_completion(api_key, cx.background().clone(), request);
- let editor = editor.downgrade();
-
- pending_assist.code_generation = cx.spawn(|this, mut cx| {
- async move {
- let mut edit_start = range.start.to_offset(&snapshot);
-
- let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
- let diff = cx.background().spawn(async move {
- let chunks = strip_markdown_codeblock(response.await?.filter_map(
- |message| async move {
- match message {
- Ok(mut message) => Some(Ok(message.choices.pop()?.delta.content?)),
- Err(error) => Some(Err(error)),
- }
- },
- ));
- futures::pin_mut!(chunks);
- let mut diff = StreamingDiff::new(selected_text.to_string());
-
- let mut indent_len;
- let indent_text;
- if let Some(base_indent) = base_indent {
- indent_len = base_indent.len;
- indent_text = match base_indent.kind {
- language::IndentKind::Space => " ",
- language::IndentKind::Tab => "\t",
- };
- } else {
- indent_len = 0;
- indent_text = "";
- };
-
- let mut first_line_len = 0;
- let mut first_line_non_whitespace_char_ix = None;
- let mut first_line = true;
- let mut new_text = String::new();
-
- while let Some(chunk) = chunks.next().await {
- let chunk = chunk?;
-
- let mut lines = chunk.split('\n');
- if let Some(mut line) = lines.next() {
- if first_line {
- if first_line_non_whitespace_char_ix.is_none() {
- if let Some(mut char_ix) =
- line.find(|ch: char| !ch.is_whitespace())
- {
- line = &line[char_ix..];
- char_ix += first_line_len;
- first_line_non_whitespace_char_ix = Some(char_ix);
- let first_line_indent = char_ix
- .saturating_sub(selection_start.column as usize)
- as usize;
- new_text.push_str(&indent_text.repeat(first_line_indent));
- indent_len = indent_len.saturating_sub(char_ix as u32);
- }
- }
- first_line_len += line.len();
- }
-
- if first_line_non_whitespace_char_ix.is_some() {
- new_text.push_str(line);
- }
- }
-
- for line in lines {
- first_line = false;
- new_text.push('\n');
- if !line.is_empty() {
- new_text.push_str(&indent_text.repeat(indent_len as usize));
- }
- new_text.push_str(line);
- }
-
- let hunks = diff.push_new(&new_text);
- hunks_tx.send(hunks).await?;
- new_text.clear();
- }
- hunks_tx.send(diff.finish()).await?;
-
- anyhow::Ok(())
- });
-
- while let Some(hunks) = hunks_rx.next().await {
- let editor = if let Some(editor) = editor.upgrade(&cx) {
- editor
- } else {
- break;
- };
-
- let this = if let Some(this) = this.upgrade(&cx) {
- this
- } else {
- break;
- };
-
- this.update(&mut cx, |this, cx| {
- let pending_assist = if let Some(pending_assist) =
- this.pending_inline_assists.get_mut(&inline_assist_id)
- {
- pending_assist
- } else {
- return;
- };
-
- pending_assist.highlighted_ranges.clear();
- editor.update(cx, |editor, cx| {
- let transaction = editor.buffer().update(cx, |buffer, cx| {
- // Avoid grouping assistant edits with user edits.
- buffer.finalize_last_transaction(cx);
-
- buffer.start_transaction(cx);
- buffer.edit(
- hunks.into_iter().filter_map(|hunk| match hunk {
- Hunk::Insert { text } => {
- let edit_start = snapshot.anchor_after(edit_start);
- Some((edit_start..edit_start, text))
- }
- Hunk::Remove { len } => {
- let edit_end = edit_start + len;
- let edit_range = snapshot.anchor_after(edit_start)
- ..snapshot.anchor_before(edit_end);
- edit_start = edit_end;
- Some((edit_range, String::new()))
- }
- Hunk::Keep { len } => {
- let edit_end = edit_start + len;
- let edit_range = snapshot.anchor_after(edit_start)
- ..snapshot.anchor_before(edit_end);
- edit_start += len;
- pending_assist.highlighted_ranges.push(edit_range);
- None
- }
- }),
- None,
- cx,
- );
-
- buffer.end_transaction(cx)
- });
-
- if let Some(transaction) = transaction {
- if let Some(first_transaction) = pending_assist.transaction_id {
- // Group all assistant edits into the first transaction.
- editor.buffer().update(cx, |buffer, cx| {
- buffer.merge_transactions(
- transaction,
- first_transaction,
- cx,
- )
- });
- } else {
- pending_assist.transaction_id = Some(transaction);
- editor.buffer().update(cx, |buffer, cx| {
- buffer.finalize_last_transaction(cx)
- });
- }
- }
- });
-
- this.update_highlights_for_editor(&editor, cx);
- });
- }
-
- if let Err(error) = diff.await {
- this.update(&mut cx, |this, cx| {
- let pending_assist = if let Some(pending_assist) =
- this.pending_inline_assists.get_mut(&inline_assist_id)
- {
- pending_assist
- } else {
- return;
- };
-
- if let Some((_, inline_assistant)) =
- pending_assist.inline_assistant.as_ref()
- {
- inline_assistant.update(cx, |inline_assistant, cx| {
- inline_assistant.set_error(error, cx);
- });
- } else if let Some(workspace) = this.workspace.upgrade(cx) {
- workspace.update(cx, |workspace, cx| {
- workspace.show_toast(
- Toast::new(
- inline_assist_id,
- format!("Inline assistant error: {}", error),
- ),
- cx,
- );
- })
- }
- })?;
- } else {
- let _ = this.update(&mut cx, |this, cx| {
- this.close_inline_assist(inline_assist_id, false, cx)
- });
- }
-
- anyhow::Ok(())
- }
- .log_err()
- });
+ pending_assist
+ .codegen
+ .update(cx, |codegen, cx| codegen.start(request, cx));
}
fn update_highlights_for_editor(
@@ -909,8 +670,9 @@ impl AssistantPanel {
for inline_assist_id in inline_assist_ids {
if let Some(pending_assist) = self.pending_inline_assists.get(inline_assist_id) {
- background_ranges.push(pending_assist.range.clone());
- foreground_ranges.extend(pending_assist.highlighted_ranges.iter().cloned());
+ let codegen = pending_assist.codegen.read(cx);
+ background_ranges.push(codegen.range());
+ foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
}
}
@@ -2900,11 +2662,11 @@ struct InlineAssistant {
has_focus: bool,
include_conversation: bool,
measurements: Rc<Cell<BlockMeasurements>>,
- error: Option<anyhow::Error>,
prompt_history: VecDeque<String>,
prompt_history_ix: Option<usize>,
pending_prompt: String,
- _subscription: Subscription,
+ codegen: ModelHandle<Codegen>,
+ _subscriptions: Vec<Subscription>,
}
impl Entity for InlineAssistant {
@@ -2933,7 +2695,7 @@ impl View for InlineAssistant {
.element()
.aligned(),
)
- .with_children(if let Some(error) = self.error.as_ref() {
+ .with_children(if let Some(error) = self.codegen.read(cx).error() {
Some(
Svg::new("icons/circle_x_mark_12.svg")
.with_color(theme.assistant.error_icon.color)
@@ -3011,6 +2773,7 @@ impl InlineAssistant {
measurements: Rc<Cell<BlockMeasurements>>,
include_conversation: bool,
prompt_history: VecDeque<String>,
+ codegen: ModelHandle<Codegen>,
cx: &mut ViewContext<Self>,
) -> Self {
let prompt_editor = cx.add_view(|cx| {
@@ -3025,7 +2788,10 @@ impl InlineAssistant {
editor.set_placeholder_text(placeholder, cx);
editor
});
- let subscription = cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events);
+ let subscriptions = vec![
+ cx.observe(&codegen, Self::handle_codegen_changed),
+ cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events),
+ ];
Self {
id,
prompt_editor,
@@ -3033,11 +2799,11 @@ impl InlineAssistant {
has_focus: false,
include_conversation,
measurements,
- error: None,
prompt_history,
prompt_history_ix: None,
pending_prompt: String::new(),
- _subscription: subscription,
+ codegen,
+ _subscriptions: subscriptions,
}
}
@@ -3053,6 +2819,31 @@ impl InlineAssistant {
}
}
+ fn handle_codegen_changed(&mut self, _: ModelHandle<Codegen>, cx: &mut ViewContext<Self>) {
+ let is_read_only = !self.codegen.read(cx).idle();
+ self.prompt_editor.update(cx, |editor, cx| {
+ let was_read_only = editor.read_only();
+ if was_read_only != is_read_only {
+ if is_read_only {
+ editor.set_read_only(true);
+ editor.set_field_editor_style(
+ Some(Arc::new(|theme| {
+ theme.assistant.inline.disabled_editor.clone()
+ })),
+ cx,
+ );
+ } else {
+ editor.set_read_only(false);
+ editor.set_field_editor_style(
+ Some(Arc::new(|theme| theme.assistant.inline.editor.clone())),
+ cx,
+ );
+ }
+ }
+ });
+ cx.notify();
+ }
+
fn cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
cx.emit(InlineAssistantEvent::Canceled);
}
@@ -3076,7 +2867,6 @@ impl InlineAssistant {
include_conversation: self.include_conversation,
});
self.confirmed = true;
- self.error = None;
cx.notify();
}
}
@@ -3093,19 +2883,6 @@ impl InlineAssistant {
cx.notify();
}
- fn set_error(&mut self, error: anyhow::Error, cx: &mut ViewContext<Self>) {
- self.error = Some(error);
- self.confirmed = false;
- self.prompt_editor.update(cx, |editor, cx| {
- editor.set_read_only(false);
- editor.set_field_editor_style(
- Some(Arc::new(|theme| theme.assistant.inline.editor.clone())),
- cx,
- );
- });
- cx.notify();
- }
-
fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
if let Some(ix) = self.prompt_history_ix {
if ix > 0 {
@@ -3154,11 +2931,8 @@ struct BlockMeasurements {
struct PendingInlineAssist {
kind: InlineAssistKind,
editor: WeakViewHandle<Editor>,
- range: Range<Anchor>,
- highlighted_ranges: Vec<Range<Anchor>>,
inline_assistant: Option<(BlockId, ViewHandle<InlineAssistant>)>,
- code_generation: Task<Option<()>>,
- transaction_id: Option<TransactionId>,
+ codegen: ModelHandle<Codegen>,
_subscriptions: Vec<Subscription>,
}
@@ -3184,65 +2958,10 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
}
}
-fn strip_markdown_codeblock(
- stream: impl Stream<Item = Result<String>>,
-) -> impl Stream<Item = Result<String>> {
- let mut first_line = true;
- let mut buffer = String::new();
- let mut starts_with_fenced_code_block = false;
- stream.filter_map(move |chunk| {
- let chunk = match chunk {
- Ok(chunk) => chunk,
- Err(err) => return future::ready(Some(Err(err))),
- };
- buffer.push_str(&chunk);
-
- if first_line {
- if buffer == "" || buffer == "`" || buffer == "``" {
- return future::ready(None);
- } else if buffer.starts_with("```") {
- starts_with_fenced_code_block = true;
- if let Some(newline_ix) = buffer.find('\n') {
- buffer.replace_range(..newline_ix + 1, "");
- first_line = false;
- } else {
- return future::ready(None);
- }
- }
- }
-
- let text = if starts_with_fenced_code_block {
- buffer
- .strip_suffix("\n```\n")
- .or_else(|| buffer.strip_suffix("\n```"))
- .or_else(|| buffer.strip_suffix("\n``"))
- .or_else(|| buffer.strip_suffix("\n`"))
- .or_else(|| buffer.strip_suffix('\n'))
- .unwrap_or(&buffer)
- } else {
- &buffer
- };
-
- if text.contains('\n') {
- first_line = false;
- }
-
- let remainder = buffer.split_off(text.len());
- let result = if buffer.is_empty() {
- None
- } else {
- Some(Ok(buffer.clone()))
- };
- buffer = remainder;
- future::ready(result)
- })
-}
-
#[cfg(test)]
mod tests {
use super::*;
use crate::MessageId;
- use futures::stream;
use gpui::AppContext;
#[gpui::test]
@@ -3611,62 +3330,6 @@ mod tests {
);
}
- #[gpui::test]
- async fn test_strip_markdown_codeblock() {
- assert_eq!(
- strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
- .map(|chunk| chunk.unwrap())
- .collect::<String>()
- .await,
- "Lorem ipsum dolor"
- );
- assert_eq!(
- strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
- .map(|chunk| chunk.unwrap())
- .collect::<String>()
- .await,
- "Lorem ipsum dolor"
- );
- assert_eq!(
- strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
- .map(|chunk| chunk.unwrap())
- .collect::<String>()
- .await,
- "Lorem ipsum dolor"
- );
- assert_eq!(
- strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
- .map(|chunk| chunk.unwrap())
- .collect::<String>()
- .await,
- "Lorem ipsum dolor"
- );
- assert_eq!(
- strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
- .map(|chunk| chunk.unwrap())
- .collect::<String>()
- .await,
- "```js\nLorem ipsum dolor\n```"
- );
- assert_eq!(
- strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
- .map(|chunk| chunk.unwrap())
- .collect::<String>()
- .await,
- "``\nLorem ipsum dolor\n```"
- );
-
- fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
- stream::iter(
- text.chars()
- .collect::<Vec<_>>()
- .chunks(size)
- .map(|chunk| Ok(chunk.iter().collect::<String>()))
- .collect::<Vec<_>>(),
- )
- }
- }
-
fn messages(
conversation: &ModelHandle<Conversation>,
cx: &AppContext,
@@ -0,0 +1,468 @@
+use crate::{
+ stream_completion,
+ streaming_diff::{Hunk, StreamingDiff},
+ OpenAIRequest,
+};
+use anyhow::Result;
+use editor::{multi_buffer, Anchor, MultiBuffer, ToOffset, ToPoint};
+use futures::{
+ channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt,
+};
+use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task};
+use language::{IndentSize, Point, Rope, TransactionId};
+use std::{cmp, future, ops::Range, sync::Arc};
+
+pub trait CompletionProvider {
+ fn complete(
+ &self,
+ prompt: OpenAIRequest,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+}
+
+pub struct OpenAICompletionProvider {
+ api_key: String,
+ executor: Arc<Background>,
+}
+
+impl OpenAICompletionProvider {
+ pub fn new(api_key: String, executor: Arc<Background>) -> Self {
+ Self { api_key, executor }
+ }
+}
+
+impl CompletionProvider for OpenAICompletionProvider {
+ fn complete(
+ &self,
+ prompt: OpenAIRequest,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
+ async move {
+ let response = request.await?;
+ let stream = response
+ .filter_map(|response| async move {
+ match response {
+ Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
+ Err(error) => Some(Err(error)),
+ }
+ })
+ .boxed();
+ Ok(stream)
+ }
+ .boxed()
+ }
+}
+
+pub enum Event {
+ Finished,
+ Undone,
+}
+
+pub struct Codegen {
+ provider: Arc<dyn CompletionProvider>,
+ buffer: ModelHandle<MultiBuffer>,
+ range: Range<Anchor>,
+ last_equal_ranges: Vec<Range<Anchor>>,
+ transaction_id: Option<TransactionId>,
+ error: Option<anyhow::Error>,
+ generation: Task<()>,
+ idle: bool,
+ _subscription: gpui::Subscription,
+}
+
+impl Entity for Codegen {
+ type Event = Event;
+}
+
+impl Codegen {
+ pub fn new(
+ buffer: ModelHandle<MultiBuffer>,
+ range: Range<Anchor>,
+ provider: Arc<dyn CompletionProvider>,
+ cx: &mut ModelContext<Self>,
+ ) -> Self {
+ Self {
+ provider,
+ buffer: buffer.clone(),
+ range,
+ last_equal_ranges: Default::default(),
+ transaction_id: Default::default(),
+ error: Default::default(),
+ idle: true,
+ generation: Task::ready(()),
+ _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
+ }
+ }
+
+ fn handle_buffer_event(
+ &mut self,
+ _buffer: ModelHandle<MultiBuffer>,
+ event: &multi_buffer::Event,
+ cx: &mut ModelContext<Self>,
+ ) {
+ if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
+ if self.transaction_id == Some(*transaction_id) {
+ self.transaction_id = None;
+ self.generation = Task::ready(());
+ cx.emit(Event::Undone);
+ }
+ }
+ }
+
+ pub fn range(&self) -> Range<Anchor> {
+ self.range.clone()
+ }
+
+ pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
+ &self.last_equal_ranges
+ }
+
+ pub fn idle(&self) -> bool {
+ self.idle
+ }
+
+ pub fn error(&self) -> Option<&anyhow::Error> {
+ self.error.as_ref()
+ }
+
+ pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
+ let range = self.range.clone();
+ let snapshot = self.buffer.read(cx).snapshot(cx);
+ let selected_text = snapshot
+ .text_for_range(range.start..range.end)
+ .collect::<Rope>();
+
+ let selection_start = range.start.to_point(&snapshot);
+ let selection_end = range.end.to_point(&snapshot);
+
+ let mut base_indent: Option<IndentSize> = None;
+ let mut start_row = selection_start.row;
+ if snapshot.is_line_blank(start_row) {
+ if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) {
+ start_row = prev_non_blank_row;
+ }
+ }
+ for row in start_row..=selection_end.row {
+ if snapshot.is_line_blank(row) {
+ continue;
+ }
+
+ let line_indent = snapshot.indent_size_for_line(row);
+ if let Some(base_indent) = base_indent.as_mut() {
+ if line_indent.len < base_indent.len {
+ *base_indent = line_indent;
+ }
+ } else {
+ base_indent = Some(line_indent);
+ }
+ }
+
+ let mut normalized_selected_text = selected_text.clone();
+ if let Some(base_indent) = base_indent {
+ for row in selection_start.row..=selection_end.row {
+ let selection_row = row - selection_start.row;
+ let line_start =
+ normalized_selected_text.point_to_offset(Point::new(selection_row, 0));
+ let indent_len = if row == selection_start.row {
+ base_indent.len.saturating_sub(selection_start.column)
+ } else {
+ let line_len = normalized_selected_text.line_len(selection_row);
+ cmp::min(line_len, base_indent.len)
+ };
+ let indent_end = cmp::min(
+ line_start + indent_len as usize,
+ normalized_selected_text.len(),
+ );
+ normalized_selected_text.replace(line_start..indent_end, "");
+ }
+ }
+
+ let response = self.provider.complete(prompt);
+ self.generation = cx.spawn_weak(|this, mut cx| {
+ async move {
+ let generate = async {
+ let mut edit_start = range.start.to_offset(&snapshot);
+
+ let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
+ let diff = cx.background().spawn(async move {
+ let chunks = strip_markdown_codeblock(response.await?);
+ futures::pin_mut!(chunks);
+ let mut diff = StreamingDiff::new(selected_text.to_string());
+
+ let mut indent_len;
+ let indent_text;
+ if let Some(base_indent) = base_indent {
+ indent_len = base_indent.len;
+ indent_text = match base_indent.kind {
+ language::IndentKind::Space => " ",
+ language::IndentKind::Tab => "\t",
+ };
+ } else {
+ indent_len = 0;
+ indent_text = "";
+ };
+
+ let mut first_line_len = 0;
+ let mut first_line_non_whitespace_char_ix = None;
+ let mut first_line = true;
+ let mut new_text = String::new();
+
+ while let Some(chunk) = chunks.next().await {
+ let chunk = chunk?;
+
+ let mut lines = chunk.split('\n');
+ if let Some(mut line) = lines.next() {
+ if first_line {
+ if first_line_non_whitespace_char_ix.is_none() {
+ if let Some(mut char_ix) =
+ line.find(|ch: char| !ch.is_whitespace())
+ {
+ line = &line[char_ix..];
+ char_ix += first_line_len;
+ first_line_non_whitespace_char_ix = Some(char_ix);
+ let first_line_indent = char_ix
+ .saturating_sub(selection_start.column as usize)
+ as usize;
+ new_text
+ .push_str(&indent_text.repeat(first_line_indent));
+ indent_len = indent_len.saturating_sub(char_ix as u32);
+ }
+ }
+ first_line_len += line.len();
+ }
+
+ if first_line_non_whitespace_char_ix.is_some() {
+ new_text.push_str(line);
+ }
+ }
+
+ for line in lines {
+ first_line = false;
+ new_text.push('\n');
+ if !line.is_empty() {
+ new_text.push_str(&indent_text.repeat(indent_len as usize));
+ }
+ new_text.push_str(line);
+ }
+
+ let hunks = diff.push_new(&new_text);
+ hunks_tx.send(hunks).await?;
+ new_text.clear();
+ }
+ hunks_tx.send(diff.finish()).await?;
+
+ anyhow::Ok(())
+ });
+
+ while let Some(hunks) = hunks_rx.next().await {
+ let this = if let Some(this) = this.upgrade(&cx) {
+ this
+ } else {
+ break;
+ };
+
+ this.update(&mut cx, |this, cx| {
+ this.last_equal_ranges.clear();
+
+ let transaction = this.buffer.update(cx, |buffer, cx| {
+ // Avoid grouping assistant edits with user edits.
+ buffer.finalize_last_transaction(cx);
+
+ buffer.start_transaction(cx);
+ buffer.edit(
+ hunks.into_iter().filter_map(|hunk| match hunk {
+ Hunk::Insert { text } => {
+ let edit_start = snapshot.anchor_after(edit_start);
+ Some((edit_start..edit_start, text))
+ }
+ Hunk::Remove { len } => {
+ let edit_end = edit_start + len;
+ let edit_range = snapshot.anchor_after(edit_start)
+ ..snapshot.anchor_before(edit_end);
+ edit_start = edit_end;
+ Some((edit_range, String::new()))
+ }
+ Hunk::Keep { len } => {
+ let edit_end = edit_start + len;
+ let edit_range = snapshot.anchor_after(edit_start)
+ ..snapshot.anchor_before(edit_end);
+ edit_start += len;
+ this.last_equal_ranges.push(edit_range);
+ None
+ }
+ }),
+ None,
+ cx,
+ );
+
+ buffer.end_transaction(cx)
+ });
+
+ if let Some(transaction) = transaction {
+ if let Some(first_transaction) = this.transaction_id {
+ // Group all assistant edits into the first transaction.
+ this.buffer.update(cx, |buffer, cx| {
+ buffer.merge_transactions(
+ transaction,
+ first_transaction,
+ cx,
+ )
+ });
+ } else {
+ this.transaction_id = Some(transaction);
+ this.buffer.update(cx, |buffer, cx| {
+ buffer.finalize_last_transaction(cx)
+ });
+ }
+ }
+
+ cx.notify();
+ });
+ }
+
+ diff.await?;
+ anyhow::Ok(())
+ };
+
+ let result = generate.await;
+ if let Some(this) = this.upgrade(&cx) {
+ this.update(&mut cx, |this, cx| {
+ this.last_equal_ranges.clear();
+ this.idle = true;
+ if let Err(error) = result {
+ this.error = Some(error);
+ }
+ cx.emit(Event::Finished);
+ cx.notify();
+ });
+ }
+ }
+ });
+ self.error.take();
+ self.idle = false;
+ cx.notify();
+ }
+
+ pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
+ if let Some(transaction_id) = self.transaction_id {
+ self.buffer
+ .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
+ }
+ }
+}
+
+fn strip_markdown_codeblock(
+ stream: impl Stream<Item = Result<String>>,
+) -> impl Stream<Item = Result<String>> {
+ let mut first_line = true;
+ let mut buffer = String::new();
+ let mut starts_with_fenced_code_block = false;
+ stream.filter_map(move |chunk| {
+ let chunk = match chunk {
+ Ok(chunk) => chunk,
+ Err(err) => return future::ready(Some(Err(err))),
+ };
+ buffer.push_str(&chunk);
+
+ if first_line {
+ if buffer == "" || buffer == "`" || buffer == "``" {
+ return future::ready(None);
+ } else if buffer.starts_with("```") {
+ starts_with_fenced_code_block = true;
+ if let Some(newline_ix) = buffer.find('\n') {
+ buffer.replace_range(..newline_ix + 1, "");
+ first_line = false;
+ } else {
+ return future::ready(None);
+ }
+ }
+ }
+
+ let text = if starts_with_fenced_code_block {
+ buffer
+ .strip_suffix("\n```\n")
+ .or_else(|| buffer.strip_suffix("\n```"))
+ .or_else(|| buffer.strip_suffix("\n``"))
+ .or_else(|| buffer.strip_suffix("\n`"))
+ .or_else(|| buffer.strip_suffix('\n'))
+ .unwrap_or(&buffer)
+ } else {
+ &buffer
+ };
+
+ if text.contains('\n') {
+ first_line = false;
+ }
+
+ let remainder = buffer.split_off(text.len());
+ let result = if buffer.is_empty() {
+ None
+ } else {
+ Some(Ok(buffer.clone()))
+ };
+ buffer = remainder;
+ future::ready(result)
+ })
+}
+
+#[cfg(test)]
+mod tests {
+ use futures::stream;
+
+ use super::*;
+
+ #[gpui::test]
+ async fn test_strip_markdown_codeblock() {
+ assert_eq!(
+ strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
+ .map(|chunk| chunk.unwrap())
+ .collect::<String>()
+ .await,
+ "Lorem ipsum dolor"
+ );
+ assert_eq!(
+ strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
+ .map(|chunk| chunk.unwrap())
+ .collect::<String>()
+ .await,
+ "Lorem ipsum dolor"
+ );
+ assert_eq!(
+ strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
+ .map(|chunk| chunk.unwrap())
+ .collect::<String>()
+ .await,
+ "Lorem ipsum dolor"
+ );
+ assert_eq!(
+ strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
+ .map(|chunk| chunk.unwrap())
+ .collect::<String>()
+ .await,
+ "Lorem ipsum dolor"
+ );
+ assert_eq!(
+ strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
+ .map(|chunk| chunk.unwrap())
+ .collect::<String>()
+ .await,
+ "```js\nLorem ipsum dolor\n```"
+ );
+ assert_eq!(
+ strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
+ .map(|chunk| chunk.unwrap())
+ .collect::<String>()
+ .await,
+ "``\nLorem ipsum dolor\n```"
+ );
+
+ fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
+ stream::iter(
+ text.chars()
+ .collect::<Vec<_>>()
+ .chunks(size)
+ .map(|chunk| Ok(chunk.iter().collect::<String>()))
+ .collect::<Vec<_>>(),
+ )
+ }
+ }
+}
@@ -1734,6 +1734,10 @@ impl Editor {
}
}
+ pub fn read_only(&self) -> bool {
+ self.read_only
+ }
+
pub fn set_read_only(&mut self, read_only: bool) {
self.read_only = read_only;
}
@@ -5103,9 +5107,6 @@ impl Editor {
self.unmark_text(cx);
self.refresh_copilot_suggestions(true, cx);
cx.emit(Event::Edited);
- cx.emit(Event::TransactionUndone {
- transaction_id: tx_id,
- });
}
}
@@ -8548,9 +8549,6 @@ pub enum Event {
local: bool,
autoscroll: bool,
},
- TransactionUndone {
- transaction_id: TransactionId,
- },
Closed,
}
@@ -70,6 +70,9 @@ pub enum Event {
Edited {
sigleton_buffer_edited: bool,
},
+ TransactionUndone {
+ transaction_id: TransactionId,
+ },
Reloaded,
DiffBaseChanged,
LanguageChanged,
@@ -771,30 +774,36 @@ impl MultiBuffer {
}
pub fn undo(&mut self, cx: &mut ModelContext<Self>) -> Option<TransactionId> {
+ let mut transaction_id = None;
if let Some(buffer) = self.as_singleton() {
- return buffer.update(cx, |buffer, cx| buffer.undo(cx));
- }
+ transaction_id = buffer.update(cx, |buffer, cx| buffer.undo(cx));
+ } else {
+ while let Some(transaction) = self.history.pop_undo() {
+ let mut undone = false;
+ for (buffer_id, buffer_transaction_id) in &mut transaction.buffer_transactions {
+ if let Some(BufferState { buffer, .. }) = self.buffers.borrow().get(buffer_id) {
+ undone |= buffer.update(cx, |buffer, cx| {
+ let undo_to = *buffer_transaction_id;
+ if let Some(entry) = buffer.peek_undo_stack() {
+ *buffer_transaction_id = entry.transaction_id();
+ }
+ buffer.undo_to_transaction(undo_to, cx)
+ });
+ }
+ }
- while let Some(transaction) = self.history.pop_undo() {
- let mut undone = false;
- for (buffer_id, buffer_transaction_id) in &mut transaction.buffer_transactions {
- if let Some(BufferState { buffer, .. }) = self.buffers.borrow().get(buffer_id) {
- undone |= buffer.update(cx, |buffer, cx| {
- let undo_to = *buffer_transaction_id;
- if let Some(entry) = buffer.peek_undo_stack() {
- *buffer_transaction_id = entry.transaction_id();
- }
- buffer.undo_to_transaction(undo_to, cx)
- });
+ if undone {
+ transaction_id = Some(transaction.id);
+ break;
}
}
+ }
- if undone {
- return Some(transaction.id);
- }
+ if let Some(transaction_id) = transaction_id {
+ cx.emit(Event::TransactionUndone { transaction_id });
}
- None
+ transaction_id
}
pub fn redo(&mut self, cx: &mut ModelContext<Self>) -> Option<TransactionId> {