@@ -8,7 +8,7 @@ use collections::{HashMap, HashSet};
use editor::{
display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
- Anchor, Editor, ToOffset as _,
+ Anchor, Editor,
};
use fs::Fs;
use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
@@ -40,7 +40,14 @@ const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
actions!(
assistant,
- [NewContext, Assist, QuoteSelection, ToggleFocus, ResetKey]
+ [
+ NewContext,
+ Assist,
+ Split,
+ QuoteSelection,
+ ToggleFocus,
+ ResetKey
+ ]
);
pub fn init(cx: &mut AppContext) {
@@ -64,6 +71,7 @@ pub fn init(cx: &mut AppContext) {
cx.capture_action(AssistantEditor::cancel_last_assist);
cx.add_action(AssistantEditor::quote_selection);
cx.capture_action(AssistantEditor::copy);
+ cx.capture_action(AssistantEditor::split);
cx.add_action(AssistantPanel::save_api_key);
cx.add_action(AssistantPanel::reset_api_key);
cx.add_action(
@@ -711,6 +719,67 @@ impl Assistant {
}
}
+ fn split_message(
+ &mut self,
+ range: Range<usize>,
+ cx: &mut ModelContext<Self>,
+ ) -> (Option<Message>, Option<Message>) {
+ let start_message = self.message_for_offset(range.start, cx);
+ let end_message = self.message_for_offset(range.end, cx);
+ if let Some((start_message, end_message)) = start_message.zip(end_message) {
+ let (start_message_ix, _, start_message_metadata) = start_message;
+ let (end_message_ix, _, _) = end_message;
+
+ // Prevent splitting when range spans multiple messages.
+ if start_message_ix != end_message_ix {
+ return (None, None);
+ }
+
+ let role = start_message_metadata.role;
+ self.buffer.update(cx, |buffer, cx| {
+ buffer.edit([(range.end..range.end, "\n")], None, cx)
+ });
+ let suffix = Message {
+ id: MessageId(post_inc(&mut self.next_message_id.0)),
+ start: self.buffer.read(cx).anchor_before(range.end + 1),
+ };
+ self.messages.insert(start_message_ix + 1, suffix.clone());
+ self.messages_metadata.insert(
+ suffix.id,
+ MessageMetadata {
+ role,
+ sent_at: Local::now(),
+ error: None,
+ },
+ );
+
+ if range.start == range.end {
+ (None, Some(suffix))
+ } else {
+ self.buffer.update(cx, |buffer, cx| {
+ buffer.edit([(range.start..range.start, "\n")], None, cx)
+ });
+ let selection = Message {
+ id: MessageId(post_inc(&mut self.next_message_id.0)),
+ start: self.buffer.read(cx).anchor_before(range.start + 1),
+ };
+ self.messages
+ .insert(start_message_ix + 1, selection.clone());
+ self.messages_metadata.insert(
+ selection.id,
+ MessageMetadata {
+ role,
+ sent_at: Local::now(),
+ error: None,
+ },
+ );
+ (Some(selection), Some(suffix))
+ }
+ } else {
+ (None, None)
+ }
+ }
+
fn summarize(&mut self, cx: &mut ModelContext<Self>) {
if self.messages.len() >= 2 && self.summary.is_none() {
let api_key = self.api_key.borrow().clone();
@@ -755,35 +824,39 @@ impl Assistant {
fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> {
let buffer = self.buffer.read(cx);
self.messages(cx)
- .map(|(_message, metadata, range)| RequestMessage {
+ .map(|(_ix, _message, metadata, range)| RequestMessage {
role: metadata.role,
content: buffer.text_for_range(range).collect(),
})
.collect()
}
- fn message_id_for_offset(&self, offset: usize, cx: &AppContext) -> Option<MessageId> {
- Some(
- self.messages(cx)
- .find(|(_, _, range)| range.contains(&offset))
- .map(|(message, _, _)| message)
- .or(self.messages.last())?
- .id,
- )
+ fn message_for_offset<'a>(
+ &'a self,
+ offset: usize,
+ cx: &'a AppContext,
+ ) -> Option<(usize, &Message, &MessageMetadata)> {
+ let mut messages = self.messages(cx).peekable();
+ while let Some((ix, message, metadata, range)) = messages.next() {
+ if range.contains(&offset) || messages.peek().is_none() {
+ return Some((ix, message, metadata));
+ }
+ }
+ None
}
fn messages<'a>(
&'a self,
cx: &'a AppContext,
- ) -> impl 'a + Iterator<Item = (&Message, &MessageMetadata, Range<usize>)> {
+ ) -> impl 'a + Iterator<Item = (usize, &Message, &MessageMetadata, Range<usize>)> {
let buffer = self.buffer.read(cx);
- let mut messages = self.messages.iter().peekable();
+ let mut messages = self.messages.iter().enumerate().peekable();
iter::from_fn(move || {
- while let Some(message) = messages.next() {
+ while let Some((ix, message)) = messages.next() {
let metadata = self.messages_metadata.get(&message.id)?;
let message_start = message.start.to_offset(buffer);
let mut message_end = None;
- while let Some(next_message) = messages.peek() {
+ while let Some((_, next_message)) = messages.peek() {
if next_message.start.is_valid(buffer) {
message_end = Some(next_message.start);
break;
@@ -794,7 +867,7 @@ impl Assistant {
let message_end = message_end
.unwrap_or(language::Anchor::MAX)
.to_offset(buffer);
- return Some((message, metadata, message_start..message_end));
+ return Some((ix, message, metadata, message_start..message_end));
}
None
})
@@ -857,21 +930,7 @@ impl AssistantEditor {
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
let user_message = self.assistant.update(cx, |assistant, cx| {
- let editor = self.editor.read(cx);
- let newest_selection = editor
- .selections
- .newest_anchor()
- .head()
- .to_offset(&editor.buffer().read(cx).snapshot(cx));
- let message_id = assistant.message_id_for_offset(newest_selection, cx)?;
- let metadata = assistant.messages_metadata.get(&message_id)?;
- let user_message = if metadata.role == Role::User {
- let (_, user_message) = assistant.assist(cx)?;
- user_message
- } else {
- let user_message = assistant.insert_message_after(message_id, Role::User, cx)?;
- user_message
- };
+ let (_, user_message) = assistant.assist(cx)?;
Some(user_message)
});
@@ -982,7 +1041,7 @@ impl AssistantEditor {
.assistant
.read(cx)
.messages(cx)
- .map(|(message, metadata, _)| BlockProperties {
+ .map(|(_, message, metadata, _)| BlockProperties {
position: buffer.anchor_in_excerpt(excerpt_id, message.start),
height: 2,
style: BlockStyle::Sticky,
@@ -1147,7 +1206,7 @@ impl AssistantEditor {
let selection = editor.selections.newest::<usize>(cx);
let mut copied_text = String::new();
let mut spanned_messages = 0;
- for (_message, metadata, message_range) in assistant.messages(cx) {
+ for (_ix, _message, metadata, message_range) in assistant.messages(cx) {
if message_range.start >= selection.range().end {
break;
} else if message_range.end >= selection.range().start {
@@ -1174,6 +1233,13 @@ impl AssistantEditor {
cx.propagate_action();
}
+ fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
+ self.assistant.update(cx, |assistant, cx| {
+ let range = self.editor.read(cx).selections.newest::<usize>(cx).range();
+ assistant.split_message(range, cx);
+ });
+ }
+
fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
self.assistant.update(cx, |assistant, cx| {
let new_model = match assistant.model.as_str() {
@@ -1510,6 +1576,30 @@ mod tests {
(message_3.id, Role::User, 4..5)
]
);
+
+ // Split a message into prefix, selection and suffix.
+ buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "3")], None, cx));
+ assert_eq!(
+ messages(&assistant, cx),
+ vec![
+ (message_1.id, Role::User, 0..4),
+ (message_5.id, Role::System, 4..5),
+ (message_3.id, Role::User, 5..6)
+ ]
+ );
+ let (message_6, message_7) =
+ assistant.update(cx, |assistant, cx| assistant.split_message(2..3, cx));
+ let (message_6, message_7) = (message_6.unwrap(), message_7.unwrap());
+ assert_eq!(
+ messages(&assistant, cx),
+ vec![
+ (message_1.id, Role::User, 0..3),
+ (message_6.id, Role::User, 3..5),
+ (message_7.id, Role::User, 5..6),
+ (message_5.id, Role::System, 6..7),
+ (message_3.id, Role::User, 7..8)
+ ]
+ );
}
fn messages(
@@ -1519,7 +1609,7 @@ mod tests {
assistant
.read(cx)
.messages(cx)
- .map(|(message, metadata, range)| (message.id, metadata.role, range))
+ .map(|(_, message, metadata, range)| (message.id, metadata.role, range))
.collect()
}
}