Cargo.lock 🔗
@@ -114,6 +114,7 @@ dependencies = [
"serde",
"serde_json",
"settings",
+ "smol",
"theme",
"tiktoken-rs",
"util",
Antonio Scandurra created
Cargo.lock | 1
crates/ai/Cargo.toml | 1
crates/ai/src/assistant.rs | 148 ++++++++++++++++++++++++---------------
3 files changed, 94 insertions(+), 56 deletions(-)
@@ -114,6 +114,7 @@ dependencies = [
"serde",
"serde_json",
"settings",
+ "smol",
"theme",
"tiktoken-rs",
"util",
@@ -28,6 +28,7 @@ isahc.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
+smol.workspace = true
tiktoken-rs = "0.4"
[dev-dependencies]
@@ -518,7 +518,7 @@ impl Assistant {
MessageMetadata {
role: Role::User,
sent_at: Local::now(),
- error: None,
+ status: MessageStatus::Done,
},
);
@@ -595,6 +595,7 @@ impl Assistant {
cx: &mut ModelContext<Self>,
) -> Vec<MessageAnchor> {
let mut user_messages = Vec::new();
+ let mut tasks = Vec::new();
for selected_message_id in selected_messages {
let selected_message_role =
if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
@@ -604,15 +605,22 @@ impl Assistant {
};
if selected_message_role == Role::Assistant {
- let Some(user_message) = self.insert_message_after(selected_message_id, Role::User, cx) else {
+ if let Some(user_message) = self.insert_message_after(
+ selected_message_id,
+ Role::User,
+ MessageStatus::Done,
+ cx,
+ ) {
+ user_messages.push(user_message);
+ } else {
continue;
- };
- user_messages.push(user_message);
+ }
} else {
let request = OpenAIRequest {
model: self.model.clone(),
messages: self
.messages(cx)
+ .filter(|message| matches!(message.status, MessageStatus::Done))
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
.chain(Some(RequestMessage {
role: Role::System,
@@ -628,10 +636,15 @@ impl Assistant {
let Some(api_key) = self.api_key.borrow().clone() else { continue };
let stream = stream_completion(api_key, cx.background().clone(), request);
let assistant_message = self
- .insert_message_after(selected_message_id, Role::Assistant, cx)
+ .insert_message_after(
+ selected_message_id,
+ Role::Assistant,
+ MessageStatus::Pending,
+ cx,
+ )
.unwrap();
- let task = cx.spawn_weak({
+ tasks.push(cx.spawn_weak({
|this, mut cx| async move {
let assistant_message_id = assistant_message.id;
let stream_completion = async {
@@ -648,16 +661,15 @@ impl Assistant {
|message| message.id == assistant_message_id,
)?;
this.buffer.update(cx, |buffer, cx| {
- let offset = if message_ix + 1
- == this.message_anchors.len()
- {
- buffer.len()
- } else {
- this.message_anchors[message_ix + 1]
- .start
- .to_offset(buffer)
- .saturating_sub(1)
- };
+ let offset = this.message_anchors[message_ix + 1..]
+ .iter()
+ .find(|message| message.start.is_valid(buffer))
+ .map_or(buffer.len(), |message| {
+ message
+ .start
+ .to_offset(buffer)
+ .saturating_sub(1)
+ });
buffer.edit([(offset..offset, text)], None, cx);
});
cx.emit(AssistantEvent::StreamedCompletion);
@@ -665,6 +677,7 @@ impl Assistant {
Some(())
});
}
+ smol::future::yield_now().await;
}
this.upgrade(&cx)
@@ -682,26 +695,35 @@ impl Assistant {
let result = stream_completion.await;
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| {
- if let Err(error) = result {
- if let Some(metadata) =
- this.messages_metadata.get_mut(&assistant_message.id)
- {
- metadata.error = Some(error.to_string().trim().into());
- cx.notify();
+ if let Some(metadata) =
+ this.messages_metadata.get_mut(&assistant_message.id)
+ {
+ match result {
+ Ok(_) => {
+ metadata.status = MessageStatus::Done;
+ }
+ Err(error) => {
+ metadata.status = MessageStatus::Error(
+ error.to_string().trim().into(),
+ );
+ }
}
+ cx.notify();
}
});
}
}
- });
-
- self.pending_completions.push(PendingCompletion {
- id: post_inc(&mut self.completion_count),
- _task: task,
- });
+ }));
}
}
+ if !tasks.is_empty() {
+ self.pending_completions.push(PendingCompletion {
+ id: post_inc(&mut self.completion_count),
+ _tasks: tasks,
+ });
+ }
+
user_messages
}
@@ -723,6 +745,7 @@ impl Assistant {
&mut self,
message_id: MessageId,
role: Role,
+ status: MessageStatus,
cx: &mut ModelContext<Self>,
) -> Option<MessageAnchor> {
if let Some(prev_message_ix) = self
@@ -749,7 +772,7 @@ impl Assistant {
MessageMetadata {
role,
sent_at: Local::now(),
- error: None,
+ status,
},
);
cx.emit(AssistantEvent::MessagesEdited);
@@ -808,7 +831,7 @@ impl Assistant {
MessageMetadata {
role,
sent_at: Local::now(),
- error: None,
+ status: MessageStatus::Done,
},
);
@@ -850,7 +873,7 @@ impl Assistant {
MessageMetadata {
role,
sent_at: Local::now(),
- error: None,
+ status: MessageStatus::Done,
},
);
(Some(selection), Some(suffix))
@@ -970,7 +993,7 @@ impl Assistant {
anchor: message_anchor.start,
role: metadata.role,
sent_at: metadata.sent_at,
- error: metadata.error.clone(),
+ status: metadata.status.clone(),
});
}
None
@@ -980,7 +1003,7 @@ impl Assistant {
struct PendingCompletion {
id: usize,
- _task: Task<()>,
+ _tasks: Vec<Task<()>>,
}
enum AssistantEditorEvent {
@@ -1239,22 +1262,28 @@ impl AssistantEditor {
.with_style(style.sent_at.container)
.aligned(),
)
- .with_children(message.error.as_ref().map(|error| {
- Svg::new("icons/circle_x_mark_12.svg")
- .with_color(style.error_icon.color)
- .constrained()
- .with_width(style.error_icon.width)
- .contained()
- .with_style(style.error_icon.container)
- .with_tooltip::<ErrorTooltip>(
- message_id.0,
- error.to_string(),
- None,
- theme.tooltip.clone(),
- cx,
+ .with_children(
+ if let MessageStatus::Error(error) = &message.status {
+ Some(
+ Svg::new("icons/circle_x_mark_12.svg")
+ .with_color(style.error_icon.color)
+ .constrained()
+ .with_width(style.error_icon.width)
+ .contained()
+ .with_style(style.error_icon.container)
+ .with_tooltip::<ErrorTooltip>(
+ message_id.0,
+ error.to_string(),
+ None,
+ theme.tooltip.clone(),
+ cx,
+ )
+ .aligned(),
)
- .aligned()
- }))
+ } else {
+ None
+ },
+ )
.aligned()
.left()
.contained()
@@ -1502,7 +1531,14 @@ struct MessageAnchor {
struct MessageMetadata {
role: Role,
sent_at: DateTime<Local>,
- error: Option<Arc<str>>,
+ status: MessageStatus,
+}
+
+#[derive(Clone, Debug)]
+enum MessageStatus {
+ Pending,
+ Done,
+ Error(Arc<str>),
}
#[derive(Clone, Debug)]
@@ -1513,7 +1549,7 @@ pub struct Message {
anchor: language::Anchor,
role: Role,
sent_at: DateTime<Local>,
- error: Option<Arc<str>>,
+ status: MessageStatus,
}
impl Message {
@@ -1632,7 +1668,7 @@ mod tests {
let message_2 = assistant.update(cx, |assistant, cx| {
assistant
- .insert_message_after(message_1.id, Role::Assistant, cx)
+ .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
@@ -1656,7 +1692,7 @@ mod tests {
let message_3 = assistant.update(cx, |assistant, cx| {
assistant
- .insert_message_after(message_2.id, Role::User, cx)
+ .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
@@ -1670,7 +1706,7 @@ mod tests {
let message_4 = assistant.update(cx, |assistant, cx| {
assistant
- .insert_message_after(message_2.id, Role::User, cx)
+ .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
@@ -1731,7 +1767,7 @@ mod tests {
// Ensure we can still insert after a merged message.
let message_5 = assistant.update(cx, |assistant, cx| {
assistant
- .insert_message_after(message_1.id, Role::System, cx)
+ .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
@@ -1852,14 +1888,14 @@ mod tests {
buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
let message_2 = assistant
.update(cx, |assistant, cx| {
- assistant.insert_message_after(message_1.id, Role::User, cx)
+ assistant.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
})
.unwrap();
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
let message_3 = assistant
.update(cx, |assistant, cx| {
- assistant.insert_message_after(message_2.id, Role::User, cx)
+ assistant.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
})
.unwrap();
buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));