clean up warnings and fix tests in the ai crate

Created by KCaverly

Change summary

crates/ai/src/completion.rs                   |   7 
crates/ai/src/prompts/base.rs                 |   4 
crates/ai/src/providers/open_ai/completion.rs |   8 
crates/ai/src/test.rs                         |  14 +
crates/assistant/src/assistant_panel.rs       | 214 +++++++-------------
5 files changed, 103 insertions(+), 144 deletions(-)

Detailed changes

crates/ai/src/completion.rs

@@ -13,4 +13,11 @@ pub trait CompletionProvider: CredentialProvider {
         &self,
         prompt: Box<dyn CompletionRequest>,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+    fn box_clone(&self) -> Box<dyn CompletionProvider>;
+}
+
+impl Clone for Box<dyn CompletionProvider> {
+    fn clone(&self) -> Box<dyn CompletionProvider> {
+        self.box_clone()
+    }
 }
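
The `box_clone` indirection is needed because `Clone` is not object safe, so a `Box<dyn CompletionProvider>` cannot be cloned directly. A minimal, self-contained sketch of the pattern, using a simplified stand-in trait rather than the real `CompletionProvider` signatures:

// Simplified stand-in trait; the real CompletionProvider also requires
// CredentialProvider and an async `complete` method.
trait Provider {
    fn name(&self) -> String;
    // Object-safe clone hook: each implementor boxes a clone of itself.
    fn box_clone(&self) -> Box<dyn Provider>;
}

// With this blanket impl, any struct holding a Box<dyn Provider> can clone
// that field without knowing the concrete provider type.
impl Clone for Box<dyn Provider> {
    fn clone(&self) -> Box<dyn Provider> {
        self.box_clone()
    }
}

#[derive(Clone)]
struct Fake {
    name: String,
}

impl Provider for Fake {
    fn name(&self) -> String {
        self.name.clone()
    }
    fn box_clone(&self) -> Box<dyn Provider> {
        Box::new(self.clone())
    }
}

fn main() {
    let provider: Box<dyn Provider> = Box::new(Fake { name: "fake".into() });
    let copy = provider.clone(); // dispatches to box_clone
    assert_eq!(provider.name(), copy.name());
}

This is what lets `AssistantPanel` and `Conversation` below hold a plain `Box<dyn CompletionProvider>` and hand out clones of it (e.g. in `new_conversation`).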

crates/ai/src/prompts/base.rs

@@ -147,7 +147,7 @@ pub(crate) mod tests {
                         content = args.model.truncate(
                             &content,
                             max_token_length,
-                            TruncationDirection::Start,
+                            TruncationDirection::End,
                         )?;
                         token_count = max_token_length;
                     }
@@ -172,7 +172,7 @@ pub(crate) mod tests {
                         content = args.model.truncate(
                             &content,
                             max_token_length,
-                            TruncationDirection::Start,
+                            TruncationDirection::End,
                         )?;
                         token_count = max_token_length;
                     }
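
The tests were truncating from the wrong end. Assuming the usual semantics here, `TruncationDirection::End` trims tokens off the end of the content (keeping the head), while `Start` trims off the beginning (keeping the tail). A hypothetical, character-based illustration of that difference; the real `LanguageModel::truncate` operates on tokens:

// Illustrative only: character-level truncation standing in for token-level.
enum TruncationDirection {
    Start, // drop from the front, keep the tail
    End,   // drop from the back, keep the head
}

fn truncate(content: &str, length: usize, direction: TruncationDirection) -> String {
    let chars: Vec<char> = content.chars().collect();
    if chars.len() <= length {
        return content.to_string();
    }
    match direction {
        TruncationDirection::Start => chars[chars.len() - length..].iter().collect(),
        TruncationDirection::End => chars[..length].iter().collect(),
    }
}

fn main() {
    assert_eq!(truncate("abcdef", 3, TruncationDirection::End), "abc");
    assert_eq!(truncate("abcdef", 3, TruncationDirection::Start), "def");
}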

crates/ai/src/providers/open_ai/completion.rs

@@ -193,6 +193,7 @@ pub async fn stream_completion(
     }
 }
 
+#[derive(Clone)]
 pub struct OpenAICompletionProvider {
     model: OpenAILanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
@@ -271,6 +272,10 @@ impl CompletionProvider for OpenAICompletionProvider {
         &self,
         prompt: Box<dyn CompletionRequest>,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        // Currently the CompletionRequest for OpenAI includes a 'model' parameter.
+        // This means the model is determined by the CompletionRequest rather than the
+        // CompletionProvider, even though the provider is itself constructed around a
+        // language model. At some point in the future we should rectify this.
         let credential = self.credential.read().clone();
         let request = stream_completion(credential, self.executor.clone(), prompt);
         async move {
@@ -287,4 +292,7 @@ impl CompletionProvider for OpenAICompletionProvider {
         }
         .boxed()
     }
+    fn box_clone(&self) -> Box<dyn CompletionProvider> {
+        Box::new((*self).clone())
+    }
 }
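
With this change the provider owns all of the OpenAI-specific response parsing: `complete` now resolves to a stream of plain `String` chunks, so callers no longer dig through `choices` and `delta` fields (see the `assistant_panel.rs` hunks below). A minimal sketch of what consuming such a stream looks like, with the request elided and the stream faked:

use anyhow::Result;
use futures::executor::block_on;
use futures::stream::{self, BoxStream};
use futures::StreamExt;

// Stand-in for the stream a CompletionProvider hands back.
fn fake_completion_stream() -> BoxStream<'static, Result<String>> {
    stream::iter(vec![Ok("Hello, ".to_string()), Ok("world!".to_string())]).boxed()
}

fn main() {
    block_on(async {
        let mut chunks = fake_completion_stream();
        let mut completion = String::new();
        while let Some(chunk) = chunks.next().await {
            // Each item is already a plain text delta.
            completion.push_str(&chunk.unwrap());
        }
        assert_eq!(completion, "Hello, world!");
    });
}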

crates/ai/src/test.rs

@@ -33,7 +33,10 @@ impl LanguageModel for FakeLanguageModel {
         length: usize,
         direction: TruncationDirection,
     ) -> anyhow::Result<String> {
+        println!("TRYING TO TRUNCATE: {:?}", length.clone());
+
         if length > self.count_tokens(content)? {
+            println!("NOT TRUNCATING");
             return anyhow::Ok(content.to_string());
         }
 
@@ -133,6 +136,14 @@ pub struct FakeCompletionProvider {
     last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
 }
 
+impl Clone for FakeCompletionProvider {
+    fn clone(&self) -> Self {
+        Self {
+            last_completion_tx: Mutex::new(None),
+        }
+    }
+}
+
 impl FakeCompletionProvider {
     pub fn new() -> Self {
         Self {
@@ -174,4 +185,7 @@ impl CompletionProvider for FakeCompletionProvider {
         *self.last_completion_tx.lock() = Some(tx);
         async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
     }
+    fn box_clone(&self) -> Box<dyn CompletionProvider> {
+        Box::new((*self).clone())
+    }
 }
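
The hand-written `Clone` resets `last_completion_tx` to `None`, so a cloned fake starts without an in-flight completion channel of its own. A hypothetical, simplified sketch of how a test can drive the fake: `complete` stores the sending half and returns the receiving stream, and the test pushes chunks through it (`send_chunk` is an illustrative helper, not necessarily the real API):

use futures::channel::mpsc;
use futures::executor::block_on;
use futures::StreamExt;
use std::sync::Mutex;

struct FakeProvider {
    last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}

impl FakeProvider {
    // Returns the stream a consumer would poll and keeps the sender around.
    fn complete(&self) -> mpsc::Receiver<String> {
        let (tx, rx) = mpsc::channel(100);
        *self.last_completion_tx.lock().unwrap() = Some(tx);
        rx
    }

    // Illustrative test helper for streaming canned completion text.
    fn send_chunk(&self, chunk: &str) {
        self.last_completion_tx
            .lock()
            .unwrap()
            .as_mut()
            .unwrap()
            .try_send(chunk.to_string())
            .unwrap();
    }
}

fn main() {
    let provider = FakeProvider {
        last_completion_tx: Mutex::new(None),
    };
    let mut stream = provider.complete();
    provider.send_chunk("Hello, ");
    provider.send_chunk("world!");
    assert_eq!(block_on(stream.next()).unwrap(), "Hello, ");
    assert_eq!(block_on(stream.next()).unwrap(), "world!");
}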

crates/assistant/src/assistant_panel.rs

@@ -9,9 +9,7 @@ use crate::{
 use ai::{
     auth::ProviderCredential,
     completion::{CompletionProvider, CompletionRequest},
-    providers::open_ai::{
-        stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage,
-    },
+    providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage},
 };
 
 use ai::prompts::repository_context::PromptCodeSnippet;
@@ -47,7 +45,7 @@ use search::BufferSearchBar;
 use semantic_index::{SemanticIndex, SemanticIndexStatus};
 use settings::SettingsStore;
 use std::{
-    cell::{Cell, RefCell},
+    cell::Cell,
     cmp,
     fmt::Write,
     iter,
@@ -144,10 +142,8 @@ pub struct AssistantPanel {
     zoomed: bool,
     has_focus: bool,
     toolbar: ViewHandle<Toolbar>,
-    credential: Rc<RefCell<ProviderCredential>>,
     completion_provider: Box<dyn CompletionProvider>,
     api_key_editor: Option<ViewHandle<Editor>>,
-    has_read_credentials: bool,
     languages: Arc<LanguageRegistry>,
     fs: Arc<dyn Fs>,
     subscriptions: Vec<Subscription>,
@@ -223,10 +219,8 @@ impl AssistantPanel {
                         zoomed: false,
                         has_focus: false,
                         toolbar,
-                        credential: Rc::new(RefCell::new(ProviderCredential::NoCredentials)),
                         completion_provider,
                         api_key_editor: None,
-                        has_read_credentials: false,
                         languages: workspace.app_state().languages.clone(),
                         fs: workspace.app_state().fs.clone(),
                         width: None,
@@ -265,7 +259,7 @@ impl AssistantPanel {
         cx: &mut ViewContext<Workspace>,
     ) {
         let this = if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
-            if this.update(cx, |assistant, cx| assistant.has_credentials(cx)) {
+            if this.update(cx, |assistant, _| assistant.has_credentials()) {
                 this
             } else {
                 workspace.focus_panel::<AssistantPanel>(cx);
@@ -331,6 +325,9 @@ impl AssistantPanel {
             cx.background().clone(),
         ));
 
+        // Retrieving credentials authenticates the provider:
+        // provider.retrieve_credentials(cx);
+
         let codegen = cx.add_model(|cx| {
             Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
         });
@@ -814,7 +811,7 @@ impl AssistantPanel {
     fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
         let editor = cx.add_view(|cx| {
             ConversationEditor::new(
-                self.credential.clone(),
+                self.completion_provider.clone(),
                 self.languages.clone(),
                 self.fs.clone(),
                 self.workspace.clone(),
@@ -883,9 +880,8 @@ impl AssistantPanel {
                 let credential = ProviderCredential::Credentials {
                     api_key: api_key.clone(),
                 };
-                self.completion_provider
-                    .save_credentials(cx, credential.clone());
-                *self.credential.borrow_mut() = credential;
+
+                self.completion_provider.save_credentials(cx, credential);
 
                 self.api_key_editor.take();
                 cx.focus_self();
@@ -898,7 +894,6 @@ impl AssistantPanel {
 
     fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
         self.completion_provider.delete_credentials(cx);
-        *self.credential.borrow_mut() = ProviderCredential::NoCredentials;
         self.api_key_editor = Some(build_api_key_editor(cx));
         cx.focus_self();
         cx.notify();
@@ -1157,19 +1152,12 @@ impl AssistantPanel {
 
         let fs = self.fs.clone();
         let workspace = self.workspace.clone();
-        let credential = self.credential.clone();
         let languages = self.languages.clone();
         cx.spawn(|this, mut cx| async move {
             let saved_conversation = fs.load(&path).await?;
             let saved_conversation = serde_json::from_str(&saved_conversation)?;
             let conversation = cx.add_model(|cx| {
-                Conversation::deserialize(
-                    saved_conversation,
-                    path.clone(),
-                    credential,
-                    languages,
-                    cx,
-                )
+                Conversation::deserialize(saved_conversation, path.clone(), languages, cx)
             });
             this.update(&mut cx, |this, cx| {
                 // If, by the time we've loaded the conversation, the user has already opened
@@ -1193,39 +1181,12 @@ impl AssistantPanel {
             .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
     }
 
-    fn has_credentials(&mut self, cx: &mut ViewContext<Self>) -> bool {
-        let credential = self.load_credentials(cx);
-        match credential {
-            ProviderCredential::Credentials { .. } => true,
-            ProviderCredential::NotNeeded => true,
-            ProviderCredential::NoCredentials => false,
-        }
+    fn has_credentials(&mut self) -> bool {
+        self.completion_provider.has_credentials()
     }
 
-    fn load_credentials(&mut self, cx: &mut ViewContext<Self>) -> ProviderCredential {
-        let existing_credential = self.credential.clone();
-        let existing_credential = existing_credential.borrow().clone();
-        match existing_credential {
-            ProviderCredential::NoCredentials => {
-                if !self.has_read_credentials {
-                    self.has_read_credentials = true;
-                    let retrieved_credentials = self.completion_provider.retrieve_credentials(cx);
-
-                    match retrieved_credentials {
-                        ProviderCredential::NoCredentials {} => {
-                            self.api_key_editor = Some(build_api_key_editor(cx));
-                            cx.notify();
-                        }
-                        _ => {
-                            *self.credential.borrow_mut() = retrieved_credentials;
-                        }
-                    }
-                }
-            }
-            _ => {}
-        }
-
-        self.credential.borrow().clone()
+    fn load_credentials(&mut self, cx: &mut ViewContext<Self>) {
+        self.completion_provider.retrieve_credentials(cx);
     }
 }
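
Credential state now lives in one place: the provider. The panel no longer mirrors it in an `Rc<RefCell<ProviderCredential>>` or tracks `has_read_credentials`; it simply forwards to the provider's credential methods, and clones of the provider stay in sync because the concrete provider keeps its credential behind an `Arc<RwLock<..>>` (see the OpenAI provider above). A simplified, illustrative stand-in for that ownership shape; the real methods also take a gpui context:

use std::sync::{Arc, RwLock};

enum Credential {
    NoCredentials,
    Credentials { api_key: String },
}

trait CredentialProvider {
    fn has_credentials(&self) -> bool;
    fn save_credentials(&self, credential: Credential);
    fn delete_credentials(&self);
}

struct Provider {
    credential: Arc<RwLock<Credential>>,
}

impl CredentialProvider for Provider {
    fn has_credentials(&self) -> bool {
        match *self.credential.read().unwrap() {
            Credential::NoCredentials => false,
            Credential::Credentials { .. } => true,
        }
    }
    fn save_credentials(&self, credential: Credential) {
        *self.credential.write().unwrap() = credential;
    }
    fn delete_credentials(&self) {
        *self.credential.write().unwrap() = Credential::NoCredentials;
    }
}

// The panel forwards instead of caching a second copy of the credential.
struct Panel {
    completion_provider: Box<dyn CredentialProvider>,
}

impl Panel {
    fn has_credentials(&self) -> bool {
        self.completion_provider.has_credentials()
    }
}

fn main() {
    let panel = Panel {
        completion_provider: Box::new(Provider {
            credential: Arc::new(RwLock::new(Credential::NoCredentials)),
        }),
    };
    assert!(!panel.has_credentials());
    panel.completion_provider.save_credentials(Credential::Credentials {
        api_key: "sk-test".into(),
    });
    assert!(panel.has_credentials());
}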
 
@@ -1475,10 +1436,10 @@ struct Conversation {
     token_count: Option<usize>,
     max_token_count: usize,
     pending_token_count: Task<Option<()>>,
-    credential: Rc<RefCell<ProviderCredential>>,
     pending_save: Task<Result<()>>,
     path: Option<PathBuf>,
     _subscriptions: Vec<Subscription>,
+    completion_provider: Box<dyn CompletionProvider>,
 }
 
 impl Entity for Conversation {
@@ -1487,10 +1448,9 @@ impl Entity for Conversation {
 
 impl Conversation {
     fn new(
-        credential: Rc<RefCell<ProviderCredential>>,
-
         language_registry: Arc<LanguageRegistry>,
         cx: &mut ModelContext<Self>,
+        completion_provider: Box<dyn CompletionProvider>,
     ) -> Self {
         let markdown = language_registry.language_for_name("Markdown");
         let buffer = cx.add_model(|cx| {
@@ -1529,8 +1489,8 @@ impl Conversation {
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
             path: None,
-            credential,
             buffer,
+            completion_provider,
         };
         let message = MessageAnchor {
             id: MessageId(post_inc(&mut this.next_message_id.0)),
@@ -1576,7 +1536,6 @@ impl Conversation {
     fn deserialize(
         saved_conversation: SavedConversation,
         path: PathBuf,
-        credential: Rc<RefCell<ProviderCredential>>,
         language_registry: Arc<LanguageRegistry>,
         cx: &mut ModelContext<Self>,
     ) -> Self {
@@ -1585,6 +1544,10 @@ impl Conversation {
             None => Some(Uuid::new_v4().to_string()),
         };
         let model = saved_conversation.model;
+        let completion_provider: Box<dyn CompletionProvider> = Box::new(
+            OpenAICompletionProvider::new(model.full_name(), cx.background().clone()),
+        );
+        completion_provider.retrieve_credentials(cx);
         let markdown = language_registry.language_for_name("Markdown");
         let mut message_anchors = Vec::new();
         let mut next_message_id = MessageId(0);
@@ -1631,8 +1594,8 @@ impl Conversation {
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
             path: Some(path),
-            credential,
             buffer,
+            completion_provider,
         };
         this.count_remaining_tokens(cx);
         this
@@ -1753,12 +1716,8 @@ impl Conversation {
         }
 
         if should_assist {
-            let credential = self.credential.borrow().clone();
-            match credential {
-                ProviderCredential::NoCredentials => {
-                    return Default::default();
-                }
-                _ => {}
+            if !self.completion_provider.has_credentials() {
+                return Default::default();
             }
 
             let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
@@ -1773,7 +1732,7 @@ impl Conversation {
                 temperature: 1.0,
             });
 
-            let stream = stream_completion(credential, cx.background().clone(), request);
+            let stream = self.completion_provider.complete(request);
             let assistant_message = self
                 .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
                 .unwrap();
@@ -1791,33 +1750,28 @@ impl Conversation {
                         let mut messages = stream.await?;
 
                         while let Some(message) = messages.next().await {
-                            let mut message = message?;
-                            if let Some(choice) = message.choices.pop() {
-                                this.upgrade(&cx)
-                                    .ok_or_else(|| anyhow!("conversation was dropped"))?
-                                    .update(&mut cx, |this, cx| {
-                                        let text: Arc<str> = choice.delta.content?.into();
-                                        let message_ix =
-                                            this.message_anchors.iter().position(|message| {
-                                                message.id == assistant_message_id
-                                            })?;
-                                        this.buffer.update(cx, |buffer, cx| {
-                                            let offset = this.message_anchors[message_ix + 1..]
-                                                .iter()
-                                                .find(|message| message.start.is_valid(buffer))
-                                                .map_or(buffer.len(), |message| {
-                                                    message
-                                                        .start
-                                                        .to_offset(buffer)
-                                                        .saturating_sub(1)
-                                                });
-                                            buffer.edit([(offset..offset, text)], None, cx);
-                                        });
-                                        cx.emit(ConversationEvent::StreamedCompletion);
-
-                                        Some(())
+                            let text = message?;
+
+                            this.upgrade(&cx)
+                                .ok_or_else(|| anyhow!("conversation was dropped"))?
+                                .update(&mut cx, |this, cx| {
+                                    let message_ix = this
+                                        .message_anchors
+                                        .iter()
+                                        .position(|message| message.id == assistant_message_id)?;
+                                    this.buffer.update(cx, |buffer, cx| {
+                                        let offset = this.message_anchors[message_ix + 1..]
+                                            .iter()
+                                            .find(|message| message.start.is_valid(buffer))
+                                            .map_or(buffer.len(), |message| {
+                                                message.start.to_offset(buffer).saturating_sub(1)
+                                            });
+                                        buffer.edit([(offset..offset, text)], None, cx);
                                     });
-                            }
+                                    cx.emit(ConversationEvent::StreamedCompletion);
+
+                                    Some(())
+                                });
                             smol::future::yield_now().await;
                         }
 
@@ -2039,13 +1993,8 @@ impl Conversation {
 
     fn summarize(&mut self, cx: &mut ModelContext<Self>) {
         if self.message_anchors.len() >= 2 && self.summary.is_none() {
-            let credential = self.credential.borrow().clone();
-
-            match credential {
-                ProviderCredential::NoCredentials => {
-                    return;
-                }
-                _ => {}
+            if !self.completion_provider.has_credentials() {
+                return;
             }
 
             let messages = self
@@ -2065,23 +2014,20 @@ impl Conversation {
                 temperature: 1.0,
             });
 
-            let stream = stream_completion(credential, cx.background().clone(), request);
+            let stream = self.completion_provider.complete(request);
             self.pending_summary = cx.spawn(|this, mut cx| {
                 async move {
                     let mut messages = stream.await?;
 
                     while let Some(message) = messages.next().await {
-                        let mut message = message?;
-                        if let Some(choice) = message.choices.pop() {
-                            let text = choice.delta.content.unwrap_or_default();
-                            this.update(&mut cx, |this, cx| {
-                                this.summary
-                                    .get_or_insert(Default::default())
-                                    .text
-                                    .push_str(&text);
-                                cx.emit(ConversationEvent::SummaryChanged);
-                            });
-                        }
+                        let text = message?;
+                        this.update(&mut cx, |this, cx| {
+                            this.summary
+                                .get_or_insert(Default::default())
+                                .text
+                                .push_str(&text);
+                            cx.emit(ConversationEvent::SummaryChanged);
+                        });
                     }
 
                     this.update(&mut cx, |this, cx| {
@@ -2255,13 +2201,14 @@ struct ConversationEditor {
 
 impl ConversationEditor {
     fn new(
-        credential: Rc<RefCell<ProviderCredential>>,
+        completion_provider: Box<dyn CompletionProvider>,
         language_registry: Arc<LanguageRegistry>,
         fs: Arc<dyn Fs>,
         workspace: WeakViewHandle<Workspace>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
-        let conversation = cx.add_model(|cx| Conversation::new(credential, language_registry, cx));
+        let conversation =
+            cx.add_model(|cx| Conversation::new(language_registry, cx, completion_provider));
         Self::for_conversation(conversation, fs, workspace, cx)
     }
 
@@ -3450,6 +3397,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
 mod tests {
     use super::*;
     use crate::MessageId;
+    use ai::test::FakeCompletionProvider;
     use gpui::AppContext;
 
     #[gpui::test]
@@ -3457,13 +3405,9 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
-        let conversation = cx.add_model(|cx| {
-            Conversation::new(
-                Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
-                registry,
-                cx,
-            )
-        });
+
+        let completion_provider = Box::new(FakeCompletionProvider::new());
+        let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
         let buffer = conversation.read(cx).buffer.clone();
 
         let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3591,13 +3535,9 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
-        let conversation = cx.add_model(|cx| {
-            Conversation::new(
-                Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
-                registry,
-                cx,
-            )
-        });
+        let completion_provider = Box::new(FakeCompletionProvider::new());
+
+        let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
         let buffer = conversation.read(cx).buffer.clone();
 
         let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3693,13 +3633,8 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
-        let conversation = cx.add_model(|cx| {
-            Conversation::new(
-                Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
-                registry,
-                cx,
-            )
-        });
+        let completion_provider = Box::new(FakeCompletionProvider::new());
+        let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
         let buffer = conversation.read(cx).buffer.clone();
 
         let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3781,13 +3716,9 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
-        let conversation = cx.add_model(|cx| {
-            Conversation::new(
-                Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
-                registry.clone(),
-                cx,
-            )
-        });
+        let completion_provider = Box::new(FakeCompletionProvider::new());
+        let conversation =
+            cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
         let buffer = conversation.read(cx).buffer.clone();
         let message_0 = conversation.read(cx).message_anchors[0].id;
         let message_1 = conversation.update(cx, |conversation, cx| {
@@ -3824,7 +3755,6 @@ mod tests {
             Conversation::deserialize(
                 conversation.read(cx).serialize(cx),
                 Default::default(),
-                Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
                 registry.clone(),
                 cx,
             )