diff --git a/Cargo.lock b/Cargo.lock index 4902050017c3143965c3764f1b84d7c4fcaaba85..a4b12223e5a8fe6770464280bf652e08346a26ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,6 +114,7 @@ dependencies = [ "serde", "serde_json", "settings", + "smol", "theme", "tiktoken-rs", "util", diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index 9d67cbd108e79145db2bae2c709ee4d7c0b61660..7f8954bb21ea88a0b14f7fd5bacf26743de3c6be 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -28,6 +28,7 @@ isahc.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true +smol.workspace = true tiktoken-rs = "0.4" [dev-dependencies] diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index 40224b3229de1665e3fac89be0d035154e2cf67f..b3b62c6a2422f0a354f27a34f0e28ffea23af95f 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use std::fmt::{self, Display}; // Data types for chat completion requests -#[derive(Serialize)] +#[derive(Debug, Serialize)] struct OpenAIRequest { model: String, messages: Vec, diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index d5c50146c35b5d463bf573a082ac438c0ec16d11..ce816f147bac15db7a12ac2dbc7666303ff098b5 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -473,7 +473,7 @@ impl Assistant { language_registry: Arc, cx: &mut ModelContext, ) -> Self { - let model = "gpt-3.5-turbo"; + let model = "gpt-3.5-turbo-0613"; let markdown = language_registry.language_for_name("Markdown"); let buffer = cx.add_model(|cx| { let mut buffer = Buffer::new(0, "", cx); @@ -518,7 +518,7 @@ impl Assistant { MessageMetadata { role: Role::User, sent_at: Local::now(), - error: None, + status: MessageStatus::Done, }, ); @@ -543,7 +543,7 @@ impl Assistant { fn count_remaining_tokens(&mut self, cx: &mut ModelContext) { let messages = self - .open_ai_request_messages(cx) + .messages(cx) .into_iter() .filter_map(|message| { Some(tiktoken_rs::ChatCompletionRequestMessage { @@ -552,7 +552,7 @@ impl Assistant { Role::Assistant => "assistant".into(), Role::System => "system".into(), }, - content: message.content, + content: self.buffer.read(cx).text_for_range(message.range).collect(), name: None, }) }) @@ -589,97 +589,169 @@ impl Assistant { cx.notify(); } - fn assist(&mut self, cx: &mut ModelContext) -> Option<(MessageAnchor, MessageAnchor)> { - let request = OpenAIRequest { - model: self.model.clone(), - messages: self.open_ai_request_messages(cx), - stream: true, - }; + fn assist( + &mut self, + selected_messages: HashSet, + cx: &mut ModelContext, + ) -> Vec { + let mut user_messages = Vec::new(); + let mut tasks = Vec::new(); + for selected_message_id in selected_messages { + let selected_message_role = + if let Some(metadata) = self.messages_metadata.get(&selected_message_id) { + metadata.role + } else { + continue; + }; + + if selected_message_role == Role::Assistant { + if let Some(user_message) = self.insert_message_after( + selected_message_id, + Role::User, + MessageStatus::Done, + cx, + ) { + user_messages.push(user_message); + } else { + continue; + } + } else { + let request = OpenAIRequest { + model: self.model.clone(), + messages: self + .messages(cx) + .filter(|message| matches!(message.status, MessageStatus::Done)) + .flat_map(|message| { + let mut system_message = None; + if message.id == selected_message_id { + system_message = Some(RequestMessage { + role: Role::System, + content: concat!( + "Treat the following messages as additional knowledge you have learned about, ", + "but act as if they were not part of this conversation. That is, treat them ", + "as if the user didn't see them and couldn't possibly inquire about them." + ).into() + }); + } + + Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message) + }) + .chain(Some(RequestMessage { + role: Role::System, + content: format!( + "Direct your reply to message with id {}. Do not include a [Message X] header.", + selected_message_id.0 + ), + })) + .collect(), + stream: true, + }; + + let Some(api_key) = self.api_key.borrow().clone() else { continue }; + let stream = stream_completion(api_key, cx.background().clone(), request); + let assistant_message = self + .insert_message_after( + selected_message_id, + Role::Assistant, + MessageStatus::Pending, + cx, + ) + .unwrap(); + + tasks.push(cx.spawn_weak({ + |this, mut cx| async move { + let assistant_message_id = assistant_message.id; + let stream_completion = async { + let mut messages = stream.await?; + + while let Some(message) = messages.next().await { + let mut message = message?; + if let Some(choice) = message.choices.pop() { + this.upgrade(&cx) + .ok_or_else(|| anyhow!("assistant was dropped"))? + .update(&mut cx, |this, cx| { + let text: Arc = choice.delta.content?.into(); + let message_ix = this.message_anchors.iter().position( + |message| message.id == assistant_message_id, + )?; + this.buffer.update(cx, |buffer, cx| { + let offset = this.message_anchors[message_ix + 1..] + .iter() + .find(|message| message.start.is_valid(buffer)) + .map_or(buffer.len(), |message| { + message + .start + .to_offset(buffer) + .saturating_sub(1) + }); + buffer.edit([(offset..offset, text)], None, cx); + }); + cx.emit(AssistantEvent::StreamedCompletion); + + Some(()) + }); + } + smol::future::yield_now().await; + } - let api_key = self.api_key.borrow().clone()?; - let stream = stream_completion(api_key, cx.background().clone(), request); - let assistant_message = - self.insert_message_after(self.message_anchors.last()?.id, Role::Assistant, cx)?; - let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?; - let task = cx.spawn_weak({ - |this, mut cx| async move { - let assistant_message_id = assistant_message.id; - let stream_completion = async { - let mut messages = stream.await?; - - while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { this.upgrade(&cx) .ok_or_else(|| anyhow!("assistant was dropped"))? .update(&mut cx, |this, cx| { - let text: Arc = choice.delta.content?.into(); - let message_ix = this - .message_anchors - .iter() - .position(|message| message.id == assistant_message_id)?; - this.buffer.update(cx, |buffer, cx| { - let offset = if message_ix + 1 == this.message_anchors.len() - { - buffer.len() - } else { - this.message_anchors[message_ix + 1] - .start - .to_offset(buffer) - .saturating_sub(1) - }; - buffer.edit([(offset..offset, text)], None, cx); + this.pending_completions.retain(|completion| { + completion.id != this.completion_count }); - cx.emit(AssistantEvent::StreamedCompletion); - - Some(()) + this.summarize(cx); }); - } - } - - this.upgrade(&cx) - .ok_or_else(|| anyhow!("assistant was dropped"))? - .update(&mut cx, |this, cx| { - this.pending_completions - .retain(|completion| completion.id != this.completion_count); - this.summarize(cx); - }); - - anyhow::Ok(()) - }; - let result = stream_completion.await; - if let Some(this) = this.upgrade(&cx) { - this.update(&mut cx, |this, cx| { - if let Err(error) = result { - if let Some(metadata) = - this.messages_metadata.get_mut(&assistant_message.id) - { - metadata.error = Some(error.to_string().trim().into()); - cx.notify(); - } + anyhow::Ok(()) + }; + + let result = stream_completion.await; + if let Some(this) = this.upgrade(&cx) { + this.update(&mut cx, |this, cx| { + if let Some(metadata) = + this.messages_metadata.get_mut(&assistant_message.id) + { + match result { + Ok(_) => { + metadata.status = MessageStatus::Done; + } + Err(error) => { + metadata.status = MessageStatus::Error( + error.to_string().trim().into(), + ); + } + } + cx.notify(); + } + }); } - }); - } + } + })); } - }); + } - self.pending_completions.push(PendingCompletion { - id: post_inc(&mut self.completion_count), - _task: task, - }); - Some((assistant_message, user_message)) + if !tasks.is_empty() { + self.pending_completions.push(PendingCompletion { + id: post_inc(&mut self.completion_count), + _tasks: tasks, + }); + } + + user_messages } fn cancel_last_assist(&mut self) -> bool { self.pending_completions.pop().is_some() } - fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext) { - if let Some(metadata) = self.messages_metadata.get_mut(&id) { - metadata.role.cycle(); - cx.emit(AssistantEvent::MessagesEdited); - cx.notify(); + fn cycle_message_roles(&mut self, ids: HashSet, cx: &mut ModelContext) { + for id in ids { + if let Some(metadata) = self.messages_metadata.get_mut(&id) { + metadata.role.cycle(); + cx.emit(AssistantEvent::MessagesEdited); + cx.notify(); + } } } @@ -687,6 +759,7 @@ impl Assistant { &mut self, message_id: MessageId, role: Role, + status: MessageStatus, cx: &mut ModelContext, ) -> Option { if let Some(prev_message_ix) = self @@ -713,7 +786,7 @@ impl Assistant { MessageMetadata { role, sent_at: Local::now(), - error: None, + status, }, ); cx.emit(AssistantEvent::MessagesEdited); @@ -772,7 +845,7 @@ impl Assistant { MessageMetadata { role, sent_at: Local::now(), - error: None, + status: MessageStatus::Done, }, ); @@ -814,7 +887,7 @@ impl Assistant { MessageMetadata { role, sent_at: Local::now(), - error: None, + status: MessageStatus::Done, }, ); (Some(selection), Some(suffix)) @@ -833,16 +906,19 @@ impl Assistant { if self.message_anchors.len() >= 2 && self.summary.is_none() { let api_key = self.api_key.borrow().clone(); if let Some(api_key) = api_key { - let mut messages = self.open_ai_request_messages(cx); - messages.truncate(2); - messages.push(RequestMessage { - role: Role::User, - content: "Summarize the conversation into a short title without punctuation" - .into(), - }); + let messages = self + .messages(cx) + .take(2) + .map(|message| message.to_open_ai_message(self.buffer.read(cx))) + .chain(Some(RequestMessage { + role: Role::User, + content: + "Summarize the conversation into a short title without punctuation" + .into(), + })); let request = OpenAIRequest { model: self.model.clone(), - messages, + messages: messages.collect(), stream: true, }; @@ -870,24 +946,39 @@ impl Assistant { } } - fn open_ai_request_messages(&self, cx: &AppContext) -> Vec { - let buffer = self.buffer.read(cx); - self.messages(cx) - .map(|message| RequestMessage { - role: message.role, - content: buffer.text_for_range(message.range).collect(), - }) - .collect() + fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option { + self.messages_for_offsets([offset], cx).pop() } - fn message_for_offset<'a>(&'a self, offset: usize, cx: &'a AppContext) -> Option { + fn messages_for_offsets( + &self, + offsets: impl IntoIterator, + cx: &AppContext, + ) -> Vec { + let mut result = Vec::new(); + + let buffer_len = self.buffer.read(cx).len(); let mut messages = self.messages(cx).peekable(); - while let Some(message) = messages.next() { - if message.range.contains(&offset) || messages.peek().is_none() { - return Some(message); + let mut offsets = offsets.into_iter().peekable(); + while let Some(offset) = offsets.next() { + // Skip messages that start after the offset. + while messages.peek().map_or(false, |message| { + message.range.end < offset || (message.range.end == offset && offset < buffer_len) + }) { + messages.next(); + } + let Some(message) = messages.peek() else { continue }; + + // Skip offsets that are in the same message. + while offsets.peek().map_or(false, |offset| { + message.range.contains(offset) || message.range.end == buffer_len + }) { + offsets.next(); } + + result.push(message.clone()); } - None + result } fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator { @@ -916,7 +1007,7 @@ impl Assistant { anchor: message_anchor.start, role: metadata.role, sent_at: metadata.sent_at, - error: metadata.error.clone(), + status: metadata.status.clone(), }); } None @@ -926,7 +1017,7 @@ impl Assistant { struct PendingCompletion { id: usize, - _task: Task<()>, + _tasks: Vec>, } enum AssistantEditorEvent { @@ -979,20 +1070,31 @@ impl AssistantEditor { } fn assist(&mut self, _: &Assist, cx: &mut ViewContext) { - let user_message = self.assistant.update(cx, |assistant, cx| { - let (_, user_message) = assistant.assist(cx)?; - Some(user_message) + let cursors = self.cursors(cx); + + let user_messages = self.assistant.update(cx, |assistant, cx| { + let selected_messages = assistant + .messages_for_offsets(cursors, cx) + .into_iter() + .map(|message| message.id) + .collect(); + assistant.assist(selected_messages, cx) }); - - if let Some(user_message) = user_message { - let cursor = user_message - .start - .to_offset(&self.assistant.read(cx).buffer.read(cx)); + let new_selections = user_messages + .iter() + .map(|message| { + let cursor = message + .start + .to_offset(self.assistant.read(cx).buffer.read(cx)); + cursor..cursor + }) + .collect::>(); + if !new_selections.is_empty() { self.editor.update(cx, |editor, cx| { editor.change_selections( Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)), cx, - |selections| selections.select_ranges([cursor..cursor]), + |selections| selections.select_ranges(new_selections), ); }); } @@ -1008,14 +1110,25 @@ impl AssistantEditor { } fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext) { - let cursor_offset = self.editor.read(cx).selections.newest(cx).head(); + let cursors = self.cursors(cx); self.assistant.update(cx, |assistant, cx| { - if let Some(message) = assistant.message_for_offset(cursor_offset, cx) { - assistant.cycle_message_role(message.id, cx); - } + let messages = assistant + .messages_for_offsets(cursors, cx) + .into_iter() + .map(|message| message.id) + .collect(); + assistant.cycle_message_roles(messages, cx) }); } + fn cursors(&self, cx: &AppContext) -> Vec { + let selections = self.editor.read(cx).selections.all::(cx); + selections + .into_iter() + .map(|selection| selection.head()) + .collect() + } + fn handle_assistant_event( &mut self, _: ModelHandle, @@ -1144,7 +1257,10 @@ impl AssistantEditor { let assistant = assistant.clone(); move |_, _, cx| { assistant.update(cx, |assistant, cx| { - assistant.cycle_message_role(message_id, cx) + assistant.cycle_message_roles( + HashSet::from_iter(Some(message_id)), + cx, + ) }) } }); @@ -1160,22 +1276,28 @@ impl AssistantEditor { .with_style(style.sent_at.container) .aligned(), ) - .with_children(message.error.as_ref().map(|error| { - Svg::new("icons/circle_x_mark_12.svg") - .with_color(style.error_icon.color) - .constrained() - .with_width(style.error_icon.width) - .contained() - .with_style(style.error_icon.container) - .with_tooltip::( - message_id.0, - error.to_string(), - None, - theme.tooltip.clone(), - cx, + .with_children( + if let MessageStatus::Error(error) = &message.status { + Some( + Svg::new("icons/circle_x_mark_12.svg") + .with_color(style.error_icon.color) + .constrained() + .with_width(style.error_icon.width) + .contained() + .with_style(style.error_icon.container) + .with_tooltip::( + message_id.0, + error.to_string(), + None, + theme.tooltip.clone(), + cx, + ) + .aligned(), ) - .aligned() - })) + } else { + None + }, + ) .aligned() .left() .contained() @@ -1308,8 +1430,8 @@ impl AssistantEditor { fn cycle_model(&mut self, cx: &mut ViewContext) { self.assistant.update(cx, |assistant, cx| { let new_model = match assistant.model.as_str() { - "gpt-4" => "gpt-3.5-turbo", - _ => "gpt-4", + "gpt-4-0613" => "gpt-3.5-turbo-0613", + _ => "gpt-4-0613", }; assistant.set_model(new_model.into(), cx); }); @@ -1423,7 +1545,14 @@ struct MessageAnchor { struct MessageMetadata { role: Role, sent_at: DateTime, - error: Option>, + status: MessageStatus, +} + +#[derive(Clone, Debug)] +enum MessageStatus { + Pending, + Done, + Error(Arc), } #[derive(Clone, Debug)] @@ -1434,7 +1563,18 @@ pub struct Message { anchor: language::Anchor, role: Role, sent_at: DateTime, - error: Option>, + status: MessageStatus, +} + +impl Message { + fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage { + let mut content = format!("[Message {}]\n", self.id.0).to_string(); + content.extend(buffer.text_for_range(self.range.clone())); + RequestMessage { + role: self.role, + content, + } + } } async fn stream_completion( @@ -1542,7 +1682,7 @@ mod tests { let message_2 = assistant.update(cx, |assistant, cx| { assistant - .insert_message_after(message_1.id, Role::Assistant, cx) + .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx) .unwrap() }); assert_eq!( @@ -1566,7 +1706,7 @@ mod tests { let message_3 = assistant.update(cx, |assistant, cx| { assistant - .insert_message_after(message_2.id, Role::User, cx) + .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) .unwrap() }); assert_eq!( @@ -1580,7 +1720,7 @@ mod tests { let message_4 = assistant.update(cx, |assistant, cx| { assistant - .insert_message_after(message_2.id, Role::User, cx) + .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) .unwrap() }); assert_eq!( @@ -1641,7 +1781,7 @@ mod tests { // Ensure we can still insert after a merged message. let message_5 = assistant.update(cx, |assistant, cx| { assistant - .insert_message_after(message_1.id, Role::System, cx) + .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx) .unwrap() }); assert_eq!( @@ -1747,6 +1887,66 @@ mod tests { ); } + #[gpui::test] + fn test_messages_for_offsets(cx: &mut AppContext) { + let registry = Arc::new(LanguageRegistry::test()); + let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx)); + let buffer = assistant.read(cx).buffer.clone(); + + let message_1 = assistant.read(cx).message_anchors[0].clone(); + assert_eq!( + messages(&assistant, cx), + vec![(message_1.id, Role::User, 0..0)] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx)); + let message_2 = assistant + .update(cx, |assistant, cx| { + assistant.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx) + }) + .unwrap(); + buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx)); + + let message_3 = assistant + .update(cx, |assistant, cx| { + assistant.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) + }) + .unwrap(); + buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx)); + + assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc"); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_2.id, Role::User, 4..8), + (message_3.id, Role::User, 8..11) + ] + ); + + assert_eq!( + message_ids_for_offsets(&assistant, &[0, 4, 9], cx), + [message_1.id, message_2.id, message_3.id] + ); + assert_eq!( + message_ids_for_offsets(&assistant, &[0, 1, 11], cx), + [message_1.id, message_3.id] + ); + + fn message_ids_for_offsets( + assistant: &ModelHandle, + offsets: &[usize], + cx: &AppContext, + ) -> Vec { + assistant + .read(cx) + .messages_for_offsets(offsets.iter().copied(), cx) + .into_iter() + .map(|message| message.id) + .collect() + } + } + fn messages( assistant: &ModelHandle, cx: &AppContext,