@@ -459,7 +459,7 @@ impl Assistant {
api_key,
buffer,
};
- this.push_message(Role::User, cx);
+ this.insert_message_after(ExcerptId::max(), Role::User, cx);
this.count_remaining_tokens(cx);
this
}
@@ -498,7 +498,7 @@ impl Assistant {
})
.collect::<Vec<_>>();
let model = self.model.clone();
- self.pending_token_count = cx.spawn(|this, mut cx| {
+ self.pending_token_count = cx.spawn_weak(|this, mut cx| {
async move {
cx.background().timer(Duration::from_millis(200)).await;
let token_count = cx
@@ -506,11 +506,13 @@ impl Assistant {
.spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
.await?;
- this.update(&mut cx, |this, cx| {
- this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
- this.token_count = Some(token_count);
- cx.notify()
- });
+ this.upgrade(&cx)
+ .ok_or_else(|| anyhow!("assistant was dropped"))?
+ .update(&mut cx, |this, cx| {
+ this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
+ this.token_count = Some(token_count);
+ cx.notify()
+ });
anyhow::Ok(())
}
.log_err()
@@ -547,9 +549,10 @@ impl Assistant {
let api_key = self.api_key.borrow().clone();
if let Some(api_key) = api_key {
let stream = stream_completion(api_key, cx.background().clone(), request);
- let (excerpt_id, content) = self.push_message(Role::Assistant, cx);
- self.push_message(Role::User, cx);
- let task = cx.spawn(|this, mut cx| async move {
+ let (excerpt_id, content) =
+ self.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
+ self.insert_message_after(ExcerptId::max(), Role::User, cx);
+ let task = cx.spawn_weak(|this, mut cx| async move {
let stream_completion = async {
let mut messages = stream.await?;
@@ -564,22 +567,26 @@ impl Assistant {
}
}
- this.update(&mut cx, |this, cx| {
- this.pending_completions
- .retain(|completion| completion.id != this.completion_count);
- this.summarize(cx);
- });
+ this.upgrade(&cx)
+ .ok_or_else(|| anyhow!("assistant was dropped"))?
+ .update(&mut cx, |this, cx| {
+ this.pending_completions
+ .retain(|completion| completion.id != this.completion_count);
+ this.summarize(cx);
+ });
anyhow::Ok(())
};
if let Err(error) = stream_completion.await {
- this.update(&mut cx, |this, cx| {
- if let Some(metadata) = this.messages_metadata.get_mut(&excerpt_id) {
- metadata.error = Some(error.to_string().trim().into());
- cx.notify();
- }
- })
+ if let Some(this) = this.upgrade(&cx) {
+ this.update(&mut cx, |this, cx| {
+ if let Some(metadata) = this.messages_metadata.get_mut(&excerpt_id) {
+ metadata.error = Some(error.to_string().trim().into());
+ cx.notify();
+ }
+ });
+ }
}
});
@@ -632,8 +639,9 @@ impl Assistant {
}
}
- fn push_message(
+ fn insert_message_after(
&mut self,
+ excerpt_id: ExcerptId,
role: Role,
cx: &mut ModelContext<Self>,
) -> (ExcerptId, ModelHandle<Buffer>) {
@@ -654,9 +662,10 @@ impl Assistant {
buffer.set_language_registry(self.languages.clone());
buffer
});
- let excerpt_id = self.buffer.update(cx, |buffer, cx| {
+ let new_excerpt_id = self.buffer.update(cx, |buffer, cx| {
buffer
- .push_excerpts(
+ .insert_excerpts_after(
+ excerpt_id,
content.clone(),
vec![ExcerptRange {
context: 0..0,
@@ -668,19 +677,27 @@ impl Assistant {
.unwrap()
});
- self.messages.push(Message {
- excerpt_id,
- content: content.clone(),
- });
+ let ix = self
+ .messages
+ .iter()
+ .position(|message| message.excerpt_id == excerpt_id)
+ .map_or(self.messages.len(), |ix| ix + 1);
+ self.messages.insert(
+ ix,
+ Message {
+ excerpt_id: new_excerpt_id,
+ content: content.clone(),
+ },
+ );
self.messages_metadata.insert(
- excerpt_id,
+ new_excerpt_id,
MessageMetadata {
role,
sent_at: Local::now(),
error: None,
},
);
- (excerpt_id, content)
+ (new_excerpt_id, content)
}
fn summarize(&mut self, cx: &mut ModelContext<Self>) {
@@ -882,7 +899,7 @@ impl AssistantEditor {
if metadata.role == Role::User {
assistant.assist(cx);
} else {
- assistant.push_message(Role::User, cx);
+ assistant.insert_message_after(excerpt_id, Role::User, cx);
}
}
}
@@ -1227,3 +1244,28 @@ async fn stream_completion(
}
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use gpui::AppContext;
+
+ #[gpui::test]
+ fn test_inserting_and_removing_messages(cx: &mut AppContext) {
+ let registry = Arc::new(LanguageRegistry::test());
+
+ cx.add_model(|cx| {
+ let mut assistant = Assistant::new(Default::default(), registry, cx);
+ let (excerpt_1, _) =
+ assistant.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
+ let (excerpt_2, _) = assistant.insert_message_after(excerpt_1, Role::User, cx);
+ let (excerpt_3, _) = assistant.insert_message_after(excerpt_1, Role::User, cx);
+ assistant.remove_empty_messages(
+ HashSet::from_iter([excerpt_2, excerpt_3]),
+ Default::default(),
+ cx,
+ );
+ assistant
+ });
+ }
+}