assistant: Allow guests to create new contexts on the host (#15439)

Marshall Bowers created

This PR extends collaboration in the Assistant to allow guests to create
new contexts on the host when collaborating.

Release Notes:

- N/A

Change summary

crates/assistant/src/assistant_panel.rs | 130 ++++++++++++++++++++++----
crates/assistant/src/context_store.rs   |  99 ++++++++++++++++++++
crates/collab/src/rpc.rs                |   3 
crates/proto/proto/zed.proto            |  13 ++
crates/proto/src/proto.rs               |   4 
5 files changed, 225 insertions(+), 24 deletions(-)

Detailed changes

crates/assistant/src/assistant_panel.rs 🔗

@@ -1,3 +1,4 @@
+use crate::ContextStoreEvent;
 use crate::{
     assistant_settings::{AssistantDockPosition, AssistantSettings},
     humanize_token_count,
@@ -389,6 +390,7 @@ impl AssistantPanel {
             cx.subscribe(&pane, Self::handle_pane_event),
             cx.subscribe(&context_editor_toolbar, Self::handle_toolbar_event),
             cx.subscribe(&model_summary_editor, Self::handle_summary_editor_event),
+            cx.subscribe(&context_store, Self::handle_context_store_event),
             cx.observe(
                 &LanguageModelCompletionProvider::global(cx),
                 |this, _, cx| {
@@ -507,6 +509,46 @@ impl AssistantPanel {
         }
     }
 
+    fn handle_context_store_event(
+        &mut self,
+        _context_store: Model<ContextStore>,
+        event: &ContextStoreEvent,
+        cx: &mut ViewContext<Self>,
+    ) {
+        let ContextStoreEvent::ContextCreated(context_id) = event;
+        let Some(context) = self
+            .context_store
+            .read(cx)
+            .loaded_context_for_id(&context_id, cx)
+        else {
+            log::error!("no context found with ID: {}", context_id.to_proto());
+            return;
+        };
+        let Some(workspace) = self.workspace.upgrade() else {
+            return;
+        };
+        let lsp_adapter_delegate = workspace.update(cx, |workspace, cx| {
+            make_lsp_adapter_delegate(workspace.project(), cx).log_err()
+        });
+
+        let assistant_panel = cx.view().downgrade();
+        let editor = cx.new_view(|cx| {
+            let mut editor = ContextEditor::for_context(
+                context,
+                self.fs.clone(),
+                workspace.clone(),
+                self.project.clone(),
+                lsp_adapter_delegate,
+                assistant_panel,
+                cx,
+            );
+            editor.insert_default_prompt(cx);
+            editor
+        });
+
+        self.show_context(editor.clone(), cx);
+    }
+
     fn completion_provider_changed(&mut self, cx: &mut ViewContext<Self>) {
         if let Some(editor) = self.active_context_editor(cx) {
             editor.update(cx, |active_context, cx| {
@@ -681,29 +723,75 @@ impl AssistantPanel {
     }
 
     fn new_context(&mut self, cx: &mut ViewContext<Self>) -> Option<View<ContextEditor>> {
-        let context = self.context_store.update(cx, |store, cx| store.create(cx));
-        let workspace = self.workspace.upgrade()?;
-        let lsp_adapter_delegate = workspace.update(cx, |workspace, cx| {
-            make_lsp_adapter_delegate(workspace.project(), cx).log_err()
-        });
+        if self.project.read(cx).is_remote() {
+            let task = self
+                .context_store
+                .update(cx, |store, cx| store.create_remote_context(cx));
 
-        let assistant_panel = cx.view().downgrade();
-        let editor = cx.new_view(|cx| {
-            let mut editor = ContextEditor::for_context(
-                context,
-                self.fs.clone(),
-                workspace.clone(),
-                self.project.clone(),
-                lsp_adapter_delegate,
-                assistant_panel,
-                cx,
-            );
-            editor.insert_default_prompt(cx);
-            editor
-        });
+            cx.spawn(|this, mut cx| async move {
+                let context = task.await?;
 
-        self.show_context(editor.clone(), cx);
-        Some(editor)
+                this.update(&mut cx, |this, cx| {
+                    let Some(workspace) = this.workspace.upgrade() else {
+                        return Ok(());
+                    };
+                    let lsp_adapter_delegate = workspace.update(cx, |workspace, cx| {
+                        make_lsp_adapter_delegate(workspace.project(), cx).log_err()
+                    });
+
+                    let fs = this.fs.clone();
+                    let project = this.project.clone();
+                    let weak_assistant_panel = cx.view().downgrade();
+
+                    let editor = cx.new_view(|cx| {
+                        let mut editor = ContextEditor::for_context(
+                            context,
+                            fs,
+                            workspace.clone(),
+                            project,
+                            lsp_adapter_delegate,
+                            weak_assistant_panel,
+                            cx,
+                        );
+                        editor.insert_default_prompt(cx);
+                        editor
+                    });
+
+                    this.show_context(editor, cx);
+
+                    anyhow::Ok(())
+                })??;
+
+                anyhow::Ok(())
+            })
+            .detach_and_log_err(cx);
+
+            None
+        } else {
+            let context = self.context_store.update(cx, |store, cx| store.create(cx));
+            let workspace = self.workspace.upgrade()?;
+            let lsp_adapter_delegate = workspace.update(cx, |workspace, cx| {
+                make_lsp_adapter_delegate(workspace.project(), cx).log_err()
+            });
+
+            let assistant_panel = cx.view().downgrade();
+            let editor = cx.new_view(|cx| {
+                let mut editor = ContextEditor::for_context(
+                    context,
+                    self.fs.clone(),
+                    workspace.clone(),
+                    self.project.clone(),
+                    lsp_adapter_delegate,
+                    assistant_panel,
+                    cx,
+                );
+                editor.insert_default_prompt(cx);
+                editor
+            });
+
+            self.show_context(editor.clone(), cx);
+            Some(editor)
+        }
     }
 
     fn show_context(&mut self, context_editor: View<ContextEditor>, cx: &mut ViewContext<Self>) {

crates/assistant/src/context_store.rs 🔗

@@ -8,7 +8,9 @@ use clock::ReplicaId;
 use fs::Fs;
 use futures::StreamExt;
 use fuzzy::StringMatchCandidate;
-use gpui::{AppContext, AsyncAppContext, Context as _, Model, ModelContext, Task, WeakModel};
+use gpui::{
+    AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, Task, WeakModel,
+};
 use language::LanguageRegistry;
 use paths::contexts_dir;
 use project::Project;
@@ -26,6 +28,7 @@ use util::{ResultExt, TryFutureExt};
 pub fn init(client: &Arc<Client>) {
     client.add_model_message_handler(ContextStore::handle_advertise_contexts);
     client.add_model_request_handler(ContextStore::handle_open_context);
+    client.add_model_request_handler(ContextStore::handle_create_context);
     client.add_model_message_handler(ContextStore::handle_update_context);
     client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
 }
@@ -51,6 +54,12 @@ pub struct ContextStore {
     _project_subscriptions: Vec<gpui::Subscription>,
 }
 
+pub enum ContextStoreEvent {
+    ContextCreated(ContextId),
+}
+
+impl EventEmitter<ContextStoreEvent> for ContextStore {}
+
 enum ContextHandle {
     Weak(WeakModel<Context>),
     Strong(Model<Context>),
@@ -169,6 +178,34 @@ impl ContextStore {
         })
     }
 
+    async fn handle_create_context(
+        this: Model<Self>,
+        _: TypedEnvelope<proto::CreateContext>,
+        mut cx: AsyncAppContext,
+    ) -> Result<proto::CreateContextResponse> {
+        let (context_id, operations) = this.update(&mut cx, |this, cx| {
+            if this.project.read(cx).is_remote() {
+                return Err(anyhow!("can only create contexts as the host"));
+            }
+
+            let context = this.create(cx);
+            let context_id = context.read(cx).id().clone();
+            cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
+
+            anyhow::Ok((
+                context_id,
+                context
+                    .read(cx)
+                    .serialize_ops(&ContextVersion::default(), cx),
+            ))
+        })??;
+        let operations = operations.await;
+        Ok(proto::CreateContextResponse {
+            context_id: context_id.to_proto(),
+            context: Some(proto::Context { operations }),
+        })
+    }
+
     async fn handle_update_context(
         this: Model<Self>,
         envelope: TypedEnvelope<proto::UpdateContext>,
@@ -299,6 +336,60 @@ impl ContextStore {
         context
     }
 
+    pub fn create_remote_context(
+        &mut self,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<Model<Context>>> {
+        let project = self.project.read(cx);
+        let Some(project_id) = project.remote_id() else {
+            return Task::ready(Err(anyhow!("project was not remote")));
+        };
+        if project.is_local() {
+            return Task::ready(Err(anyhow!("cannot create remote contexts as the host")));
+        }
+
+        let replica_id = project.replica_id();
+        let capability = project.capability();
+        let language_registry = self.languages.clone();
+        let telemetry = self.telemetry.clone();
+        let request = self.client.request(proto::CreateContext { project_id });
+        cx.spawn(|this, mut cx| async move {
+            let response = request.await?;
+            let context_id = ContextId::from_proto(response.context_id);
+            let context_proto = response.context.context("invalid context")?;
+            let context = cx.new_model(|cx| {
+                Context::new(
+                    context_id.clone(),
+                    replica_id,
+                    capability,
+                    language_registry,
+                    Some(telemetry),
+                    cx,
+                )
+            })?;
+            let operations = cx
+                .background_executor()
+                .spawn(async move {
+                    context_proto
+                        .operations
+                        .into_iter()
+                        .map(|op| ContextOperation::from_proto(op))
+                        .collect::<Result<Vec<_>>>()
+                })
+                .await?;
+            context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
+            this.update(&mut cx, |this, cx| {
+                if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
+                    existing_context
+                } else {
+                    this.register_context(&context, cx);
+                    this.synchronize_contexts(cx);
+                    context
+                }
+            })
+        })
+    }
+
     pub fn open_local_context(
         &mut self,
         path: PathBuf,
@@ -346,7 +437,11 @@ impl ContextStore {
         })
     }
 
-    fn loaded_context_for_id(&self, id: &ContextId, cx: &AppContext) -> Option<Model<Context>> {
+    pub(super) fn loaded_context_for_id(
+        &self,
+        id: &ContextId,
+        cx: &AppContext,
+    ) -> Option<Model<Context>> {
         self.contexts.iter().find_map(|context| {
             let context = context.upgrade()?;
             if context.read(cx).id() == id {

crates/collab/src/rpc.rs 🔗

@@ -600,6 +600,9 @@ impl Server {
             .add_request_handler(user_handler(
                 forward_mutating_project_request::<proto::OpenContext>,
             ))
+            .add_request_handler(user_handler(
+                forward_mutating_project_request::<proto::CreateContext>,
+            ))
             .add_request_handler(user_handler(
                 forward_mutating_project_request::<proto::SynchronizeContexts>,
             ))

crates/proto/proto/zed.proto 🔗

@@ -199,7 +199,7 @@ message Envelope {
         StreamCompleteWithLanguageModel stream_complete_with_language_model = 228;
         StreamCompleteWithLanguageModelResponse stream_complete_with_language_model_response = 229;
         CountLanguageModelTokens count_language_model_tokens = 230;
-        CountLanguageModelTokensResponse count_language_model_tokens_response = 231; // current max
+        CountLanguageModelTokensResponse count_language_model_tokens_response = 231;
         GetCachedEmbeddings get_cached_embeddings = 189;
         GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
         ComputeEmbeddings compute_embeddings = 191;
@@ -255,6 +255,8 @@ message Envelope {
         AdvertiseContexts advertise_contexts = 211;
         OpenContext open_context = 212;
         OpenContextResponse open_context_response = 213;
+        CreateContext create_context = 232;
+        CreateContextResponse create_context_response = 233; // current max
         UpdateContext update_context = 214;
         SynchronizeContexts synchronize_contexts = 215;
         SynchronizeContextsResponse synchronize_contexts_response = 216;
@@ -2381,6 +2383,15 @@ message OpenContextResponse {
     Context context = 1;
 }
 
+message CreateContext {
+    uint64 project_id = 1;
+}
+
+message CreateContextResponse {
+    string context_id = 1;
+    Context context = 2;
+}
+
 message UpdateContext {
     uint64 project_id = 1;
     string context_id = 2;

crates/proto/src/proto.rs 🔗

@@ -398,6 +398,8 @@ messages!(
     (AdvertiseContexts, Foreground),
     (OpenContext, Foreground),
     (OpenContextResponse, Foreground),
+    (CreateContext, Foreground),
+    (CreateContextResponse, Foreground),
     (UpdateContext, Foreground),
     (SynchronizeContexts, Foreground),
     (SynchronizeContextsResponse, Foreground),
@@ -523,6 +525,7 @@ request_messages!(
     (RenameDevServer, Ack),
     (RestartLanguageServers, Ack),
     (OpenContext, OpenContextResponse),
+    (CreateContext, CreateContextResponse),
     (SynchronizeContexts, SynchronizeContextsResponse),
     (AddWorktree, AddWorktreeResponse),
 );
@@ -589,6 +592,7 @@ entity_messages!(
     LspExtExpandMacro,
     AdvertiseContexts,
     OpenContext,
+    CreateContext,
     UpdateContext,
     SynchronizeContexts,
 );