@@ -6,7 +6,7 @@ use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
MessageContent, Role, StopReason,
};
-use util::ResultExt as _;
+use util::{post_inc, ResultExt as _};
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
@@ -19,17 +19,24 @@ pub struct Message {
pub text: String,
}
+struct PendingCompletion {
+ id: usize,
+ _task: Task<()>,
+}
+
/// A thread of conversation with the LLM.
pub struct Thread {
messages: Vec<Message>,
- pending_completion_tasks: Vec<Task<()>>,
+ completion_count: usize,
+ pending_completions: Vec<PendingCompletion>,
}
impl Thread {
pub fn new(_cx: &mut ModelContext<Self>) -> Self {
Self {
messages: Vec::new(),
- pending_completion_tasks: Vec::new(),
+ completion_count: 0,
+ pending_completions: Vec::new(),
}
}
@@ -79,7 +86,9 @@ impl Thread {
model: Arc<dyn LanguageModel>,
cx: &mut ModelContext<Self>,
) {
- let task = cx.spawn(|this, mut cx| async move {
+ let pending_completion_id = post_inc(&mut self.completion_count);
+
+ let task = cx.spawn(|thread, mut cx| async move {
let stream = model.stream_completion(request, &cx);
let stream_completion = async {
let mut events = stream.await?;
@@ -88,7 +97,7 @@ impl Thread {
while let Some(event) = events.next().await {
let event = event?;
- this.update(&mut cx, |thread, cx| {
+ thread.update(&mut cx, |thread, cx| {
match event {
LanguageModelCompletionEvent::StartMessage { .. } => {
thread.messages.push(Message {
@@ -116,6 +125,12 @@ impl Thread {
smol::future::yield_now().await;
}
+ thread.update(&mut cx, |thread, _cx| {
+ thread
+ .pending_completions
+ .retain(|completion| completion.id != pending_completion_id);
+ })?;
+
anyhow::Ok(stop_reason)
};
@@ -123,7 +138,10 @@ impl Thread {
let _ = result.log_err();
});
- self.pending_completion_tasks.push(task);
+ self.pending_completions.push(PendingCompletion {
+ id: pending_completion_id,
+ _task: task,
+ });
}
}