Cargo.lock 🔗
@@ -114,6 +114,7 @@ dependencies = [
"serde",
"serde_json",
"settings",
+ "smol",
"theme",
"tiktoken-rs",
"util",
Antonio Scandurra created
Closes
https://linear.app/zed-industries/issue/Z-2384/hitting-cmd-enter-in-a-user-or-system-message-should-generate-a
Release Notes:
- Introduced the ability to generate assistant messages for any
user/system message, as well as generating multiple assists at the same
time, one for each cursor. (preview-only)
Cargo.lock | 1
crates/ai/Cargo.toml | 1
crates/ai/src/ai.rs | 2
crates/ai/src/assistant.rs | 484 ++++++++++++++++++++++++++++-----------
4 files changed, 345 insertions(+), 143 deletions(-)
@@ -114,6 +114,7 @@ dependencies = [
"serde",
"serde_json",
"settings",
+ "smol",
"theme",
"tiktoken-rs",
"util",
@@ -28,6 +28,7 @@ isahc.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
+smol.workspace = true
tiktoken-rs = "0.4"
[dev-dependencies]
@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
use std::fmt::{self, Display};
// Data types for chat completion requests
-#[derive(Serialize)]
+#[derive(Debug, Serialize)]
struct OpenAIRequest {
model: String,
messages: Vec<RequestMessage>,
@@ -473,7 +473,7 @@ impl Assistant {
language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>,
) -> Self {
- let model = "gpt-3.5-turbo";
+ let model = "gpt-3.5-turbo-0613";
let markdown = language_registry.language_for_name("Markdown");
let buffer = cx.add_model(|cx| {
let mut buffer = Buffer::new(0, "", cx);
@@ -518,7 +518,7 @@ impl Assistant {
MessageMetadata {
role: Role::User,
sent_at: Local::now(),
- error: None,
+ status: MessageStatus::Done,
},
);
@@ -543,7 +543,7 @@ impl Assistant {
fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
let messages = self
- .open_ai_request_messages(cx)
+ .messages(cx)
.into_iter()
.filter_map(|message| {
Some(tiktoken_rs::ChatCompletionRequestMessage {
@@ -552,7 +552,7 @@ impl Assistant {
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
- content: message.content,
+ content: self.buffer.read(cx).text_for_range(message.range).collect(),
name: None,
})
})
@@ -589,97 +589,169 @@ impl Assistant {
cx.notify();
}
- fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(MessageAnchor, MessageAnchor)> {
- let request = OpenAIRequest {
- model: self.model.clone(),
- messages: self.open_ai_request_messages(cx),
- stream: true,
- };
+ fn assist(
+ &mut self,
+ selected_messages: HashSet<MessageId>,
+ cx: &mut ModelContext<Self>,
+ ) -> Vec<MessageAnchor> {
+ let mut user_messages = Vec::new();
+ let mut tasks = Vec::new();
+ for selected_message_id in selected_messages {
+ let selected_message_role =
+ if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
+ metadata.role
+ } else {
+ continue;
+ };
+
+ if selected_message_role == Role::Assistant {
+ if let Some(user_message) = self.insert_message_after(
+ selected_message_id,
+ Role::User,
+ MessageStatus::Done,
+ cx,
+ ) {
+ user_messages.push(user_message);
+ } else {
+ continue;
+ }
+ } else {
+ let request = OpenAIRequest {
+ model: self.model.clone(),
+ messages: self
+ .messages(cx)
+ .filter(|message| matches!(message.status, MessageStatus::Done))
+ .flat_map(|message| {
+ let mut system_message = None;
+ if message.id == selected_message_id {
+ system_message = Some(RequestMessage {
+ role: Role::System,
+ content: concat!(
+ "Treat the following messages as additional knowledge you have learned about, ",
+ "but act as if they were not part of this conversation. That is, treat them ",
+ "as if the user didn't see them and couldn't possibly inquire about them."
+ ).into()
+ });
+ }
+
+ Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
+ })
+ .chain(Some(RequestMessage {
+ role: Role::System,
+ content: format!(
+ "Direct your reply to message with id {}. Do not include a [Message X] header.",
+ selected_message_id.0
+ ),
+ }))
+ .collect(),
+ stream: true,
+ };
+
+ let Some(api_key) = self.api_key.borrow().clone() else { continue };
+ let stream = stream_completion(api_key, cx.background().clone(), request);
+ let assistant_message = self
+ .insert_message_after(
+ selected_message_id,
+ Role::Assistant,
+ MessageStatus::Pending,
+ cx,
+ )
+ .unwrap();
+
+ tasks.push(cx.spawn_weak({
+ |this, mut cx| async move {
+ let assistant_message_id = assistant_message.id;
+ let stream_completion = async {
+ let mut messages = stream.await?;
+
+ while let Some(message) = messages.next().await {
+ let mut message = message?;
+ if let Some(choice) = message.choices.pop() {
+ this.upgrade(&cx)
+ .ok_or_else(|| anyhow!("assistant was dropped"))?
+ .update(&mut cx, |this, cx| {
+ let text: Arc<str> = choice.delta.content?.into();
+ let message_ix = this.message_anchors.iter().position(
+ |message| message.id == assistant_message_id,
+ )?;
+ this.buffer.update(cx, |buffer, cx| {
+ let offset = this.message_anchors[message_ix + 1..]
+ .iter()
+ .find(|message| message.start.is_valid(buffer))
+ .map_or(buffer.len(), |message| {
+ message
+ .start
+ .to_offset(buffer)
+ .saturating_sub(1)
+ });
+ buffer.edit([(offset..offset, text)], None, cx);
+ });
+ cx.emit(AssistantEvent::StreamedCompletion);
+
+ Some(())
+ });
+ }
+ smol::future::yield_now().await;
+ }
- let api_key = self.api_key.borrow().clone()?;
- let stream = stream_completion(api_key, cx.background().clone(), request);
- let assistant_message =
- self.insert_message_after(self.message_anchors.last()?.id, Role::Assistant, cx)?;
- let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?;
- let task = cx.spawn_weak({
- |this, mut cx| async move {
- let assistant_message_id = assistant_message.id;
- let stream_completion = async {
- let mut messages = stream.await?;
-
- while let Some(message) = messages.next().await {
- let mut message = message?;
- if let Some(choice) = message.choices.pop() {
this.upgrade(&cx)
.ok_or_else(|| anyhow!("assistant was dropped"))?
.update(&mut cx, |this, cx| {
- let text: Arc<str> = choice.delta.content?.into();
- let message_ix = this
- .message_anchors
- .iter()
- .position(|message| message.id == assistant_message_id)?;
- this.buffer.update(cx, |buffer, cx| {
- let offset = if message_ix + 1 == this.message_anchors.len()
- {
- buffer.len()
- } else {
- this.message_anchors[message_ix + 1]
- .start
- .to_offset(buffer)
- .saturating_sub(1)
- };
- buffer.edit([(offset..offset, text)], None, cx);
+ this.pending_completions.retain(|completion| {
+ completion.id != this.completion_count
});
- cx.emit(AssistantEvent::StreamedCompletion);
-
- Some(())
+ this.summarize(cx);
});
- }
- }
-
- this.upgrade(&cx)
- .ok_or_else(|| anyhow!("assistant was dropped"))?
- .update(&mut cx, |this, cx| {
- this.pending_completions
- .retain(|completion| completion.id != this.completion_count);
- this.summarize(cx);
- });
-
- anyhow::Ok(())
- };
- let result = stream_completion.await;
- if let Some(this) = this.upgrade(&cx) {
- this.update(&mut cx, |this, cx| {
- if let Err(error) = result {
- if let Some(metadata) =
- this.messages_metadata.get_mut(&assistant_message.id)
- {
- metadata.error = Some(error.to_string().trim().into());
- cx.notify();
- }
+ anyhow::Ok(())
+ };
+
+ let result = stream_completion.await;
+ if let Some(this) = this.upgrade(&cx) {
+ this.update(&mut cx, |this, cx| {
+ if let Some(metadata) =
+ this.messages_metadata.get_mut(&assistant_message.id)
+ {
+ match result {
+ Ok(_) => {
+ metadata.status = MessageStatus::Done;
+ }
+ Err(error) => {
+ metadata.status = MessageStatus::Error(
+ error.to_string().trim().into(),
+ );
+ }
+ }
+ cx.notify();
+ }
+ });
}
- });
- }
+ }
+ }));
}
- });
+ }
- self.pending_completions.push(PendingCompletion {
- id: post_inc(&mut self.completion_count),
- _task: task,
- });
- Some((assistant_message, user_message))
+ if !tasks.is_empty() {
+ self.pending_completions.push(PendingCompletion {
+ id: post_inc(&mut self.completion_count),
+ _tasks: tasks,
+ });
+ }
+
+ user_messages
}
fn cancel_last_assist(&mut self) -> bool {
self.pending_completions.pop().is_some()
}
- fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext<Self>) {
- if let Some(metadata) = self.messages_metadata.get_mut(&id) {
- metadata.role.cycle();
- cx.emit(AssistantEvent::MessagesEdited);
- cx.notify();
+ fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
+ for id in ids {
+ if let Some(metadata) = self.messages_metadata.get_mut(&id) {
+ metadata.role.cycle();
+ cx.emit(AssistantEvent::MessagesEdited);
+ cx.notify();
+ }
}
}
@@ -687,6 +759,7 @@ impl Assistant {
&mut self,
message_id: MessageId,
role: Role,
+ status: MessageStatus,
cx: &mut ModelContext<Self>,
) -> Option<MessageAnchor> {
if let Some(prev_message_ix) = self
@@ -713,7 +786,7 @@ impl Assistant {
MessageMetadata {
role,
sent_at: Local::now(),
- error: None,
+ status,
},
);
cx.emit(AssistantEvent::MessagesEdited);
@@ -772,7 +845,7 @@ impl Assistant {
MessageMetadata {
role,
sent_at: Local::now(),
- error: None,
+ status: MessageStatus::Done,
},
);
@@ -814,7 +887,7 @@ impl Assistant {
MessageMetadata {
role,
sent_at: Local::now(),
- error: None,
+ status: MessageStatus::Done,
},
);
(Some(selection), Some(suffix))
@@ -833,16 +906,19 @@ impl Assistant {
if self.message_anchors.len() >= 2 && self.summary.is_none() {
let api_key = self.api_key.borrow().clone();
if let Some(api_key) = api_key {
- let mut messages = self.open_ai_request_messages(cx);
- messages.truncate(2);
- messages.push(RequestMessage {
- role: Role::User,
- content: "Summarize the conversation into a short title without punctuation"
- .into(),
- });
+ let messages = self
+ .messages(cx)
+ .take(2)
+ .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
+ .chain(Some(RequestMessage {
+ role: Role::User,
+ content:
+ "Summarize the conversation into a short title without punctuation"
+ .into(),
+ }));
let request = OpenAIRequest {
model: self.model.clone(),
- messages,
+ messages: messages.collect(),
stream: true,
};
@@ -870,24 +946,39 @@ impl Assistant {
}
}
- fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> {
- let buffer = self.buffer.read(cx);
- self.messages(cx)
- .map(|message| RequestMessage {
- role: message.role,
- content: buffer.text_for_range(message.range).collect(),
- })
- .collect()
+ fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
+ self.messages_for_offsets([offset], cx).pop()
}
- fn message_for_offset<'a>(&'a self, offset: usize, cx: &'a AppContext) -> Option<Message> {
+ fn messages_for_offsets(
+ &self,
+ offsets: impl IntoIterator<Item = usize>,
+ cx: &AppContext,
+ ) -> Vec<Message> {
+ let mut result = Vec::new();
+
+ let buffer_len = self.buffer.read(cx).len();
let mut messages = self.messages(cx).peekable();
- while let Some(message) = messages.next() {
- if message.range.contains(&offset) || messages.peek().is_none() {
- return Some(message);
+ let mut offsets = offsets.into_iter().peekable();
+ while let Some(offset) = offsets.next() {
+ // Skip messages that start after the offset.
+ while messages.peek().map_or(false, |message| {
+ message.range.end < offset || (message.range.end == offset && offset < buffer_len)
+ }) {
+ messages.next();
+ }
+ let Some(message) = messages.peek() else { continue };
+
+ // Skip offsets that are in the same message.
+ while offsets.peek().map_or(false, |offset| {
+ message.range.contains(offset) || message.range.end == buffer_len
+ }) {
+ offsets.next();
}
+
+ result.push(message.clone());
}
- None
+ result
}
fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
@@ -916,7 +1007,7 @@ impl Assistant {
anchor: message_anchor.start,
role: metadata.role,
sent_at: metadata.sent_at,
- error: metadata.error.clone(),
+ status: metadata.status.clone(),
});
}
None
@@ -926,7 +1017,7 @@ impl Assistant {
struct PendingCompletion {
id: usize,
- _task: Task<()>,
+ _tasks: Vec<Task<()>>,
}
enum AssistantEditorEvent {
@@ -979,20 +1070,31 @@ impl AssistantEditor {
}
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
- let user_message = self.assistant.update(cx, |assistant, cx| {
- let (_, user_message) = assistant.assist(cx)?;
- Some(user_message)
+ let cursors = self.cursors(cx);
+
+ let user_messages = self.assistant.update(cx, |assistant, cx| {
+ let selected_messages = assistant
+ .messages_for_offsets(cursors, cx)
+ .into_iter()
+ .map(|message| message.id)
+ .collect();
+ assistant.assist(selected_messages, cx)
});
-
- if let Some(user_message) = user_message {
- let cursor = user_message
- .start
- .to_offset(&self.assistant.read(cx).buffer.read(cx));
+ let new_selections = user_messages
+ .iter()
+ .map(|message| {
+ let cursor = message
+ .start
+ .to_offset(self.assistant.read(cx).buffer.read(cx));
+ cursor..cursor
+ })
+ .collect::<Vec<_>>();
+ if !new_selections.is_empty() {
self.editor.update(cx, |editor, cx| {
editor.change_selections(
Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
cx,
- |selections| selections.select_ranges([cursor..cursor]),
+ |selections| selections.select_ranges(new_selections),
);
});
}
@@ -1008,14 +1110,25 @@ impl AssistantEditor {
}
fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
- let cursor_offset = self.editor.read(cx).selections.newest(cx).head();
+ let cursors = self.cursors(cx);
self.assistant.update(cx, |assistant, cx| {
- if let Some(message) = assistant.message_for_offset(cursor_offset, cx) {
- assistant.cycle_message_role(message.id, cx);
- }
+ let messages = assistant
+ .messages_for_offsets(cursors, cx)
+ .into_iter()
+ .map(|message| message.id)
+ .collect();
+ assistant.cycle_message_roles(messages, cx)
});
}
+ fn cursors(&self, cx: &AppContext) -> Vec<usize> {
+ let selections = self.editor.read(cx).selections.all::<usize>(cx);
+ selections
+ .into_iter()
+ .map(|selection| selection.head())
+ .collect()
+ }
+
fn handle_assistant_event(
&mut self,
_: ModelHandle<Assistant>,
@@ -1144,7 +1257,10 @@ impl AssistantEditor {
let assistant = assistant.clone();
move |_, _, cx| {
assistant.update(cx, |assistant, cx| {
- assistant.cycle_message_role(message_id, cx)
+ assistant.cycle_message_roles(
+ HashSet::from_iter(Some(message_id)),
+ cx,
+ )
})
}
});
@@ -1160,22 +1276,28 @@ impl AssistantEditor {
.with_style(style.sent_at.container)
.aligned(),
)
- .with_children(message.error.as_ref().map(|error| {
- Svg::new("icons/circle_x_mark_12.svg")
- .with_color(style.error_icon.color)
- .constrained()
- .with_width(style.error_icon.width)
- .contained()
- .with_style(style.error_icon.container)
- .with_tooltip::<ErrorTooltip>(
- message_id.0,
- error.to_string(),
- None,
- theme.tooltip.clone(),
- cx,
+ .with_children(
+ if let MessageStatus::Error(error) = &message.status {
+ Some(
+ Svg::new("icons/circle_x_mark_12.svg")
+ .with_color(style.error_icon.color)
+ .constrained()
+ .with_width(style.error_icon.width)
+ .contained()
+ .with_style(style.error_icon.container)
+ .with_tooltip::<ErrorTooltip>(
+ message_id.0,
+ error.to_string(),
+ None,
+ theme.tooltip.clone(),
+ cx,
+ )
+ .aligned(),
)
- .aligned()
- }))
+ } else {
+ None
+ },
+ )
.aligned()
.left()
.contained()
@@ -1308,8 +1430,8 @@ impl AssistantEditor {
fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
self.assistant.update(cx, |assistant, cx| {
let new_model = match assistant.model.as_str() {
- "gpt-4" => "gpt-3.5-turbo",
- _ => "gpt-4",
+ "gpt-4-0613" => "gpt-3.5-turbo-0613",
+ _ => "gpt-4-0613",
};
assistant.set_model(new_model.into(), cx);
});
@@ -1423,7 +1545,14 @@ struct MessageAnchor {
struct MessageMetadata {
role: Role,
sent_at: DateTime<Local>,
- error: Option<Arc<str>>,
+ status: MessageStatus,
+}
+
+#[derive(Clone, Debug)]
+enum MessageStatus {
+ Pending,
+ Done,
+ Error(Arc<str>),
}
#[derive(Clone, Debug)]
@@ -1434,7 +1563,18 @@ pub struct Message {
anchor: language::Anchor,
role: Role,
sent_at: DateTime<Local>,
- error: Option<Arc<str>>,
+ status: MessageStatus,
+}
+
+impl Message {
+ fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
+ let mut content = format!("[Message {}]\n", self.id.0).to_string();
+ content.extend(buffer.text_for_range(self.range.clone()));
+ RequestMessage {
+ role: self.role,
+ content,
+ }
+ }
}
async fn stream_completion(
@@ -1542,7 +1682,7 @@ mod tests {
let message_2 = assistant.update(cx, |assistant, cx| {
assistant
- .insert_message_after(message_1.id, Role::Assistant, cx)
+ .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
@@ -1566,7 +1706,7 @@ mod tests {
let message_3 = assistant.update(cx, |assistant, cx| {
assistant
- .insert_message_after(message_2.id, Role::User, cx)
+ .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
@@ -1580,7 +1720,7 @@ mod tests {
let message_4 = assistant.update(cx, |assistant, cx| {
assistant
- .insert_message_after(message_2.id, Role::User, cx)
+ .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
@@ -1641,7 +1781,7 @@ mod tests {
// Ensure we can still insert after a merged message.
let message_5 = assistant.update(cx, |assistant, cx| {
assistant
- .insert_message_after(message_1.id, Role::System, cx)
+ .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
@@ -1747,6 +1887,66 @@ mod tests {
);
}
+ #[gpui::test]
+ fn test_messages_for_offsets(cx: &mut AppContext) {
+ let registry = Arc::new(LanguageRegistry::test());
+ let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
+ let buffer = assistant.read(cx).buffer.clone();
+
+ let message_1 = assistant.read(cx).message_anchors[0].clone();
+ assert_eq!(
+ messages(&assistant, cx),
+ vec![(message_1.id, Role::User, 0..0)]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
+ let message_2 = assistant
+ .update(cx, |assistant, cx| {
+ assistant.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
+ })
+ .unwrap();
+ buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
+
+ let message_3 = assistant
+ .update(cx, |assistant, cx| {
+ assistant.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
+ })
+ .unwrap();
+ buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
+
+ assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
+ assert_eq!(
+ messages(&assistant, cx),
+ vec![
+ (message_1.id, Role::User, 0..4),
+ (message_2.id, Role::User, 4..8),
+ (message_3.id, Role::User, 8..11)
+ ]
+ );
+
+ assert_eq!(
+ message_ids_for_offsets(&assistant, &[0, 4, 9], cx),
+ [message_1.id, message_2.id, message_3.id]
+ );
+ assert_eq!(
+ message_ids_for_offsets(&assistant, &[0, 1, 11], cx),
+ [message_1.id, message_3.id]
+ );
+
+ fn message_ids_for_offsets(
+ assistant: &ModelHandle<Assistant>,
+ offsets: &[usize],
+ cx: &AppContext,
+ ) -> Vec<MessageId> {
+ assistant
+ .read(cx)
+ .messages_for_offsets(offsets.iter().copied(), cx)
+ .into_iter()
+ .map(|message| message.id)
+ .collect()
+ }
+ }
+
fn messages(
assistant: &ModelHandle<Assistant>,
cx: &AppContext,