From b5fe0d72ee4328d9d7a6d9cf9ec0e371a7d0a3f1 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 2 Nov 2023 09:34:18 -0400 Subject: [PATCH 1/2] authenticate with completion provider on new inline assists --- crates/assistant/src/assistant_panel.rs | 15 +++++++++------ crates/assistant/src/codegen.rs | 14 ++++++++------ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 03eb3c238f98d55e7bd40793dafa931244bcc073..022c22879050b275d525c2170b0c4d436e301c9f 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -259,7 +259,13 @@ impl AssistantPanel { cx: &mut ViewContext, ) { let this = if let Some(this) = workspace.panel::(cx) { - if this.update(cx, |assistant, _| assistant.has_credentials()) { + if this.update(cx, |assistant, cx| { + if !assistant.has_credentials() { + assistant.load_credentials(cx); + }; + + assistant.has_credentials() + }) { this } else { workspace.focus_panel::(cx); @@ -320,13 +326,10 @@ impl AssistantPanel { }; let inline_assist_id = post_inc(&mut self.next_inline_assist_id); - let provider = Arc::new(OpenAICompletionProvider::new( - "gpt-4", - cx.background().clone(), - )); + let provider = self.completion_provider.clone(); // Retrieve Credentials Authenticates the Provider - // provider.retrieve_credentials(cx); + provider.retrieve_credentials(cx); let codegen = cx.add_model(|cx| { Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx) diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index f62c91fcb7e8a2f5891e2e93ae9cba7b13b437eb..da7beda2dc5a41e646b296acb15460b100075c80 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -6,7 +6,7 @@ use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; use gpui::{Entity, ModelContext, ModelHandle, Task}; use language::{Rope, TransactionId}; use multi_buffer; -use std::{cmp, future, ops::Range, sync::Arc}; +use std::{cmp, future, ops::Range}; pub enum Event { Finished, @@ -20,7 +20,7 @@ pub enum CodegenKind { } pub struct Codegen { - provider: Arc, + provider: Box, buffer: ModelHandle, snapshot: MultiBufferSnapshot, kind: CodegenKind, @@ -40,7 +40,7 @@ impl Codegen { pub fn new( buffer: ModelHandle, kind: CodegenKind, - provider: Arc, + provider: Box, cx: &mut ModelContext, ) -> Self { let snapshot = buffer.read(cx).snapshot(cx); @@ -367,6 +367,8 @@ fn strip_invalid_spans_from_codeblock( #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use ai::test::FakeCompletionProvider; use futures::stream::{self}; @@ -412,7 +414,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) }); - let provider = Arc::new(FakeCompletionProvider::new()); + let provider = Box::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), @@ -478,7 +480,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 6)) }); - let provider = Arc::new(FakeCompletionProvider::new()); + let provider = Box::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), @@ -544,7 +546,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 2)) }); - let provider = Arc::new(FakeCompletionProvider::new()); + let provider = Box::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), From d5b6300fd7edb6441516be5004887a5090e5a599 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 2 Nov 2023 10:08:47 -0400 Subject: [PATCH 2/2] moved from Boxes to Arcs for shared access of completion providers across the assistant panel and inline assistant --- crates/ai/src/test.rs | 11 ++++++++++- crates/assistant/src/assistant_panel.rs | 20 ++++++++++---------- crates/assistant/src/codegen.rs | 14 ++++++++------ 3 files changed, 28 insertions(+), 17 deletions(-) diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs index d4165f3cca897c4adbf11c2babf6038a8d86f0a6..3f331da1175e32da22fa9b2acd6ba31800c85757 100644 --- a/crates/ai/src/test.rs +++ b/crates/ai/src/test.rs @@ -153,10 +153,17 @@ impl FakeCompletionProvider { pub fn send_completion(&self, completion: impl Into) { let mut tx = self.last_completion_tx.lock(); - tx.as_mut().unwrap().try_send(completion.into()).unwrap(); + + println!("COMPLETION TX: {:?}", &tx); + + let a = tx.as_mut().unwrap(); + a.try_send(completion.into()).unwrap(); + + // tx.as_mut().unwrap().try_send(completion.into()).unwrap(); } pub fn finish_completion(&self) { + println!("FINISHING COMPLETION"); self.last_completion_tx.lock().take().unwrap(); } } @@ -181,8 +188,10 @@ impl CompletionProvider for FakeCompletionProvider { &self, _prompt: Box, ) -> BoxFuture<'static, anyhow::Result>>> { + println!("COMPLETING"); let (tx, rx) = mpsc::channel(1); *self.last_completion_tx.lock() = Some(tx); + println!("TX: {:?}", *self.last_completion_tx.lock()); async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() } fn box_clone(&self) -> Box { diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 022c22879050b275d525c2170b0c4d436e301c9f..6ab96093a74e3f30ee44b21c396eb76a41a1e179 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -142,7 +142,7 @@ pub struct AssistantPanel { zoomed: bool, has_focus: bool, toolbar: ViewHandle, - completion_provider: Box, + completion_provider: Arc, api_key_editor: Option>, languages: Arc, fs: Arc, @@ -204,7 +204,7 @@ impl AssistantPanel { let semantic_index = SemanticIndex::global(cx); // Defaulting currently to GPT4, allow for this to be set via config. - let completion_provider = Box::new(OpenAICompletionProvider::new( + let completion_provider = Arc::new(OpenAICompletionProvider::new( "gpt-4", cx.background().clone(), )); @@ -1442,7 +1442,7 @@ struct Conversation { pending_save: Task>, path: Option, _subscriptions: Vec, - completion_provider: Box, + completion_provider: Arc, } impl Entity for Conversation { @@ -1453,7 +1453,7 @@ impl Conversation { fn new( language_registry: Arc, cx: &mut ModelContext, - completion_provider: Box, + completion_provider: Arc, ) -> Self { let markdown = language_registry.language_for_name("Markdown"); let buffer = cx.add_model(|cx| { @@ -1547,7 +1547,7 @@ impl Conversation { None => Some(Uuid::new_v4().to_string()), }; let model = saved_conversation.model; - let completion_provider: Box = Box::new( + let completion_provider: Arc = Arc::new( OpenAICompletionProvider::new(model.full_name(), cx.background().clone()), ); completion_provider.retrieve_credentials(cx); @@ -2204,7 +2204,7 @@ struct ConversationEditor { impl ConversationEditor { fn new( - completion_provider: Box, + completion_provider: Arc, language_registry: Arc, fs: Arc, workspace: WeakViewHandle, @@ -3409,7 +3409,7 @@ mod tests { init(cx); let registry = Arc::new(LanguageRegistry::test()); - let completion_provider = Box::new(FakeCompletionProvider::new()); + let completion_provider = Arc::new(FakeCompletionProvider::new()); let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); @@ -3538,7 +3538,7 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let completion_provider = Box::new(FakeCompletionProvider::new()); + let completion_provider = Arc::new(FakeCompletionProvider::new()); let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); @@ -3636,7 +3636,7 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let completion_provider = Box::new(FakeCompletionProvider::new()); + let completion_provider = Arc::new(FakeCompletionProvider::new()); let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); @@ -3719,7 +3719,7 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let completion_provider = Box::new(FakeCompletionProvider::new()); + let completion_provider = Arc::new(FakeCompletionProvider::new()); let conversation = cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index da7beda2dc5a41e646b296acb15460b100075c80..25c9deef7f6ef972dd67911046f430e63bb48d40 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -6,7 +6,7 @@ use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; use gpui::{Entity, ModelContext, ModelHandle, Task}; use language::{Rope, TransactionId}; use multi_buffer; -use std::{cmp, future, ops::Range}; +use std::{cmp, future, ops::Range, sync::Arc}; pub enum Event { Finished, @@ -20,7 +20,7 @@ pub enum CodegenKind { } pub struct Codegen { - provider: Box, + provider: Arc, buffer: ModelHandle, snapshot: MultiBufferSnapshot, kind: CodegenKind, @@ -40,7 +40,7 @@ impl Codegen { pub fn new( buffer: ModelHandle, kind: CodegenKind, - provider: Box, + provider: Arc, cx: &mut ModelContext, ) -> Self { let snapshot = buffer.read(cx).snapshot(cx); @@ -414,7 +414,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) }); - let provider = Box::new(FakeCompletionProvider::new()); + let provider = Arc::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), @@ -439,6 +439,7 @@ mod tests { let max_len = cmp::min(new_text.len(), 10); let len = rng.gen_range(1..=max_len); let (chunk, suffix) = new_text.split_at(len); + println!("CHUNK: {:?}", &chunk); provider.send_completion(chunk); new_text = suffix; deterministic.run_until_parked(); @@ -480,7 +481,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 6)) }); - let provider = Box::new(FakeCompletionProvider::new()); + let provider = Arc::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), @@ -546,7 +547,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 2)) }); - let provider = Box::new(FakeCompletionProvider::new()); + let provider = Arc::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), @@ -571,6 +572,7 @@ mod tests { let max_len = cmp::min(new_text.len(), 10); let len = rng.gen_range(1..=max_len); let (chunk, suffix) = new_text.split_at(len); + println!("{:?}", &chunk); provider.send_completion(chunk); new_text = suffix; deterministic.run_until_parked();