replace api_key with ProviderCredential throughout the AssistantPanel

KCaverly created

Change summary

crates/ai/src/auth.rs                         |   4 
crates/ai/src/completion.rs                   |   6 
crates/ai/src/providers/open_ai/auth.rs       |  13 
crates/ai/src/providers/open_ai/completion.rs |  24 +
crates/assistant/src/assistant_panel.rs       | 276 ++++++++++++--------
5 files changed, 205 insertions(+), 118 deletions(-)

Detailed changes

crates/ai/src/auth.rs 🔗

@@ -9,6 +9,8 @@ pub enum ProviderCredential {
 
 pub trait CredentialProvider: Send + Sync {
     fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
+    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential);
+    fn delete_credentials(&self, cx: &AppContext);
 }
 
 #[derive(Clone)]
@@ -17,4 +19,6 @@ impl CredentialProvider for NullCredentialProvider {
     fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
         ProviderCredential::NotNeeded
     }
+    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {}
+    fn delete_credentials(&self, cx: &AppContext) {}
 }

crates/ai/src/completion.rs 🔗

@@ -17,6 +17,12 @@ pub trait CompletionProvider {
     fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
         self.credential_provider().retrieve_credentials(cx)
     }
+    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
+        self.credential_provider().save_credentials(cx, credential);
+    }
+    fn delete_credentials(&self, cx: &AppContext) {
+        self.credential_provider().delete_credentials(cx);
+    }
     fn complete(
         &self,
         prompt: Box<dyn CompletionRequest>,

crates/ai/src/providers/open_ai/auth.rs 🔗

@@ -30,4 +30,17 @@ impl CredentialProvider for OpenAICredentialProvider {
             ProviderCredential::NoCredentials
         }
     }
+    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
+        match credential {
+            ProviderCredential::Credentials { api_key } => {
+                cx.platform()
+                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+                    .log_err();
+            }
+            _ => {}
+        }
+    }
+    fn delete_credentials(&self, cx: &AppContext) {
+        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
+    }
 }

crates/ai/src/providers/open_ai/completion.rs 🔗

@@ -13,7 +13,7 @@ use std::{
 };
 
 use crate::{
-    auth::CredentialProvider,
+    auth::{CredentialProvider, ProviderCredential},
     completion::{CompletionProvider, CompletionRequest},
     models::LanguageModel,
 };
@@ -102,10 +102,17 @@ pub struct OpenAIResponseStreamEvent {
 }
 
 pub async fn stream_completion(
-    api_key: String,
+    credential: ProviderCredential,
     executor: Arc<Background>,
     request: Box<dyn CompletionRequest>,
 ) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
+    let api_key = match credential {
+        ProviderCredential::Credentials { api_key } => api_key,
+        _ => {
+            return Err(anyhow!("no credentials provider for completion"));
+        }
+    };
+
     let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
 
     let json_data = request.data()?;
@@ -188,18 +195,22 @@ pub async fn stream_completion(
 pub struct OpenAICompletionProvider {
     model: OpenAILanguageModel,
     credential_provider: OpenAICredentialProvider,
-    api_key: String,
+    credential: ProviderCredential,
     executor: Arc<Background>,
 }
 
 impl OpenAICompletionProvider {
-    pub fn new(model_name: &str, api_key: String, executor: Arc<Background>) -> Self {
+    pub fn new(
+        model_name: &str,
+        credential: ProviderCredential,
+        executor: Arc<Background>,
+    ) -> Self {
         let model = OpenAILanguageModel::load(model_name);
         let credential_provider = OpenAICredentialProvider {};
         Self {
             model,
             credential_provider,
-            api_key,
+            credential,
             executor,
         }
     }
@@ -218,7 +229,8 @@ impl CompletionProvider for OpenAICompletionProvider {
         &self,
         prompt: Box<dyn CompletionRequest>,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
+        let credential = self.credential.clone();
+        let request = stream_completion(credential, self.executor.clone(), prompt);
         async move {
             let response = request.await?;
             let stream = response

crates/assistant/src/assistant_panel.rs 🔗

@@ -7,7 +7,8 @@ use crate::{
 };
 
 use ai::{
-    completion::CompletionRequest,
+    auth::ProviderCredential,
+    completion::{CompletionProvider, CompletionRequest},
     providers::open_ai::{
         stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
     },
@@ -100,8 +101,8 @@ pub fn init(cx: &mut AppContext) {
     cx.capture_action(ConversationEditor::copy);
     cx.add_action(ConversationEditor::split);
     cx.capture_action(ConversationEditor::cycle_message_role);
-    cx.add_action(AssistantPanel::save_api_key);
-    cx.add_action(AssistantPanel::reset_api_key);
+    cx.add_action(AssistantPanel::save_credentials);
+    cx.add_action(AssistantPanel::reset_credentials);
     cx.add_action(AssistantPanel::toggle_zoom);
     cx.add_action(AssistantPanel::deploy);
     cx.add_action(AssistantPanel::select_next_match);
@@ -143,7 +144,8 @@ pub struct AssistantPanel {
     zoomed: bool,
     has_focus: bool,
     toolbar: ViewHandle<Toolbar>,
-    api_key: Rc<RefCell<Option<String>>>,
+    credential: Rc<RefCell<ProviderCredential>>,
+    completion_provider: Box<dyn CompletionProvider>,
     api_key_editor: Option<ViewHandle<Editor>>,
     has_read_credentials: bool,
     languages: Arc<LanguageRegistry>,
@@ -205,6 +207,12 @@ impl AssistantPanel {
                     });
 
                     let semantic_index = SemanticIndex::global(cx);
+                    // Defaulting currently to GPT4, allow for this to be set via config.
+                    let completion_provider = Box::new(OpenAICompletionProvider::new(
+                        "gpt-4",
+                        ProviderCredential::NoCredentials,
+                        cx.background().clone(),
+                    ));
 
                     let mut this = Self {
                         workspace: workspace_handle,
@@ -216,7 +224,8 @@ impl AssistantPanel {
                         zoomed: false,
                         has_focus: false,
                         toolbar,
-                        api_key: Rc::new(RefCell::new(None)),
+                        credential: Rc::new(RefCell::new(ProviderCredential::NoCredentials)),
+                        completion_provider,
                         api_key_editor: None,
                         has_read_credentials: false,
                         languages: workspace.app_state().languages.clone(),
@@ -257,10 +266,7 @@ impl AssistantPanel {
         cx: &mut ViewContext<Workspace>,
     ) {
         let this = if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
-            if this
-                .update(cx, |assistant, cx| assistant.load_api_key(cx))
-                .is_some()
-            {
+            if this.update(cx, |assistant, cx| assistant.has_credentials(cx)) {
                 this
             } else {
                 workspace.focus_panel::<AssistantPanel>(cx);
@@ -292,12 +298,7 @@ impl AssistantPanel {
         cx: &mut ViewContext<Self>,
         project: &ModelHandle<Project>,
     ) {
-        let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
-            api_key
-        } else {
-            return;
-        };
-
+        let credential = self.credential.borrow().clone();
         let selection = editor.read(cx).selections.newest_anchor().clone();
         if selection.start.excerpt_id() != selection.end.excerpt_id() {
             return;
@@ -329,7 +330,7 @@ impl AssistantPanel {
         let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
         let provider = Arc::new(OpenAICompletionProvider::new(
             "gpt-4",
-            api_key,
+            credential,
             cx.background().clone(),
         ));
 
@@ -816,7 +817,7 @@ impl AssistantPanel {
     fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
         let editor = cx.add_view(|cx| {
             ConversationEditor::new(
-                self.api_key.clone(),
+                self.credential.clone(),
                 self.languages.clone(),
                 self.fs.clone(),
                 self.workspace.clone(),
@@ -875,17 +876,20 @@ impl AssistantPanel {
         }
     }
 
-    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
+    fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
         if let Some(api_key) = self
             .api_key_editor
             .as_ref()
             .map(|editor| editor.read(cx).text(cx))
         {
             if !api_key.is_empty() {
-                cx.platform()
-                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
-                    .log_err();
-                *self.api_key.borrow_mut() = Some(api_key);
+                let credential = ProviderCredential::Credentials {
+                    api_key: api_key.clone(),
+                };
+                self.completion_provider
+                    .save_credentials(cx, credential.clone());
+                *self.credential.borrow_mut() = credential;
+
                 self.api_key_editor.take();
                 cx.focus_self();
                 cx.notify();
@@ -895,9 +899,9 @@ impl AssistantPanel {
         }
     }
 
-    fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
-        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
-        self.api_key.take();
+    fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
+        self.completion_provider.delete_credentials(cx);
+        *self.credential.borrow_mut() = ProviderCredential::NoCredentials;
         self.api_key_editor = Some(build_api_key_editor(cx));
         cx.focus_self();
         cx.notify();
@@ -1156,13 +1160,19 @@ impl AssistantPanel {
 
         let fs = self.fs.clone();
         let workspace = self.workspace.clone();
-        let api_key = self.api_key.clone();
+        let credential = self.credential.clone();
         let languages = self.languages.clone();
         cx.spawn(|this, mut cx| async move {
             let saved_conversation = fs.load(&path).await?;
             let saved_conversation = serde_json::from_str(&saved_conversation)?;
             let conversation = cx.add_model(|cx| {
-                Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
+                Conversation::deserialize(
+                    saved_conversation,
+                    path.clone(),
+                    credential,
+                    languages,
+                    cx,
+                )
             });
             this.update(&mut cx, |this, cx| {
                 // If, by the time we've loaded the conversation, the user has already opened
@@ -1186,30 +1196,39 @@ impl AssistantPanel {
             .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
     }
 
-    fn load_api_key(&mut self, cx: &mut ViewContext<Self>) -> Option<String> {
-        if self.api_key.borrow().is_none() && !self.has_read_credentials {
-            self.has_read_credentials = true;
-            let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-                Some(api_key)
-            } else if let Some((_, api_key)) = cx
-                .platform()
-                .read_credentials(OPENAI_API_URL)
-                .log_err()
-                .flatten()
-            {
-                String::from_utf8(api_key).log_err()
-            } else {
-                None
-            };
-            if let Some(api_key) = api_key {
-                *self.api_key.borrow_mut() = Some(api_key);
-            } else if self.api_key_editor.is_none() {
-                self.api_key_editor = Some(build_api_key_editor(cx));
-                cx.notify();
+    fn has_credentials(&mut self, cx: &mut ViewContext<Self>) -> bool {
+        let credential = self.load_credentials(cx);
+        match credential {
+            ProviderCredential::Credentials { .. } => true,
+            ProviderCredential::NotNeeded => true,
+            ProviderCredential::NoCredentials => false,
+        }
+    }
+
+    fn load_credentials(&mut self, cx: &mut ViewContext<Self>) -> ProviderCredential {
+        let existing_credential = self.credential.clone();
+        let existing_credential = existing_credential.borrow().clone();
+        match existing_credential {
+            ProviderCredential::NoCredentials => {
+                if !self.has_read_credentials {
+                    self.has_read_credentials = true;
+                    let retrieved_credentials = self.completion_provider.retrieve_credentials(cx);
+
+                    match retrieved_credentials {
+                        ProviderCredential::NoCredentials {} => {
+                            self.api_key_editor = Some(build_api_key_editor(cx));
+                            cx.notify();
+                        }
+                        _ => {
+                            *self.credential.borrow_mut() = retrieved_credentials;
+                        }
+                    }
+                }
             }
+            _ => {}
         }
 
-        self.api_key.borrow().clone()
+        self.credential.borrow().clone()
     }
 }
 
@@ -1394,7 +1413,7 @@ impl Panel for AssistantPanel {
 
     fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
         if active {
-            self.load_api_key(cx);
+            self.load_credentials(cx);
 
             if self.editors.is_empty() {
                 self.new_conversation(cx);
@@ -1459,7 +1478,7 @@ struct Conversation {
     token_count: Option<usize>,
     max_token_count: usize,
     pending_token_count: Task<Option<()>>,
-    api_key: Rc<RefCell<Option<String>>>,
+    credential: Rc<RefCell<ProviderCredential>>,
     pending_save: Task<Result<()>>,
     path: Option<PathBuf>,
     _subscriptions: Vec<Subscription>,
@@ -1471,7 +1490,8 @@ impl Entity for Conversation {
 
 impl Conversation {
     fn new(
-        api_key: Rc<RefCell<Option<String>>>,
+        credential: Rc<RefCell<ProviderCredential>>,
+
         language_registry: Arc<LanguageRegistry>,
         cx: &mut ModelContext<Self>,
     ) -> Self {
@@ -1512,7 +1532,7 @@ impl Conversation {
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
             path: None,
-            api_key,
+            credential,
             buffer,
         };
         let message = MessageAnchor {
@@ -1559,7 +1579,7 @@ impl Conversation {
     fn deserialize(
         saved_conversation: SavedConversation,
         path: PathBuf,
-        api_key: Rc<RefCell<Option<String>>>,
+        credential: Rc<RefCell<ProviderCredential>>,
         language_registry: Arc<LanguageRegistry>,
         cx: &mut ModelContext<Self>,
     ) -> Self {
@@ -1614,7 +1634,7 @@ impl Conversation {
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
             path: Some(path),
-            api_key,
+            credential,
             buffer,
         };
         this.count_remaining_tokens(cx);
@@ -1736,9 +1756,13 @@ impl Conversation {
         }
 
         if should_assist {
-            let Some(api_key) = self.api_key.borrow().clone() else {
-                return Default::default();
-            };
+            let credential = self.credential.borrow().clone();
+            match credential {
+                ProviderCredential::NoCredentials => {
+                    return Default::default();
+                }
+                _ => {}
+            }
 
             let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
                 model: self.model.full_name().to_string(),
@@ -1752,7 +1776,7 @@ impl Conversation {
                 temperature: 1.0,
             });
 
-            let stream = stream_completion(api_key, cx.background().clone(), request);
+            let stream = stream_completion(credential, cx.background().clone(), request);
             let assistant_message = self
                 .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
                 .unwrap();
@@ -2018,57 +2042,62 @@ impl Conversation {
 
     fn summarize(&mut self, cx: &mut ModelContext<Self>) {
         if self.message_anchors.len() >= 2 && self.summary.is_none() {
-            let api_key = self.api_key.borrow().clone();
-            if let Some(api_key) = api_key {
-                let messages = self
-                    .messages(cx)
-                    .take(2)
-                    .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
-                    .chain(Some(RequestMessage {
-                        role: Role::User,
-                        content:
-                            "Summarize the conversation into a short title without punctuation"
-                                .into(),
-                    }));
-                let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
-                    model: self.model.full_name().to_string(),
-                    messages: messages.collect(),
-                    stream: true,
-                    stop: vec![],
-                    temperature: 1.0,
-                });
+            let credential = self.credential.borrow().clone();
 
-                let stream = stream_completion(api_key, cx.background().clone(), request);
-                self.pending_summary = cx.spawn(|this, mut cx| {
-                    async move {
-                        let mut messages = stream.await?;
+            match credential {
+                ProviderCredential::NoCredentials => {
+                    return;
+                }
+                _ => {}
+            }
 
-                        while let Some(message) = messages.next().await {
-                            let mut message = message?;
-                            if let Some(choice) = message.choices.pop() {
-                                let text = choice.delta.content.unwrap_or_default();
-                                this.update(&mut cx, |this, cx| {
-                                    this.summary
-                                        .get_or_insert(Default::default())
-                                        .text
-                                        .push_str(&text);
-                                    cx.emit(ConversationEvent::SummaryChanged);
-                                });
-                            }
-                        }
+            let messages = self
+                .messages(cx)
+                .take(2)
+                .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
+                .chain(Some(RequestMessage {
+                    role: Role::User,
+                    content: "Summarize the conversation into a short title without punctuation"
+                        .into(),
+                }));
+            let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
+                model: self.model.full_name().to_string(),
+                messages: messages.collect(),
+                stream: true,
+                stop: vec![],
+                temperature: 1.0,
+            });
 
-                        this.update(&mut cx, |this, cx| {
-                            if let Some(summary) = this.summary.as_mut() {
-                                summary.done = true;
-                                cx.emit(ConversationEvent::SummaryChanged);
-                            }
-                        });
+            let stream = stream_completion(credential, cx.background().clone(), request);
+            self.pending_summary = cx.spawn(|this, mut cx| {
+                async move {
+                    let mut messages = stream.await?;
 
-                        anyhow::Ok(())
+                    while let Some(message) = messages.next().await {
+                        let mut message = message?;
+                        if let Some(choice) = message.choices.pop() {
+                            let text = choice.delta.content.unwrap_or_default();
+                            this.update(&mut cx, |this, cx| {
+                                this.summary
+                                    .get_or_insert(Default::default())
+                                    .text
+                                    .push_str(&text);
+                                cx.emit(ConversationEvent::SummaryChanged);
+                            });
+                        }
                     }
-                    .log_err()
-                });
-            }
+
+                    this.update(&mut cx, |this, cx| {
+                        if let Some(summary) = this.summary.as_mut() {
+                            summary.done = true;
+                            cx.emit(ConversationEvent::SummaryChanged);
+                        }
+                    });
+
+                    anyhow::Ok(())
+                }
+                .log_err()
+            });
         }
     }
 
@@ -2229,13 +2258,13 @@ struct ConversationEditor {
 
 impl ConversationEditor {
     fn new(
-        api_key: Rc<RefCell<Option<String>>>,
+        credential: Rc<RefCell<ProviderCredential>>,
         language_registry: Arc<LanguageRegistry>,
         fs: Arc<dyn Fs>,
         workspace: WeakViewHandle<Workspace>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
-        let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
+        let conversation = cx.add_model(|cx| Conversation::new(credential, language_registry, cx));
         Self::for_conversation(conversation, fs, workspace, cx)
     }
 
@@ -3431,7 +3460,13 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
-        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
+        let conversation = cx.add_model(|cx| {
+            Conversation::new(
+                Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
+                registry,
+                cx,
+            )
+        });
         let buffer = conversation.read(cx).buffer.clone();
 
         let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3559,7 +3594,13 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
-        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
+        let conversation = cx.add_model(|cx| {
+            Conversation::new(
+                Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
+                registry,
+                cx,
+            )
+        });
         let buffer = conversation.read(cx).buffer.clone();
 
         let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3655,7 +3696,13 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
-        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
+        let conversation = cx.add_model(|cx| {
+            Conversation::new(
+                Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
+                registry,
+                cx,
+            )
+        });
         let buffer = conversation.read(cx).buffer.clone();
 
         let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3737,8 +3784,13 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
-        let conversation =
-            cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
+        let conversation = cx.add_model(|cx| {
+            Conversation::new(
+                Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
+                registry.clone(),
+                cx,
+            )
+        });
         let buffer = conversation.read(cx).buffer.clone();
         let message_0 = conversation.read(cx).message_anchors[0].id;
         let message_1 = conversation.update(cx, |conversation, cx| {
@@ -3775,7 +3827,7 @@ mod tests {
             Conversation::deserialize(
                 conversation.read(cx).serialize(cx),
                 Default::default(),
-                Default::default(),
+                Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
                 registry.clone(),
                 cx,
             )