moved from Boxes to Arcs for shared access of completion providers across the assistant panel and inline assistant

KCaverly created

Change summary

crates/ai/src/test.rs                   | 11 ++++++++++-
crates/assistant/src/assistant_panel.rs | 20 ++++++++++----------
crates/assistant/src/codegen.rs         | 14 ++++++++------
3 files changed, 28 insertions(+), 17 deletions(-)

Detailed changes

crates/ai/src/test.rs 🔗

@@ -153,10 +153,17 @@ impl FakeCompletionProvider {
 
     pub fn send_completion(&self, completion: impl Into<String>) {
         let mut tx = self.last_completion_tx.lock();
-        tx.as_mut().unwrap().try_send(completion.into()).unwrap();
+
+        println!("COMPLETION TX: {:?}", &tx);
+
+        let a = tx.as_mut().unwrap();
+        a.try_send(completion.into()).unwrap();
+
+        // tx.as_mut().unwrap().try_send(completion.into()).unwrap();
     }
 
     pub fn finish_completion(&self) {
+        println!("FINISHING COMPLETION");
         self.last_completion_tx.lock().take().unwrap();
     }
 }
@@ -181,8 +188,10 @@ impl CompletionProvider for FakeCompletionProvider {
         &self,
         _prompt: Box<dyn CompletionRequest>,
     ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
+        println!("COMPLETING");
         let (tx, rx) = mpsc::channel(1);
         *self.last_completion_tx.lock() = Some(tx);
+        println!("TX: {:?}", *self.last_completion_tx.lock());
         async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
     }
     fn box_clone(&self) -> Box<dyn CompletionProvider> {

crates/assistant/src/assistant_panel.rs 🔗

@@ -142,7 +142,7 @@ pub struct AssistantPanel {
     zoomed: bool,
     has_focus: bool,
     toolbar: ViewHandle<Toolbar>,
-    completion_provider: Box<dyn CompletionProvider>,
+    completion_provider: Arc<dyn CompletionProvider>,
     api_key_editor: Option<ViewHandle<Editor>>,
     languages: Arc<LanguageRegistry>,
     fs: Arc<dyn Fs>,
@@ -204,7 +204,7 @@ impl AssistantPanel {
 
                     let semantic_index = SemanticIndex::global(cx);
                     // Defaulting currently to GPT4, allow for this to be set via config.
-                    let completion_provider = Box::new(OpenAICompletionProvider::new(
+                    let completion_provider = Arc::new(OpenAICompletionProvider::new(
                         "gpt-4",
                         cx.background().clone(),
                     ));
@@ -1442,7 +1442,7 @@ struct Conversation {
     pending_save: Task<Result<()>>,
     path: Option<PathBuf>,
     _subscriptions: Vec<Subscription>,
-    completion_provider: Box<dyn CompletionProvider>,
+    completion_provider: Arc<dyn CompletionProvider>,
 }
 
 impl Entity for Conversation {
@@ -1453,7 +1453,7 @@ impl Conversation {
     fn new(
         language_registry: Arc<LanguageRegistry>,
         cx: &mut ModelContext<Self>,
-        completion_provider: Box<dyn CompletionProvider>,
+        completion_provider: Arc<dyn CompletionProvider>,
     ) -> Self {
         let markdown = language_registry.language_for_name("Markdown");
         let buffer = cx.add_model(|cx| {
@@ -1547,7 +1547,7 @@ impl Conversation {
             None => Some(Uuid::new_v4().to_string()),
         };
         let model = saved_conversation.model;
-        let completion_provider: Box<dyn CompletionProvider> = Box::new(
+        let completion_provider: Arc<dyn CompletionProvider> = Arc::new(
             OpenAICompletionProvider::new(model.full_name(), cx.background().clone()),
         );
         completion_provider.retrieve_credentials(cx);
@@ -2204,7 +2204,7 @@ struct ConversationEditor {
 
 impl ConversationEditor {
     fn new(
-        completion_provider: Box<dyn CompletionProvider>,
+        completion_provider: Arc<dyn CompletionProvider>,
         language_registry: Arc<LanguageRegistry>,
         fs: Arc<dyn Fs>,
         workspace: WeakViewHandle<Workspace>,
@@ -3409,7 +3409,7 @@ mod tests {
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
 
-        let completion_provider = Box::new(FakeCompletionProvider::new());
+        let completion_provider = Arc::new(FakeCompletionProvider::new());
         let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
         let buffer = conversation.read(cx).buffer.clone();
 
@@ -3538,7 +3538,7 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
-        let completion_provider = Box::new(FakeCompletionProvider::new());
+        let completion_provider = Arc::new(FakeCompletionProvider::new());
 
         let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
         let buffer = conversation.read(cx).buffer.clone();
@@ -3636,7 +3636,7 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
-        let completion_provider = Box::new(FakeCompletionProvider::new());
+        let completion_provider = Arc::new(FakeCompletionProvider::new());
         let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
         let buffer = conversation.read(cx).buffer.clone();
 
@@ -3719,7 +3719,7 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
-        let completion_provider = Box::new(FakeCompletionProvider::new());
+        let completion_provider = Arc::new(FakeCompletionProvider::new());
         let conversation =
             cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
         let buffer = conversation.read(cx).buffer.clone();

crates/assistant/src/codegen.rs 🔗

@@ -6,7 +6,7 @@ use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
 use gpui::{Entity, ModelContext, ModelHandle, Task};
 use language::{Rope, TransactionId};
 use multi_buffer;
-use std::{cmp, future, ops::Range};
+use std::{cmp, future, ops::Range, sync::Arc};
 
 pub enum Event {
     Finished,
@@ -20,7 +20,7 @@ pub enum CodegenKind {
 }
 
 pub struct Codegen {
-    provider: Box<dyn CompletionProvider>,
+    provider: Arc<dyn CompletionProvider>,
     buffer: ModelHandle<MultiBuffer>,
     snapshot: MultiBufferSnapshot,
     kind: CodegenKind,
@@ -40,7 +40,7 @@ impl Codegen {
     pub fn new(
         buffer: ModelHandle<MultiBuffer>,
         kind: CodegenKind,
-        provider: Box<dyn CompletionProvider>,
+        provider: Arc<dyn CompletionProvider>,
         cx: &mut ModelContext<Self>,
     ) -> Self {
         let snapshot = buffer.read(cx).snapshot(cx);
@@ -414,7 +414,7 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
         });
-        let provider = Box::new(FakeCompletionProvider::new());
+        let provider = Arc::new(FakeCompletionProvider::new());
         let codegen = cx.add_model(|cx| {
             Codegen::new(
                 buffer.clone(),
@@ -439,6 +439,7 @@ mod tests {
             let max_len = cmp::min(new_text.len(), 10);
             let len = rng.gen_range(1..=max_len);
             let (chunk, suffix) = new_text.split_at(len);
+            println!("CHUNK: {:?}", &chunk);
             provider.send_completion(chunk);
             new_text = suffix;
             deterministic.run_until_parked();
@@ -480,7 +481,7 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 6))
         });
-        let provider = Box::new(FakeCompletionProvider::new());
+        let provider = Arc::new(FakeCompletionProvider::new());
         let codegen = cx.add_model(|cx| {
             Codegen::new(
                 buffer.clone(),
@@ -546,7 +547,7 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 2))
         });
-        let provider = Box::new(FakeCompletionProvider::new());
+        let provider = Arc::new(FakeCompletionProvider::new());
         let codegen = cx.add_model(|cx| {
             Codegen::new(
                 buffer.clone(),
@@ -571,6 +572,7 @@ mod tests {
             let max_len = cmp::min(new_text.len(), 10);
             let len = rng.gen_range(1..=max_len);
             let (chunk, suffix) = new_text.split_at(len);
+            println!("{:?}", &chunk);
             provider.send_completion(chunk);
             new_text = suffix;
             deterministic.run_until_parked();