assistant2: Implement refresh of context on message editor send (#22944)

Michael Sloan created

Release Notes:

- N/A

Change summary

crates/assistant2/src/context.rs        |  63 +++++--
crates/assistant2/src/context_store.rs  | 225 ++++++++++++++++++++------
crates/assistant2/src/message_editor.rs |  62 ++++---
3 files changed, 254 insertions(+), 96 deletions(-)

Detailed changes

crates/assistant2/src/context.rs 🔗

@@ -1,6 +1,5 @@
 use std::path::Path;
 use std::rc::Rc;
-use std::sync::Arc;
 
 use file_icons::FileIcons;
 use gpui::{AppContext, Model, SharedString};
@@ -11,7 +10,7 @@ use text::BufferId;
 use ui::IconName;
 use util::post_inc;
 
-use crate::thread::Thread;
+use crate::{context_store::buffer_path_log_err, thread::Thread};
 
 #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 pub struct ContextId(pub(crate) usize);
@@ -76,7 +75,7 @@ impl Context {
 #[derive(Debug)]
 pub struct FileContext {
     pub id: ContextId,
-    pub buffer: ContextBuffer,
+    pub context_buffer: ContextBuffer,
 }
 
 #[derive(Debug)]
@@ -84,7 +83,7 @@ pub struct DirectoryContext {
     #[allow(unused)]
     pub path: Rc<Path>,
     #[allow(unused)]
-    pub buffers: Vec<ContextBuffer>,
+    pub context_buffers: Vec<ContextBuffer>,
     pub snapshot: ContextSnapshot,
 }
 
@@ -108,7 +107,7 @@ pub struct ThreadContext {
 // TODO: Model<Buffer> holds onto the buffer even if the file is deleted and closed. Should remove
 // the context from the message editor in this case.
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct ContextBuffer {
     #[allow(unused)]
     pub id: BufferId,
@@ -130,18 +129,9 @@ impl Context {
 }
 
 impl FileContext {
-    pub fn path(&self, cx: &AppContext) -> Option<Arc<Path>> {
-        let buffer = self.buffer.buffer.read(cx);
-        if let Some(file) = buffer.file() {
-            Some(file.path().clone())
-        } else {
-            log::error!("Buffer that had a path unexpectedly no longer has a path.");
-            None
-        }
-    }
-
     pub fn snapshot(&self, cx: &AppContext) -> Option<ContextSnapshot> {
-        let path = self.path(cx)?;
+        let buffer = self.context_buffer.buffer.read(cx);
+        let path = buffer_path_log_err(buffer)?;
         let full_path: SharedString = path.to_string_lossy().into_owned().into();
         let name = match path.file_name() {
             Some(name) => name.to_string_lossy().into_owned().into(),
@@ -161,12 +151,51 @@ impl FileContext {
             tooltip: Some(full_path),
             icon_path,
             kind: ContextKind::File,
-            text: Box::new([self.buffer.text.clone()]),
+            text: Box::new([self.context_buffer.text.clone()]),
         })
     }
 }
 
 impl DirectoryContext {
+    pub fn new(
+        id: ContextId,
+        path: &Path,
+        context_buffers: Vec<ContextBuffer>,
+    ) -> DirectoryContext {
+        let full_path: SharedString = path.to_string_lossy().into_owned().into();
+
+        let name = match path.file_name() {
+            Some(name) => name.to_string_lossy().into_owned().into(),
+            None => full_path.clone(),
+        };
+
+        let parent = path
+            .parent()
+            .and_then(|p| p.file_name())
+            .map(|p| p.to_string_lossy().into_owned().into());
+
+        // TODO: include directory path in text?
+        let text = context_buffers
+            .iter()
+            .map(|b| b.text.clone())
+            .collect::<Vec<_>>()
+            .into();
+
+        DirectoryContext {
+            path: path.into(),
+            context_buffers,
+            snapshot: ContextSnapshot {
+                id,
+                name,
+                parent,
+                tooltip: Some(full_path),
+                icon_path: None,
+                kind: ContextKind::Directory,
+                text,
+            },
+        }
+    }
+
     pub fn snapshot(&self) -> ContextSnapshot {
         self.snapshot.clone()
     }

crates/assistant2/src/context_store.rs 🔗

@@ -3,6 +3,7 @@ use std::sync::Arc;
 
 use anyhow::{anyhow, bail, Result};
 use collections::{BTreeMap, HashMap};
+use futures::{self, future, Future, FutureExt};
 use gpui::{AppContext, AsyncAppContext, Model, ModelContext, SharedString, Task, WeakView};
 use language::Buffer;
 use project::{ProjectPath, Worktree};
@@ -11,8 +12,8 @@ use text::BufferId;
 use workspace::Workspace;
 
 use crate::context::{
-    Context, ContextBuffer, ContextId, ContextKind, ContextSnapshot, DirectoryContext,
-    FetchedUrlContext, FileContext, ThreadContext,
+    Context, ContextBuffer, ContextId, ContextSnapshot, DirectoryContext, FetchedUrlContext,
+    FileContext, ThreadContext,
 };
 use crate::thread::{Thread, ThreadId};
 
@@ -104,7 +105,7 @@ impl ContextStore {
                     project_path.path.clone(),
                     buffer_model,
                     buffer,
-                    &cx.to_async(),
+                    cx.to_async(),
                 )
             })?;
 
@@ -133,7 +134,7 @@ impl ContextStore {
                     file.path().clone(),
                     buffer_model,
                     buffer,
-                    &cx.to_async(),
+                    cx.to_async(),
                 ))
             })??;
 
@@ -150,10 +151,8 @@ impl ContextStore {
     pub fn insert_file(&mut self, context_buffer: ContextBuffer) {
         let id = self.next_context_id.post_inc();
         self.files.insert(context_buffer.id, id);
-        self.context.push(Context::File(FileContext {
-            id,
-            buffer: context_buffer,
-        }));
+        self.context
+            .push(Context::File(FileContext { id, context_buffer }));
     }
 
     pub fn add_directory(
@@ -207,7 +206,7 @@ impl ContextStore {
                     .collect::<Vec<_>>()
             })?;
 
-            let buffers = futures::future::join_all(open_buffer_tasks).await;
+            let buffers = future::join_all(open_buffer_tasks).await;
 
             let mut buffer_infos = Vec::new();
             let mut text_tasks = Vec::new();
@@ -216,68 +215,41 @@ impl ContextStore {
                     let buffer_model = buffer_model?;
                     let buffer = buffer_model.read(cx);
                     let (buffer_info, text_task) =
-                        collect_buffer_info_and_text(path, buffer_model, buffer, &cx.to_async());
+                        collect_buffer_info_and_text(path, buffer_model, buffer, cx.to_async());
                     buffer_infos.push(buffer_info);
                     text_tasks.push(text_task);
                 }
                 anyhow::Ok(())
             })??;
 
-            let buffer_texts = futures::future::join_all(text_tasks).await;
-            let directory_buffers = buffer_infos
+            let buffer_texts = future::join_all(text_tasks).await;
+            let context_buffers = buffer_infos
                 .into_iter()
-                .zip(buffer_texts.iter())
-                .map(|(info, text)| make_context_buffer(info, text.clone()))
+                .zip(buffer_texts)
+                .map(|(info, text)| make_context_buffer(info, text))
                 .collect::<Vec<_>>();
 
-            if directory_buffers.is_empty() {
+            if context_buffers.is_empty() {
                 bail!("No text files found in {}", &project_path.path.display());
             }
 
-            // TODO: include directory path in text?
-
             this.update(&mut cx, |this, _| {
-                this.insert_directory(&project_path.path, directory_buffers, buffer_texts.into());
+                this.insert_directory(&project_path.path, context_buffers);
             })?;
 
             anyhow::Ok(())
         })
     }
 
-    pub fn insert_directory(
-        &mut self,
-        path: &Path,
-        buffers: Vec<ContextBuffer>,
-        text: Box<[SharedString]>,
-    ) {
+    pub fn insert_directory(&mut self, path: &Path, context_buffers: Vec<ContextBuffer>) {
         let id = self.next_context_id.post_inc();
         self.directories.insert(path.to_path_buf(), id);
 
-        let full_path: SharedString = path.to_string_lossy().into_owned().into();
-
-        let name = match path.file_name() {
-            Some(name) => name.to_string_lossy().into_owned().into(),
-            None => full_path.clone(),
-        };
-
-        let parent = path
-            .parent()
-            .and_then(|p| p.file_name())
-            .map(|p| p.to_string_lossy().into_owned().into());
-
-        self.context.push(Context::Directory(DirectoryContext {
-            path: path.into(),
-            buffers,
-            snapshot: ContextSnapshot {
-                id,
-                name,
-                parent,
-                tooltip: Some(full_path),
-                icon_path: None,
-                kind: ContextKind::Directory,
-                text,
-            },
-        }));
+        self.context.push(Context::Directory(DirectoryContext::new(
+            id,
+            path,
+            context_buffers,
+        )));
     }
 
     pub fn add_thread(&mut self, thread: Model<Thread>, cx: &mut ModelContext<Self>) {
@@ -347,7 +319,8 @@ impl ContextStore {
         if !self.files.is_empty() {
             let found_file_context = self.context.iter().find(|context| match &context {
                 Context::File(file_context) => {
-                    if let Some(file_path) = file_context.path(cx) {
+                    let buffer = file_context.context_buffer.buffer.read(cx);
+                    if let Some(file_path) = buffer_path_log_err(buffer) {
                         *file_path == *path
                     } else {
                         false
@@ -390,6 +363,17 @@ impl ContextStore {
     pub fn includes_url(&self, url: &str) -> Option<ContextId> {
         self.fetched_urls.get(url).copied()
     }
+
+    /// Replaces the context that matches the ID of the new context, if any match.
+    fn replace_context(&mut self, new_context: Context) {
+        let id = new_context.id();
+        for context in self.context.iter_mut() {
+            if context.id() == id {
+                *context = new_context;
+                break;
+            }
+        }
+    }
 }
 
 pub enum FileInclusion {
@@ -417,7 +401,7 @@ fn collect_buffer_info_and_text(
     path: Arc<Path>,
     buffer_model: Model<Buffer>,
     buffer: &Buffer,
-    cx: &AsyncAppContext,
+    cx: AsyncAppContext,
 ) -> (BufferInfo, Task<SharedString>) {
     let buffer_info = BufferInfo {
         id: buffer.remote_id(),
@@ -432,6 +416,15 @@ fn collect_buffer_info_and_text(
     (buffer_info, text_task)
 }
 
+pub fn buffer_path_log_err(buffer: &Buffer) -> Option<Arc<Path>> {
+    if let Some(file) = buffer.file() {
+        Some(file.path().clone())
+    } else {
+        log::error!("Buffer that had a path unexpectedly no longer has a path.");
+        None
+    }
+}
+
 fn to_fenced_codeblock(path: &Path, content: Rope) -> SharedString {
     let path_extension = path.extension().and_then(|ext| ext.to_str());
     let path_string = path.to_string_lossy();
@@ -485,3 +478,133 @@ fn collect_files_in_path(worktree: &Worktree, path: &Path) -> Vec<Arc<Path>> {
 
     files
 }
+
+pub fn refresh_context_store_text(
+    context_store: Model<ContextStore>,
+    cx: &AppContext,
+) -> impl Future<Output = ()> {
+    let mut tasks = Vec::new();
+    let context_store_ref = context_store.read(cx);
+    for context in &context_store_ref.context {
+        match context {
+            Context::File(file_context) => {
+                let context_store = context_store.clone();
+                if let Some(task) = refresh_file_text(context_store, file_context, cx) {
+                    tasks.push(task);
+                }
+            }
+            Context::Directory(directory_context) => {
+                let context_store = context_store.clone();
+                if let Some(task) = refresh_directory_text(context_store, directory_context, cx) {
+                    tasks.push(task);
+                }
+            }
+            Context::Thread(thread_context) => {
+                let context_store = context_store.clone();
+                tasks.push(refresh_thread_text(context_store, thread_context, cx));
+            }
+            // Intentionally omit refreshing fetched URLs as it doesn't seem all that useful,
+            // and doing the caching properly could be tricky (unless it's already handled by
+            // the HttpClient?).
+            Context::FetchedUrl(_) => {}
+        }
+    }
+
+    future::join_all(tasks).map(|_| ())
+}
+
+fn refresh_file_text(
+    context_store: Model<ContextStore>,
+    file_context: &FileContext,
+    cx: &AppContext,
+) -> Option<Task<()>> {
+    let id = file_context.id;
+    let task = refresh_context_buffer(&file_context.context_buffer, cx);
+    if let Some(task) = task {
+        Some(cx.spawn(|mut cx| async move {
+            let context_buffer = task.await;
+            context_store
+                .update(&mut cx, |context_store, _| {
+                    let new_file_context = FileContext { id, context_buffer };
+                    context_store.replace_context(Context::File(new_file_context));
+                })
+                .ok();
+        }))
+    } else {
+        None
+    }
+}
+
+fn refresh_directory_text(
+    context_store: Model<ContextStore>,
+    directory_context: &DirectoryContext,
+    cx: &AppContext,
+) -> Option<Task<()>> {
+    let mut stale = false;
+    let futures = directory_context
+        .context_buffers
+        .iter()
+        .map(|context_buffer| {
+            if let Some(refresh_task) = refresh_context_buffer(context_buffer, cx) {
+                stale = true;
+                future::Either::Left(refresh_task)
+            } else {
+                future::Either::Right(future::ready((*context_buffer).clone()))
+            }
+        })
+        .collect::<Vec<_>>();
+
+    if !stale {
+        return None;
+    }
+
+    let context_buffers = future::join_all(futures);
+
+    let id = directory_context.snapshot.id;
+    let path = directory_context.path.clone();
+    Some(cx.spawn(|mut cx| async move {
+        let context_buffers = context_buffers.await;
+        context_store
+            .update(&mut cx, |context_store, _| {
+                let new_directory_context = DirectoryContext::new(id, &path, context_buffers);
+                context_store.replace_context(Context::Directory(new_directory_context));
+            })
+            .ok();
+    }))
+}
+
+fn refresh_thread_text(
+    context_store: Model<ContextStore>,
+    thread_context: &ThreadContext,
+    cx: &AppContext,
+) -> Task<()> {
+    let id = thread_context.id;
+    let thread = thread_context.thread.clone();
+    cx.spawn(move |mut cx| async move {
+        context_store
+            .update(&mut cx, |context_store, cx| {
+                let text = thread.read(cx).text().into();
+                context_store.replace_context(Context::Thread(ThreadContext { id, thread, text }));
+            })
+            .ok();
+    })
+}
+
+fn refresh_context_buffer(
+    context_buffer: &ContextBuffer,
+    cx: &AppContext,
+) -> Option<impl Future<Output = ContextBuffer>> {
+    let buffer = context_buffer.buffer.read(cx);
+    let path = buffer_path_log_err(buffer)?;
+    if buffer.version.changed_since(&context_buffer.version) {
+        let (buffer_info, text_task) = collect_buffer_info_and_text(
+            path,
+            context_buffer.buffer.clone(),
+            buffer,
+            cx.to_async(),
+        );
+        Some(text_task.map(move |text| make_context_buffer(buffer_info, text)))
+    } else {
+        None
+    }
+}

crates/assistant2/src/message_editor.rs 🔗

@@ -19,7 +19,7 @@ use workspace::Workspace;
 
 use crate::assistant_model_selector::AssistantModelSelector;
 use crate::context_picker::{ConfirmBehavior, ContextPicker};
-use crate::context_store::ContextStore;
+use crate::context_store::{refresh_context_store_text, ContextStore};
 use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
 use crate::thread::{RequestKind, Thread};
 use crate::thread_store::ThreadStore;
@@ -125,22 +125,20 @@ impl MessageEditor {
         self.send_to_model(RequestKind::Chat, cx);
     }
 
-    fn send_to_model(
-        &mut self,
-        request_kind: RequestKind,
-        cx: &mut ViewContext<Self>,
-    ) -> Option<()> {
+    fn send_to_model(&mut self, request_kind: RequestKind, cx: &mut ViewContext<Self>) {
         let provider = LanguageModelRegistry::read_global(cx).active_provider();
         if provider
             .as_ref()
             .map_or(false, |provider| provider.must_accept_terms(cx))
         {
             cx.notify();
-            return None;
+            return;
         }
 
         let model_registry = LanguageModelRegistry::read_global(cx);
-        let model = model_registry.active_model()?;
+        let Some(model) = model_registry.active_model() else {
+            return;
+        };
 
         let user_message = self.editor.update(cx, |editor, cx| {
             let text = editor.text(cx);
@@ -148,29 +146,37 @@ impl MessageEditor {
             text
         });
 
-        let thread = self.thread.clone();
-        thread.update(cx, |thread, cx| {
-            let context = self.context_store.read(cx).snapshot(cx).collect::<Vec<_>>();
-            thread.insert_user_message(user_message, context, cx);
-            let mut request = thread.to_completion_request(request_kind, cx);
+        let refresh_task = refresh_context_store_text(self.context_store.clone(), cx);
 
-            if self.use_tools {
-                request.tools = thread
-                    .tools()
-                    .tools(cx)
-                    .into_iter()
-                    .map(|tool| LanguageModelRequestTool {
-                        name: tool.name(),
-                        description: tool.description(),
-                        input_schema: tool.input_schema(),
-                    })
-                    .collect();
-            }
+        let thread = self.thread.clone();
+        let context_store = self.context_store.clone();
+        let use_tools = self.use_tools;
+        cx.spawn(move |_, mut cx| async move {
+            refresh_task.await;
+            thread
+                .update(&mut cx, |thread, cx| {
+                    let context = context_store.read(cx).snapshot(cx).collect::<Vec<_>>();
+                    thread.insert_user_message(user_message, context, cx);
+                    let mut request = thread.to_completion_request(request_kind, cx);
 
-            thread.stream_completion(request, model, cx)
-        });
+                    if use_tools {
+                        request.tools = thread
+                            .tools()
+                            .tools(cx)
+                            .into_iter()
+                            .map(|tool| LanguageModelRequestTool {
+                                name: tool.name(),
+                                description: tool.description(),
+                                input_schema: tool.input_schema(),
+                            })
+                            .collect();
+                    }
 
-        None
+                    thread.stream_completion(request, model, cx)
+                })
+                .ok();
+        })
+        .detach();
     }
 
     fn handle_editor_event(