assistant: Factor `RecentBuffersContext` logic out of `AssistantPanel` (#11876)

Marshall Bowers and Max created

This PR factors some more code related to the `RecentBuffersContext` out
of the `AssistantPanel` and into the corresponding module.

We're trying to strike a balance between keeping this code easy to
evolve as we work on the Assistant, while also having some semblance of
separation/structure.

This also adds the missing functionality of updating the remaining token
count when the `CurrentProjectContext` is enabled/disabled.

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>

Change summary

crates/assistant/src/ambient_context.rs                 |   6 
crates/assistant/src/ambient_context/current_project.rs |  11 
crates/assistant/src/ambient_context/recent_buffers.rs  | 168 +++++++++
crates/assistant/src/assistant_panel.rs                 | 189 +---------
4 files changed, 204 insertions(+), 170 deletions(-)

Detailed changes

crates/assistant/src/ambient_context.rs 🔗

@@ -9,3 +9,9 @@ pub struct AmbientContext {
     pub recent_buffers: RecentBuffersContext,
     pub current_project: CurrentProjectContext,
 }
+
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
+pub enum ContextUpdated {
+    Updating,
+    Disabled,
+}

crates/assistant/src/ambient_context/current_project.rs 🔗

@@ -8,6 +8,7 @@ use gpui::{AsyncAppContext, ModelContext, Task, WeakModel};
 use project::{Project, ProjectPath};
 use util::ResultExt;
 
+use crate::ambient_context::ContextUpdated;
 use crate::assistant_panel::Conversation;
 use crate::{LanguageModelRequestMessage, Role};
 
@@ -44,12 +45,12 @@ impl CurrentProjectContext {
         fs: Arc<dyn Fs>,
         project: WeakModel<Project>,
         cx: &mut ModelContext<Conversation>,
-    ) {
+    ) -> ContextUpdated {
         if !self.enabled {
             self.message.clear();
             self.pending_message = None;
             cx.notify();
-            return;
+            return ContextUpdated::Disabled;
         }
 
         self.pending_message = Some(cx.spawn(|conversation, mut cx| async move {
@@ -74,12 +75,16 @@ impl CurrentProjectContext {
 
             if let Some(message) = message_task.await.log_err() {
                 conversation
-                    .update(&mut cx, |conversation, _cx| {
+                    .update(&mut cx, |conversation, cx| {
                         conversation.ambient_context.current_project.message = message;
+                        conversation.count_remaining_tokens(cx);
+                        cx.notify();
                     })
                     .log_err();
             }
         }));
+
+        ContextUpdated::Updating
     }
 
     async fn build_message(fs: Arc<dyn Fs>, path_to_cargo_toml: &Path) -> Result<String> {

crates/assistant/src/ambient_context/recent_buffers.rs 🔗

@@ -1,6 +1,13 @@
-use gpui::{Subscription, Task, WeakModel};
-use language::Buffer;
+use std::fmt::Write;
+use std::iter;
+use std::path::PathBuf;
+use std::time::Duration;
 
+use gpui::{ModelContext, Subscription, Task, WeakModel};
+use language::{Buffer, BufferSnapshot, DiagnosticEntry, Point};
+
+use crate::ambient_context::ContextUpdated;
+use crate::assistant_panel::Conversation;
 use crate::{LanguageModelRequestMessage, Role};
 
 pub struct RecentBuffersContext {
@@ -34,4 +41,161 @@ impl RecentBuffersContext {
             content: self.message.clone(),
         })
     }
+
+    pub fn update(&mut self, cx: &mut ModelContext<Conversation>) -> ContextUpdated {
+        let buffers = self
+            .buffers
+            .iter()
+            .filter_map(|recent| {
+                recent
+                    .buffer
+                    .read_with(cx, |buffer, cx| {
+                        (
+                            buffer.file().map(|file| file.full_path(cx)),
+                            buffer.snapshot(),
+                        )
+                    })
+                    .ok()
+            })
+            .collect::<Vec<_>>();
+
+        if !self.enabled || buffers.is_empty() {
+            self.message.clear();
+            self.pending_message = None;
+            cx.notify();
+            ContextUpdated::Disabled
+        } else {
+            self.pending_message = Some(cx.spawn(|this, mut cx| async move {
+                const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(100);
+                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
+
+                let message = cx
+                    .background_executor()
+                    .spawn(async move { Self::build_message(&buffers) })
+                    .await;
+                this.update(&mut cx, |conversation, cx| {
+                    conversation.ambient_context.recent_buffers.message = message;
+                    conversation.count_remaining_tokens(cx);
+                    cx.notify();
+                })
+                .ok();
+            }));
+
+            ContextUpdated::Updating
+        }
+    }
+
+    fn build_message(buffers: &[(Option<PathBuf>, BufferSnapshot)]) -> String {
+        let mut message = String::new();
+        writeln!(
+            message,
+            "The following is a list of recent buffers that the user has opened."
+        )
+        .unwrap();
+        writeln!(
+            message,
+            "For every line in the buffer, I will include a row number that line corresponds to."
+        )
+        .unwrap();
+        writeln!(
+            message,
+            "Lines that don't have a number correspond to errors and warnings. For example:"
+        )
+        .unwrap();
+        writeln!(message, "path/to/file.md").unwrap();
+        writeln!(message, "```markdown").unwrap();
+        writeln!(message, "1 The quick brown fox").unwrap();
+        writeln!(message, "2 jumps over one active").unwrap();
+        writeln!(message, "             --- error: should be 'the'").unwrap();
+        writeln!(message, "                 ------ error: should be 'lazy'").unwrap();
+        writeln!(message, "3 dog").unwrap();
+        writeln!(message, "```").unwrap();
+
+        message.push('\n');
+        writeln!(message, "Here's the actual recent buffer list:").unwrap();
+        for (path, buffer) in buffers {
+            if let Some(path) = path {
+                writeln!(message, "{}", path.display()).unwrap();
+            } else {
+                writeln!(message, "untitled").unwrap();
+            }
+
+            if let Some(language) = buffer.language() {
+                writeln!(message, "```{}", language.name().to_lowercase()).unwrap();
+            } else {
+                writeln!(message, "```").unwrap();
+            }
+
+            let mut diagnostics = buffer
+                .diagnostics_in_range::<_, Point>(
+                    language::Anchor::MIN..language::Anchor::MAX,
+                    false,
+                )
+                .peekable();
+
+            let mut active_diagnostics = Vec::<DiagnosticEntry<Point>>::new();
+            const GUTTER_PADDING: usize = 4;
+            let gutter_width =
+                ((buffer.max_point().row + 1) as f32).log10() as usize + 1 + GUTTER_PADDING;
+            for buffer_row in 0..=buffer.max_point().row {
+                let display_row = buffer_row + 1;
+                active_diagnostics.retain(|diagnostic| {
+                    (diagnostic.range.start.row..=diagnostic.range.end.row).contains(&buffer_row)
+                });
+                while diagnostics.peek().map_or(false, |diagnostic| {
+                    (diagnostic.range.start.row..=diagnostic.range.end.row).contains(&buffer_row)
+                }) {
+                    active_diagnostics.push(diagnostics.next().unwrap());
+                }
+
+                let row_width = (display_row as f32).log10() as usize + 1;
+                write!(message, "{}", display_row).unwrap();
+                if row_width < gutter_width {
+                    message.extend(iter::repeat(' ').take(gutter_width - row_width));
+                }
+
+                for chunk in buffer.text_for_range(
+                    Point::new(buffer_row, 0)..Point::new(buffer_row, buffer.line_len(buffer_row)),
+                ) {
+                    message.push_str(chunk);
+                }
+                message.push('\n');
+
+                for diagnostic in &active_diagnostics {
+                    message.extend(iter::repeat(' ').take(gutter_width));
+
+                    let start_column = if diagnostic.range.start.row == buffer_row {
+                        message
+                            .extend(iter::repeat(' ').take(diagnostic.range.start.column as usize));
+                        diagnostic.range.start.column
+                    } else {
+                        0
+                    };
+                    let end_column = if diagnostic.range.end.row == buffer_row {
+                        diagnostic.range.end.column
+                    } else {
+                        buffer.line_len(buffer_row)
+                    };
+
+                    message.extend(iter::repeat('-').take((end_column - start_column) as usize));
+                    writeln!(message, " {}", diagnostic.diagnostic.message).unwrap();
+                }
+            }
+
+            message.push('\n');
+        }
+
+        writeln!(
+            message,
+            "When quoting the above code, mention which rows the code occurs at."
+        )
+        .unwrap();
+        writeln!(
+            message,
+            "Never include rows in the quoted code itself and only report lines that didn't start with a row number."
+        )
+        .unwrap();
+
+        message
+    }
 }

crates/assistant/src/assistant_panel.rs 🔗

@@ -1,4 +1,4 @@
-use crate::ambient_context::{AmbientContext, RecentBuffer};
+use crate::ambient_context::{AmbientContext, ContextUpdated, RecentBuffer};
 use crate::{
     assistant_settings::{AssistantDockPosition, AssistantSettings, ZedDotDevModel},
     codegen::{self, Codegen, CodegenKind},
@@ -31,10 +31,7 @@ use gpui::{
     Subscription, Task, TextStyle, UniformListScrollHandle, View, ViewContext, VisualContext,
     WeakModel, WeakView, WhiteSpace, WindowContext,
 };
-use language::{
-    language_settings::SoftWrap, Buffer, BufferSnapshot, DiagnosticEntry, LanguageRegistry, Point,
-    ToOffset as _,
-};
+use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, Point, ToOffset as _};
 use multi_buffer::MultiBufferRow;
 use parking_lot::Mutex;
 use project::Project;
@@ -1519,7 +1516,12 @@ impl Conversation {
 
     fn toggle_recent_buffers(&mut self, cx: &mut ModelContext<Self>) {
         self.ambient_context.recent_buffers.enabled = !self.ambient_context.recent_buffers.enabled;
-        self.update_recent_buffers_context(cx);
+        match self.ambient_context.recent_buffers.update(cx) {
+            ContextUpdated::Updating => {}
+            ContextUpdated::Disabled => {
+                self.count_remaining_tokens(cx);
+            }
+        }
     }
 
     fn toggle_current_project_context(
@@ -1530,7 +1532,12 @@ impl Conversation {
     ) {
         self.ambient_context.current_project.enabled =
             !self.ambient_context.current_project.enabled;
-        self.ambient_context.current_project.update(fs, project, cx);
+        match self.ambient_context.current_project.update(fs, project, cx) {
+            ContextUpdated::Updating => {}
+            ContextUpdated::Disabled => {
+                self.count_remaining_tokens(cx);
+            }
+        }
     }
 
     fn set_recent_buffers(
@@ -1545,168 +1552,20 @@ impl Conversation {
             .extend(buffers.into_iter().map(|buffer| RecentBuffer {
                 buffer: buffer.downgrade(),
                 _subscription: cx.observe(&buffer, |this, _, cx| {
-                    this.update_recent_buffers_context(cx);
+                    match this.ambient_context.recent_buffers.update(cx) {
+                        ContextUpdated::Updating => {}
+                        ContextUpdated::Disabled => {
+                            this.count_remaining_tokens(cx);
+                        }
+                    }
                 }),
             }));
-        self.update_recent_buffers_context(cx);
-    }
-
-    fn update_recent_buffers_context(&mut self, cx: &mut ModelContext<Self>) {
-        let buffers = self
-            .ambient_context
-            .recent_buffers
-            .buffers
-            .iter()
-            .filter_map(|recent| {
-                recent
-                    .buffer
-                    .read_with(cx, |buffer, cx| {
-                        (
-                            buffer.file().map(|file| file.full_path(cx)),
-                            buffer.snapshot(),
-                        )
-                    })
-                    .ok()
-            })
-            .collect::<Vec<_>>();
-
-        if !self.ambient_context.recent_buffers.enabled || buffers.is_empty() {
-            self.ambient_context.recent_buffers.message.clear();
-            self.ambient_context.recent_buffers.pending_message = None;
-            self.count_remaining_tokens(cx);
-            cx.notify();
-        } else {
-            self.ambient_context.recent_buffers.pending_message =
-                Some(cx.spawn(|this, mut cx| async move {
-                    const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(100);
-                    cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
-
-                    let message = cx
-                        .background_executor()
-                        .spawn(async move { Self::message_for_recent_buffers(&buffers) })
-                        .await;
-                    this.update(&mut cx, |this, cx| {
-                        this.ambient_context.recent_buffers.message = message;
-                        this.count_remaining_tokens(cx);
-                        cx.notify();
-                    })
-                    .ok();
-                }));
-        }
-    }
-
-    fn message_for_recent_buffers(buffers: &[(Option<PathBuf>, BufferSnapshot)]) -> String {
-        let mut message = String::new();
-        writeln!(
-            message,
-            "The following is a list of recent buffers that the user has opened."
-        )
-        .unwrap();
-        writeln!(
-            message,
-            "For every line in the buffer, I will include a row number that line corresponds to."
-        )
-        .unwrap();
-        writeln!(
-            message,
-            "Lines that don't have a number correspond to errors and warnings. For example:"
-        )
-        .unwrap();
-        writeln!(message, "path/to/file.md").unwrap();
-        writeln!(message, "```markdown").unwrap();
-        writeln!(message, "1 The quick brown fox").unwrap();
-        writeln!(message, "2 jumps over one active").unwrap();
-        writeln!(message, "             --- error: should be 'the'").unwrap();
-        writeln!(message, "                 ------ error: should be 'lazy'").unwrap();
-        writeln!(message, "3 dog").unwrap();
-        writeln!(message, "```").unwrap();
-
-        message.push('\n');
-        writeln!(message, "Here's the actual recent buffer list:").unwrap();
-        for (path, buffer) in buffers {
-            if let Some(path) = path {
-                writeln!(message, "{}", path.display()).unwrap();
-            } else {
-                writeln!(message, "untitled").unwrap();
+        match self.ambient_context.recent_buffers.update(cx) {
+            ContextUpdated::Updating => {}
+            ContextUpdated::Disabled => {
+                self.count_remaining_tokens(cx);
             }
-
-            if let Some(language) = buffer.language() {
-                writeln!(message, "```{}", language.name().to_lowercase()).unwrap();
-            } else {
-                writeln!(message, "```").unwrap();
-            }
-
-            let mut diagnostics = buffer
-                .diagnostics_in_range::<_, Point>(
-                    language::Anchor::MIN..language::Anchor::MAX,
-                    false,
-                )
-                .peekable();
-
-            let mut active_diagnostics = Vec::<DiagnosticEntry<Point>>::new();
-            const GUTTER_PADDING: usize = 4;
-            let gutter_width =
-                ((buffer.max_point().row + 1) as f32).log10() as usize + 1 + GUTTER_PADDING;
-            for buffer_row in 0..=buffer.max_point().row {
-                let display_row = buffer_row + 1;
-                active_diagnostics.retain(|diagnostic| {
-                    (diagnostic.range.start.row..=diagnostic.range.end.row).contains(&buffer_row)
-                });
-                while diagnostics.peek().map_or(false, |diagnostic| {
-                    (diagnostic.range.start.row..=diagnostic.range.end.row).contains(&buffer_row)
-                }) {
-                    active_diagnostics.push(diagnostics.next().unwrap());
-                }
-
-                let row_width = (display_row as f32).log10() as usize + 1;
-                write!(message, "{}", display_row).unwrap();
-                if row_width < gutter_width {
-                    message.extend(iter::repeat(' ').take(gutter_width - row_width));
-                }
-
-                for chunk in buffer.text_for_range(
-                    Point::new(buffer_row, 0)..Point::new(buffer_row, buffer.line_len(buffer_row)),
-                ) {
-                    message.push_str(chunk);
-                }
-                message.push('\n');
-
-                for diagnostic in &active_diagnostics {
-                    message.extend(iter::repeat(' ').take(gutter_width));
-
-                    let start_column = if diagnostic.range.start.row == buffer_row {
-                        message
-                            .extend(iter::repeat(' ').take(diagnostic.range.start.column as usize));
-                        diagnostic.range.start.column
-                    } else {
-                        0
-                    };
-                    let end_column = if diagnostic.range.end.row == buffer_row {
-                        diagnostic.range.end.column
-                    } else {
-                        buffer.line_len(buffer_row)
-                    };
-
-                    message.extend(iter::repeat('-').take((end_column - start_column) as usize));
-                    writeln!(message, " {}", diagnostic.diagnostic.message).unwrap();
-                }
-            }
-
-            message.push('\n');
         }
-
-        writeln!(
-            message,
-            "When quoting the above code, mention which rows the code occurs at."
-        )
-        .unwrap();
-        writeln!(
-            message,
-            "Never include rows in the quoted code itself and only report lines that didn't start with a row number."
-        )
-        .unwrap();
-
-        message
     }
 
     fn handle_buffer_event(