@@ -7,7 +7,8 @@ use crate::{
};
use ai::{
- completion::CompletionRequest,
+ auth::ProviderCredential,
+ completion::{CompletionProvider, CompletionRequest},
providers::open_ai::{
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
},
@@ -100,8 +101,8 @@ pub fn init(cx: &mut AppContext) {
cx.capture_action(ConversationEditor::copy);
cx.add_action(ConversationEditor::split);
cx.capture_action(ConversationEditor::cycle_message_role);
- cx.add_action(AssistantPanel::save_api_key);
- cx.add_action(AssistantPanel::reset_api_key);
+ cx.add_action(AssistantPanel::save_credentials);
+ cx.add_action(AssistantPanel::reset_credentials);
cx.add_action(AssistantPanel::toggle_zoom);
cx.add_action(AssistantPanel::deploy);
cx.add_action(AssistantPanel::select_next_match);
@@ -143,7 +144,8 @@ pub struct AssistantPanel {
zoomed: bool,
has_focus: bool,
toolbar: ViewHandle<Toolbar>,
- api_key: Rc<RefCell<Option<String>>>,
+ credential: Rc<RefCell<ProviderCredential>>,
+ completion_provider: Box<dyn CompletionProvider>,
api_key_editor: Option<ViewHandle<Editor>>,
has_read_credentials: bool,
languages: Arc<LanguageRegistry>,
@@ -205,6 +207,12 @@ impl AssistantPanel {
});
let semantic_index = SemanticIndex::global(cx);
+ // Defaulting currently to GPT4, allow for this to be set via config.
+ let completion_provider = Box::new(OpenAICompletionProvider::new(
+ "gpt-4",
+ ProviderCredential::NoCredentials,
+ cx.background().clone(),
+ ));
let mut this = Self {
workspace: workspace_handle,
@@ -216,7 +224,8 @@ impl AssistantPanel {
zoomed: false,
has_focus: false,
toolbar,
- api_key: Rc::new(RefCell::new(None)),
+ credential: Rc::new(RefCell::new(ProviderCredential::NoCredentials)),
+ completion_provider,
api_key_editor: None,
has_read_credentials: false,
languages: workspace.app_state().languages.clone(),
@@ -257,10 +266,7 @@ impl AssistantPanel {
cx: &mut ViewContext<Workspace>,
) {
let this = if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
- if this
- .update(cx, |assistant, cx| assistant.load_api_key(cx))
- .is_some()
- {
+ if this.update(cx, |assistant, cx| assistant.has_credentials(cx)) {
this
} else {
workspace.focus_panel::<AssistantPanel>(cx);
@@ -292,12 +298,7 @@ impl AssistantPanel {
cx: &mut ViewContext<Self>,
project: &ModelHandle<Project>,
) {
- let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
- api_key
- } else {
- return;
- };
-
+ let credential = self.credential.borrow().clone();
let selection = editor.read(cx).selections.newest_anchor().clone();
if selection.start.excerpt_id() != selection.end.excerpt_id() {
return;
@@ -329,7 +330,7 @@ impl AssistantPanel {
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
let provider = Arc::new(OpenAICompletionProvider::new(
"gpt-4",
- api_key,
+ credential,
cx.background().clone(),
));
@@ -816,7 +817,7 @@ impl AssistantPanel {
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
let editor = cx.add_view(|cx| {
ConversationEditor::new(
- self.api_key.clone(),
+ self.credential.clone(),
self.languages.clone(),
self.fs.clone(),
self.workspace.clone(),
@@ -875,17 +876,20 @@ impl AssistantPanel {
}
}
- fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
+ fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
if let Some(api_key) = self
.api_key_editor
.as_ref()
.map(|editor| editor.read(cx).text(cx))
{
if !api_key.is_empty() {
- cx.platform()
- .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
- .log_err();
- *self.api_key.borrow_mut() = Some(api_key);
+ let credential = ProviderCredential::Credentials {
+ api_key: api_key.clone(),
+ };
+ self.completion_provider
+ .save_credentials(cx, credential.clone());
+ *self.credential.borrow_mut() = credential;
+
self.api_key_editor.take();
cx.focus_self();
cx.notify();
@@ -895,9 +899,9 @@ impl AssistantPanel {
}
}
- fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
- cx.platform().delete_credentials(OPENAI_API_URL).log_err();
- self.api_key.take();
+ fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
+ self.completion_provider.delete_credentials(cx);
+ *self.credential.borrow_mut() = ProviderCredential::NoCredentials;
self.api_key_editor = Some(build_api_key_editor(cx));
cx.focus_self();
cx.notify();
@@ -1156,13 +1160,19 @@ impl AssistantPanel {
let fs = self.fs.clone();
let workspace = self.workspace.clone();
- let api_key = self.api_key.clone();
+ let credential = self.credential.clone();
let languages = self.languages.clone();
cx.spawn(|this, mut cx| async move {
let saved_conversation = fs.load(&path).await?;
let saved_conversation = serde_json::from_str(&saved_conversation)?;
let conversation = cx.add_model(|cx| {
- Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
+ Conversation::deserialize(
+ saved_conversation,
+ path.clone(),
+ credential,
+ languages,
+ cx,
+ )
});
this.update(&mut cx, |this, cx| {
// If, by the time we've loaded the conversation, the user has already opened
@@ -1186,30 +1196,39 @@ impl AssistantPanel {
.position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
}
- fn load_api_key(&mut self, cx: &mut ViewContext<Self>) -> Option<String> {
- if self.api_key.borrow().is_none() && !self.has_read_credentials {
- self.has_read_credentials = true;
- let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
- Some(api_key)
- } else if let Some((_, api_key)) = cx
- .platform()
- .read_credentials(OPENAI_API_URL)
- .log_err()
- .flatten()
- {
- String::from_utf8(api_key).log_err()
- } else {
- None
- };
- if let Some(api_key) = api_key {
- *self.api_key.borrow_mut() = Some(api_key);
- } else if self.api_key_editor.is_none() {
- self.api_key_editor = Some(build_api_key_editor(cx));
- cx.notify();
+ fn has_credentials(&mut self, cx: &mut ViewContext<Self>) -> bool {
+ let credential = self.load_credentials(cx);
+ match credential {
+ ProviderCredential::Credentials { .. } => true,
+ ProviderCredential::NotNeeded => true,
+ ProviderCredential::NoCredentials => false,
+ }
+ }
+
+ fn load_credentials(&mut self, cx: &mut ViewContext<Self>) -> ProviderCredential {
+ let existing_credential = self.credential.clone();
+ let existing_credential = existing_credential.borrow().clone();
+ match existing_credential {
+ ProviderCredential::NoCredentials => {
+ if !self.has_read_credentials {
+ self.has_read_credentials = true;
+ let retrieved_credentials = self.completion_provider.retrieve_credentials(cx);
+
+ match retrieved_credentials {
+ ProviderCredential::NoCredentials {} => {
+ self.api_key_editor = Some(build_api_key_editor(cx));
+ cx.notify();
+ }
+ _ => {
+ *self.credential.borrow_mut() = retrieved_credentials;
+ }
+ }
+ }
}
+ _ => {}
}
- self.api_key.borrow().clone()
+ self.credential.borrow().clone()
}
}
@@ -1394,7 +1413,7 @@ impl Panel for AssistantPanel {
fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
if active {
- self.load_api_key(cx);
+ self.load_credentials(cx);
if self.editors.is_empty() {
self.new_conversation(cx);
@@ -1459,7 +1478,7 @@ struct Conversation {
token_count: Option<usize>,
max_token_count: usize,
pending_token_count: Task<Option<()>>,
- api_key: Rc<RefCell<Option<String>>>,
+ credential: Rc<RefCell<ProviderCredential>>,
pending_save: Task<Result<()>>,
path: Option<PathBuf>,
_subscriptions: Vec<Subscription>,
@@ -1471,7 +1490,8 @@ impl Entity for Conversation {
impl Conversation {
fn new(
- api_key: Rc<RefCell<Option<String>>>,
+ credential: Rc<RefCell<ProviderCredential>>,
+
language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>,
) -> Self {
@@ -1512,7 +1532,7 @@ impl Conversation {
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: None,
- api_key,
+ credential,
buffer,
};
let message = MessageAnchor {
@@ -1559,7 +1579,7 @@ impl Conversation {
fn deserialize(
saved_conversation: SavedConversation,
path: PathBuf,
- api_key: Rc<RefCell<Option<String>>>,
+ credential: Rc<RefCell<ProviderCredential>>,
language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>,
) -> Self {
@@ -1614,7 +1634,7 @@ impl Conversation {
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: Some(path),
- api_key,
+ credential,
buffer,
};
this.count_remaining_tokens(cx);
@@ -1736,9 +1756,13 @@ impl Conversation {
}
if should_assist {
- let Some(api_key) = self.api_key.borrow().clone() else {
- return Default::default();
- };
+ let credential = self.credential.borrow().clone();
+ match credential {
+ ProviderCredential::NoCredentials => {
+ return Default::default();
+ }
+ _ => {}
+ }
let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
model: self.model.full_name().to_string(),
@@ -1752,7 +1776,7 @@ impl Conversation {
temperature: 1.0,
});
- let stream = stream_completion(api_key, cx.background().clone(), request);
+ let stream = stream_completion(credential, cx.background().clone(), request);
let assistant_message = self
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
.unwrap();
@@ -2018,57 +2042,62 @@ impl Conversation {
fn summarize(&mut self, cx: &mut ModelContext<Self>) {
if self.message_anchors.len() >= 2 && self.summary.is_none() {
- let api_key = self.api_key.borrow().clone();
- if let Some(api_key) = api_key {
- let messages = self
- .messages(cx)
- .take(2)
- .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
- .chain(Some(RequestMessage {
- role: Role::User,
- content:
- "Summarize the conversation into a short title without punctuation"
- .into(),
- }));
- let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
- model: self.model.full_name().to_string(),
- messages: messages.collect(),
- stream: true,
- stop: vec![],
- temperature: 1.0,
- });
+ let credential = self.credential.borrow().clone();
- let stream = stream_completion(api_key, cx.background().clone(), request);
- self.pending_summary = cx.spawn(|this, mut cx| {
- async move {
- let mut messages = stream.await?;
+ match credential {
+ ProviderCredential::NoCredentials => {
+ return;
+ }
+ _ => {}
+ }
- while let Some(message) = messages.next().await {
- let mut message = message?;
- if let Some(choice) = message.choices.pop() {
- let text = choice.delta.content.unwrap_or_default();
- this.update(&mut cx, |this, cx| {
- this.summary
- .get_or_insert(Default::default())
- .text
- .push_str(&text);
- cx.emit(ConversationEvent::SummaryChanged);
- });
- }
- }
+ let messages = self
+ .messages(cx)
+ .take(2)
+ .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
+ .chain(Some(RequestMessage {
+ role: Role::User,
+ content: "Summarize the conversation into a short title without punctuation"
+ .into(),
+ }));
+ let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
+ model: self.model.full_name().to_string(),
+ messages: messages.collect(),
+ stream: true,
+ stop: vec![],
+ temperature: 1.0,
+ });
- this.update(&mut cx, |this, cx| {
- if let Some(summary) = this.summary.as_mut() {
- summary.done = true;
- cx.emit(ConversationEvent::SummaryChanged);
- }
- });
+ let stream = stream_completion(credential, cx.background().clone(), request);
+ self.pending_summary = cx.spawn(|this, mut cx| {
+ async move {
+ let mut messages = stream.await?;
- anyhow::Ok(())
+ while let Some(message) = messages.next().await {
+ let mut message = message?;
+ if let Some(choice) = message.choices.pop() {
+ let text = choice.delta.content.unwrap_or_default();
+ this.update(&mut cx, |this, cx| {
+ this.summary
+ .get_or_insert(Default::default())
+ .text
+ .push_str(&text);
+ cx.emit(ConversationEvent::SummaryChanged);
+ });
+ }
}
- .log_err()
- });
- }
+
+ this.update(&mut cx, |this, cx| {
+ if let Some(summary) = this.summary.as_mut() {
+ summary.done = true;
+ cx.emit(ConversationEvent::SummaryChanged);
+ }
+ });
+
+ anyhow::Ok(())
+ }
+ .log_err()
+ });
}
}
@@ -2229,13 +2258,13 @@ struct ConversationEditor {
impl ConversationEditor {
fn new(
- api_key: Rc<RefCell<Option<String>>>,
+ credential: Rc<RefCell<ProviderCredential>>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
workspace: WeakViewHandle<Workspace>,
cx: &mut ViewContext<Self>,
) -> Self {
- let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
+ let conversation = cx.add_model(|cx| Conversation::new(credential, language_registry, cx));
Self::for_conversation(conversation, fs, workspace, cx)
}
@@ -3431,7 +3460,13 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
- let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
+ let conversation = cx.add_model(|cx| {
+ Conversation::new(
+ Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
+ registry,
+ cx,
+ )
+ });
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3559,7 +3594,13 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
- let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
+ let conversation = cx.add_model(|cx| {
+ Conversation::new(
+ Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
+ registry,
+ cx,
+ )
+ });
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3655,7 +3696,13 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
- let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
+ let conversation = cx.add_model(|cx| {
+ Conversation::new(
+ Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
+ registry,
+ cx,
+ )
+ });
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3737,8 +3784,13 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
- let conversation =
- cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
+ let conversation = cx.add_model(|cx| {
+ Conversation::new(
+ Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
+ registry.clone(),
+ cx,
+ )
+ });
let buffer = conversation.read(cx).buffer.clone();
let message_0 = conversation.read(cx).message_anchors[0].id;
let message_1 = conversation.update(cx, |conversation, cx| {
@@ -3775,7 +3827,7 @@ mod tests {
Conversation::deserialize(
conversation.read(cx).serialize(cx),
Default::default(),
- Default::default(),
+ Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
registry.clone(),
cx,
)