From 8673b0b75bdf9a654a05751e44ada4eb0bf07052 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 20 Jun 2023 11:59:51 +0200 Subject: [PATCH] Avoid including pending or errored messages on `assist` --- Cargo.lock | 1 + crates/ai/Cargo.toml | 1 + crates/ai/src/assistant.rs | 148 +++++++++++++++++++++++-------------- 3 files changed, 94 insertions(+), 56 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 24fd67f90ec17ecc7d0440fa9d6dc52996a939b2..957f8c04bb7a57fdae9fab29e1ede0510c70c077 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,6 +114,7 @@ dependencies = [ "serde", "serde_json", "settings", + "smol", "theme", "tiktoken-rs", "util", diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index 9d67cbd108e79145db2bae2c709ee4d7c0b61660..7f8954bb21ea88a0b14f7fd5bacf26743de3c6be 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -28,6 +28,7 @@ isahc.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true +smol.workspace = true tiktoken-rs = "0.4" [dev-dependencies] diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index 121e4600cc92190df5514710feca3adb495699c0..e6b62e4ea1532d3e9f9fde6f3867b1cda7e143c2 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -518,7 +518,7 @@ impl Assistant { MessageMetadata { role: Role::User, sent_at: Local::now(), - error: None, + status: MessageStatus::Done, }, ); @@ -595,6 +595,7 @@ impl Assistant { cx: &mut ModelContext, ) -> Vec { let mut user_messages = Vec::new(); + let mut tasks = Vec::new(); for selected_message_id in selected_messages { let selected_message_role = if let Some(metadata) = self.messages_metadata.get(&selected_message_id) { @@ -604,15 +605,22 @@ impl Assistant { }; if selected_message_role == Role::Assistant { - let Some(user_message) = self.insert_message_after(selected_message_id, Role::User, cx) else { + if let Some(user_message) = self.insert_message_after( + selected_message_id, + Role::User, + MessageStatus::Done, + cx, + ) { + user_messages.push(user_message); + } else { continue; - }; - user_messages.push(user_message); + } } else { let request = OpenAIRequest { model: self.model.clone(), messages: self .messages(cx) + .filter(|message| matches!(message.status, MessageStatus::Done)) .map(|message| message.to_open_ai_message(self.buffer.read(cx))) .chain(Some(RequestMessage { role: Role::System, @@ -628,10 +636,15 @@ impl Assistant { let Some(api_key) = self.api_key.borrow().clone() else { continue }; let stream = stream_completion(api_key, cx.background().clone(), request); let assistant_message = self - .insert_message_after(selected_message_id, Role::Assistant, cx) + .insert_message_after( + selected_message_id, + Role::Assistant, + MessageStatus::Pending, + cx, + ) .unwrap(); - let task = cx.spawn_weak({ + tasks.push(cx.spawn_weak({ |this, mut cx| async move { let assistant_message_id = assistant_message.id; let stream_completion = async { @@ -648,16 +661,15 @@ impl Assistant { |message| message.id == assistant_message_id, )?; this.buffer.update(cx, |buffer, cx| { - let offset = if message_ix + 1 - == this.message_anchors.len() - { - buffer.len() - } else { - this.message_anchors[message_ix + 1] - .start - .to_offset(buffer) - .saturating_sub(1) - }; + let offset = this.message_anchors[message_ix + 1..] + .iter() + .find(|message| message.start.is_valid(buffer)) + .map_or(buffer.len(), |message| { + message + .start + .to_offset(buffer) + .saturating_sub(1) + }); buffer.edit([(offset..offset, text)], None, cx); }); cx.emit(AssistantEvent::StreamedCompletion); @@ -665,6 +677,7 @@ impl Assistant { Some(()) }); } + smol::future::yield_now().await; } this.upgrade(&cx) @@ -682,26 +695,35 @@ impl Assistant { let result = stream_completion.await; if let Some(this) = this.upgrade(&cx) { this.update(&mut cx, |this, cx| { - if let Err(error) = result { - if let Some(metadata) = - this.messages_metadata.get_mut(&assistant_message.id) - { - metadata.error = Some(error.to_string().trim().into()); - cx.notify(); + if let Some(metadata) = + this.messages_metadata.get_mut(&assistant_message.id) + { + match result { + Ok(_) => { + metadata.status = MessageStatus::Done; + } + Err(error) => { + metadata.status = MessageStatus::Error( + error.to_string().trim().into(), + ); + } } + cx.notify(); } }); } } - }); - - self.pending_completions.push(PendingCompletion { - id: post_inc(&mut self.completion_count), - _task: task, - }); + })); } } + if !tasks.is_empty() { + self.pending_completions.push(PendingCompletion { + id: post_inc(&mut self.completion_count), + _tasks: tasks, + }); + } + user_messages } @@ -723,6 +745,7 @@ impl Assistant { &mut self, message_id: MessageId, role: Role, + status: MessageStatus, cx: &mut ModelContext, ) -> Option { if let Some(prev_message_ix) = self @@ -749,7 +772,7 @@ impl Assistant { MessageMetadata { role, sent_at: Local::now(), - error: None, + status, }, ); cx.emit(AssistantEvent::MessagesEdited); @@ -808,7 +831,7 @@ impl Assistant { MessageMetadata { role, sent_at: Local::now(), - error: None, + status: MessageStatus::Done, }, ); @@ -850,7 +873,7 @@ impl Assistant { MessageMetadata { role, sent_at: Local::now(), - error: None, + status: MessageStatus::Done, }, ); (Some(selection), Some(suffix)) @@ -970,7 +993,7 @@ impl Assistant { anchor: message_anchor.start, role: metadata.role, sent_at: metadata.sent_at, - error: metadata.error.clone(), + status: metadata.status.clone(), }); } None @@ -980,7 +1003,7 @@ impl Assistant { struct PendingCompletion { id: usize, - _task: Task<()>, + _tasks: Vec>, } enum AssistantEditorEvent { @@ -1239,22 +1262,28 @@ impl AssistantEditor { .with_style(style.sent_at.container) .aligned(), ) - .with_children(message.error.as_ref().map(|error| { - Svg::new("icons/circle_x_mark_12.svg") - .with_color(style.error_icon.color) - .constrained() - .with_width(style.error_icon.width) - .contained() - .with_style(style.error_icon.container) - .with_tooltip::( - message_id.0, - error.to_string(), - None, - theme.tooltip.clone(), - cx, + .with_children( + if let MessageStatus::Error(error) = &message.status { + Some( + Svg::new("icons/circle_x_mark_12.svg") + .with_color(style.error_icon.color) + .constrained() + .with_width(style.error_icon.width) + .contained() + .with_style(style.error_icon.container) + .with_tooltip::( + message_id.0, + error.to_string(), + None, + theme.tooltip.clone(), + cx, + ) + .aligned(), ) - .aligned() - })) + } else { + None + }, + ) .aligned() .left() .contained() @@ -1502,7 +1531,14 @@ struct MessageAnchor { struct MessageMetadata { role: Role, sent_at: DateTime, - error: Option>, + status: MessageStatus, +} + +#[derive(Clone, Debug)] +enum MessageStatus { + Pending, + Done, + Error(Arc), } #[derive(Clone, Debug)] @@ -1513,7 +1549,7 @@ pub struct Message { anchor: language::Anchor, role: Role, sent_at: DateTime, - error: Option>, + status: MessageStatus, } impl Message { @@ -1632,7 +1668,7 @@ mod tests { let message_2 = assistant.update(cx, |assistant, cx| { assistant - .insert_message_after(message_1.id, Role::Assistant, cx) + .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx) .unwrap() }); assert_eq!( @@ -1656,7 +1692,7 @@ mod tests { let message_3 = assistant.update(cx, |assistant, cx| { assistant - .insert_message_after(message_2.id, Role::User, cx) + .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) .unwrap() }); assert_eq!( @@ -1670,7 +1706,7 @@ mod tests { let message_4 = assistant.update(cx, |assistant, cx| { assistant - .insert_message_after(message_2.id, Role::User, cx) + .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) .unwrap() }); assert_eq!( @@ -1731,7 +1767,7 @@ mod tests { // Ensure we can still insert after a merged message. let message_5 = assistant.update(cx, |assistant, cx| { assistant - .insert_message_after(message_1.id, Role::System, cx) + .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx) .unwrap() }); assert_eq!( @@ -1852,14 +1888,14 @@ mod tests { buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx)); let message_2 = assistant .update(cx, |assistant, cx| { - assistant.insert_message_after(message_1.id, Role::User, cx) + assistant.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx) }) .unwrap(); buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx)); let message_3 = assistant .update(cx, |assistant, cx| { - assistant.insert_message_after(message_2.id, Role::User, cx) + assistant.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) }) .unwrap(); buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));