Detailed changes
@@ -102,14 +102,20 @@ dependencies = [
"anyhow",
"chrono",
"collections",
+ "ctor",
"editor",
+ "env_logger 0.9.3",
"fs",
"futures 0.3.28",
"gpui",
+ "indoc",
"isahc",
"language",
+ "log",
"menu",
+ "ordered-float",
"project",
+ "rand 0.8.5",
"regex",
"schemars",
"search",
@@ -5649,6 +5655,7 @@ dependencies = [
name = "quick_action_bar"
version = "0.1.0"
dependencies = [
+ "ai",
"editor",
"gpui",
"search",
@@ -530,7 +530,8 @@
"bindings": {
"alt-enter": "editor::OpenExcerpts",
"cmd-f8": "editor::GoToHunk",
- "cmd-shift-f8": "editor::GoToPrevHunk"
+ "cmd-shift-f8": "editor::GoToPrevHunk",
+ "ctrl-enter": "assistant::InlineAssist"
}
},
{
@@ -24,7 +24,9 @@ workspace = { path = "../workspace" }
anyhow.workspace = true
chrono = { version = "0.4", features = ["serde"] }
futures.workspace = true
+indoc.workspace = true
isahc.workspace = true
+ordered-float.workspace = true
regex.workspace = true
schemars.workspace = true
serde.workspace = true
@@ -35,3 +37,8 @@ tiktoken-rs = "0.4"
[dev-dependencies]
editor = { path = "../editor", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
+
+ctor.workspace = true
+env_logger.workspace = true
+log.workspace = true
+rand.workspace = true
@@ -1,28 +1,33 @@
pub mod assistant;
mod assistant_settings;
+mod streaming_diff;
-use anyhow::Result;
+use anyhow::{anyhow, Result};
pub use assistant::AssistantPanel;
use assistant_settings::OpenAIModel;
use chrono::{DateTime, Local};
use collections::HashMap;
use fs::Fs;
-use futures::StreamExt;
-use gpui::AppContext;
+use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
+use gpui::{executor::Background, AppContext};
+use isahc::{http::StatusCode, Request, RequestExt};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::{
cmp::Reverse,
ffi::OsStr,
fmt::{self, Display},
+ io,
path::PathBuf,
sync::Arc,
};
use util::paths::CONVERSATIONS_DIR;
+const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
+
// Data types for chat completion requests
#[derive(Debug, Serialize)]
-struct OpenAIRequest {
+pub struct OpenAIRequest {
model: String,
messages: Vec<RequestMessage>,
stream: bool,
@@ -116,7 +121,7 @@ struct RequestMessage {
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-struct ResponseMessage {
+pub struct ResponseMessage {
role: Option<Role>,
content: Option<String>,
}
@@ -150,7 +155,7 @@ impl Display for Role {
}
#[derive(Deserialize, Debug)]
-struct OpenAIResponseStreamEvent {
+pub struct OpenAIResponseStreamEvent {
pub id: Option<String>,
pub object: String,
pub created: u32,
@@ -160,14 +165,14 @@ struct OpenAIResponseStreamEvent {
}
#[derive(Deserialize, Debug)]
-struct Usage {
+pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
-struct ChatChoiceDelta {
+pub struct ChatChoiceDelta {
pub index: u32,
pub delta: ResponseMessage,
pub finish_reason: Option<String>,
@@ -191,3 +196,97 @@ struct OpenAIChoice {
pub fn init(cx: &mut AppContext) {
assistant::init(cx);
}
+
+pub async fn stream_completion(
+ api_key: String,
+ executor: Arc<Background>,
+ mut request: OpenAIRequest,
+) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
+ request.stream = true;
+
+ let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
+
+ let json_data = serde_json::to_string(&request)?;
+ let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key))
+ .body(json_data)?
+ .send_async()
+ .await?;
+
+ let status = response.status();
+ if status == StatusCode::OK {
+ executor
+ .spawn(async move {
+ let mut lines = BufReader::new(response.body_mut()).lines();
+
+ fn parse_line(
+ line: Result<String, io::Error>,
+ ) -> Result<Option<OpenAIResponseStreamEvent>> {
+ if let Some(data) = line?.strip_prefix("data: ") {
+ let event = serde_json::from_str(&data)?;
+ Ok(Some(event))
+ } else {
+ Ok(None)
+ }
+ }
+
+ while let Some(line) = lines.next().await {
+ if let Some(event) = parse_line(line).transpose() {
+ let done = event.as_ref().map_or(false, |event| {
+ event
+ .choices
+ .last()
+ .map_or(false, |choice| choice.finish_reason.is_some())
+ });
+ if tx.unbounded_send(event).is_err() {
+ break;
+ }
+
+ if done {
+ break;
+ }
+ }
+ }
+
+ anyhow::Ok(())
+ })
+ .detach();
+
+ Ok(rx)
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ #[derive(Deserialize)]
+ struct OpenAIResponse {
+ error: OpenAIError,
+ }
+
+ #[derive(Deserialize)]
+ struct OpenAIError {
+ message: String,
+ }
+
+ match serde_json::from_str::<OpenAIResponse>(&body) {
+ Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+ "Failed to connect to OpenAI API: {}",
+ response.error.message,
+ )),
+
+ _ => Err(anyhow!(
+ "Failed to connect to OpenAI API: {} {}",
+ response.status(),
+ body,
+ )),
+ }
+ }
+}
+
+#[cfg(test)]
+#[ctor::ctor]
+fn init_logger() {
+ if std::env::var("RUST_LOG").is_ok() {
+ env_logger::init();
+ }
+}
@@ -1,53 +1,63 @@
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
- MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
- RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
+ stream_completion,
+ streaming_diff::{Hunk, StreamingDiff},
+ MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage, Role,
+ SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL,
};
use anyhow::{anyhow, Result};
use chrono::{DateTime, Local};
-use collections::{HashMap, HashSet};
+use collections::{hash_map, HashMap, HashSet, VecDeque};
use editor::{
- display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
+ display_map::{
+ BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint,
+ },
scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
- Anchor, Editor, ToOffset,
+ Anchor, Editor, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset, ToPoint,
};
use fs::Fs;
-use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
+use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
use gpui::{
actions,
- elements::*,
- executor::Background,
+ elements::{
+ ChildView, Component, Empty, Flex, Label, MouseEventHandler, ParentElement, SafeStylable,
+ Stack, Svg, Text, UniformList, UniformListState,
+ },
+ fonts::HighlightStyle,
geometry::vector::{vec2f, Vector2F},
platform::{CursorStyle, MouseButton},
- Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
- Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
+ Action, AnyElement, AppContext, AsyncAppContext, ClipboardItem, Element, Entity, ModelContext,
+ ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle,
+ WindowContext,
+};
+use language::{
+ language_settings::SoftWrap, Buffer, LanguageRegistry, Point, Rope, ToOffset as _,
+ TransactionId,
};
-use isahc::{http::StatusCode, Request, RequestExt};
-use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
use search::BufferSearchBar;
-use serde::Deserialize;
use settings::SettingsStore;
use std::{
- cell::RefCell,
+ cell::{Cell, RefCell},
cmp, env,
fmt::Write,
- io, iter,
+ future, iter,
ops::Range,
path::{Path, PathBuf},
rc::Rc,
sync::Arc,
time::Duration,
};
-use theme::AssistantStyle;
+use theme::{
+ components::{action_button::Button, ComponentExt},
+ AssistantStyle,
+};
use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
use workspace::{
dock::{DockPosition, Panel},
searchable::Direction,
- Save, ToggleZoom, Toolbar, Workspace,
+ Save, Toast, ToggleZoom, Toolbar, Workspace,
};
-const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
-
actions!(
assistant,
[
@@ -58,6 +68,8 @@ actions!(
QuoteSelection,
ToggleFocus,
ResetKey,
+ InlineAssist,
+ ToggleIncludeConversation,
]
);
@@ -89,6 +101,13 @@ pub fn init(cx: &mut AppContext) {
workspace.toggle_panel_focus::<AssistantPanel>(cx);
},
);
+ cx.add_action(AssistantPanel::inline_assist);
+ cx.add_action(AssistantPanel::cancel_last_inline_assist);
+ cx.add_action(InlineAssistant::confirm);
+ cx.add_action(InlineAssistant::cancel);
+ cx.add_action(InlineAssistant::toggle_include_conversation);
+ cx.add_action(InlineAssistant::move_up);
+ cx.add_action(InlineAssistant::move_down);
}
#[derive(Debug)]
@@ -118,10 +137,17 @@ pub struct AssistantPanel {
languages: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
subscriptions: Vec<Subscription>,
+ next_inline_assist_id: usize,
+ pending_inline_assists: HashMap<usize, PendingInlineAssist>,
+ pending_inline_assist_ids_by_editor: HashMap<WeakViewHandle<Editor>, Vec<usize>>,
+ include_conversation_in_next_inline_assist: bool,
+ inline_prompt_history: VecDeque<String>,
_watch_saved_conversations: Task<Result<()>>,
}
impl AssistantPanel {
+ const INLINE_PROMPT_HISTORY_MAX_LEN: usize = 20;
+
pub fn load(
workspace: WeakViewHandle<Workspace>,
cx: AsyncAppContext,
@@ -181,6 +207,11 @@ impl AssistantPanel {
width: None,
height: None,
subscriptions: Default::default(),
+ next_inline_assist_id: 0,
+ pending_inline_assists: Default::default(),
+ pending_inline_assist_ids_by_editor: Default::default(),
+ include_conversation_in_next_inline_assist: false,
+ inline_prompt_history: Default::default(),
_watch_saved_conversations,
};
@@ -201,6 +232,720 @@ impl AssistantPanel {
})
}
+ pub fn inline_assist(
+ workspace: &mut Workspace,
+ _: &InlineAssist,
+ cx: &mut ViewContext<Workspace>,
+ ) {
+ let this = if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
+ if this
+ .update(cx, |assistant, cx| assistant.load_api_key(cx))
+ .is_some()
+ {
+ this
+ } else {
+ workspace.focus_panel::<AssistantPanel>(cx);
+ return;
+ }
+ } else {
+ return;
+ };
+
+ let active_editor = if let Some(active_editor) = workspace
+ .active_item(cx)
+ .and_then(|item| item.act_as::<Editor>(cx))
+ {
+ active_editor
+ } else {
+ return;
+ };
+
+ this.update(cx, |assistant, cx| {
+ assistant.new_inline_assist(&active_editor, cx)
+ });
+ }
+
+ fn new_inline_assist(&mut self, editor: &ViewHandle<Editor>, cx: &mut ViewContext<Self>) {
+ let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
+ let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
+ let selection = editor.read(cx).selections.newest_anchor().clone();
+ let range = selection.start.bias_left(&snapshot)..selection.end.bias_right(&snapshot);
+ let assist_kind = if editor.read(cx).selections.newest::<usize>(cx).is_empty() {
+ InlineAssistKind::Generate
+ } else {
+ InlineAssistKind::Transform
+ };
+ let measurements = Rc::new(Cell::new(BlockMeasurements::default()));
+ let inline_assistant = cx.add_view(|cx| {
+ let assistant = InlineAssistant::new(
+ inline_assist_id,
+ assist_kind,
+ measurements.clone(),
+ self.include_conversation_in_next_inline_assist,
+ self.inline_prompt_history.clone(),
+ cx,
+ );
+ cx.focus_self();
+ assistant
+ });
+ let block_id = editor.update(cx, |editor, cx| {
+ editor.change_selections(None, cx, |selections| {
+ selections.select_anchor_ranges([selection.head()..selection.head()])
+ });
+ editor.insert_blocks(
+ [BlockProperties {
+ style: BlockStyle::Flex,
+ position: selection.head().bias_left(&snapshot),
+ height: 2,
+ render: Arc::new({
+ let inline_assistant = inline_assistant.clone();
+ move |cx: &mut BlockContext| {
+ measurements.set(BlockMeasurements {
+ anchor_x: cx.anchor_x,
+ gutter_width: cx.gutter_width,
+ });
+ ChildView::new(&inline_assistant, cx).into_any()
+ }
+ }),
+ disposition: if selection.reversed {
+ BlockDisposition::Above
+ } else {
+ BlockDisposition::Below
+ },
+ }],
+ Some(Autoscroll::Strategy(AutoscrollStrategy::Newest)),
+ cx,
+ )[0]
+ });
+
+ self.pending_inline_assists.insert(
+ inline_assist_id,
+ PendingInlineAssist {
+ kind: assist_kind,
+ editor: editor.downgrade(),
+ range,
+ highlighted_ranges: Default::default(),
+ inline_assistant: Some((block_id, inline_assistant.clone())),
+ code_generation: Task::ready(None),
+ transaction_id: None,
+ _subscriptions: vec![
+ cx.subscribe(&inline_assistant, Self::handle_inline_assistant_event),
+ cx.subscribe(editor, {
+ let inline_assistant = inline_assistant.downgrade();
+ move |this, editor, event, cx| {
+ if let Some(inline_assistant) = inline_assistant.upgrade(cx) {
+ match event {
+ editor::Event::SelectionsChanged { local } => {
+ if *local && inline_assistant.read(cx).has_focus {
+ cx.focus(&editor);
+ }
+ }
+ editor::Event::TransactionUndone {
+ transaction_id: tx_id,
+ } => {
+ if let Some(pending_assist) =
+ this.pending_inline_assists.get(&inline_assist_id)
+ {
+ if pending_assist.transaction_id == Some(*tx_id) {
+ // Notice we are supplying `undo: false` here. This
+ // is because there's no need to undo the transaction
+ // because the user just did so.
+ this.close_inline_assist(
+ inline_assist_id,
+ false,
+ cx,
+ );
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+ }),
+ ],
+ },
+ );
+ self.pending_inline_assist_ids_by_editor
+ .entry(editor.downgrade())
+ .or_default()
+ .push(inline_assist_id);
+ self.update_highlights_for_editor(&editor, cx);
+ }
+
+ fn handle_inline_assistant_event(
+ &mut self,
+ inline_assistant: ViewHandle<InlineAssistant>,
+ event: &InlineAssistantEvent,
+ cx: &mut ViewContext<Self>,
+ ) {
+ let assist_id = inline_assistant.read(cx).id;
+ match event {
+ InlineAssistantEvent::Confirmed {
+ prompt,
+ include_conversation,
+ } => {
+ self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx);
+ }
+ InlineAssistantEvent::Canceled => {
+ self.close_inline_assist(assist_id, true, cx);
+ }
+ InlineAssistantEvent::Dismissed => {
+ self.hide_inline_assist(assist_id, cx);
+ }
+ InlineAssistantEvent::IncludeConversationToggled {
+ include_conversation,
+ } => {
+ self.include_conversation_in_next_inline_assist = *include_conversation;
+ }
+ }
+ }
+
+ fn cancel_last_inline_assist(
+ workspace: &mut Workspace,
+ _: &editor::Cancel,
+ cx: &mut ViewContext<Workspace>,
+ ) {
+ let panel = if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
+ panel
+ } else {
+ return;
+ };
+ let editor = if let Some(editor) = workspace
+ .active_item(cx)
+ .and_then(|item| item.downcast::<Editor>())
+ {
+ editor
+ } else {
+ return;
+ };
+
+ let handled = panel.update(cx, |panel, cx| {
+ if let Some(assist_id) = panel
+ .pending_inline_assist_ids_by_editor
+ .get(&editor.downgrade())
+ .and_then(|assist_ids| assist_ids.last().copied())
+ {
+ panel.close_inline_assist(assist_id, true, cx);
+ true
+ } else {
+ false
+ }
+ });
+
+ if !handled {
+ cx.propagate_action();
+ }
+ }
+
+ fn close_inline_assist(&mut self, assist_id: usize, undo: bool, cx: &mut ViewContext<Self>) {
+ self.hide_inline_assist(assist_id, cx);
+
+ if let Some(pending_assist) = self.pending_inline_assists.remove(&assist_id) {
+ if let hash_map::Entry::Occupied(mut entry) = self
+ .pending_inline_assist_ids_by_editor
+ .entry(pending_assist.editor)
+ {
+ entry.get_mut().retain(|id| *id != assist_id);
+ if entry.get().is_empty() {
+ entry.remove();
+ }
+ }
+
+ if let Some(editor) = pending_assist.editor.upgrade(cx) {
+ self.update_highlights_for_editor(&editor, cx);
+
+ if undo {
+ if let Some(transaction_id) = pending_assist.transaction_id {
+ editor.update(cx, |editor, cx| {
+ editor.buffer().update(cx, |buffer, cx| {
+ buffer.undo_transaction(transaction_id, cx)
+ });
+ });
+ }
+ }
+ }
+ }
+ }
+
+ fn hide_inline_assist(&mut self, assist_id: usize, cx: &mut ViewContext<Self>) {
+ if let Some(pending_assist) = self.pending_inline_assists.get_mut(&assist_id) {
+ if let Some(editor) = pending_assist.editor.upgrade(cx) {
+ if let Some((block_id, _)) = pending_assist.inline_assistant.take() {
+ editor.update(cx, |editor, cx| {
+ editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
+ });
+ }
+ }
+ }
+ }
+
+ fn confirm_inline_assist(
+ &mut self,
+ inline_assist_id: usize,
+ user_prompt: &str,
+ include_conversation: bool,
+ cx: &mut ViewContext<Self>,
+ ) {
+ let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
+ api_key
+ } else {
+ return;
+ };
+
+ let conversation = if include_conversation {
+ self.active_editor()
+ .map(|editor| editor.read(cx).conversation.clone())
+ } else {
+ None
+ };
+
+ let pending_assist =
+ if let Some(pending_assist) = self.pending_inline_assists.get_mut(&inline_assist_id) {
+ pending_assist
+ } else {
+ return;
+ };
+
+ let editor = if let Some(editor) = pending_assist.editor.upgrade(cx) {
+ editor
+ } else {
+ return;
+ };
+
+ self.inline_prompt_history.push_back(user_prompt.into());
+ if self.inline_prompt_history.len() > Self::INLINE_PROMPT_HISTORY_MAX_LEN {
+ self.inline_prompt_history.pop_front();
+ }
+ let range = pending_assist.range.clone();
+ let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
+ let selected_text = snapshot
+ .text_for_range(range.start..range.end)
+ .collect::<Rope>();
+
+ let selection_start = range.start.to_point(&snapshot);
+ let selection_end = range.end.to_point(&snapshot);
+
+ let mut base_indent: Option<language::IndentSize> = None;
+ let mut start_row = selection_start.row;
+ if snapshot.is_line_blank(start_row) {
+ if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) {
+ start_row = prev_non_blank_row;
+ }
+ }
+ for row in start_row..=selection_end.row {
+ if snapshot.is_line_blank(row) {
+ continue;
+ }
+
+ let line_indent = snapshot.indent_size_for_line(row);
+ if let Some(base_indent) = base_indent.as_mut() {
+ if line_indent.len < base_indent.len {
+ *base_indent = line_indent;
+ }
+ } else {
+ base_indent = Some(line_indent);
+ }
+ }
+
+ let mut normalized_selected_text = selected_text.clone();
+ if let Some(base_indent) = base_indent {
+ for row in selection_start.row..=selection_end.row {
+ let selection_row = row - selection_start.row;
+ let line_start =
+ normalized_selected_text.point_to_offset(Point::new(selection_row, 0));
+ let indent_len = if row == selection_start.row {
+ base_indent.len.saturating_sub(selection_start.column)
+ } else {
+ let line_len = normalized_selected_text.line_len(selection_row);
+ cmp::min(line_len, base_indent.len)
+ };
+ let indent_end = cmp::min(
+ line_start + indent_len as usize,
+ normalized_selected_text.len(),
+ );
+ normalized_selected_text.replace(line_start..indent_end, "");
+ }
+ }
+
+ let language = snapshot.language_at(range.start);
+ let language_name = if let Some(language) = language.as_ref() {
+ if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
+ None
+ } else {
+ Some(language.name())
+ }
+ } else {
+ None
+ };
+ let language_name = language_name.as_deref();
+
+ let mut prompt = String::new();
+ if let Some(language_name) = language_name {
+ writeln!(prompt, "You're an expert {language_name} engineer.").unwrap();
+ }
+ match pending_assist.kind {
+ InlineAssistKind::Transform => {
+ writeln!(
+ prompt,
+ "You're currently working inside an editor on this file:"
+ )
+ .unwrap();
+ if let Some(language_name) = language_name {
+ writeln!(prompt, "```{language_name}").unwrap();
+ } else {
+ writeln!(prompt, "```").unwrap();
+ }
+ for chunk in snapshot.text_for_range(Anchor::min()..Anchor::max()) {
+ write!(prompt, "{chunk}").unwrap();
+ }
+ writeln!(prompt, "```").unwrap();
+
+ writeln!(
+ prompt,
+ "In particular, the user has selected the following text:"
+ )
+ .unwrap();
+ if let Some(language_name) = language_name {
+ writeln!(prompt, "```{language_name}").unwrap();
+ } else {
+ writeln!(prompt, "```").unwrap();
+ }
+ writeln!(prompt, "{normalized_selected_text}").unwrap();
+ writeln!(prompt, "```").unwrap();
+ writeln!(prompt).unwrap();
+ writeln!(
+ prompt,
+ "Modify the selected text given the user prompt: {user_prompt}"
+ )
+ .unwrap();
+ writeln!(
+ prompt,
+ "You MUST reply only with the edited selected text, not the entire file."
+ )
+ .unwrap();
+ }
+ InlineAssistKind::Generate => {
+ writeln!(
+ prompt,
+ "You're currently working inside an editor on this file:"
+ )
+ .unwrap();
+ if let Some(language_name) = language_name {
+ writeln!(prompt, "```{language_name}").unwrap();
+ } else {
+ writeln!(prompt, "```").unwrap();
+ }
+ for chunk in snapshot.text_for_range(Anchor::min()..range.start) {
+ write!(prompt, "{chunk}").unwrap();
+ }
+ write!(prompt, "<|>").unwrap();
+ for chunk in snapshot.text_for_range(range.start..Anchor::max()) {
+ write!(prompt, "{chunk}").unwrap();
+ }
+ writeln!(prompt).unwrap();
+ writeln!(prompt, "```").unwrap();
+ writeln!(
+ prompt,
+ "Assume the cursor is located where the `<|>` marker is."
+ )
+ .unwrap();
+ writeln!(
+ prompt,
+ "Text can't be replaced, so assume your answer will be inserted at the cursor."
+ )
+ .unwrap();
+ writeln!(
+ prompt,
+ "Complete the text given the user prompt: {user_prompt}"
+ )
+ .unwrap();
+ }
+ }
+ if let Some(language_name) = language_name {
+ writeln!(prompt, "Your answer MUST always be valid {language_name}.").unwrap();
+ }
+ writeln!(prompt, "Always wrap your response in a Markdown codeblock.").unwrap();
+ writeln!(prompt, "Never make remarks about the output.").unwrap();
+
+ let mut messages = Vec::new();
+ let mut model = settings::get::<AssistantSettings>(cx)
+ .default_open_ai_model
+ .clone();
+ if let Some(conversation) = conversation {
+ let conversation = conversation.read(cx);
+ let buffer = conversation.buffer.read(cx);
+ messages.extend(
+ conversation
+ .messages(cx)
+ .map(|message| message.to_open_ai_message(buffer)),
+ );
+ model = conversation.model.clone();
+ }
+
+ messages.push(RequestMessage {
+ role: Role::User,
+ content: prompt,
+ });
+ let request = OpenAIRequest {
+ model: model.full_name().into(),
+ messages,
+ stream: true,
+ };
+ let response = stream_completion(api_key, cx.background().clone(), request);
+ let editor = editor.downgrade();
+
+ pending_assist.code_generation = cx.spawn(|this, mut cx| {
+ async move {
+ let mut edit_start = range.start.to_offset(&snapshot);
+
+ let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
+ let diff = cx.background().spawn(async move {
+ let chunks = strip_markdown_codeblock(response.await?.filter_map(
+ |message| async move {
+ match message {
+ Ok(mut message) => Some(Ok(message.choices.pop()?.delta.content?)),
+ Err(error) => Some(Err(error)),
+ }
+ },
+ ));
+ futures::pin_mut!(chunks);
+ let mut diff = StreamingDiff::new(selected_text.to_string());
+
+ let mut indent_len;
+ let indent_text;
+ if let Some(base_indent) = base_indent {
+ indent_len = base_indent.len;
+ indent_text = match base_indent.kind {
+ language::IndentKind::Space => " ",
+ language::IndentKind::Tab => "\t",
+ };
+ } else {
+ indent_len = 0;
+ indent_text = "";
+ };
+
+ let mut first_line_len = 0;
+ let mut first_line_non_whitespace_char_ix = None;
+ let mut first_line = true;
+ let mut new_text = String::new();
+
+ while let Some(chunk) = chunks.next().await {
+ let chunk = chunk?;
+
+ let mut lines = chunk.split('\n');
+ if let Some(mut line) = lines.next() {
+ if first_line {
+ if first_line_non_whitespace_char_ix.is_none() {
+ if let Some(mut char_ix) =
+ line.find(|ch: char| !ch.is_whitespace())
+ {
+ line = &line[char_ix..];
+ char_ix += first_line_len;
+ first_line_non_whitespace_char_ix = Some(char_ix);
+ let first_line_indent = char_ix
+ .saturating_sub(selection_start.column as usize)
+ as usize;
+ new_text.push_str(&indent_text.repeat(first_line_indent));
+ indent_len = indent_len.saturating_sub(char_ix as u32);
+ }
+ }
+ first_line_len += line.len();
+ }
+
+ if first_line_non_whitespace_char_ix.is_some() {
+ new_text.push_str(line);
+ }
+ }
+
+ for line in lines {
+ first_line = false;
+ new_text.push('\n');
+ if !line.is_empty() {
+ new_text.push_str(&indent_text.repeat(indent_len as usize));
+ }
+ new_text.push_str(line);
+ }
+
+ let hunks = diff.push_new(&new_text);
+ hunks_tx.send(hunks).await?;
+ new_text.clear();
+ }
+ hunks_tx.send(diff.finish()).await?;
+
+ anyhow::Ok(())
+ });
+
+ while let Some(hunks) = hunks_rx.next().await {
+ let editor = if let Some(editor) = editor.upgrade(&cx) {
+ editor
+ } else {
+ break;
+ };
+
+ let this = if let Some(this) = this.upgrade(&cx) {
+ this
+ } else {
+ break;
+ };
+
+ this.update(&mut cx, |this, cx| {
+ let pending_assist = if let Some(pending_assist) =
+ this.pending_inline_assists.get_mut(&inline_assist_id)
+ {
+ pending_assist
+ } else {
+ return;
+ };
+
+ pending_assist.highlighted_ranges.clear();
+ editor.update(cx, |editor, cx| {
+ let transaction = editor.buffer().update(cx, |buffer, cx| {
+ // Avoid grouping assistant edits with user edits.
+ buffer.finalize_last_transaction(cx);
+
+ buffer.start_transaction(cx);
+ buffer.edit(
+ hunks.into_iter().filter_map(|hunk| match hunk {
+ Hunk::Insert { text } => {
+ let edit_start = snapshot.anchor_after(edit_start);
+ Some((edit_start..edit_start, text))
+ }
+ Hunk::Remove { len } => {
+ let edit_end = edit_start + len;
+ let edit_range = snapshot.anchor_after(edit_start)
+ ..snapshot.anchor_before(edit_end);
+ edit_start = edit_end;
+ Some((edit_range, String::new()))
+ }
+ Hunk::Keep { len } => {
+ let edit_end = edit_start + len;
+ let edit_range = snapshot.anchor_after(edit_start)
+ ..snapshot.anchor_before(edit_end);
+ edit_start += len;
+ pending_assist.highlighted_ranges.push(edit_range);
+ None
+ }
+ }),
+ None,
+ cx,
+ );
+
+ buffer.end_transaction(cx)
+ });
+
+ if let Some(transaction) = transaction {
+ if let Some(first_transaction) = pending_assist.transaction_id {
+ // Group all assistant edits into the first transaction.
+ editor.buffer().update(cx, |buffer, cx| {
+ buffer.merge_transactions(
+ transaction,
+ first_transaction,
+ cx,
+ )
+ });
+ } else {
+ pending_assist.transaction_id = Some(transaction);
+ editor.buffer().update(cx, |buffer, cx| {
+ buffer.finalize_last_transaction(cx)
+ });
+ }
+ }
+ });
+
+ this.update_highlights_for_editor(&editor, cx);
+ });
+ }
+
+ if let Err(error) = diff.await {
+ this.update(&mut cx, |this, cx| {
+ let pending_assist = if let Some(pending_assist) =
+ this.pending_inline_assists.get_mut(&inline_assist_id)
+ {
+ pending_assist
+ } else {
+ return;
+ };
+
+ if let Some((_, inline_assistant)) =
+ pending_assist.inline_assistant.as_ref()
+ {
+ inline_assistant.update(cx, |inline_assistant, cx| {
+ inline_assistant.set_error(error, cx);
+ });
+ } else if let Some(workspace) = this.workspace.upgrade(cx) {
+ workspace.update(cx, |workspace, cx| {
+ workspace.show_toast(
+ Toast::new(
+ inline_assist_id,
+ format!("Inline assistant error: {}", error),
+ ),
+ cx,
+ );
+ })
+ }
+ })?;
+ } else {
+ let _ = this.update(&mut cx, |this, cx| {
+ this.close_inline_assist(inline_assist_id, false, cx)
+ });
+ }
+
+ anyhow::Ok(())
+ }
+ .log_err()
+ });
+ }
+
+ fn update_highlights_for_editor(
+ &self,
+ editor: &ViewHandle<Editor>,
+ cx: &mut ViewContext<Self>,
+ ) {
+ let mut background_ranges = Vec::new();
+ let mut foreground_ranges = Vec::new();
+ let empty_inline_assist_ids = Vec::new();
+ let inline_assist_ids = self
+ .pending_inline_assist_ids_by_editor
+ .get(&editor.downgrade())
+ .unwrap_or(&empty_inline_assist_ids);
+
+ for inline_assist_id in inline_assist_ids {
+ if let Some(pending_assist) = self.pending_inline_assists.get(inline_assist_id) {
+ background_ranges.push(pending_assist.range.clone());
+ foreground_ranges.extend(pending_assist.highlighted_ranges.iter().cloned());
+ }
+ }
+
+ let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
+ merge_ranges(&mut background_ranges, &snapshot);
+ merge_ranges(&mut foreground_ranges, &snapshot);
+ editor.update(cx, |editor, cx| {
+ if background_ranges.is_empty() {
+ editor.clear_background_highlights::<PendingInlineAssist>(cx);
+ } else {
+ editor.highlight_background::<PendingInlineAssist>(
+ background_ranges,
+ |theme| theme.assistant.inline.pending_edit_background,
+ cx,
+ );
+ }
+
+ if foreground_ranges.is_empty() {
+ editor.clear_text_highlights::<PendingInlineAssist>(cx);
+ } else {
+ editor.highlight_text::<PendingInlineAssist>(
+ foreground_ranges,
+ HighlightStyle {
+ fade_out: Some(0.6),
+ ..Default::default()
+ },
+ cx,
+ );
+ }
+ });
+ }
+
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
let editor = cx.add_view(|cx| {
ConversationEditor::new(
@@ -570,6 +1315,32 @@ impl AssistantPanel {
.iter()
.position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
}
+
+ fn load_api_key(&mut self, cx: &mut ViewContext<Self>) -> Option<String> {
+ if self.api_key.borrow().is_none() && !self.has_read_credentials {
+ self.has_read_credentials = true;
+ let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+ Some(api_key)
+ } else if let Some((_, api_key)) = cx
+ .platform()
+ .read_credentials(OPENAI_API_URL)
+ .log_err()
+ .flatten()
+ {
+ String::from_utf8(api_key).log_err()
+ } else {
+ None
+ };
+ if let Some(api_key) = api_key {
+ *self.api_key.borrow_mut() = Some(api_key);
+ } else if self.api_key_editor.is_none() {
+ self.api_key_editor = Some(build_api_key_editor(cx));
+ cx.notify();
+ }
+ }
+
+ self.api_key.borrow().clone()
+ }
}
fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
@@ -753,27 +1524,7 @@ impl Panel for AssistantPanel {
fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
if active {
- if self.api_key.borrow().is_none() && !self.has_read_credentials {
- self.has_read_credentials = true;
- let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
- Some(api_key)
- } else if let Some((_, api_key)) = cx
- .platform()
- .read_credentials(OPENAI_API_URL)
- .log_err()
- .flatten()
- {
- String::from_utf8(api_key).log_err()
- } else {
- None
- };
- if let Some(api_key) = api_key {
- *self.api_key.borrow_mut() = Some(api_key);
- } else if self.api_key_editor.is_none() {
- self.api_key_editor = Some(build_api_key_editor(cx));
- cx.notify();
- }
- }
+ self.load_api_key(cx);
if self.editors.is_empty() {
self.new_conversation(cx);
@@ -1068,15 +1819,20 @@ impl Conversation {
cx: &mut ModelContext<Self>,
) -> Vec<MessageAnchor> {
let mut user_messages = Vec::new();
- let mut tasks = Vec::new();
- let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
- message
- .start
- .is_valid(self.buffer.read(cx))
- .then_some(message.id)
- });
+ let last_message_id = if let Some(last_message_id) =
+ self.message_anchors.iter().rev().find_map(|message| {
+ message
+ .start
+ .is_valid(self.buffer.read(cx))
+ .then_some(message.id)
+ }) {
+ last_message_id
+ } else {
+ return Default::default();
+ };
+ let mut should_assist = false;
for selected_message_id in selected_messages {
let selected_message_role =
if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
@@ -1093,144 +1849,111 @@ impl Conversation {
cx,
) {
user_messages.push(user_message);
- } else {
- continue;
}
} else {
- let request = OpenAIRequest {
- model: self.model.full_name().to_string(),
- messages: self
- .messages(cx)
- .filter(|message| matches!(message.status, MessageStatus::Done))
- .flat_map(|message| {
- let mut system_message = None;
- if message.id == selected_message_id {
- system_message = Some(RequestMessage {
- role: Role::System,
- content: concat!(
- "Treat the following messages as additional knowledge you have learned about, ",
- "but act as if they were not part of this conversation. That is, treat them ",
- "as if the user didn't see them and couldn't possibly inquire about them."
- ).into()
- });
- }
+ should_assist = true;
+ }
+ }
- Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
- })
- .chain(Some(RequestMessage {
- role: Role::System,
- content: format!(
- "Direct your reply to message with id {}. Do not include a [Message X] header.",
- selected_message_id.0
- ),
- }))
- .collect(),
- stream: true,
- };
+ if should_assist {
+ let Some(api_key) = self.api_key.borrow().clone() else {
+ return Default::default();
+ };
- let Some(api_key) = self.api_key.borrow().clone() else {
- continue;
- };
- let stream = stream_completion(api_key, cx.background().clone(), request);
- let assistant_message = self
- .insert_message_after(
- selected_message_id,
- Role::Assistant,
- MessageStatus::Pending,
- cx,
- )
- .unwrap();
-
- // Queue up the user's next reply
- if Some(selected_message_id) == last_message_id {
- let user_message = self
- .insert_message_after(
- assistant_message.id,
- Role::User,
- MessageStatus::Done,
- cx,
- )
- .unwrap();
- user_messages.push(user_message);
- }
+ let request = OpenAIRequest {
+ model: self.model.full_name().to_string(),
+ messages: self
+ .messages(cx)
+ .filter(|message| matches!(message.status, MessageStatus::Done))
+ .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
+ .collect(),
+ stream: true,
+ };
+
+ let stream = stream_completion(api_key, cx.background().clone(), request);
+ let assistant_message = self
+ .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
+ .unwrap();
+
+ // Queue up the user's next reply.
+ let user_message = self
+ .insert_message_after(assistant_message.id, Role::User, MessageStatus::Done, cx)
+ .unwrap();
+ user_messages.push(user_message);
+
+ let task = cx.spawn_weak({
+ |this, mut cx| async move {
+ let assistant_message_id = assistant_message.id;
+ let stream_completion = async {
+ let mut messages = stream.await?;
- tasks.push(cx.spawn_weak({
- |this, mut cx| async move {
- let assistant_message_id = assistant_message.id;
- let stream_completion = async {
- let mut messages = stream.await?;
-
- while let Some(message) = messages.next().await {
- let mut message = message?;
- if let Some(choice) = message.choices.pop() {
- this.upgrade(&cx)
- .ok_or_else(|| anyhow!("conversation was dropped"))?
- .update(&mut cx, |this, cx| {
- let text: Arc<str> = choice.delta.content?.into();
- let message_ix = this.message_anchors.iter().position(
- |message| message.id == assistant_message_id,
- )?;
- this.buffer.update(cx, |buffer, cx| {
- let offset = this.message_anchors[message_ix + 1..]
- .iter()
- .find(|message| message.start.is_valid(buffer))
- .map_or(buffer.len(), |message| {
- message
- .start
- .to_offset(buffer)
- .saturating_sub(1)
- });
- buffer.edit([(offset..offset, text)], None, cx);
- });
- cx.emit(ConversationEvent::StreamedCompletion);
-
- Some(())
+ while let Some(message) = messages.next().await {
+ let mut message = message?;
+ if let Some(choice) = message.choices.pop() {
+ this.upgrade(&cx)
+ .ok_or_else(|| anyhow!("conversation was dropped"))?
+ .update(&mut cx, |this, cx| {
+ let text: Arc<str> = choice.delta.content?.into();
+ let message_ix =
+ this.message_anchors.iter().position(|message| {
+ message.id == assistant_message_id
+ })?;
+ this.buffer.update(cx, |buffer, cx| {
+ let offset = this.message_anchors[message_ix + 1..]
+ .iter()
+ .find(|message| message.start.is_valid(buffer))
+ .map_or(buffer.len(), |message| {
+ message
+ .start
+ .to_offset(buffer)
+ .saturating_sub(1)
+ });
+ buffer.edit([(offset..offset, text)], None, cx);
});
- }
- smol::future::yield_now().await;
- }
+ cx.emit(ConversationEvent::StreamedCompletion);
- this.upgrade(&cx)
- .ok_or_else(|| anyhow!("conversation was dropped"))?
- .update(&mut cx, |this, cx| {
- this.pending_completions.retain(|completion| {
- completion.id != this.completion_count
+ Some(())
});
- this.summarize(cx);
- });
+ }
+ smol::future::yield_now().await;
+ }
- anyhow::Ok(())
- };
+ this.upgrade(&cx)
+ .ok_or_else(|| anyhow!("conversation was dropped"))?
+ .update(&mut cx, |this, cx| {
+ this.pending_completions
+ .retain(|completion| completion.id != this.completion_count);
+ this.summarize(cx);
+ });
- let result = stream_completion.await;
- if let Some(this) = this.upgrade(&cx) {
- this.update(&mut cx, |this, cx| {
- if let Some(metadata) =
- this.messages_metadata.get_mut(&assistant_message.id)
- {
- match result {
- Ok(_) => {
- metadata.status = MessageStatus::Done;
- }
- Err(error) => {
- metadata.status = MessageStatus::Error(
- error.to_string().trim().into(),
- );
- }
+ anyhow::Ok(())
+ };
+
+ let result = stream_completion.await;
+ if let Some(this) = this.upgrade(&cx) {
+ this.update(&mut cx, |this, cx| {
+ if let Some(metadata) =
+ this.messages_metadata.get_mut(&assistant_message.id)
+ {
+ match result {
+ Ok(_) => {
+ metadata.status = MessageStatus::Done;
+ }
+ Err(error) => {
+ metadata.status =
+ MessageStatus::Error(error.to_string().trim().into());
}
- cx.notify();
}
- });
- }
+ cx.notify();
+ }
+ });
}
- }));
- }
- }
+ }
+ });
- if !tasks.is_empty() {
self.pending_completions.push(PendingCompletion {
id: post_inc(&mut self.completion_count),
- _tasks: tasks,
+ _task: task,
});
}
@@ -0,0 +1,293 @@
+use collections::HashMap;
+use ordered_float::OrderedFloat;
+use std::{
+ cmp,
+ fmt::{self, Debug},
+ ops::Range,
+};
+
+struct Matrix {
+ cells: Vec<f64>,
+ rows: usize,
+ cols: usize,
+}
+
+impl Matrix {
+ fn new() -> Self {
+ Self {
+ cells: Vec::new(),
+ rows: 0,
+ cols: 0,
+ }
+ }
+
+ fn resize(&mut self, rows: usize, cols: usize) {
+ self.cells.resize(rows * cols, 0.);
+ self.rows = rows;
+ self.cols = cols;
+ }
+
+ fn get(&self, row: usize, col: usize) -> f64 {
+ if row >= self.rows {
+ panic!("row out of bounds")
+ }
+
+ if col >= self.cols {
+ panic!("col out of bounds")
+ }
+ self.cells[col * self.rows + row]
+ }
+
+ fn set(&mut self, row: usize, col: usize, value: f64) {
+ if row >= self.rows {
+ panic!("row out of bounds")
+ }
+
+ if col >= self.cols {
+ panic!("col out of bounds")
+ }
+
+ self.cells[col * self.rows + row] = value;
+ }
+}
+
+impl Debug for Matrix {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ writeln!(f)?;
+ for i in 0..self.rows {
+ for j in 0..self.cols {
+ write!(f, "{:5}", self.get(i, j))?;
+ }
+ writeln!(f)?;
+ }
+ Ok(())
+ }
+}
+
+#[derive(Debug)]
+pub enum Hunk {
+ Insert { text: String },
+ Remove { len: usize },
+ Keep { len: usize },
+}
+
+pub struct StreamingDiff {
+ old: Vec<char>,
+ new: Vec<char>,
+ scores: Matrix,
+ old_text_ix: usize,
+ new_text_ix: usize,
+ equal_runs: HashMap<(usize, usize), u32>,
+}
+
+impl StreamingDiff {
+ const INSERTION_SCORE: f64 = -1.;
+ const DELETION_SCORE: f64 = -20.;
+ const EQUALITY_BASE: f64 = 1.8;
+ const MAX_EQUALITY_EXPONENT: i32 = 16;
+
+ pub fn new(old: String) -> Self {
+ let old = old.chars().collect::<Vec<_>>();
+ let mut scores = Matrix::new();
+ scores.resize(old.len() + 1, 1);
+ for i in 0..=old.len() {
+ scores.set(i, 0, i as f64 * Self::DELETION_SCORE);
+ }
+ Self {
+ old,
+ new: Vec::new(),
+ scores,
+ old_text_ix: 0,
+ new_text_ix: 0,
+ equal_runs: Default::default(),
+ }
+ }
+
+ pub fn push_new(&mut self, text: &str) -> Vec<Hunk> {
+ self.new.extend(text.chars());
+ self.scores.resize(self.old.len() + 1, self.new.len() + 1);
+
+ for j in self.new_text_ix + 1..=self.new.len() {
+ self.scores.set(0, j, j as f64 * Self::INSERTION_SCORE);
+ for i in 1..=self.old.len() {
+ let insertion_score = self.scores.get(i, j - 1) + Self::INSERTION_SCORE;
+ let deletion_score = self.scores.get(i - 1, j) + Self::DELETION_SCORE;
+ let equality_score = if self.old[i - 1] == self.new[j - 1] {
+ let mut equal_run = self.equal_runs.get(&(i - 1, j - 1)).copied().unwrap_or(0);
+ equal_run += 1;
+ self.equal_runs.insert((i, j), equal_run);
+
+ let exponent = cmp::min(equal_run as i32 / 4, Self::MAX_EQUALITY_EXPONENT);
+ self.scores.get(i - 1, j - 1) + Self::EQUALITY_BASE.powi(exponent)
+ } else {
+ f64::NEG_INFINITY
+ };
+
+ let score = insertion_score.max(deletion_score).max(equality_score);
+ self.scores.set(i, j, score);
+ }
+ }
+
+ let mut max_score = f64::NEG_INFINITY;
+ let mut next_old_text_ix = self.old_text_ix;
+ let next_new_text_ix = self.new.len();
+ for i in self.old_text_ix..=self.old.len() {
+ let score = self.scores.get(i, next_new_text_ix);
+ if score > max_score {
+ max_score = score;
+ next_old_text_ix = i;
+ }
+ }
+
+ let hunks = self.backtrack(next_old_text_ix, next_new_text_ix);
+ self.old_text_ix = next_old_text_ix;
+ self.new_text_ix = next_new_text_ix;
+ hunks
+ }
+
+ fn backtrack(&self, old_text_ix: usize, new_text_ix: usize) -> Vec<Hunk> {
+ let mut pending_insert: Option<Range<usize>> = None;
+ let mut hunks = Vec::new();
+ let mut i = old_text_ix;
+ let mut j = new_text_ix;
+ while (i, j) != (self.old_text_ix, self.new_text_ix) {
+ let insertion_score = if j > self.new_text_ix {
+ Some((i, j - 1))
+ } else {
+ None
+ };
+ let deletion_score = if i > self.old_text_ix {
+ Some((i - 1, j))
+ } else {
+ None
+ };
+ let equality_score = if i > self.old_text_ix && j > self.new_text_ix {
+ if self.old[i - 1] == self.new[j - 1] {
+ Some((i - 1, j - 1))
+ } else {
+ None
+ }
+ } else {
+ None
+ };
+
+ let (prev_i, prev_j) = [insertion_score, deletion_score, equality_score]
+ .iter()
+ .max_by_key(|cell| cell.map(|(i, j)| OrderedFloat(self.scores.get(i, j))))
+ .unwrap()
+ .unwrap();
+
+ if prev_i == i && prev_j == j - 1 {
+ if let Some(pending_insert) = pending_insert.as_mut() {
+ pending_insert.start = prev_j;
+ } else {
+ pending_insert = Some(prev_j..j);
+ }
+ } else {
+ if let Some(range) = pending_insert.take() {
+ hunks.push(Hunk::Insert {
+ text: self.new[range].iter().collect(),
+ });
+ }
+
+ let char_len = self.old[i - 1].len_utf8();
+ if prev_i == i - 1 && prev_j == j {
+ if let Some(Hunk::Remove { len }) = hunks.last_mut() {
+ *len += char_len;
+ } else {
+ hunks.push(Hunk::Remove { len: char_len })
+ }
+ } else {
+ if let Some(Hunk::Keep { len }) = hunks.last_mut() {
+ *len += char_len;
+ } else {
+ hunks.push(Hunk::Keep { len: char_len })
+ }
+ }
+ }
+
+ i = prev_i;
+ j = prev_j;
+ }
+
+ if let Some(range) = pending_insert.take() {
+ hunks.push(Hunk::Insert {
+ text: self.new[range].iter().collect(),
+ });
+ }
+
+ hunks.reverse();
+ hunks
+ }
+
+ pub fn finish(self) -> Vec<Hunk> {
+ self.backtrack(self.old.len(), self.new.len())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::env;
+
+ use super::*;
+ use rand::prelude::*;
+
+ #[gpui::test(iterations = 100)]
+ fn test_random_diffs(mut rng: StdRng) {
+ let old_text_len = env::var("OLD_TEXT_LEN")
+ .map(|i| i.parse().expect("invalid `OLD_TEXT_LEN` variable"))
+ .unwrap_or(10);
+ let new_text_len = env::var("NEW_TEXT_LEN")
+ .map(|i| i.parse().expect("invalid `NEW_TEXT_LEN` variable"))
+ .unwrap_or(10);
+
+ let old = util::RandomCharIter::new(&mut rng)
+ .take(old_text_len)
+ .collect::<String>();
+ log::info!("old text: {:?}", old);
+
+ let mut diff = StreamingDiff::new(old.clone());
+ let mut hunks = Vec::new();
+ let mut new_len = 0;
+ let mut new = String::new();
+ while new_len < new_text_len {
+ let new_chunk_len = rng.gen_range(1..=new_text_len - new_len);
+ let new_chunk = util::RandomCharIter::new(&mut rng)
+ .take(new_len)
+ .collect::<String>();
+ log::info!("new chunk: {:?}", new_chunk);
+ new_len += new_chunk_len;
+ new.push_str(&new_chunk);
+ let new_hunks = diff.push_new(&new_chunk);
+ log::info!("hunks: {:?}", new_hunks);
+ hunks.extend(new_hunks);
+ }
+ let final_hunks = diff.finish();
+ log::info!("final hunks: {:?}", final_hunks);
+ hunks.extend(final_hunks);
+
+ log::info!("new text: {:?}", new);
+ let mut old_ix = 0;
+ let mut new_ix = 0;
+ let mut patched = String::new();
+ for hunk in hunks {
+ match hunk {
+ Hunk::Keep { len } => {
+ assert_eq!(&old[old_ix..old_ix + len], &new[new_ix..new_ix + len]);
+ patched.push_str(&old[old_ix..old_ix + len]);
+ old_ix += len;
+ new_ix += len;
+ }
+ Hunk::Remove { len } => {
+ old_ix += len;
+ }
+ Hunk::Insert { text } => {
+ assert_eq!(text, &new[new_ix..new_ix + text.len()]);
+ patched.push_str(&text);
+ new_ix += text.len();
+ }
+ }
+ }
+ assert_eq!(patched, new);
+ }
+}
@@ -1635,6 +1635,15 @@ impl Editor {
self.read_only = read_only;
}
+ pub fn set_field_editor_style(
+ &mut self,
+ style: Option<Arc<GetFieldEditorTheme>>,
+ cx: &mut ViewContext<Self>,
+ ) {
+ self.get_field_editor_theme = style;
+ cx.notify();
+ }
+
pub fn replica_id_map(&self) -> Option<&HashMap<ReplicaId, ReplicaId>> {
self.replica_id_mapping.as_ref()
}
@@ -4989,6 +4998,9 @@ impl Editor {
self.unmark_text(cx);
self.refresh_copilot_suggestions(true, cx);
cx.emit(Event::Edited);
+ cx.emit(Event::TransactionUndone {
+ transaction_id: tx_id,
+ });
}
}
@@ -8428,6 +8440,9 @@ pub enum Event {
local: bool,
autoscroll: bool,
},
+ TransactionUndone {
+ transaction_id: TransactionId,
+ },
Closed,
}
@@ -8468,7 +8483,7 @@ impl View for Editor {
"Editor"
}
- fn focus_in(&mut self, _: AnyViewHandle, cx: &mut ViewContext<Self>) {
+ fn focus_in(&mut self, focused: AnyViewHandle, cx: &mut ViewContext<Self>) {
if cx.is_self_focused() {
let focused_event = EditorFocused(cx.handle());
cx.emit(Event::Focused);
@@ -8476,7 +8491,7 @@ impl View for Editor {
}
if let Some(rename) = self.pending_rename.as_ref() {
cx.focus(&rename.editor);
- } else {
+ } else if cx.is_self_focused() || !focused.is::<Editor>() {
if !self.focused {
self.blink_manager.update(cx, BlinkManager::enable);
}
@@ -617,6 +617,42 @@ impl MultiBuffer {
}
}
+ pub fn merge_transactions(
+ &mut self,
+ transaction: TransactionId,
+ destination: TransactionId,
+ cx: &mut ModelContext<Self>,
+ ) {
+ if let Some(buffer) = self.as_singleton() {
+ buffer.update(cx, |buffer, _| {
+ buffer.merge_transactions(transaction, destination)
+ });
+ } else {
+ if let Some(transaction) = self.history.forget(transaction) {
+ if let Some(destination) = self.history.transaction_mut(destination) {
+ for (buffer_id, buffer_transaction_id) in transaction.buffer_transactions {
+ if let Some(destination_buffer_transaction_id) =
+ destination.buffer_transactions.get(&buffer_id)
+ {
+ if let Some(state) = self.buffers.borrow().get(&buffer_id) {
+ state.buffer.update(cx, |buffer, _| {
+ buffer.merge_transactions(
+ buffer_transaction_id,
+ *destination_buffer_transaction_id,
+ )
+ });
+ }
+ } else {
+ destination
+ .buffer_transactions
+ .insert(buffer_id, buffer_transaction_id);
+ }
+ }
+ }
+ }
+ }
+ }
+
pub fn finalize_last_transaction(&mut self, cx: &mut ModelContext<Self>) {
self.history.finalize_last_transaction();
for BufferState { buffer, .. } in self.buffers.borrow().values() {
@@ -788,6 +824,20 @@ impl MultiBuffer {
None
}
+ pub fn undo_transaction(&mut self, transaction_id: TransactionId, cx: &mut ModelContext<Self>) {
+ if let Some(buffer) = self.as_singleton() {
+ buffer.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
+ } else if let Some(transaction) = self.history.remove_from_undo(transaction_id) {
+ for (buffer_id, transaction_id) in &transaction.buffer_transactions {
+ if let Some(BufferState { buffer, .. }) = self.buffers.borrow().get(buffer_id) {
+ buffer.update(cx, |buffer, cx| {
+ buffer.undo_transaction(*transaction_id, cx)
+ });
+ }
+ }
+ }
+ }
+
pub fn stream_excerpts_with_context_lines(
&mut self,
buffer: ModelHandle<Buffer>,
@@ -2316,6 +2366,16 @@ impl MultiBufferSnapshot {
}
}
+ pub fn prev_non_blank_row(&self, mut row: u32) -> Option<u32> {
+ while row > 0 {
+ row -= 1;
+ if !self.is_line_blank(row) {
+ return Some(row);
+ }
+ }
+ None
+ }
+
pub fn line_len(&self, row: u32) -> u32 {
if let Some((_, range)) = self.buffer_line_for_row(row) {
range.end.column - range.start.column
@@ -3347,6 +3407,35 @@ impl History {
}
}
+ fn forget(&mut self, transaction_id: TransactionId) -> Option<Transaction> {
+ if let Some(ix) = self
+ .undo_stack
+ .iter()
+ .rposition(|transaction| transaction.id == transaction_id)
+ {
+ Some(self.undo_stack.remove(ix))
+ } else if let Some(ix) = self
+ .redo_stack
+ .iter()
+ .rposition(|transaction| transaction.id == transaction_id)
+ {
+ Some(self.redo_stack.remove(ix))
+ } else {
+ None
+ }
+ }
+
+ fn transaction_mut(&mut self, transaction_id: TransactionId) -> Option<&mut Transaction> {
+ self.undo_stack
+ .iter_mut()
+ .find(|transaction| transaction.id == transaction_id)
+ .or_else(|| {
+ self.redo_stack
+ .iter_mut()
+ .find(|transaction| transaction.id == transaction_id)
+ })
+ }
+
fn pop_undo(&mut self) -> Option<&mut Transaction> {
assert_eq!(self.transaction_depth, 0);
if let Some(transaction) = self.undo_stack.pop() {
@@ -3367,6 +3456,16 @@ impl History {
}
}
+ fn remove_from_undo(&mut self, transaction_id: TransactionId) -> Option<&Transaction> {
+ let ix = self
+ .undo_stack
+ .iter()
+ .rposition(|transaction| transaction.id == transaction_id)?;
+ let transaction = self.undo_stack.remove(ix);
+ self.redo_stack.push(transaction);
+ self.redo_stack.last()
+ }
+
fn group(&mut self) -> Option<TransactionId> {
let mut count = 0;
let mut transactions = self.undo_stack.iter();
@@ -1298,6 +1298,10 @@ impl Buffer {
self.text.forget_transaction(transaction_id);
}
+ pub fn merge_transactions(&mut self, transaction: TransactionId, destination: TransactionId) {
+ self.text.merge_transactions(transaction, destination);
+ }
+
pub fn wait_for_edits(
&mut self,
edit_ids: impl IntoIterator<Item = clock::Local>,
@@ -1664,6 +1668,22 @@ impl Buffer {
}
}
+ pub fn undo_transaction(
+ &mut self,
+ transaction_id: TransactionId,
+ cx: &mut ModelContext<Self>,
+ ) -> bool {
+ let was_dirty = self.is_dirty();
+ let old_version = self.version.clone();
+ if let Some(operation) = self.text.undo_transaction(transaction_id) {
+ self.send_operation(Operation::Buffer(operation), cx);
+ self.did_edit(&old_version, was_dirty, cx);
+ true
+ } else {
+ false
+ }
+ }
+
pub fn undo_to_transaction(
&mut self,
transaction_id: TransactionId,
@@ -9,6 +9,7 @@ path = "src/quick_action_bar.rs"
doctest = false
[dependencies]
+ai = { path = "../ai" }
editor = { path = "../editor" }
gpui = { path = "../gpui" }
search = { path = "../search" }
@@ -1,25 +1,29 @@
+use ai::{assistant::InlineAssist, AssistantPanel};
use editor::Editor;
use gpui::{
elements::{Empty, Flex, MouseEventHandler, ParentElement, Svg},
platform::{CursorStyle, MouseButton},
Action, AnyElement, Element, Entity, EventContext, Subscription, View, ViewContext, ViewHandle,
+ WeakViewHandle,
};
use search::{buffer_search, BufferSearchBar};
-use workspace::{item::ItemHandle, ToolbarItemLocation, ToolbarItemView};
+use workspace::{item::ItemHandle, ToolbarItemLocation, ToolbarItemView, Workspace};
pub struct QuickActionBar {
buffer_search_bar: ViewHandle<BufferSearchBar>,
active_item: Option<Box<dyn ItemHandle>>,
_inlay_hints_enabled_subscription: Option<Subscription>,
+ workspace: WeakViewHandle<Workspace>,
}
impl QuickActionBar {
- pub fn new(buffer_search_bar: ViewHandle<BufferSearchBar>) -> Self {
+ pub fn new(buffer_search_bar: ViewHandle<BufferSearchBar>, workspace: &Workspace) -> Self {
Self {
buffer_search_bar,
active_item: None,
_inlay_hints_enabled_subscription: None,
+ workspace: workspace.weak_handle(),
}
}
@@ -88,6 +92,21 @@ impl View for QuickActionBar {
));
}
+ bar.add_child(render_quick_action_bar_button(
+ 2,
+ "icons/radix/magic-wand.svg",
+ false,
+ ("Inline Assist".into(), Some(Box::new(InlineAssist))),
+ cx,
+ move |this, cx| {
+ if let Some(workspace) = this.workspace.upgrade(cx) {
+ workspace.update(cx, |workspace, cx| {
+ AssistantPanel::inline_assist(workspace, &Default::default(), cx);
+ });
+ }
+ },
+ ));
+
bar.into_any()
}
}
@@ -384,6 +384,16 @@ impl<'a> From<&'a str> for Rope {
}
}
+impl<'a> FromIterator<&'a str> for Rope {
+ fn from_iter<T: IntoIterator<Item = &'a str>>(iter: T) -> Self {
+ let mut rope = Rope::new();
+ for chunk in iter {
+ rope.push(chunk);
+ }
+ rope
+ }
+}
+
impl From<String> for Rope {
fn from(text: String) -> Self {
Rope::from(text.as_str())
@@ -22,6 +22,7 @@ use postage::{oneshot, prelude::*};
pub use rope::*;
pub use selection::*;
+use util::ResultExt;
use std::{
cmp::{self, Ordering, Reverse},
@@ -263,7 +264,19 @@ impl History {
}
}
- fn remove_from_undo(&mut self, transaction_id: TransactionId) -> &[HistoryEntry] {
+ fn remove_from_undo(&mut self, transaction_id: TransactionId) -> Option<&HistoryEntry> {
+ assert_eq!(self.transaction_depth, 0);
+
+ let entry_ix = self
+ .undo_stack
+ .iter()
+ .rposition(|entry| entry.transaction.id == transaction_id)?;
+ let entry = self.undo_stack.remove(entry_ix);
+ self.redo_stack.push(entry);
+ self.redo_stack.last()
+ }
+
+ fn remove_from_undo_until(&mut self, transaction_id: TransactionId) -> &[HistoryEntry] {
assert_eq!(self.transaction_depth, 0);
let redo_stack_start_len = self.redo_stack.len();
@@ -278,20 +291,43 @@ impl History {
&self.redo_stack[redo_stack_start_len..]
}
- fn forget(&mut self, transaction_id: TransactionId) {
+ fn forget(&mut self, transaction_id: TransactionId) -> Option<Transaction> {
assert_eq!(self.transaction_depth, 0);
if let Some(entry_ix) = self
.undo_stack
.iter()
.rposition(|entry| entry.transaction.id == transaction_id)
{
- self.undo_stack.remove(entry_ix);
+ Some(self.undo_stack.remove(entry_ix).transaction)
} else if let Some(entry_ix) = self
.redo_stack
.iter()
.rposition(|entry| entry.transaction.id == transaction_id)
{
- self.undo_stack.remove(entry_ix);
+ Some(self.redo_stack.remove(entry_ix).transaction)
+ } else {
+ None
+ }
+ }
+
+ fn transaction_mut(&mut self, transaction_id: TransactionId) -> Option<&mut Transaction> {
+ let entry = self
+ .undo_stack
+ .iter_mut()
+ .rfind(|entry| entry.transaction.id == transaction_id)
+ .or_else(|| {
+ self.redo_stack
+ .iter_mut()
+ .rfind(|entry| entry.transaction.id == transaction_id)
+ })?;
+ Some(&mut entry.transaction)
+ }
+
+ fn merge_transactions(&mut self, transaction: TransactionId, destination: TransactionId) {
+ if let Some(transaction) = self.forget(transaction) {
+ if let Some(destination) = self.transaction_mut(destination) {
+ destination.edit_ids.extend(transaction.edit_ids);
+ }
}
}
@@ -1183,11 +1219,20 @@ impl Buffer {
}
}
+ pub fn undo_transaction(&mut self, transaction_id: TransactionId) -> Option<Operation> {
+ let transaction = self
+ .history
+ .remove_from_undo(transaction_id)?
+ .transaction
+ .clone();
+ self.undo_or_redo(transaction).log_err()
+ }
+
#[allow(clippy::needless_collect)]
pub fn undo_to_transaction(&mut self, transaction_id: TransactionId) -> Vec<Operation> {
let transactions = self
.history
- .remove_from_undo(transaction_id)
+ .remove_from_undo_until(transaction_id)
.iter()
.map(|entry| entry.transaction.clone())
.collect::<Vec<_>>();
@@ -1202,6 +1247,10 @@ impl Buffer {
self.history.forget(transaction_id);
}
+ pub fn merge_transactions(&mut self, transaction: TransactionId, destination: TransactionId) {
+ self.history.merge_transactions(transaction, destination);
+ }
+
pub fn redo(&mut self) -> Option<(TransactionId, Operation)> {
if let Some(entry) = self.history.pop_redo() {
let transaction = entry.transaction.clone();
@@ -1150,6 +1150,17 @@ pub struct AssistantStyle {
pub api_key_editor: FieldEditor,
pub api_key_prompt: ContainedText,
pub saved_conversation: SavedConversation,
+ pub inline: InlineAssistantStyle,
+}
+
+#[derive(Clone, Deserialize, Default, JsonSchema)]
+pub struct InlineAssistantStyle {
+ #[serde(flatten)]
+ pub container: ContainerStyle,
+ pub editor: FieldEditor,
+ pub disabled_editor: FieldEditor,
+ pub pending_edit_background: Color,
+ pub include_conversation: ToggleIconButtonStyle,
}
#[derive(Clone, Deserialize, Default, JsonSchema)]
@@ -264,8 +264,9 @@ pub fn initialize_workspace(
toolbar.add_item(breadcrumbs, cx);
let buffer_search_bar = cx.add_view(BufferSearchBar::new);
toolbar.add_item(buffer_search_bar.clone(), cx);
- let quick_action_bar =
- cx.add_view(|_| QuickActionBar::new(buffer_search_bar));
+ let quick_action_bar = cx.add_view(|_| {
+ QuickActionBar::new(buffer_search_bar, workspace)
+ });
toolbar.add_item(quick_action_bar, cx);
let project_search_bar = cx.add_view(|_| ProjectSearchBar::new());
toolbar.add_item(project_search_bar, cx);
@@ -1,5 +1,5 @@
import { text, border, background, foreground, TextStyle } from "./components"
-import { Interactive, interactive } from "../element"
+import { Interactive, interactive, toggleable } from "../element"
import { tab_bar_button } from "../component/tab_bar_button"
import { StyleSets, useTheme } from "../theme"
@@ -59,6 +59,85 @@ export default function assistant(): any {
background: background(theme.highest),
padding: { left: 12 },
},
+ inline: {
+ background: background(theme.highest),
+ margin: { top: 3, bottom: 3 },
+ border: border(theme.lowest, "on", {
+ top: true,
+ bottom: true,
+ overlay: true,
+ }),
+ editor: {
+ text: text(theme.highest, "mono", "default", { size: "sm" }),
+ placeholder_text: text(theme.highest, "sans", "on", "disabled"),
+ selection: theme.players[0],
+ },
+ disabled_editor: {
+ text: text(theme.highest, "mono", "disabled", { size: "sm" }),
+ placeholder_text: text(theme.highest, "sans", "on", "disabled"),
+ selection: {
+ cursor: text(theme.highest, "mono", "disabled").color,
+ selection: theme.players[0].selection,
+ },
+ },
+ pending_edit_background: background(theme.highest, "positive"),
+ include_conversation: toggleable({
+ base: interactive({
+ base: {
+ icon_size: 12,
+ color: foreground(theme.highest, "variant"),
+
+ button_width: 12,
+ background: background(theme.highest, "on"),
+ corner_radius: 2,
+ border: {
+ width: 1., color: background(theme.highest, "on")
+ },
+ padding: {
+ left: 4,
+ right: 4,
+ top: 4,
+ bottom: 4,
+ },
+ },
+ state: {
+ hovered: {
+ ...text(theme.highest, "mono", "variant", "hovered"),
+ background: background(theme.highest, "on", "hovered"),
+ border: {
+ width: 1., color: background(theme.highest, "on", "hovered")
+ },
+ },
+ clicked: {
+ ...text(theme.highest, "mono", "variant", "pressed"),
+ background: background(theme.highest, "on", "pressed"),
+ border: {
+ width: 1., color: background(theme.highest, "on", "pressed")
+ },
+ },
+ },
+ }),
+ state: {
+ active: {
+ default: {
+ icon_size: 12,
+ button_width: 12,
+ color: foreground(theme.highest, "variant"),
+ background: background(theme.highest, "accent"),
+ border: border(theme.highest, "accent"),
+ },
+ hovered: {
+ background: background(theme.highest, "accent", "hovered"),
+ border: border(theme.highest, "accent", "hovered"),
+ },
+ clicked: {
+ background: background(theme.highest, "accent", "pressed"),
+ border: border(theme.highest, "accent", "pressed"),
+ },
+ },
+ },
+ }),
+ },
message_header: {
margin: { bottom: 4, top: 4 },
background: background(theme.highest),