assistant2: Wire up context picker with inline assist (#22106)

Marshall Bowers , Richard , and Agus created

This PR wire up the context picker with the inline assist.

UI is not finalized.

Release Notes:

- N/A

---------

Co-authored-by: Richard <richard@zed.dev>
Co-authored-by: Agus <agus@zed.dev>

Change summary

crates/assistant2/src/assistant.rs                            |   1 
crates/assistant2/src/assistant_panel.rs                      |   4 
crates/assistant2/src/context.rs                              |  49 
crates/assistant2/src/context_picker.rs                       |  83 
crates/assistant2/src/context_picker/fetch_context_picker.rs  |  20 
crates/assistant2/src/context_picker/file_context_picker.rs   |  18 
crates/assistant2/src/context_picker/thread_context_picker.rs |  18 
crates/assistant2/src/context_store.rs                        |  47 
crates/assistant2/src/context_strip.rs                        |  63 
crates/assistant2/src/inline_assistant.rs                     | 264 +++-
crates/assistant2/src/message_editor.rs                       |  17 
crates/assistant2/src/thread.rs                               |  48 
12 files changed, 391 insertions(+), 241 deletions(-)

Detailed changes

crates/assistant2/src/assistant.rs 🔗

@@ -3,6 +3,7 @@ mod assistant_panel;
 mod assistant_settings;
 mod context;
 mod context_picker;
+mod context_store;
 mod context_strip;
 mod inline_assistant;
 mod message_editor;

crates/assistant2/src/assistant_panel.rs 🔗

@@ -110,6 +110,10 @@ impl AssistantPanel {
         self.local_timezone
     }
 
+    pub(crate) fn thread_store(&self) -> &Model<ThreadStore> {
+        &self.thread_store
+    }
+
     fn new_thread(&mut self, cx: &mut ViewContext<Self>) {
         let thread = self
             .thread_store

crates/assistant2/src/context.rs 🔗

@@ -1,4 +1,5 @@
 use gpui::SharedString;
+use language_model::{LanguageModelRequestMessage, MessageContent};
 use serde::{Deserialize, Serialize};
 use util::post_inc;
 
@@ -26,3 +27,51 @@ pub enum ContextKind {
     FetchedUrl,
     Thread,
 }
+
+pub fn attach_context_to_message(
+    message: &mut LanguageModelRequestMessage,
+    context: impl IntoIterator<Item = Context>,
+) {
+    let mut file_context = String::new();
+    let mut fetch_context = String::new();
+    let mut thread_context = String::new();
+
+    for context in context.into_iter() {
+        match context.kind {
+            ContextKind::File => {
+                file_context.push_str(&context.text);
+                file_context.push('\n');
+            }
+            ContextKind::FetchedUrl => {
+                fetch_context.push_str(&context.name);
+                fetch_context.push('\n');
+                fetch_context.push_str(&context.text);
+                fetch_context.push('\n');
+            }
+            ContextKind::Thread => {
+                thread_context.push_str(&context.name);
+                thread_context.push('\n');
+                thread_context.push_str(&context.text);
+                thread_context.push('\n');
+            }
+        }
+    }
+
+    let mut context_text = String::new();
+    if !file_context.is_empty() {
+        context_text.push_str("The following files are available:\n");
+        context_text.push_str(&file_context);
+    }
+
+    if !fetch_context.is_empty() {
+        context_text.push_str("The following fetched results are available\n");
+        context_text.push_str(&fetch_context);
+    }
+
+    if !thread_context.is_empty() {
+        context_text.push_str("The following previous conversation threads are available\n");
+        context_text.push_str(&thread_context);
+    }
+
+    message.content.push(MessageContent::Text(context_text));
+}

crates/assistant2/src/context_picker.rs 🔗

@@ -16,7 +16,7 @@ use workspace::Workspace;
 use crate::context_picker::fetch_context_picker::FetchContextPicker;
 use crate::context_picker::file_context_picker::FileContextPicker;
 use crate::context_picker::thread_context_picker::ThreadContextPicker;
-use crate::context_strip::ContextStrip;
+use crate::context_store::ContextStore;
 use crate::thread_store::ThreadStore;
 
 #[derive(Debug, Clone)]
@@ -35,37 +35,42 @@ pub(super) struct ContextPicker {
 impl ContextPicker {
     pub fn new(
         workspace: WeakView<Workspace>,
-        thread_store: WeakModel<ThreadStore>,
-        context_strip: WeakView<ContextStrip>,
+        thread_store: Option<WeakModel<ThreadStore>>,
+        context_store: WeakModel<ContextStore>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
+        let mut entries = vec![
+            ContextPickerEntry {
+                name: "directory".into(),
+                description: "Insert any directory".into(),
+                icon: IconName::Folder,
+            },
+            ContextPickerEntry {
+                name: "file".into(),
+                description: "Insert any file".into(),
+                icon: IconName::File,
+            },
+            ContextPickerEntry {
+                name: "fetch".into(),
+                description: "Fetch content from URL".into(),
+                icon: IconName::Globe,
+            },
+        ];
+
+        if thread_store.is_some() {
+            entries.push(ContextPickerEntry {
+                name: "thread".into(),
+                description: "Insert any thread".into(),
+                icon: IconName::MessageBubbles,
+            });
+        }
+
         let delegate = ContextPickerDelegate {
             context_picker: cx.view().downgrade(),
             workspace,
             thread_store,
-            context_strip,
-            entries: vec![
-                ContextPickerEntry {
-                    name: "directory".into(),
-                    description: "Insert any directory".into(),
-                    icon: IconName::Folder,
-                },
-                ContextPickerEntry {
-                    name: "file".into(),
-                    description: "Insert any file".into(),
-                    icon: IconName::File,
-                },
-                ContextPickerEntry {
-                    name: "fetch".into(),
-                    description: "Fetch content from URL".into(),
-                    icon: IconName::Globe,
-                },
-                ContextPickerEntry {
-                    name: "thread".into(),
-                    description: "Insert any thread".into(),
-                    icon: IconName::MessageBubbles,
-                },
-            ],
+            context_store,
+            entries,
             selected_ix: 0,
         };
 
@@ -121,8 +126,8 @@ struct ContextPickerEntry {
 pub(crate) struct ContextPickerDelegate {
     context_picker: WeakView<ContextPicker>,
     workspace: WeakView<Workspace>,
-    thread_store: WeakModel<ThreadStore>,
-    context_strip: WeakView<ContextStrip>,
+    thread_store: Option<WeakModel<ThreadStore>>,
+    context_store: WeakModel<ContextStore>,
     entries: Vec<ContextPickerEntry>,
     selected_ix: usize,
 }
@@ -161,7 +166,7 @@ impl PickerDelegate for ContextPickerDelegate {
                                 FileContextPicker::new(
                                     self.context_picker.clone(),
                                     self.workspace.clone(),
-                                    self.context_strip.clone(),
+                                    self.context_store.clone(),
                                     cx,
                                 )
                             }));
@@ -171,20 +176,22 @@ impl PickerDelegate for ContextPickerDelegate {
                                 FetchContextPicker::new(
                                     self.context_picker.clone(),
                                     self.workspace.clone(),
-                                    self.context_strip.clone(),
+                                    self.context_store.clone(),
                                     cx,
                                 )
                             }));
                         }
                         "thread" => {
-                            this.mode = ContextPickerMode::Thread(cx.new_view(|cx| {
-                                ThreadContextPicker::new(
-                                    self.thread_store.clone(),
-                                    self.context_picker.clone(),
-                                    self.context_strip.clone(),
-                                    cx,
-                                )
-                            }));
+                            if let Some(thread_store) = self.thread_store.as_ref() {
+                                this.mode = ContextPickerMode::Thread(cx.new_view(|cx| {
+                                    ThreadContextPicker::new(
+                                        thread_store.clone(),
+                                        self.context_picker.clone(),
+                                        self.context_store.clone(),
+                                        cx,
+                                    )
+                                }));
+                            }
                         }
                         _ => {}
                     }

crates/assistant2/src/context_picker/fetch_context_picker.rs 🔗

@@ -4,7 +4,7 @@ use std::sync::Arc;
 
 use anyhow::{bail, Context as _, Result};
 use futures::AsyncReadExt as _;
-use gpui::{AppContext, DismissEvent, FocusHandle, FocusableView, Task, View, WeakView};
+use gpui::{AppContext, DismissEvent, FocusHandle, FocusableView, Task, View, WeakModel, WeakView};
 use html_to_markdown::{convert_html_to_markdown, markdown, TagHandler};
 use http_client::{AsyncBody, HttpClientWithUrl};
 use picker::{Picker, PickerDelegate};
@@ -13,7 +13,7 @@ use workspace::Workspace;
 
 use crate::context::ContextKind;
 use crate::context_picker::ContextPicker;
-use crate::context_strip::ContextStrip;
+use crate::context_store::ContextStore;
 
 pub struct FetchContextPicker {
     picker: View<Picker<FetchContextPickerDelegate>>,
@@ -23,10 +23,10 @@ impl FetchContextPicker {
     pub fn new(
         context_picker: WeakView<ContextPicker>,
         workspace: WeakView<Workspace>,
-        context_strip: WeakView<ContextStrip>,
+        context_store: WeakModel<ContextStore>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
-        let delegate = FetchContextPickerDelegate::new(context_picker, workspace, context_strip);
+        let delegate = FetchContextPickerDelegate::new(context_picker, workspace, context_store);
         let picker = cx.new_view(|cx| Picker::uniform_list(delegate, cx));
 
         Self { picker }
@@ -55,7 +55,7 @@ enum ContentType {
 pub struct FetchContextPickerDelegate {
     context_picker: WeakView<ContextPicker>,
     workspace: WeakView<Workspace>,
-    context_strip: WeakView<ContextStrip>,
+    context_store: WeakModel<ContextStore>,
     url: String,
 }
 
@@ -63,12 +63,12 @@ impl FetchContextPickerDelegate {
     pub fn new(
         context_picker: WeakView<ContextPicker>,
         workspace: WeakView<Workspace>,
-        context_strip: WeakView<ContextStrip>,
+        context_store: WeakModel<ContextStore>,
     ) -> Self {
         FetchContextPickerDelegate {
             context_picker,
             workspace,
-            context_strip,
+            context_store,
             url: String::new(),
         }
     }
@@ -189,9 +189,9 @@ impl PickerDelegate for FetchContextPickerDelegate {
 
             this.update(&mut cx, |this, cx| {
                 this.delegate
-                    .context_strip
-                    .update(cx, |context_strip, _cx| {
-                        context_strip.insert_context(ContextKind::FetchedUrl, url, text);
+                    .context_store
+                    .update(cx, |context_store, _cx| {
+                        context_store.insert_context(ContextKind::FetchedUrl, url, text);
                     })
             })??;
 

crates/assistant2/src/context_picker/file_context_picker.rs 🔗

@@ -5,7 +5,7 @@ use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
 
 use fuzzy::PathMatch;
-use gpui::{AppContext, DismissEvent, FocusHandle, FocusableView, Task, View, WeakView};
+use gpui::{AppContext, DismissEvent, FocusHandle, FocusableView, Task, View, WeakModel, WeakView};
 use picker::{Picker, PickerDelegate};
 use project::{PathMatchCandidateSet, WorktreeId};
 use ui::{prelude::*, ListItem};
@@ -14,7 +14,7 @@ use workspace::Workspace;
 
 use crate::context::ContextKind;
 use crate::context_picker::ContextPicker;
-use crate::context_strip::ContextStrip;
+use crate::context_store::ContextStore;
 
 pub struct FileContextPicker {
     picker: View<Picker<FileContextPickerDelegate>>,
@@ -24,10 +24,10 @@ impl FileContextPicker {
     pub fn new(
         context_picker: WeakView<ContextPicker>,
         workspace: WeakView<Workspace>,
-        context_strip: WeakView<ContextStrip>,
+        context_store: WeakModel<ContextStore>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
-        let delegate = FileContextPickerDelegate::new(context_picker, workspace, context_strip);
+        let delegate = FileContextPickerDelegate::new(context_picker, workspace, context_store);
         let picker = cx.new_view(|cx| Picker::uniform_list(delegate, cx));
 
         Self { picker }
@@ -49,7 +49,7 @@ impl Render for FileContextPicker {
 pub struct FileContextPickerDelegate {
     context_picker: WeakView<ContextPicker>,
     workspace: WeakView<Workspace>,
-    context_strip: WeakView<ContextStrip>,
+    context_store: WeakModel<ContextStore>,
     matches: Vec<PathMatch>,
     selected_index: usize,
 }
@@ -58,12 +58,12 @@ impl FileContextPickerDelegate {
     pub fn new(
         context_picker: WeakView<ContextPicker>,
         workspace: WeakView<Workspace>,
-        context_strip: WeakView<ContextStrip>,
+        context_store: WeakModel<ContextStore>,
     ) -> Self {
         Self {
             context_picker,
             workspace,
-            context_strip,
+            context_store,
             matches: Vec::new(),
             selected_index: 0,
         }
@@ -214,7 +214,7 @@ impl PickerDelegate for FileContextPickerDelegate {
             let buffer = open_buffer_task.await?;
 
             this.update(&mut cx, |this, cx| {
-                this.delegate.context_strip.update(cx, |context_strip, cx| {
+                this.delegate.context_store.update(cx, |context_store, cx| {
                     let mut text = String::new();
                     text.push_str(&codeblock_fence_for_path(Some(&path), None));
                     text.push_str(&buffer.read(cx).text());
@@ -224,7 +224,7 @@ impl PickerDelegate for FileContextPickerDelegate {
 
                     text.push_str("```\n");
 
-                    context_strip.insert_context(
+                    context_store.insert_context(
                         ContextKind::File,
                         path.to_string_lossy().to_string(),
                         text,

crates/assistant2/src/context_picker/thread_context_picker.rs 🔗

@@ -7,7 +7,7 @@ use ui::{prelude::*, ListItem};
 
 use crate::context::ContextKind;
 use crate::context_picker::ContextPicker;
-use crate::context_strip::ContextStrip;
+use crate::context_store;
 use crate::thread::ThreadId;
 use crate::thread_store::ThreadStore;
 
@@ -19,11 +19,11 @@ impl ThreadContextPicker {
     pub fn new(
         thread_store: WeakModel<ThreadStore>,
         context_picker: WeakView<ContextPicker>,
-        context_strip: WeakView<ContextStrip>,
+        context_store: WeakModel<context_store::ContextStore>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
         let delegate =
-            ThreadContextPickerDelegate::new(thread_store, context_picker, context_strip);
+            ThreadContextPickerDelegate::new(thread_store, context_picker, context_store);
         let picker = cx.new_view(|cx| Picker::uniform_list(delegate, cx));
 
         ThreadContextPicker { picker }
@@ -51,7 +51,7 @@ struct ThreadContextEntry {
 pub struct ThreadContextPickerDelegate {
     thread_store: WeakModel<ThreadStore>,
     context_picker: WeakView<ContextPicker>,
-    context_strip: WeakView<ContextStrip>,
+    context_store: WeakModel<context_store::ContextStore>,
     matches: Vec<ThreadContextEntry>,
     selected_index: usize,
 }
@@ -60,12 +60,12 @@ impl ThreadContextPickerDelegate {
     pub fn new(
         thread_store: WeakModel<ThreadStore>,
         context_picker: WeakView<ContextPicker>,
-        context_strip: WeakView<ContextStrip>,
+        context_store: WeakModel<context_store::ContextStore>,
     ) -> Self {
         ThreadContextPickerDelegate {
             thread_store,
             context_picker,
-            context_strip,
+            context_store,
             matches: Vec::new(),
             selected_index: 0,
         }
@@ -157,8 +157,8 @@ impl PickerDelegate for ThreadContextPickerDelegate {
             return;
         };
 
-        self.context_strip
-            .update(cx, |context_strip, cx| {
+        self.context_store
+            .update(cx, |context_store, cx| {
                 let text = thread.update(cx, |thread, _cx| {
                     let mut text = String::new();
 
@@ -177,7 +177,7 @@ impl PickerDelegate for ThreadContextPickerDelegate {
                     text
                 });
 
-                context_strip.insert_context(ContextKind::Thread, entry.summary.clone(), text);
+                context_store.insert_context(ContextKind::Thread, entry.summary.clone(), text);
             })
             .ok();
     }

crates/assistant2/src/context_store.rs 🔗

@@ -0,0 +1,47 @@
+use gpui::SharedString;
+
+use crate::context::{Context, ContextId, ContextKind};
+
+pub struct ContextStore {
+    context: Vec<Context>,
+    next_context_id: ContextId,
+}
+
+impl ContextStore {
+    pub fn new() -> Self {
+        Self {
+            context: Vec::new(),
+            next_context_id: ContextId(0),
+        }
+    }
+
+    pub fn context(&self) -> &Vec<Context> {
+        &self.context
+    }
+
+    pub fn drain(&mut self) -> Vec<Context> {
+        self.context.drain(..).collect()
+    }
+
+    pub fn clear(&mut self) {
+        self.context.clear();
+    }
+
+    pub fn insert_context(
+        &mut self,
+        kind: ContextKind,
+        name: impl Into<SharedString>,
+        text: impl Into<SharedString>,
+    ) {
+        self.context.push(Context {
+            id: self.next_context_id.post_inc(),
+            name: name.into(),
+            kind,
+            text: text.into(),
+        });
+    }
+
+    pub fn remove_context(&mut self, id: &ContextId) {
+        self.context.retain(|context| context.id != *id);
+    }
+}

crates/assistant2/src/context_strip.rs 🔗

@@ -1,60 +1,45 @@
 use std::rc::Rc;
 
-use gpui::{View, WeakModel, WeakView};
+use gpui::{Model, View, WeakModel, WeakView};
 use ui::{prelude::*, IconButtonShape, PopoverMenu, PopoverMenuHandle, Tooltip};
 use workspace::Workspace;
 
-use crate::context::{Context, ContextId, ContextKind};
 use crate::context_picker::ContextPicker;
+use crate::context_store::ContextStore;
 use crate::thread_store::ThreadStore;
 use crate::ui::ContextPill;
 
 pub struct ContextStrip {
-    context: Vec<Context>,
-    next_context_id: ContextId,
+    context_store: Model<ContextStore>,
     context_picker: View<ContextPicker>,
     pub(crate) context_picker_handle: PopoverMenuHandle<ContextPicker>,
 }
 
 impl ContextStrip {
     pub fn new(
+        context_store: Model<ContextStore>,
         workspace: WeakView<Workspace>,
-        thread_store: WeakModel<ThreadStore>,
+        thread_store: Option<WeakModel<ThreadStore>>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
-        let weak_self = cx.view().downgrade();
-
         Self {
-            context: Vec::new(),
-            next_context_id: ContextId(0),
+            context_store: context_store.clone(),
             context_picker: cx.new_view(|cx| {
-                ContextPicker::new(workspace.clone(), thread_store.clone(), weak_self, cx)
+                ContextPicker::new(
+                    workspace.clone(),
+                    thread_store.clone(),
+                    context_store.downgrade(),
+                    cx,
+                )
             }),
             context_picker_handle: PopoverMenuHandle::default(),
         }
     }
-
-    pub fn drain(&mut self) -> Vec<Context> {
-        self.context.drain(..).collect()
-    }
-
-    pub fn insert_context(
-        &mut self,
-        kind: ContextKind,
-        name: impl Into<SharedString>,
-        text: impl Into<SharedString>,
-    ) {
-        self.context.push(Context {
-            id: self.next_context_id.post_inc(),
-            name: name.into(),
-            kind,
-            text: text.into(),
-        });
-    }
 }
 
 impl Render for ContextStrip {
     fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+        let context = self.context_store.read(cx).context();
         let context_picker = self.context_picker.clone();
 
         h_flex()
@@ -76,25 +61,31 @@ impl Render for ContextStrip {
                     })
                     .with_handle(self.context_picker_handle.clone()),
             )
-            .children(self.context.iter().map(|context| {
+            .children(context.iter().map(|context| {
                 ContextPill::new(context.clone()).on_remove({
                     let context = context.clone();
-                    Rc::new(cx.listener(move |this, _event, cx| {
-                        this.context.retain(|other| other.id != context.id);
+                    let context_store = self.context_store.clone();
+                    Rc::new(cx.listener(move |_this, _event, cx| {
+                        context_store.update(cx, |this, _cx| {
+                            this.remove_context(&context.id);
+                        });
                         cx.notify();
                     }))
                 })
             }))
-            .when(!self.context.is_empty(), |parent| {
+            .when(!context.is_empty(), |parent| {
                 parent.child(
                     IconButton::new("remove-all-context", IconName::Eraser)
                         .shape(IconButtonShape::Square)
                         .icon_size(IconSize::Small)
                         .tooltip(move |cx| Tooltip::text("Remove All Context", cx))
-                        .on_click(cx.listener(|this, _event, cx| {
-                            this.context.clear();
-                            cx.notify();
-                        })),
+                        .on_click({
+                            let context_store = self.context_store.clone();
+                            cx.listener(move |_this, _event, cx| {
+                                context_store.update(cx, |this, _cx| this.clear());
+                                cx.notify();
+                            })
+                        }),
                 )
             })
     }

crates/assistant2/src/inline_assistant.rs 🔗

@@ -1,3 +1,8 @@
+use crate::context::attach_context_to_message;
+use crate::context_store::ContextStore;
+use crate::context_strip::ContextStrip;
+use crate::thread_store::ThreadStore;
+use crate::AssistantPanel;
 use crate::{
     assistant_settings::AssistantSettings,
     prompts::PromptBuilder,
@@ -24,7 +29,8 @@ use futures::{channel::mpsc, future::LocalBoxFuture, join, SinkExt, Stream, Stre
 use gpui::{
     anchored, deferred, point, AnyElement, AppContext, ClickEvent, CursorStyle, EventEmitter,
     FocusHandle, FocusableView, FontWeight, Global, HighlightStyle, Model, ModelContext,
-    Subscription, Task, TextStyle, UpdateGlobal, View, ViewContext, WeakView, WindowContext,
+    Subscription, Task, TextStyle, UpdateGlobal, View, ViewContext, WeakModel, WeakView,
+    WindowContext,
 };
 use language::{Buffer, IndentKind, Point, Selection, TransactionId};
 use language_model::{
@@ -178,10 +184,16 @@ impl InlineAssistant {
     ) {
         if let Some(editor) = item.act_as::<Editor>(cx) {
             editor.update(cx, |editor, cx| {
+                let thread_store = workspace
+                    .read(cx)
+                    .panel::<AssistantPanel>(cx)
+                    .map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade());
+
                 editor.push_code_action_provider(
                     Rc::new(AssistantCodeActionProvider {
                         editor: cx.view().downgrade(),
                         workspace: workspace.downgrade(),
+                        thread_store,
                     }),
                     cx,
                 );
@@ -212,7 +224,11 @@ impl InlineAssistant {
         let handle_assist = |cx: &mut ViewContext<Workspace>| match inline_assist_target {
             InlineAssistTarget::Editor(active_editor) => {
                 InlineAssistant::update_global(cx, |assistant, cx| {
-                    assistant.assist(&active_editor, cx.view().downgrade(), cx)
+                    let thread_store = workspace
+                        .panel::<AssistantPanel>(cx)
+                        .map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade());
+
+                    assistant.assist(&active_editor, cx.view().downgrade(), thread_store, cx)
                 })
             }
             InlineAssistTarget::Terminal(active_terminal) => {
@@ -265,6 +281,7 @@ impl InlineAssistant {
         &mut self,
         editor: &View<Editor>,
         workspace: WeakView<Workspace>,
+        thread_store: Option<WeakModel<ThreadStore>>,
         cx: &mut WindowContext,
     ) {
         let (snapshot, initial_selections) = editor.update(cx, |editor, cx| {
@@ -343,11 +360,13 @@ impl InlineAssistant {
         let mut assist_to_focus = None;
         for range in codegen_ranges {
             let assist_id = self.next_assist_id.post_inc();
+            let context_store = cx.new_model(|_cx| ContextStore::new());
             let codegen = cx.new_model(|cx| {
                 Codegen::new(
                     editor.read(cx).buffer().clone(),
                     range.clone(),
                     None,
+                    context_store.clone(),
                     self.telemetry.clone(),
                     self.prompt_builder.clone(),
                     cx,
@@ -363,6 +382,9 @@ impl InlineAssistant {
                     prompt_buffer.clone(),
                     codegen.clone(),
                     self.fs.clone(),
+                    context_store,
+                    workspace.clone(),
+                    thread_store.clone(),
                     cx,
                 )
             });
@@ -430,6 +452,7 @@ impl InlineAssistant {
         initial_transaction_id: Option<TransactionId>,
         focus: bool,
         workspace: WeakView<Workspace>,
+        thread_store: Option<WeakModel<ThreadStore>>,
         cx: &mut WindowContext,
     ) -> InlineAssistId {
         let assist_group_id = self.next_assist_group_id.post_inc();
@@ -445,11 +468,14 @@ impl InlineAssistant {
             range.end = range.end.bias_right(&snapshot);
         }
 
+        let context_store = cx.new_model(|_cx| ContextStore::new());
+
         let codegen = cx.new_model(|cx| {
             Codegen::new(
                 editor.read(cx).buffer().clone(),
                 range.clone(),
                 initial_transaction_id,
+                context_store.clone(),
                 self.telemetry.clone(),
                 self.prompt_builder.clone(),
                 cx,
@@ -465,6 +491,9 @@ impl InlineAssistant {
                 prompt_buffer.clone(),
                 codegen.clone(),
                 self.fs.clone(),
+                context_store,
+                workspace.clone(),
+                thread_store,
                 cx,
             )
         });
@@ -1456,6 +1485,7 @@ enum PromptEditorEvent {
 struct PromptEditor {
     id: InlineAssistId,
     editor: View<Editor>,
+    context_strip: View<ContextStrip>,
     language_model_selector: View<LanguageModelSelector>,
     edited_since_done: bool,
     gutter_dimensions: Arc<Mutex<GutterDimensions>>,
@@ -1473,11 +1503,7 @@ impl EventEmitter<PromptEditorEvent> for PromptEditor {}
 impl Render for PromptEditor {
     fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
         let gutter_dimensions = *self.gutter_dimensions.lock();
-        let mut buttons = vec![Button::new("add-context", "Add Context")
-            .style(ButtonStyle::Filled)
-            .icon(IconName::Plus)
-            .icon_position(IconPosition::Start)
-            .into_any_element()];
+        let mut buttons = Vec::new();
         let codegen = self.codegen.read(cx);
         if codegen.alternative_count(cx) > 1 {
             buttons.push(self.render_cycle_controls(cx));
@@ -1570,91 +1596,114 @@ impl Render for PromptEditor {
             }
         });
 
-        h_flex()
-            .key_context("PromptEditor")
-            .bg(cx.theme().colors().editor_background)
-            .block_mouse_down()
-            .cursor(CursorStyle::Arrow)
+        v_flex()
             .border_y_1()
             .border_color(cx.theme().status().info_border)
             .size_full()
             .py(cx.line_height() / 2.5)
-            .on_action(cx.listener(Self::confirm))
-            .on_action(cx.listener(Self::cancel))
-            .on_action(cx.listener(Self::move_up))
-            .on_action(cx.listener(Self::move_down))
-            .capture_action(cx.listener(Self::cycle_prev))
-            .capture_action(cx.listener(Self::cycle_next))
             .child(
                 h_flex()
-                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
-                    .justify_center()
-                    .gap_2()
-                    .child(LanguageModelSelectorPopoverMenu::new(
-                        self.language_model_selector.clone(),
-                        IconButton::new("context", IconName::SettingsAlt)
-                            .shape(IconButtonShape::Square)
-                            .icon_size(IconSize::Small)
-                            .icon_color(Color::Muted)
-                            .tooltip(move |cx| {
-                                Tooltip::with_meta(
-                                    format!(
-                                        "Using {}",
-                                        LanguageModelRegistry::read_global(cx)
-                                            .active_model()
-                                            .map(|model| model.name().0)
-                                            .unwrap_or_else(|| "No model selected".into()),
-                                    ),
-                                    None,
-                                    "Change Model",
-                                    cx,
-                                )
-                            }),
-                    ))
-                    .map(|el| {
-                        let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx) else {
-                            return el;
-                        };
-
-                        let error_message = SharedString::from(error.to_string());
-                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
-                            && cx.has_flag::<ZedPro>()
-                        {
-                            el.child(
-                                v_flex()
-                                    .child(
-                                        IconButton::new("rate-limit-error", IconName::XCircle)
-                                            .toggle_state(self.show_rate_limit_notice)
-                                            .shape(IconButtonShape::Square)
-                                            .icon_size(IconSize::Small)
-                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
-                                    )
-                                    .children(self.show_rate_limit_notice.then(|| {
-                                        deferred(
-                                            anchored()
-                                                .position_mode(gpui::AnchoredPositionMode::Local)
-                                                .position(point(px(0.), px(24.)))
-                                                .anchor(gpui::AnchorCorner::TopLeft)
-                                                .child(self.render_rate_limit_notice(cx)),
+                    .key_context("PromptEditor")
+                    .bg(cx.theme().colors().editor_background)
+                    .block_mouse_down()
+                    .cursor(CursorStyle::Arrow)
+                    .on_action(cx.listener(Self::confirm))
+                    .on_action(cx.listener(Self::cancel))
+                    .on_action(cx.listener(Self::move_up))
+                    .on_action(cx.listener(Self::move_down))
+                    .capture_action(cx.listener(Self::cycle_prev))
+                    .capture_action(cx.listener(Self::cycle_next))
+                    .child(
+                        h_flex()
+                            .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
+                            .justify_center()
+                            .gap_2()
+                            .child(LanguageModelSelectorPopoverMenu::new(
+                                self.language_model_selector.clone(),
+                                IconButton::new("context", IconName::SettingsAlt)
+                                    .shape(IconButtonShape::Square)
+                                    .icon_size(IconSize::Small)
+                                    .icon_color(Color::Muted)
+                                    .tooltip(move |cx| {
+                                        Tooltip::with_meta(
+                                            format!(
+                                                "Using {}",
+                                                LanguageModelRegistry::read_global(cx)
+                                                    .active_model()
+                                                    .map(|model| model.name().0)
+                                                    .unwrap_or_else(|| "No model selected".into()),
+                                            ),
+                                            None,
+                                            "Change Model",
+                                            cx,
                                         )
-                                    })),
-                            )
-                        } else {
-                            el.child(
-                                div()
-                                    .id("error")
-                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
-                                    .child(
-                                        Icon::new(IconName::XCircle)
-                                            .size(IconSize::Small)
-                                            .color(Color::Error),
-                                    ),
-                            )
-                        }
-                    }),
+                                    }),
+                            ))
+                            .map(|el| {
+                                let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx)
+                                else {
+                                    return el;
+                                };
+
+                                let error_message = SharedString::from(error.to_string());
+                                if error.error_code() == proto::ErrorCode::RateLimitExceeded
+                                    && cx.has_flag::<ZedPro>()
+                                {
+                                    el.child(
+                                        v_flex()
+                                            .child(
+                                                IconButton::new(
+                                                    "rate-limit-error",
+                                                    IconName::XCircle,
+                                                )
+                                                .toggle_state(self.show_rate_limit_notice)
+                                                .shape(IconButtonShape::Square)
+                                                .icon_size(IconSize::Small)
+                                                .on_click(
+                                                    cx.listener(Self::toggle_rate_limit_notice),
+                                                ),
+                                            )
+                                            .children(self.show_rate_limit_notice.then(|| {
+                                                deferred(
+                                                    anchored()
+                                                        .position_mode(
+                                                            gpui::AnchoredPositionMode::Local,
+                                                        )
+                                                        .position(point(px(0.), px(24.)))
+                                                        .anchor(gpui::AnchorCorner::TopLeft)
+                                                        .child(self.render_rate_limit_notice(cx)),
+                                                )
+                                            })),
+                                    )
+                                } else {
+                                    el.child(
+                                        div()
+                                            .id("error")
+                                            .tooltip(move |cx| {
+                                                Tooltip::text(error_message.clone(), cx)
+                                            })
+                                            .child(
+                                                Icon::new(IconName::XCircle)
+                                                    .size(IconSize::Small)
+                                                    .color(Color::Error),
+                                            ),
+                                    )
+                                }
+                            }),
+                    )
+                    .child(div().flex_1().child(self.render_editor(cx)))
+                    .child(h_flex().gap_2().pr_6().children(buttons)),
+            )
+            .child(
+                h_flex()
+                    .child(
+                        h_flex()
+                            .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
+                            .justify_center()
+                            .gap_2(),
+                    )
+                    .child(self.context_strip.clone()),
             )
-            .child(div().flex_1().child(self.render_editor(cx)))
-            .child(h_flex().gap_2().pr_6().children(buttons))
     }
 }
 
@@ -1675,6 +1724,9 @@ impl PromptEditor {
         prompt_buffer: Model<MultiBuffer>,
         codegen: Model<Codegen>,
         fs: Arc<dyn Fs>,
+        context_store: Model<ContextStore>,
+        workspace: WeakView<Workspace>,
+        thread_store: Option<WeakModel<ThreadStore>>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
         let prompt_editor = cx.new_view(|cx| {
@@ -1699,6 +1751,9 @@ impl PromptEditor {
         let mut this = Self {
             id,
             editor: prompt_editor,
+            context_strip: cx.new_view(|cx| {
+                ContextStrip::new(context_store, workspace.clone(), thread_store.clone(), cx)
+            }),
             language_model_selector: cx.new_view(|cx| {
                 let fs = fs.clone();
                 LanguageModelSelector::new(
@@ -2293,6 +2348,7 @@ pub struct Codegen {
     buffer: Model<MultiBuffer>,
     range: Range<Anchor>,
     initial_transaction_id: Option<TransactionId>,
+    context_store: Model<ContextStore>,
     telemetry: Arc<Telemetry>,
     builder: Arc<PromptBuilder>,
     is_insertion: bool,
@@ -2303,6 +2359,7 @@ impl Codegen {
         buffer: Model<MultiBuffer>,
         range: Range<Anchor>,
         initial_transaction_id: Option<TransactionId>,
+        context_store: Model<ContextStore>,
         telemetry: Arc<Telemetry>,
         builder: Arc<PromptBuilder>,
         cx: &mut ModelContext<Self>,
@@ -2312,6 +2369,7 @@ impl Codegen {
                 buffer.clone(),
                 range.clone(),
                 false,
+                Some(context_store.clone()),
                 Some(telemetry.clone()),
                 builder.clone(),
                 cx,
@@ -2326,6 +2384,7 @@ impl Codegen {
             buffer,
             range,
             initial_transaction_id,
+            context_store,
             telemetry,
             builder,
         };
@@ -2398,6 +2457,7 @@ impl Codegen {
                     self.buffer.clone(),
                     self.range.clone(),
                     false,
+                    Some(self.context_store.clone()),
                     Some(self.telemetry.clone()),
                     self.builder.clone(),
                     cx,
@@ -2477,6 +2537,7 @@ pub struct CodegenAlternative {
     status: CodegenStatus,
     generation: Task<()>,
     diff: Diff,
+    context_store: Option<Model<ContextStore>>,
     telemetry: Option<Arc<Telemetry>>,
     _subscription: gpui::Subscription,
     builder: Arc<PromptBuilder>,
@@ -2515,6 +2576,7 @@ impl CodegenAlternative {
         buffer: Model<MultiBuffer>,
         range: Range<Anchor>,
         active: bool,
+        context_store: Option<Model<ContextStore>>,
         telemetry: Option<Arc<Telemetry>>,
         builder: Arc<PromptBuilder>,
         cx: &mut ModelContext<Self>,
@@ -2552,6 +2614,7 @@ impl CodegenAlternative {
             status: CodegenStatus::Idle,
             generation: Task::ready(()),
             diff: Diff::default(),
+            context_store,
             telemetry,
             _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
             builder,
@@ -2637,7 +2700,11 @@ impl CodegenAlternative {
         Ok(())
     }
 
-    fn build_request(&self, user_prompt: String, cx: &AppContext) -> Result<LanguageModelRequest> {
+    fn build_request(
+        &self,
+        user_prompt: String,
+        cx: &mut AppContext,
+    ) -> Result<LanguageModelRequest> {
         let buffer = self.buffer.read(cx).snapshot(cx);
         let language = buffer.language_at(self.range.start);
         let language_name = if let Some(language) = language.as_ref() {
@@ -2670,15 +2737,24 @@ impl CodegenAlternative {
             .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
             .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
 
+        let mut request_message = LanguageModelRequestMessage {
+            role: Role::User,
+            content: Vec::new(),
+            cache: false,
+        };
+
+        if let Some(context_store) = &self.context_store {
+            let context = context_store.update(cx, |this, _cx| this.context().clone());
+            attach_context_to_message(&mut request_message, context);
+        }
+
+        request_message.content.push(prompt.into());
+
         Ok(LanguageModelRequest {
             tools: Vec::new(),
             stop: Vec::new(),
             temperature: None,
-            messages: vec![LanguageModelRequestMessage {
-                role: Role::User,
-                content: vec![prompt.into()],
-                cache: false,
-            }],
+            messages: vec![request_message],
         })
     }
 
@@ -3273,6 +3349,7 @@ where
 struct AssistantCodeActionProvider {
     editor: WeakView<Editor>,
     workspace: WeakView<Workspace>,
+    thread_store: Option<WeakModel<ThreadStore>>,
 }
 
 impl CodeActionProvider for AssistantCodeActionProvider {
@@ -3337,6 +3414,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
     ) -> Task<Result<ProjectTransaction>> {
         let editor = self.editor.clone();
         let workspace = self.workspace.clone();
+        let thread_store = self.thread_store.clone();
         cx.spawn(|mut cx| async move {
             let editor = editor.upgrade().context("editor was released")?;
             let range = editor
@@ -3384,6 +3462,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
                     None,
                     true,
                     workspace,
+                    thread_store,
                     cx,
                 );
                 assistant.start_assist(assist_id, cx);
@@ -3469,6 +3548,7 @@ mod tests {
                 range.clone(),
                 true,
                 None,
+                None,
                 prompt_builder,
                 cx,
             )
@@ -3533,6 +3613,7 @@ mod tests {
                 range.clone(),
                 true,
                 None,
+                None,
                 prompt_builder,
                 cx,
             )
@@ -3600,6 +3681,7 @@ mod tests {
                 range.clone(),
                 true,
                 None,
+                None,
                 prompt_builder,
                 cx,
             )
@@ -3666,6 +3748,7 @@ mod tests {
                 range.clone(),
                 true,
                 None,
+                None,
                 prompt_builder,
                 cx,
             )
@@ -3721,6 +3804,7 @@ mod tests {
                 range.clone(),
                 false,
                 None,
+                None,
                 prompt_builder,
                 cx,
             )

crates/assistant2/src/message_editor.rs 🔗

@@ -7,6 +7,7 @@ use theme::ThemeSettings;
 use ui::{prelude::*, ButtonLike, CheckboxWithLabel, ElevationIndex, KeyBinding, Tooltip};
 use workspace::Workspace;
 
+use crate::context_store::ContextStore;
 use crate::context_strip::ContextStrip;
 use crate::thread::{RequestKind, Thread};
 use crate::thread_store::ThreadStore;
@@ -15,6 +16,7 @@ use crate::{Chat, ToggleModelSelector};
 pub struct MessageEditor {
     thread: Model<Thread>,
     editor: View<Editor>,
+    context_store: Model<ContextStore>,
     context_strip: View<ContextStrip>,
     language_model_selector: View<LanguageModelSelector>,
     use_tools: bool,
@@ -27,6 +29,8 @@ impl MessageEditor {
         thread: Model<Thread>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
+        let context_store = cx.new_model(|_cx| ContextStore::new());
+
         Self {
             thread,
             editor: cx.new_view(|cx| {
@@ -35,8 +39,15 @@ impl MessageEditor {
 
                 editor
             }),
-            context_strip: cx
-                .new_view(|cx| ContextStrip::new(workspace.clone(), thread_store.clone(), cx)),
+            context_store: context_store.clone(),
+            context_strip: cx.new_view(|cx| {
+                ContextStrip::new(
+                    context_store,
+                    workspace.clone(),
+                    Some(thread_store.clone()),
+                    cx,
+                )
+            }),
             language_model_selector: cx.new_view(|cx| {
                 LanguageModelSelector::new(
                     |model, _cx| {
@@ -75,7 +86,7 @@ impl MessageEditor {
             editor.clear(cx);
             text
         });
-        let context = self.context_strip.update(cx, |this, _cx| this.drain());
+        let context = self.context_store.update(cx, |this, _cx| this.drain());
 
         self.thread.update(cx, |thread, cx| {
             thread.insert_user_message(user_message, context, cx);

crates/assistant2/src/thread.rs 🔗

@@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize};
 use util::{post_inc, TryFutureExt as _};
 use uuid::Uuid;
 
-use crate::context::{Context, ContextKind};
+use crate::context::{attach_context_to_message, Context};
 
 #[derive(Debug, Clone, Copy)]
 pub enum RequestKind {
@@ -192,51 +192,7 @@ impl Thread {
             }
 
             if let Some(context) = self.context_for_message(message.id) {
-                let mut file_context = String::new();
-                let mut fetch_context = String::new();
-                let mut thread_context = String::new();
-
-                for context in context.iter() {
-                    match context.kind {
-                        ContextKind::File => {
-                            file_context.push_str(&context.text);
-                            file_context.push('\n');
-                        }
-                        ContextKind::FetchedUrl => {
-                            fetch_context.push_str(&context.name);
-                            fetch_context.push('\n');
-                            fetch_context.push_str(&context.text);
-                            fetch_context.push('\n');
-                        }
-                        ContextKind::Thread => {
-                            thread_context.push_str(&context.name);
-                            thread_context.push('\n');
-                            thread_context.push_str(&context.text);
-                            thread_context.push('\n');
-                        }
-                    }
-                }
-
-                let mut context_text = String::new();
-                if !file_context.is_empty() {
-                    context_text.push_str("The following files are available:\n");
-                    context_text.push_str(&file_context);
-                }
-
-                if !fetch_context.is_empty() {
-                    context_text.push_str("The following fetched results are available\n");
-                    context_text.push_str(&fetch_context);
-                }
-
-                if !thread_context.is_empty() {
-                    context_text
-                        .push_str("The following previous conversation threads are available\n");
-                    context_text.push_str(&thread_context);
-                }
-
-                request_message
-                    .content
-                    .push(MessageContent::Text(context_text))
+                attach_context_to_message(&mut request_message, context.clone());
             }
 
             if !message.text.is_empty() {