assistant2: Improve tracking of pending completions (#21186)

Marshall Bowers created

This PR improves the tracking of pending completions in `assistant2`
such that we actually remove ones that have been completed.

Release Notes:

- N/A

Change summary

crates/assistant2/src/thread.rs | 30 ++++++++++++++++++++++++------
1 file changed, 24 insertions(+), 6 deletions(-)

Detailed changes

crates/assistant2/src/thread.rs 🔗

@@ -6,7 +6,7 @@ use language_model::{
     LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
     MessageContent, Role, StopReason,
 };
-use util::ResultExt as _;
+use util::{post_inc, ResultExt as _};
 
 #[derive(Debug, Clone, Copy)]
 pub enum RequestKind {
@@ -19,17 +19,24 @@ pub struct Message {
     pub text: String,
 }
 
+struct PendingCompletion {
+    id: usize,
+    _task: Task<()>,
+}
+
 /// A thread of conversation with the LLM.
 pub struct Thread {
     messages: Vec<Message>,
-    pending_completion_tasks: Vec<Task<()>>,
+    completion_count: usize,
+    pending_completions: Vec<PendingCompletion>,
 }
 
 impl Thread {
     pub fn new(_cx: &mut ModelContext<Self>) -> Self {
         Self {
             messages: Vec::new(),
-            pending_completion_tasks: Vec::new(),
+            completion_count: 0,
+            pending_completions: Vec::new(),
         }
     }
 
@@ -79,7 +86,9 @@ impl Thread {
         model: Arc<dyn LanguageModel>,
         cx: &mut ModelContext<Self>,
     ) {
-        let task = cx.spawn(|this, mut cx| async move {
+        let pending_completion_id = post_inc(&mut self.completion_count);
+
+        let task = cx.spawn(|thread, mut cx| async move {
             let stream = model.stream_completion(request, &cx);
             let stream_completion = async {
                 let mut events = stream.await?;
@@ -88,7 +97,7 @@ impl Thread {
                 while let Some(event) = events.next().await {
                     let event = event?;
 
-                    this.update(&mut cx, |thread, cx| {
+                    thread.update(&mut cx, |thread, cx| {
                         match event {
                             LanguageModelCompletionEvent::StartMessage { .. } => {
                                 thread.messages.push(Message {
@@ -116,6 +125,12 @@ impl Thread {
                     smol::future::yield_now().await;
                 }
 
+                thread.update(&mut cx, |thread, _cx| {
+                    thread
+                        .pending_completions
+                        .retain(|completion| completion.id != pending_completion_id);
+                })?;
+
                 anyhow::Ok(stop_reason)
             };
 
@@ -123,7 +138,10 @@ impl Thread {
             let _ = result.log_err();
         });
 
-        self.pending_completion_tasks.push(task);
+        self.pending_completions.push(PendingCompletion {
+            id: pending_completion_id,
+            _task: task,
+        });
     }
 }