From 80469283eed951fa60b31cf48a021997ee4ee42f Mon Sep 17 00:00:00 2001 From: Kyle Caverly Date: Thu, 2 Nov 2023 10:14:52 -0400 Subject: [PATCH] authenticate with completion provider on new inline assists (#3209) authenticate with completion provider on new inline assists Release Notes: - Fixed bug which lead the inline assist functionality to never authenticate --- crates/ai/src/test.rs | 11 +++++++- crates/assistant/src/assistant_panel.rs | 35 ++++++++++++++----------- crates/assistant/src/codegen.rs | 4 +++ 3 files changed, 33 insertions(+), 17 deletions(-) diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs index d4165f3cca897c4adbf11c2babf6038a8d86f0a6..3f331da1175e32da22fa9b2acd6ba31800c85757 100644 --- a/crates/ai/src/test.rs +++ b/crates/ai/src/test.rs @@ -153,10 +153,17 @@ impl FakeCompletionProvider { pub fn send_completion(&self, completion: impl Into) { let mut tx = self.last_completion_tx.lock(); - tx.as_mut().unwrap().try_send(completion.into()).unwrap(); + + println!("COMPLETION TX: {:?}", &tx); + + let a = tx.as_mut().unwrap(); + a.try_send(completion.into()).unwrap(); + + // tx.as_mut().unwrap().try_send(completion.into()).unwrap(); } pub fn finish_completion(&self) { + println!("FINISHING COMPLETION"); self.last_completion_tx.lock().take().unwrap(); } } @@ -181,8 +188,10 @@ impl CompletionProvider for FakeCompletionProvider { &self, _prompt: Box, ) -> BoxFuture<'static, anyhow::Result>>> { + println!("COMPLETING"); let (tx, rx) = mpsc::channel(1); *self.last_completion_tx.lock() = Some(tx); + println!("TX: {:?}", *self.last_completion_tx.lock()); async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() } fn box_clone(&self) -> Box { diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 03eb3c238f98d55e7bd40793dafa931244bcc073..6ab96093a74e3f30ee44b21c396eb76a41a1e179 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -142,7 +142,7 @@ pub struct AssistantPanel { zoomed: bool, has_focus: bool, toolbar: ViewHandle, - completion_provider: Box, + completion_provider: Arc, api_key_editor: Option>, languages: Arc, fs: Arc, @@ -204,7 +204,7 @@ impl AssistantPanel { let semantic_index = SemanticIndex::global(cx); // Defaulting currently to GPT4, allow for this to be set via config. - let completion_provider = Box::new(OpenAICompletionProvider::new( + let completion_provider = Arc::new(OpenAICompletionProvider::new( "gpt-4", cx.background().clone(), )); @@ -259,7 +259,13 @@ impl AssistantPanel { cx: &mut ViewContext, ) { let this = if let Some(this) = workspace.panel::(cx) { - if this.update(cx, |assistant, _| assistant.has_credentials()) { + if this.update(cx, |assistant, cx| { + if !assistant.has_credentials() { + assistant.load_credentials(cx); + }; + + assistant.has_credentials() + }) { this } else { workspace.focus_panel::(cx); @@ -320,13 +326,10 @@ impl AssistantPanel { }; let inline_assist_id = post_inc(&mut self.next_inline_assist_id); - let provider = Arc::new(OpenAICompletionProvider::new( - "gpt-4", - cx.background().clone(), - )); + let provider = self.completion_provider.clone(); // Retrieve Credentials Authenticates the Provider - // provider.retrieve_credentials(cx); + provider.retrieve_credentials(cx); let codegen = cx.add_model(|cx| { Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx) @@ -1439,7 +1442,7 @@ struct Conversation { pending_save: Task>, path: Option, _subscriptions: Vec, - completion_provider: Box, + completion_provider: Arc, } impl Entity for Conversation { @@ -1450,7 +1453,7 @@ impl Conversation { fn new( language_registry: Arc, cx: &mut ModelContext, - completion_provider: Box, + completion_provider: Arc, ) -> Self { let markdown = language_registry.language_for_name("Markdown"); let buffer = cx.add_model(|cx| { @@ -1544,7 +1547,7 @@ impl Conversation { None => Some(Uuid::new_v4().to_string()), }; let model = saved_conversation.model; - let completion_provider: Box = Box::new( + let completion_provider: Arc = Arc::new( OpenAICompletionProvider::new(model.full_name(), cx.background().clone()), ); completion_provider.retrieve_credentials(cx); @@ -2201,7 +2204,7 @@ struct ConversationEditor { impl ConversationEditor { fn new( - completion_provider: Box, + completion_provider: Arc, language_registry: Arc, fs: Arc, workspace: WeakViewHandle, @@ -3406,7 +3409,7 @@ mod tests { init(cx); let registry = Arc::new(LanguageRegistry::test()); - let completion_provider = Box::new(FakeCompletionProvider::new()); + let completion_provider = Arc::new(FakeCompletionProvider::new()); let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); @@ -3535,7 +3538,7 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let completion_provider = Box::new(FakeCompletionProvider::new()); + let completion_provider = Arc::new(FakeCompletionProvider::new()); let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); @@ -3633,7 +3636,7 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let completion_provider = Box::new(FakeCompletionProvider::new()); + let completion_provider = Arc::new(FakeCompletionProvider::new()); let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); @@ -3716,7 +3719,7 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let completion_provider = Box::new(FakeCompletionProvider::new()); + let completion_provider = Arc::new(FakeCompletionProvider::new()); let conversation = cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index f62c91fcb7e8a2f5891e2e93ae9cba7b13b437eb..25c9deef7f6ef972dd67911046f430e63bb48d40 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -367,6 +367,8 @@ fn strip_invalid_spans_from_codeblock( #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use ai::test::FakeCompletionProvider; use futures::stream::{self}; @@ -437,6 +439,7 @@ mod tests { let max_len = cmp::min(new_text.len(), 10); let len = rng.gen_range(1..=max_len); let (chunk, suffix) = new_text.split_at(len); + println!("CHUNK: {:?}", &chunk); provider.send_completion(chunk); new_text = suffix; deterministic.run_until_parked(); @@ -569,6 +572,7 @@ mod tests { let max_len = cmp::min(new_text.len(), 10); let len = rng.gen_range(1..=max_len); let (chunk, suffix) = new_text.split_at(len); + println!("{:?}", &chunk); provider.send_completion(chunk); new_text = suffix; deterministic.run_until_parked();