From be3a8584ff4141bd84170206d69714b78d28f724 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Thu, 1 Aug 2024 15:54:47 +0200 Subject: [PATCH] assistant: Add a Configuration page (#15490) - [x] bug: setting a key doesn't update anything - [x] show high-level text on configuration page to explain what it is - [x] show "everything okay!" status when credentials are set - [x] maybe: add "verify" button to check credentials - [x] open configuration page when opening panel for first time and nothing is configured - [x] BUG: need to fix empty assistant panel if provider is `zed.dev` but not logged in Co-Authored-By: Thorsten Release Notes: - N/A --------- Co-authored-by: Thorsten Co-authored-by: Nate Butler Co-authored-by: Thorsten Ball --- Cargo.lock | 1 + crates/assistant/Cargo.toml | 1 + crates/assistant/src/assistant.rs | 2 +- crates/assistant/src/assistant_panel.rs | 493 +++++++++++++----- crates/assistant/src/using-the-assistant.md | 25 + crates/language_model/src/language_model.rs | 4 +- .../language_model/src/provider/anthropic.rs | 187 ++++--- crates/language_model/src/provider/cloud.rs | 122 +++-- .../src/provider/copilot_chat.rs | 207 ++++---- crates/language_model/src/provider/fake.rs | 4 +- crates/language_model/src/provider/google.rs | 136 +++-- crates/language_model/src/provider/ollama.rs | 70 +-- crates/language_model/src/provider/open_ai.rs | 146 ++++-- 13 files changed, 928 insertions(+), 470 deletions(-) create mode 100644 crates/assistant/src/using-the-assistant.md diff --git a/Cargo.lock b/Cargo.lock index ded3128acd58e161229868d50f9b62d040d5013a..56e8f1911e2d8665606e03005e963a43b31c02ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -423,6 +423,7 @@ dependencies = [ "language", "language_model", "log", + "markdown", "menu", "multi_buffer", "ollama", diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index bcff377a4563c335f12476f6e8b9ee994e3e1028..e5f769ede0544e75af2f6da0654861e8d3f1910c 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -47,6 +47,7 @@ indoc.workspace = true language.workspace = true language_model.workspace = true log.workspace = true +markdown.workspace = true menu.workspace = true multi_buffer.workspace = true ollama = { workspace = true, features = ["schemars"] } diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 8e8f00fc9aadbeef4c5be8420fb8d3a23535f76c..4c6941d756dc5724df6edb583984e0b9ba4f067f 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -45,8 +45,8 @@ actions!( QuoteSelection, InsertIntoEditor, ToggleFocus, - ResetKey, InsertActivePrompt, + ShowConfiguration, DeployHistory, DeployPromptLibrary, ConfirmCommand, diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index a335e80b39ec701b54a1d1c84a05c59873b03c07..c3bc4b9a06cb2edade21f6d86effa7570aca2b7b 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -1,4 +1,3 @@ -use crate::ContextStoreEvent; use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings}, humanize_token_count, @@ -13,8 +12,9 @@ use crate::{ DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep, EditStepOperations, EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant, InsertIntoEditor, MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus, QuoteSelection, - RemoteContextMetadata, ResetKey, SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector, + RemoteContextMetadata, SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector, }; +use crate::{ContextStoreEvent, ShowConfiguration}; use anyhow::{anyhow, Result}; use assistant_slash_command::{SlashCommand, SlashCommandOutputSection}; use client::proto; @@ -31,18 +31,20 @@ use editor::{ use editor::{display_map::CreaseId, FoldPlaceholder}; use fs::Fs; use gpui::{ - div, percentage, point, Action, Animation, AnimationExt, AnyElement, AnyView, AppContext, + div, percentage, point, svg, Action, Animation, AnimationExt, AnyElement, AnyView, AppContext, AsyncWindowContext, ClipboardItem, Context as _, DismissEvent, Empty, Entity, EventEmitter, FocusHandle, FocusableView, InteractiveElement, IntoElement, Model, ParentElement, Pixels, - Render, SharedString, StatefulInteractiveElement, Styled, Subscription, Task, Transformation, - UpdateGlobal, View, ViewContext, VisualContext, WeakView, WindowContext, + Render, SharedString, StatefulInteractiveElement, Styled, Subscription, Task, + TextStyleRefinement, Transformation, UpdateGlobal, View, ViewContext, VisualContext, WeakView, + WindowContext, }; use indexed_docs::IndexedDocsStore; use language::{ language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point, ToOffset, }; -use language_model::{LanguageModelProviderId, LanguageModelRegistry, Role}; +use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry, Role}; +use markdown::{Markdown, MarkdownStyle}; use multi_buffer::MultiBufferRow; use picker::{Picker, PickerDelegate}; use project::{Project, ProjectLspAdapterDelegate}; @@ -58,6 +60,7 @@ use std::{ time::Duration, }; use terminal_view::{terminal_panel::TerminalPanel, TerminalView}; +use theme::ThemeSettings; use ui::TintColor; use ui::{ prelude::*, @@ -91,7 +94,8 @@ pub fn init(cx: &mut AppContext) { }) .register_action(AssistantPanel::inline_assist) .register_action(ContextEditor::quote_selection) - .register_action(ContextEditor::insert_selection); + .register_action(ContextEditor::insert_selection) + .register_action(AssistantPanel::show_configuration); }, ) .detach(); @@ -136,7 +140,6 @@ pub struct AssistantPanel { languages: Arc, fs: Arc, subscriptions: Vec, - authentication_prompt: Option, model_selector_menu_handle: PopoverMenuHandle, model_summary_editor: View, authenticate_provider_task: Option<(LanguageModelProviderId, Task<()>)>, @@ -365,6 +368,7 @@ impl AssistantPanel { .action("New Context", Box::new(NewFile)) .action("History", Box::new(DeployHistory)) .action("Prompt Library", Box::new(DeployPromptLibrary)) + .action("Configure", Box::new(ShowConfiguration)) .action(zoom_label, Box::new(ToggleZoom)) }); cx.subscribe(&menu, |pane, _, _: &DismissEvent, _| { @@ -399,8 +403,10 @@ impl AssistantPanel { language_model::Event::ActiveModelChanged => { this.completion_provider_changed(cx); } - language_model::Event::ProviderStateChanged - | language_model::Event::AddedProvider(_) + language_model::Event::ProviderStateChanged => { + this.ensure_authenticated(cx); + } + language_model::Event::AddedProvider(_) | language_model::Event::RemovedProvider(_) => { this.ensure_authenticated(cx); } @@ -408,7 +414,7 @@ impl AssistantPanel { ), ]; - Self { + let mut this = Self { pane, workspace: workspace.weak_handle(), width: None, @@ -418,11 +424,21 @@ impl AssistantPanel { languages: workspace.app_state().languages.clone(), fs: workspace.app_state().fs.clone(), subscriptions, - authentication_prompt: None, model_selector_menu_handle, model_summary_editor, authenticate_provider_task: None, - } + }; + + if LanguageModelRegistry::read_global(cx) + .active_provider() + .is_none() + { + this.show_configuration_for_provider(None, cx); + } else { + this.new_context(cx); + }; + + this } fn handle_pane_event( @@ -582,63 +598,39 @@ impl AssistantPanel { *old_provider_id != new_provider_id }) { + self.authenticate_provider_task = None; self.ensure_authenticated(cx); } } - fn authentication_prompt(cx: &mut WindowContext) -> Option { - if let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() { - if !provider.is_authenticated(cx) { - return Some(provider.authentication_prompt(cx)); - } - } - None - } - fn ensure_authenticated(&mut self, cx: &mut ViewContext) { if self.is_authenticated(cx) { - self.set_authentication_prompt(None, cx); return; } - let Some(provider_id) = LanguageModelRegistry::read_global(cx) - .active_provider() - .map(|p| p.id()) - else { + let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else { return; }; let load_credentials = self.authenticate(cx); - self.authenticate_provider_task = Some(( - provider_id, - cx.spawn(|this, mut cx| async move { - let _ = load_credentials.await; - this.update(&mut cx, |this, cx| { - this.show_authentication_prompt(cx); - this.authenticate_provider_task = None; - }) - .log_err(); - }), - )); - } - - fn show_authentication_prompt(&mut self, cx: &mut ViewContext) { - let prompt = Self::authentication_prompt(cx); - self.set_authentication_prompt(prompt, cx); - } - - fn set_authentication_prompt(&mut self, prompt: Option, cx: &mut ViewContext) { - if self.active_context_editor(cx).is_none() { - self.new_context(cx); - } - - for context_editor in self.context_editors(cx) { - context_editor.update(cx, |editor, cx| { - editor.set_authentication_prompt(prompt.clone(), cx); - }); + if self.authenticate_provider_task.is_none() { + self.authenticate_provider_task = Some(( + provider.id(), + cx.spawn(|this, mut cx| async move { + let _ = load_credentials.await; + this.update(&mut cx, |this, cx| { + if !provider.is_authenticated(cx) { + this.show_configuration_for_provider(Some(provider), cx) + } else if !this.has_any_context_editors(cx) { + this.new_context(cx); + } + this.authenticate_provider_task = None; + }) + .log_err(); + }), + )); } - cx.notify(); } pub fn inline_assist( @@ -900,6 +892,58 @@ impl AssistantPanel { } } + fn show_configuration( + workspace: &mut Workspace, + _: &ShowConfiguration, + cx: &mut ViewContext, + ) { + let Some(panel) = workspace.panel::(cx) else { + return; + }; + + if !panel.focus_handle(cx).contains_focused(cx) { + workspace.toggle_panel_focus::(cx); + } + + panel.update(cx, |this, cx| { + this.show_configuration_for_active_provider(cx); + }) + } + + fn show_configuration_for_active_provider(&mut self, cx: &mut ViewContext) { + let provider = LanguageModelRegistry::read_global(cx).active_provider(); + self.show_configuration_for_provider(provider, cx); + } + + fn show_configuration_for_provider( + &mut self, + provider: Option>, + cx: &mut ViewContext, + ) { + let configuration_item_ix = self + .pane + .read(cx) + .items() + .position(|item| item.downcast::().is_some()); + + if let Some(configuration_item_ix) = configuration_item_ix { + self.pane.update(cx, |pane, cx| { + pane.activate_item(configuration_item_ix, true, true, cx); + }); + } else { + let configuration = cx.new_view(|cx| { + let mut view = ConfigurationView::new(self.focus_handle(cx), cx); + if let Some(provider) = provider { + view.set_active_tab(provider, cx); + } + view + }); + self.pane.update(cx, |pane, cx| { + pane.add_item(Box::new(configuration), true, true, None, cx); + }); + } + } + fn deploy_history(&mut self, _: &DeployHistory, cx: &mut ViewContext) { let history_item_ix = self .pane @@ -931,35 +975,22 @@ impl AssistantPanel { open_prompt_library(self.languages.clone(), cx).detach_and_log_err(cx); } - fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext) { - if let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() { - let reset_credentials = provider.reset_credentials(cx); - cx.spawn(|this, mut cx| async move { - reset_credentials.await?; - this.update(&mut cx, |this, cx| { - this.show_authentication_prompt(cx); - }) - }) - .detach_and_log_err(cx); - } - } - fn toggle_model_selector(&mut self, _: &ToggleModelSelector, cx: &mut ViewContext) { self.model_selector_menu_handle.toggle(cx); } - fn context_editors(&self, cx: &AppContext) -> Vec> { + fn active_context_editor(&self, cx: &AppContext) -> Option> { self.pane .read(cx) - .items_of_type::() - .collect() + .active_item()? + .downcast::() } - fn active_context_editor(&self, cx: &AppContext) -> Option> { + fn has_any_context_editors(&self, cx: &AppContext) -> bool { self.pane .read(cx) - .active_item()? - .downcast::() + .items() + .any(|item| item.downcast::().is_some()) } pub fn active_context(&self, cx: &AppContext) -> Option> { @@ -1083,8 +1114,10 @@ impl AssistantPanel { |provider| provider.authenticate(cx), ) } +} - fn render_signed_in(&mut self, cx: &mut ViewContext) -> impl IntoElement { +impl Render for AssistantPanel { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { let mut registrar = DivRegistrar::new( |panel, cx| { panel @@ -1105,21 +1138,14 @@ impl AssistantPanel { .on_action(cx.listener(|this, _: &workspace::NewFile, cx| { this.new_context(cx); })) + .on_action(cx.listener(|this, _: &ShowConfiguration, cx| { + this.show_configuration_for_active_provider(cx) + })) .on_action(cx.listener(AssistantPanel::deploy_history)) .on_action(cx.listener(AssistantPanel::deploy_prompt_library)) - .on_action(cx.listener(AssistantPanel::reset_credentials)) .on_action(cx.listener(AssistantPanel::toggle_model_selector)) .child(registrar.size_full().child(self.pane.clone())) - } -} - -impl Render for AssistantPanel { - fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - if let Some(authentication_prompt) = self.authentication_prompt.as_ref() { - authentication_prompt.clone().into_any() - } else { - self.render_signed_in(cx).into_any_element() - } + .into_any_element() } } @@ -1242,7 +1268,6 @@ struct ActiveEditStep { pub struct ContextEditor { context: Model, - authentication_prompt: Option, fs: Arc, workspace: WeakView, project: Model, @@ -1300,7 +1325,6 @@ impl ContextEditor { let sections = context.read(cx).slash_command_output_sections().to_vec(); let mut this = Self { context, - authentication_prompt: None, editor, lsp_adapter_delegate, blocks: Default::default(), @@ -1320,15 +1344,6 @@ impl ContextEditor { this } - fn set_authentication_prompt( - &mut self, - authentication_prompt: Option, - cx: &mut ViewContext, - ) { - self.authentication_prompt = authentication_prompt; - cx.notify(); - } - fn insert_default_prompt(&mut self, cx: &mut ViewContext) { let command_name = DefaultSlashCommand.name(); self.editor.update(cx, |editor, cx| { @@ -1355,10 +1370,6 @@ impl ContextEditor { } fn assist(&mut self, _: &Assist, cx: &mut ViewContext) { - if self.authentication_prompt.is_some() { - return; - } - if !self.apply_edit_step(cx) { self.send_to_model(cx); } @@ -2419,26 +2430,19 @@ impl Render for ContextEditor { .size_full() .v_flex() .child( - if let Some(authentication_prompt) = self.authentication_prompt.as_ref() { - div() - .flex_grow() - .bg(cx.theme().colors().editor_background) - .child(authentication_prompt.clone().into_any()) - } else { - div() - .flex_grow() - .bg(cx.theme().colors().editor_background) - .child(self.editor.clone()) - .child( - h_flex() - .w_full() - .absolute() - .bottom_0() - .p_4() - .justify_end() - .child(self.render_send_button(cx)), - ) - }, + div() + .flex_grow() + .bg(cx.theme().colors().editor_background) + .child(self.editor.clone()) + .child( + h_flex() + .w_full() + .absolute() + .bottom_0() + .p_4() + .justify_end() + .child(self.render_send_button(cx)), + ), ) } } @@ -2992,6 +2996,253 @@ impl Item for ContextHistory { } } +pub struct ConfigurationView { + fallback_handle: FocusHandle, + using_assistant_description: View, + active_tab: Option, +} + +struct ActiveTab { + provider: Arc, + configuration_prompt: AnyView, + focus_handle: Option, + load_credentials_task: Option>, +} + +impl ActiveTab { + fn is_loading_credentials(&self) -> bool { + if let Some(task) = &self.load_credentials_task { + if let Task::Spawned(_) = task { + return true; + } + } + false + } +} + +// TODO: We need to remove this once we have proper text and styling +const SHOW_CONFIGURATION_TEXT: bool = false; + +impl ConfigurationView { + fn new(fallback_handle: FocusHandle, cx: &mut ViewContext) -> Self { + let usage_description = cx.new_view(|cx| { + let text = include_str!("./using-the-assistant.md"); + + let settings = ThemeSettings::get_global(cx); + let mut base_text_style = cx.text_style(); + base_text_style.refine(&TextStyleRefinement { + font_family: Some(settings.ui_font.family.clone()), + font_size: Some(TextSize::XSmall.rems(cx).into()), + color: Some(cx.theme().colors().editor_foreground), + background_color: Some(gpui::transparent_black()), + ..Default::default() + }); + let markdown_style = MarkdownStyle { + base_text_style, + selection_background_color: { cx.theme().players().local().selection }, + inline_code: TextStyleRefinement { + background_color: Some(cx.theme().colors().background), + ..Default::default() + }, + link: TextStyleRefinement { + underline: Some(gpui::UnderlineStyle { + thickness: px(1.), + color: Some(cx.theme().colors().editor_foreground), + wavy: false, + }), + ..Default::default() + }, + ..Default::default() + }; + Markdown::new(text.to_string(), markdown_style.clone(), None, cx, None) + }); + + Self { + fallback_handle, + using_assistant_description: usage_description, + active_tab: None, + } + } + + fn set_active_tab( + &mut self, + provider: Arc, + cx: &mut ViewContext, + ) { + let (view, focus_handle) = provider.configuration_view(cx); + + if let Some(focus_handle) = &focus_handle { + focus_handle.focus(cx); + } else { + self.fallback_handle.focus(cx); + } + + let load_credentials = provider.authenticate(cx); + let load_credentials_task = cx.spawn(|this, mut cx| async move { + let _ = load_credentials.await; + this.update(&mut cx, |this, cx| { + if let Some(active_tab) = &mut this.active_tab { + active_tab.load_credentials_task = None; + cx.notify(); + } + }) + .log_err(); + }); + + self.active_tab = Some(ActiveTab { + provider, + configuration_prompt: view, + focus_handle, + load_credentials_task: Some(load_credentials_task), + }); + cx.notify(); + } + + fn render_active_tab_view(&mut self, cx: &mut ViewContext) -> Option
{ + let Some(active_tab) = &self.active_tab else { + return None; + }; + + let show_spinner = active_tab.is_loading_credentials(); + + let content = if show_spinner { + let loading_icon = svg() + .size_4() + .path(IconName::ArrowCircle.path()) + .text_color(cx.text_style().color) + .with_animation( + "icon_circle_arrow", + Animation::new(Duration::from_secs(2)).repeat(), + |svg, delta| svg.with_transformation(Transformation::rotate(percentage(delta))), + ); + + h_flex() + .gap_2() + .child(loading_icon) + .child(Label::new("Loading provider configuration...").size(LabelSize::Small)) + .into_any_element() + } else { + active_tab.configuration_prompt.clone().into_any_element() + }; + + Some( + div() + .p(Spacing::Large.rems(cx)) + .bg(cx.theme().colors().title_bar_background) + .border_1() + .border_color(cx.theme().colors().border_variant) + .rounded_md() + .child(content), + ) + } + + fn render_tab( + &self, + provider: &Arc, + cx: &mut ViewContext, + ) -> impl IntoElement { + let button_id = SharedString::from(format!("tab-{}", provider.id().0)); + let is_active = self.active_tab.as_ref().map(|t| t.provider.id()) == Some(provider.id()); + ButtonLike::new(button_id) + .size(ButtonSize::Compact) + .style(ButtonStyle::Transparent) + .selected(is_active) + .on_click(cx.listener({ + let provider = provider.clone(); + move |this, _, cx| { + this.set_active_tab(provider.clone(), cx); + } + })) + .child( + div() + .my_3() + .pb_px() + .border_b_1() + .border_color(if is_active { + cx.theme().colors().text_accent + } else { + cx.theme().colors().border_transparent + }) + .when(!is_active, |this| { + this.group_hover("", |this| { + this.border_color(cx.theme().colors().border_variant) + }) + }) + .child(Label::new(provider.name().0).size(LabelSize::Small).color( + if is_active { + Color::Accent + } else { + Color::Default + }, + )), + ) + } +} + +impl Render for ConfigurationView { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + let providers = LanguageModelRegistry::read_global(cx).providers(); + + if self.active_tab.is_none() && !providers.is_empty() { + self.set_active_tab(providers[0].clone(), cx); + } + + let tabs = h_flex().mx_neg_1().gap_3().children( + providers + .iter() + .map(|provider| self.render_tab(provider, cx)), + ); + + v_flex() + .id("assistant-configuration-view") + .w_full() + .min_h_full() + .p(Spacing::XXLarge.rems(cx)) + .overflow_y_scroll() + .gap_6() + .child( + v_flex() + .gap_2() + .child( + Headline::new("Get Started with the Assistant").size(HeadlineSize::Medium), + ) + .child( + Label::new("Choose a provider to get started with the assistant.") + .color(Color::Muted), + ), + ) + .child( + v_flex() + .gap_2() + .child(Headline::new("Choosing a Provider").size(HeadlineSize::Small)) + .child(tabs) + .children(self.render_active_tab_view(cx)), + ) + .when(SHOW_CONFIGURATION_TEXT, |this| { + this.child(self.using_assistant_description.clone()) + }) + } +} + +impl EventEmitter<()> for ConfigurationView {} + +impl FocusableView for ConfigurationView { + fn focus_handle(&self, _: &AppContext) -> FocusHandle { + self.active_tab + .as_ref() + .and_then(|tab| tab.focus_handle.clone()) + .unwrap_or(self.fallback_handle.clone()) + } +} + +impl Item for ConfigurationView { + type Event = (); + + fn tab_content_text(&self, _cx: &WindowContext) -> Option { + Some("Configuration".into()) + } +} + type ToggleFold = Arc; fn render_slash_command_output_toggle( diff --git a/crates/assistant/src/using-the-assistant.md b/crates/assistant/src/using-the-assistant.md new file mode 100644 index 0000000000000000000000000000000000000000..b064c7d10abfbfc6f73a59c55a5d406168e544ea --- /dev/null +++ b/crates/assistant/src/using-the-assistant.md @@ -0,0 +1,25 @@ +### Using the Assistant + +Once you have configured a provider, you can interact with the provider's language models in a context editor. + +To create a new context editor, use the menu in the top right of the assistant panel and the `New Context` option. + +In the context editor, select a model from one of the configured providers, type a message in the `You` block, and submit with `cmd-enter` (or `ctrl-enter` on Linux). + +### Inline assistant + +When you're in a normal editor, you can use `ctrl-enter` to open the inline assistant. + +The inline assistant allows you to send the current selection (or the current line) to a language model and modify the selection with the language model's response. + +### Adding Prompts + +You can customize the default prompts that are used in new context editor, by opening the `Prompt Library`. + +Open the `Prompt Library` using either the menu in the top right of the assistant panel and choosing the `Prompt Library` option, or by using the `assistant: deploy prompt library` command when the assistant panel is focused. + +### Viewing past contexts + +You view all previous contexts by opening up the `History` tab in the assistant panel. + +Open the `History` using the menu in the top right of the assistant panel and choosing the `History`. diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 611e4208c2dd396527f2e17eaba141dff5c4a81d..98abbcb56c6e5fc698bc066143772eb0ddc916cd 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -9,7 +9,7 @@ pub mod settings; use anyhow::Result; use client::Client; use futures::{future::BoxFuture, stream::BoxStream}; -use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext}; +use gpui::{AnyView, AppContext, AsyncAppContext, FocusHandle, SharedString, Task, WindowContext}; pub use model::*; use project::Fs; pub(crate) use rate_limiter::*; @@ -84,7 +84,7 @@ pub trait LanguageModelProvider: 'static { fn load_model(&self, _model: Arc, _cx: &AppContext) {} fn is_authenticated(&self, cx: &AppContext) -> bool; fn authenticate(&self, cx: &mut AppContext) -> Task>; - fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView; + fn configuration_view(&self, cx: &mut WindowContext) -> (AnyView, Option); fn reset_credentials(&self, cx: &mut AppContext) -> Task>; } diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index 3999483da07c2de026216d7315bb8a602c34d065..d92fddd4937c4f3d4996cf4f823c2d1373e816c7 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -8,8 +8,8 @@ use collections::BTreeMap; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use gpui::{ - AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View, - WhiteSpace, + AnyView, AppContext, AsyncAppContext, FocusHandle, FocusableView, FontStyle, ModelContext, + Subscription, Task, TextStyle, View, WhiteSpace, }; use http_client::HttpClient; use schemars::JsonSchema; @@ -18,8 +18,7 @@ use settings::{Settings, SettingsStore}; use std::{sync::Arc, time::Duration}; use strum::IntoEnumIterator; use theme::ThemeSettings; -use ui::prelude::*; -use util::ResultExt; +use ui::{prelude::*, Indicator}; const PROVIDER_ID: &str = "anthropic"; const PROVIDER_NAME: &str = "Anthropic"; @@ -49,6 +48,43 @@ pub struct State { _subscription: Subscription, } +impl State { + fn reset_api_key(&self, cx: &mut ModelContext) -> Task> { + let delete_credentials = + cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url); + cx.spawn(|this, mut cx| async move { + delete_credentials.await.ok(); + this.update(&mut cx, |this, cx| { + this.api_key = None; + cx.notify(); + }) + }) + } + + fn set_api_key(&mut self, api_key: String, cx: &mut ModelContext) -> Task> { + let write_credentials = cx.write_credentials( + AllLanguageModelSettings::get_global(cx) + .anthropic + .api_url + .as_str(), + "Bearer", + api_key.as_bytes(), + ); + cx.spawn(|this, mut cx| async move { + write_credentials.await?; + + this.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) + } + + fn is_authenticated(&self) -> bool { + self.api_key.is_some() + } +} + impl AnthropicLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut AppContext) -> Self { let state = cx.new_model(|cx| State { @@ -120,7 +156,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { } fn is_authenticated(&self, cx: &AppContext) -> bool { - self.state.read(cx).api_key.is_some() + self.state.read(cx).is_authenticated() } fn authenticate(&self, cx: &mut AppContext) -> Task> { @@ -151,22 +187,14 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { } } - fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { - cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx)) - .into() + fn configuration_view(&self, cx: &mut WindowContext) -> (AnyView, Option) { + let view = cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx)); + let focus_handle = view.focus_handle(cx); + (view.into(), Some(focus_handle)) } fn reset_credentials(&self, cx: &mut AppContext) -> Task> { - let state = self.state.clone(); - let delete_credentials = - cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url); - cx.spawn(|mut cx| async move { - delete_credentials.await.log_err(); - state.update(&mut cx, |this, cx| { - this.api_key = None; - cx.notify(); - }) - }) + self.state.update(cx, |state, cx| state.reset_api_key(cx)) } } @@ -350,18 +378,24 @@ impl LanguageModel for AnthropicModel { } } -struct AuthenticationPrompt { - api_key: View, +struct ConfigurationView { + api_key_editor: View, state: gpui::Model, } -impl AuthenticationPrompt { +impl FocusableView for ConfigurationView { + fn focus_handle(&self, cx: &AppContext) -> FocusHandle { + self.api_key_editor.read(cx).focus_handle(cx) + } +} + +impl ConfigurationView { fn new(state: gpui::Model, cx: &mut WindowContext) -> Self { Self { - api_key: cx.new_view(|cx| { + api_key_editor: cx.new_view(|cx| { let mut editor = Editor::single_line(cx); editor.set_placeholder_text( - "sk-000000000000000000000000000000000000000000000000", + "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", cx, ); editor @@ -371,29 +405,22 @@ impl AuthenticationPrompt { } fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { - let api_key = self.api_key.read(cx).text(cx); + let api_key = self.api_key_editor.read(cx).text(cx); if api_key.is_empty() { return; } - let write_credentials = cx.write_credentials( - AllLanguageModelSettings::get_global(cx) - .anthropic - .api_url - .as_str(), - "Bearer", - api_key.as_bytes(), - ); - let state = self.state.clone(); - cx.spawn(|_, mut cx| async move { - write_credentials.await?; + self.state + .update(cx, |state, cx| state.set_api_key(api_key, cx)) + .detach_and_log_err(cx); + } - state.update(&mut cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) - .detach_and_log_err(cx); + fn reset_api_key(&mut self, cx: &mut ViewContext) { + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", cx)); + self.state + .update(cx, |state, cx| state.reset_api_key(cx)) + .detach_and_log_err(cx); } fn render_api_key_editor(&self, cx: &mut ViewContext) -> impl IntoElement { @@ -413,7 +440,7 @@ impl AuthenticationPrompt { white_space: WhiteSpace::Normal, }; EditorElement::new( - &self.api_key, + &self.api_key_editor, EditorStyle { background: cx.theme().colors().editor_background, local_player: cx.theme().players().local(), @@ -424,7 +451,7 @@ impl AuthenticationPrompt { } } -impl Render for AuthenticationPrompt { +impl Render for ConfigurationView { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { const INSTRUCTIONS: [&str; 4] = [ "To use the assistant panel or inline assistant, you need to add your Anthropic API key.", @@ -433,38 +460,48 @@ impl Render for AuthenticationPrompt { "Paste your Anthropic API key below and hit enter to use the assistant:", ]; - v_flex() - .p_4() - .size_full() - .on_action(cx.listener(Self::save_api_key)) - .children( - INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)), - ) - .child( - h_flex() - .w_full() - .my_2() - .px_2() - .py_1() - .bg(cx.theme().colors().editor_background) - .rounded_md() - .child(self.render_api_key_editor(cx)), - ) - .child( - Label::new( - "You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.", + if self.state.read(cx).is_authenticated() { + h_flex() + .size_full() + .justify_between() + .child( + h_flex() + .gap_2() + .child(Indicator::dot().color(Color::Success)) + .child(Label::new("API Key configured").size(LabelSize::Small)), ) - .size(LabelSize::Small), - ) - .child( - h_flex() - .gap_2() - .child(Label::new("Click on").size(LabelSize::Small)) - .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall)) - .child( - Label::new("in the status bar to close this panel.").size(LabelSize::Small), - ), - ) - .into_any() + .child( + Button::new("reset-key", "Reset key") + .icon(Some(IconName::Trash)) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))), + ) + .into_any() + } else { + v_flex() + .size_full() + .on_action(cx.listener(Self::save_api_key)) + .children( + INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)), + ) + .child( + h_flex() + .w_full() + .my_2() + .px_2() + .py_1() + .bg(cx.theme().colors().editor_background) + .rounded_md() + .child(self.render_api_key_editor(cx)), + ) + .child( + Label::new( + "You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.", + ) + .size(LabelSize::Small), + ) + .into_any() + } } } diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index c9afc34ba4a0d708234e503a2076ae1c3ecd67a4..4b8018ba4693d145b7d975fd829912d9de01d8fa 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -8,7 +8,7 @@ use anyhow::{anyhow, Context as _, Result}; use client::Client; use collections::BTreeMap; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; -use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task}; +use gpui::{AnyView, AppContext, AsyncAppContext, FocusHandle, ModelContext, Subscription, Task}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; @@ -21,7 +21,7 @@ use crate::LanguageModelProvider; use super::anthropic::count_anthropic_tokens; pub const PROVIDER_ID: &str = "zed.dev"; -pub const PROVIDER_NAME: &str = "zed.dev"; +pub const PROVIDER_NAME: &str = "Zed AI"; #[derive(Default, Clone, Debug, PartialEq)] pub struct ZedDotDevSettings { @@ -57,6 +57,10 @@ pub struct State { } impl State { + fn is_connected(&self) -> bool { + self.status.is_connected() + } + fn authenticate(&self, cx: &mut ModelContext) -> Task> { let client = self.client.clone(); cx.spawn(move |this, mut cx| async move { @@ -179,15 +183,17 @@ impl LanguageModelProvider for CloudLanguageModelProvider { self.state.read(cx).status.is_connected() } - fn authenticate(&self, cx: &mut AppContext) -> Task> { - self.state.update(cx, |state, cx| state.authenticate(cx)) + fn authenticate(&self, _cx: &mut AppContext) -> Task> { + Task::ready(Ok(())) } - fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { - cx.new_view(|_cx| AuthenticationPrompt { - state: self.state.clone(), - }) - .into() + fn configuration_view(&self, cx: &mut WindowContext) -> (AnyView, Option) { + let view = cx + .new_view(|_cx| ConfigurationView { + state: self.state.clone(), + }) + .into(); + (view, None) } fn reset_credentials(&self, _cx: &mut AppContext) -> Task> { @@ -376,38 +382,88 @@ impl LanguageModel for CloudLanguageModel { } } -struct AuthenticationPrompt { +struct ConfigurationView { state: gpui::Model, } -impl Render for AuthenticationPrompt { +impl ConfigurationView { + fn authenticate(&mut self, cx: &mut ViewContext) { + self.state.update(cx, |state, cx| { + state.authenticate(cx).detach_and_log_err(cx); + }); + cx.notify(); + } +} + +impl Render for ConfigurationView { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline."; + const ZED_AI_URL: &str = "https://zed.dev/ai"; + const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/settings"; + + let is_connected = self.state.read(cx).is_connected(); + + let is_pro = false; - v_flex().gap_6().p_4().child(Label::new(LABEL)).child( + if is_connected { v_flex() - .gap_2() + .gap_3() + .max_w_4_5() + .child(Label::new( + if is_pro { + "You have full access to Zed's hosted models from Anthropic, OpenAI, Google through Zed Pro." + } else { + "You have basic access to models from Anthropic, OpenAI, Google and more through the Zed AI Free plan." + })) .child( - Button::new("sign_in", "Sign in") - .icon_color(Color::Muted) - .icon(IconName::Github) - .icon_position(IconPosition::Start) - .style(ButtonStyle::Filled) - .full_width() - .on_click(cx.listener(move |this, _, cx| { - this.state.update(cx, |provider, cx| { - provider.authenticate(cx).detach_and_log_err(cx); - cx.notify(); - }); - })), + if is_pro { + h_flex().child( + Button::new("manage_settings", "Manage Subscription") + .style(ButtonStyle::Filled) + .on_click(cx.listener(|_, _, cx| { + cx.open_url(ACCOUNT_SETTINGS_URL) + }))) + } else { + h_flex() + .gap_2() + .child( + Button::new("learn_more", "Learn more") + .style(ButtonStyle::Subtle) + .on_click(cx.listener(|_, _, cx| { + cx.open_url(ZED_AI_URL) + }))) + .child( + Button::new("upgrade", "Upgrade") + .style(ButtonStyle::Subtle) + .color(Color::Accent) + .on_click(cx.listener(|_, _, cx| { + cx.open_url(ACCOUNT_SETTINGS_URL) + }))) + }, ) + } else { + v_flex() + .gap_6() + .child(Label::new("Use the zed.dev to access language models.")) .child( - div().flex().w_full().items_center().child( - Label::new("Sign in to enable collaboration.") - .color(Color::Muted) - .size(LabelSize::Small), - ), - ), - ) + v_flex() + .gap_2() + .child( + Button::new("sign_in", "Sign in") + .icon_color(Color::Muted) + .icon(IconName::Github) + .icon_position(IconPosition::Start) + .style(ButtonStyle::Filled) + .full_width() + .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))), + ) + .child( + div().flex().w_full().items_center().child( + Label::new("Sign in to enable collaboration.") + .color(Color::Muted) + .size(LabelSize::Small), + ), + ), + ) + } } } diff --git a/crates/language_model/src/provider/copilot_chat.rs b/crates/language_model/src/provider/copilot_chat.rs index f73ddb74bfb6549d4359817e3fc6f8ee1b159f47..7de851038908c7c651177783e95e36351c809120 100644 --- a/crates/language_model/src/provider/copilot_chat.rs +++ b/crates/language_model/src/provider/copilot_chat.rs @@ -11,16 +11,16 @@ use futures::future::BoxFuture; use futures::stream::BoxStream; use futures::{FutureExt, StreamExt}; use gpui::{ - percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model, Render, - Subscription, Task, Transformation, + percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, FocusHandle, + Model, Render, Subscription, Task, Transformation, }; use settings::{Settings, SettingsStore}; use std::time::Duration; use strum::IntoEnumIterator; use ui::{ - div, v_flex, Button, ButtonCommon, Clickable, Color, Context, FixedWidth, IconName, - IconPosition, IconSize, IntoElement, Label, LabelCommon, ParentElement, Styled, ViewContext, - VisualContext, WindowContext, + div, h_flex, v_flex, Button, ButtonCommon, Clickable, Color, Context, FixedWidth, IconName, + IconPosition, IconSize, Indicator, IntoElement, Label, LabelCommon, ParentElement, Styled, + ViewContext, VisualContext, WindowContext, }; use crate::settings::AllLanguageModelSettings; @@ -49,6 +49,14 @@ pub struct State { _settings_subscription: Subscription, } +impl State { + fn is_authenticated(&self, cx: &AppContext) -> bool { + CopilotChat::global(cx) + .map(|m| m.read(cx).is_authenticated()) + .unwrap_or(false) + } +} + impl CopilotChatLanguageModelProvider { pub fn new(cx: &mut AppContext) -> Self { let state = cx.new_model(|cx| { @@ -95,9 +103,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider { } fn is_authenticated(&self, cx: &AppContext) -> bool { - CopilotChat::global(cx) - .map(|m| m.read(cx).is_authenticated()) - .unwrap_or(false) + self.state.read(cx).is_authenticated(cx) } fn authenticate(&self, cx: &mut AppContext) -> Task> { @@ -122,29 +128,16 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider { Task::ready(result) } - fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { - cx.new_view(|cx| AuthenticationPrompt::new(cx)).into() - } - - fn reset_credentials(&self, cx: &mut AppContext) -> Task> { - let Some(copilot) = Copilot::global(cx) else { - return Task::ready(Err(anyhow::anyhow!( - "Copilot is not available. Please ensure Copilot is enabled and running and try again." - ))); - }; - + fn configuration_view(&self, cx: &mut WindowContext) -> (AnyView, Option) { let state = self.state.clone(); + let view = cx.new_view(|cx| ConfigurationView::new(state, cx)).into(); + (view, None) + } - cx.spawn(|mut cx| async move { - cx.update_model(&copilot, |model, cx| model.sign_out(cx))? - .await?; - - cx.update_model(&state, |_, cx| { - cx.notify(); - })?; - - Ok(()) - }) + fn reset_credentials(&self, _cx: &mut AppContext) -> Task> { + Task::ready(Err(anyhow!( + "Signing out of GitHub Copilot Chat is currently not supported." + ))) } } @@ -281,17 +274,19 @@ impl CopilotChatLanguageModel { } } -struct AuthenticationPrompt { +struct ConfigurationView { copilot_status: Option, + state: Model, _subscription: Option, } -impl AuthenticationPrompt { - pub fn new(cx: &mut ViewContext) -> Self { +impl ConfigurationView { + pub fn new(state: Model, cx: &mut ViewContext) -> Self { let copilot = Copilot::global(cx); Self { copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()), + state, _subscription: copilot.as_ref().map(|copilot| { cx.observe(copilot, |this, model, cx| { this.copilot_status = Some(model.read(cx).status()); @@ -302,81 +297,85 @@ impl AuthenticationPrompt { } } -impl Render for AuthenticationPrompt { +impl Render for ConfigurationView { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - let loading_icon = svg() - .size_8() - .path(IconName::ArrowCircle.path()) - .text_color(cx.text_style().color) - .with_animation( - "icon_circle_arrow", - Animation::new(Duration::from_secs(2)).repeat(), - |svg, delta| svg.with_transformation(Transformation::rotate(percentage(delta))), - ); - - const ERROR_LABEL: &str = "Copilot Chat requires the Copilot plugin to be available and running. Please ensure Copilot is running and try again, or use a different Assistant provider."; - match &self.copilot_status { - Some(status) => match status { - Status::Disabled => { - return v_flex().gap_6().p_4().child(Label::new(ERROR_LABEL)); - } - Status::Starting { task: _ } => { - const LABEL: &str = "Starting Copilot..."; - return v_flex() - .gap_6() - .p_4() - .justify_center() - .items_center() - .child(Label::new(LABEL)) - .child(loading_icon); - } - Status::SigningIn { prompt: _ } => { - const LABEL: &str = "Signing in to Copilot..."; - return v_flex() - .gap_6() - .p_4() - .justify_center() - .items_center() - .child(Label::new(LABEL)) - .child(loading_icon); - } - Status::Error(_) => { - const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot."; - return v_flex() - .gap_6() - .p_4() - .child(Label::new(LABEL)) - .child(svg().size_8().path(IconName::CopilotError.path())); - } - _ => { - const LABEL: &str = - "To use the assistant panel or inline assistant, you must login to GitHub Copilot. Your GitHub account must have an active Copilot Chat subscription."; - v_flex().gap_6().p_4().child(Label::new(LABEL)).child( + if self.state.read(cx).is_authenticated(cx) { + const LABEL: &str = "Authorized."; + h_flex() + .gap_2() + .child(Indicator::dot().color(Color::Success)) + .child(Label::new(LABEL)) + } else { + let loading_icon = svg() + .size_8() + .path(IconName::ArrowCircle.path()) + .text_color(cx.text_style().color) + .with_animation( + "icon_circle_arrow", + Animation::new(Duration::from_secs(2)).repeat(), + |svg, delta| svg.with_transformation(Transformation::rotate(percentage(delta))), + ); + + const ERROR_LABEL: &str = "Copilot Chat requires the Copilot plugin to be available and running. Please ensure Copilot is running and try again, or use a different Assistant provider."; + + match &self.copilot_status { + Some(status) => match status { + Status::Disabled => v_flex().gap_6().p_4().child(Label::new(ERROR_LABEL)), + Status::Starting { task: _ } => { + const LABEL: &str = "Starting Copilot..."; + v_flex() + .gap_6() + .justify_center() + .items_center() + .child(Label::new(LABEL)) + .child(loading_icon) + } + Status::SigningIn { prompt: _ } => { + const LABEL: &str = "Signing in to Copilot..."; v_flex() - .gap_2() - .child( - Button::new("sign_in", "Sign In") - .icon_color(Color::Muted) - .icon(IconName::Github) - .icon_position(IconPosition::Start) - .icon_size(IconSize::Medium) - .style(ui::ButtonStyle::Filled) - .full_width() - .on_click(|_, cx| { - inline_completion_button::initiate_sign_in(cx) - }), - ) - .child( - div().flex().w_full().items_center().child( - Label::new("Sign in to start using Github Copilot Chat.") - .color(Color::Muted) - .size(ui::LabelSize::Small), + .gap_6() + .justify_center() + .items_center() + .child(Label::new(LABEL)) + .child(loading_icon) + } + Status::Error(_) => { + const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot."; + v_flex() + .gap_6() + .child(Label::new(LABEL)) + .child(svg().size_8().path(IconName::CopilotError.path())) + } + _ => { + const LABEL: &str = + "To use the assistant panel or inline assistant, you must login to GitHub Copilot. Your GitHub account must have an active Copilot Chat subscription."; + v_flex().gap_6().child(Label::new(LABEL)).child( + v_flex() + .gap_2() + .child( + Button::new("sign_in", "Sign In") + .icon_color(Color::Muted) + .icon(IconName::Github) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Medium) + .style(ui::ButtonStyle::Filled) + .full_width() + .on_click(|_, cx| { + inline_completion_button::initiate_sign_in(cx) + }), + ) + .child( + div().flex().w_full().items_center().child( + Label::new("Sign in to start using Github Copilot Chat.") + .color(Color::Muted) + .size(ui::LabelSize::Small), + ), ), - ), - ) - } - }, - None => v_flex().gap_6().p_4().child(Label::new(ERROR_LABEL)), + ) + } + }, + None => v_flex().gap_6().child(Label::new(ERROR_LABEL)), + } } } } diff --git a/crates/language_model/src/provider/fake.rs b/crates/language_model/src/provider/fake.rs index 70f8402bccf827755b60b501f8a7c211eb24f048..511e11788a08ffe50558642f131d52752e2cb0fe 100644 --- a/crates/language_model/src/provider/fake.rs +++ b/crates/language_model/src/provider/fake.rs @@ -6,7 +6,7 @@ use crate::{ use anyhow::anyhow; use collections::HashMap; use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; -use gpui::{AnyView, AppContext, AsyncAppContext, Task}; +use gpui::{AnyView, AppContext, AsyncAppContext, FocusHandle, Task}; use http_client::Result; use std::{ future, @@ -66,7 +66,7 @@ impl LanguageModelProvider for FakeLanguageModelProvider { Task::ready(Ok(())) } - fn authentication_prompt(&self, _: &mut WindowContext) -> AnyView { + fn configuration_view(&self, _: &mut WindowContext) -> (AnyView, Option) { unimplemented!() } diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs index a1a6cbcceb5d3c11520d498f2fd6ac9a4dd694c3..368778ce29198a25baa3b39a4b07e0100775eab9 100644 --- a/crates/language_model/src/provider/google.rs +++ b/crates/language_model/src/provider/google.rs @@ -4,8 +4,8 @@ use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, FutureExt, StreamExt}; use google_ai::stream_generate_content; use gpui::{ - AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View, - WhiteSpace, + AnyView, AppContext, AsyncAppContext, FocusHandle, FocusableView, FontStyle, ModelContext, + Subscription, Task, TextStyle, View, WhiteSpace, }; use http_client::HttpClient; use schemars::JsonSchema; @@ -14,7 +14,7 @@ use settings::{Settings, SettingsStore}; use std::{future, sync::Arc, time::Duration}; use strum::IntoEnumIterator; use theme::ThemeSettings; -use ui::prelude::*; +use ui::{prelude::*, Indicator}; use util::ResultExt; use crate::{ @@ -49,6 +49,24 @@ pub struct State { _subscription: Subscription, } +impl State { + fn is_authenticated(&self) -> bool { + self.api_key.is_some() + } + + fn reset_api_key(&self, cx: &mut ModelContext) -> Task> { + let delete_credentials = + cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url); + cx.spawn(|this, mut cx| async move { + delete_credentials.await.ok(); + this.update(&mut cx, |this, cx| { + this.api_key = None; + cx.notify(); + }) + }) + } +} + impl GoogleLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut AppContext) -> Self { let state = cx.new_model(|cx| State { @@ -118,7 +136,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider { } fn is_authenticated(&self, cx: &AppContext) -> bool { - self.state.read(cx).api_key.is_some() + self.state.read(cx).is_authenticated() } fn authenticate(&self, cx: &mut AppContext) -> Task> { @@ -149,9 +167,11 @@ impl LanguageModelProvider for GoogleLanguageModelProvider { } } - fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { - cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx)) - .into() + fn configuration_view(&self, cx: &mut WindowContext) -> (AnyView, Option) { + let view = cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx)); + + let focus_handle = view.focus_handle(cx); + (view.into(), Some(focus_handle)) } fn reset_credentials(&self, cx: &mut AppContext) -> Task> { @@ -267,15 +287,15 @@ impl LanguageModel for GoogleLanguageModel { } } -struct AuthenticationPrompt { - api_key: View, +struct ConfigurationView { + api_key_editor: View, state: gpui::Model, } -impl AuthenticationPrompt { +impl ConfigurationView { fn new(state: gpui::Model, cx: &mut WindowContext) -> Self { Self { - api_key: cx.new_view(|cx| { + api_key_editor: cx.new_view(|cx| { let mut editor = Editor::single_line(cx); editor.set_placeholder_text("AIzaSy...", cx); editor @@ -285,7 +305,7 @@ impl AuthenticationPrompt { } fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { - let api_key = self.api_key.read(cx).text(cx); + let api_key = self.api_key_editor.read(cx).text(cx); if api_key.is_empty() { return; } @@ -304,6 +324,14 @@ impl AuthenticationPrompt { .detach_and_log_err(cx); } + fn reset_api_key(&mut self, cx: &mut ViewContext) { + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", cx)); + self.state + .update(cx, |state, cx| state.reset_api_key(cx)) + .detach_and_log_err(cx); + } + fn render_api_key_editor(&self, cx: &mut ViewContext) -> impl IntoElement { let settings = ThemeSettings::get_global(cx); let text_style = TextStyle { @@ -321,7 +349,7 @@ impl AuthenticationPrompt { white_space: WhiteSpace::Normal, }; EditorElement::new( - &self.api_key, + &self.api_key_editor, EditorStyle { background: cx.theme().colors().editor_background, local_player: cx.theme().players().local(), @@ -332,7 +360,13 @@ impl AuthenticationPrompt { } } -impl Render for AuthenticationPrompt { +impl FocusableView for ConfigurationView { + fn focus_handle(&self, cx: &AppContext) -> FocusHandle { + self.api_key_editor.read(cx).focus_handle(cx) + } +} + +impl Render for ConfigurationView { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { const INSTRUCTIONS: [&str; 4] = [ "To use the Google AI assistant, you need to add your Google AI API key.", @@ -341,38 +375,48 @@ impl Render for AuthenticationPrompt { "Paste your Google AI API key below and hit enter to use the assistant:", ]; - v_flex() - .p_4() - .size_full() - .on_action(cx.listener(Self::save_api_key)) - .children( - INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)), - ) - .child( - h_flex() - .w_full() - .my_2() - .px_2() - .py_1() - .bg(cx.theme().colors().editor_background) - .rounded_md() - .child(self.render_api_key_editor(cx)), - ) - .child( - Label::new( - "You can also assign the GOOGLE_AI_API_KEY environment variable and restart Zed.", + if self.state.read(cx).is_authenticated() { + h_flex() + .size_full() + .justify_between() + .child( + h_flex() + .gap_2() + .child(Indicator::dot().color(Color::Success)) + .child(Label::new("API Key configured").size(LabelSize::Small)), ) - .size(LabelSize::Small), - ) - .child( - h_flex() - .gap_2() - .child(Label::new("Click on").size(LabelSize::Small)) - .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall)) - .child( - Label::new("in the status bar to close this panel.").size(LabelSize::Small), - ), - ) - .into_any() + .child( + Button::new("reset-key", "Reset key") + .icon(Some(IconName::Trash)) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))), + ) + .into_any() + } else { + v_flex() + .size_full() + .on_action(cx.listener(Self::save_api_key)) + .children( + INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)), + ) + .child( + h_flex() + .w_full() + .my_2() + .px_2() + .py_1() + .bg(cx.theme().colors().editor_background) + .rounded_md() + .child(self.render_api_key_editor(cx)), + ) + .child( + Label::new( + "You can also assign the GOOGLE_AI_API_KEY environment variable and restart Zed.", + ) + .size(LabelSize::Small), + ) + .into_any() + } } } diff --git a/crates/language_model/src/provider/ollama.rs b/crates/language_model/src/provider/ollama.rs index 9afa3825b0b4af4c9a5edc7efbb82aabdbb979c2..934b87eb89e9317b8c59bb7ff89c30bd16ac2fcb 100644 --- a/crates/language_model/src/provider/ollama.rs +++ b/crates/language_model/src/provider/ollama.rs @@ -1,13 +1,13 @@ use anyhow::{anyhow, Result}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; -use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task}; +use gpui::{AnyView, AppContext, AsyncAppContext, FocusHandle, ModelContext, Subscription, Task}; use http_client::HttpClient; use ollama::{ get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, }; use settings::{Settings, SettingsStore}; use std::{future, sync::Arc, time::Duration}; -use ui::{prelude::*, ButtonLike, ElevationIndex}; +use ui::{prelude::*, ButtonLike, ElevationIndex, Indicator}; use crate::{ settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName, @@ -39,6 +39,10 @@ pub struct State { } impl State { + fn is_authenticated(&self) -> bool { + !self.available_models.is_empty() + } + fn fetch_models(&mut self, cx: &mut ModelContext) -> Task> { let settings = &AllLanguageModelSettings::get_global(cx).ollama; let http_client = self.http_client.clone(); @@ -129,7 +133,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { } fn is_authenticated(&self, cx: &AppContext) -> bool { - !self.state.read(cx).available_models.is_empty() + self.state.read(cx).is_authenticated() } fn authenticate(&self, cx: &mut AppContext) -> Task> { @@ -140,14 +144,12 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { } } - fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { + fn configuration_view(&self, cx: &mut WindowContext) -> (AnyView, Option) { let state = self.state.clone(); - let fetch_models = Box::new(move |cx: &mut WindowContext| { - state.update(cx, |this, cx| this.fetch_models(cx)) - }); - - cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx)) - .into() + ( + cx.new_view(|cx| ConfigurationView::new(state, cx)).into(), + None, + ) } fn reset_credentials(&self, cx: &mut AppContext) -> Task> { @@ -287,16 +289,19 @@ impl LanguageModel for OllamaLanguageModel { } } -struct DownloadOllamaMessage { - retry_connection: Box Task>>, +struct ConfigurationView { + state: gpui::Model, } -impl DownloadOllamaMessage { - pub fn new( - retry_connection: Box Task>>, - _cx: &mut ViewContext, - ) -> Self { - Self { retry_connection } +impl ConfigurationView { + pub fn new(state: gpui::Model, _cx: &mut ViewContext) -> Self { + Self { state } + } + + fn retry_connection(&self, cx: &mut WindowContext) { + self.state + .update(cx, |state, cx| state.fetch_models(cx)) + .detach_and_log_err(cx); } fn render_download_button(&self, _cx: &mut ViewContext) -> impl IntoElement { @@ -314,15 +319,7 @@ impl DownloadOllamaMessage { .size(ButtonSize::Large) .layer(ElevationIndex::ModalSurface) .child(Label::new("Retry")) - .on_click(cx.listener(move |this, _, cx| { - let connected = (this.retry_connection)(cx); - - cx.spawn(|_this, _cx| async move { - connected.await?; - anyhow::Ok(()) - }) - .detach_and_log_err(cx) - })) + .on_click(cx.listener(move |this, _, cx| this.retry_connection(cx))) } fn render_next_steps(&self, _cx: &mut ViewContext) -> impl IntoElement { @@ -347,10 +344,22 @@ impl DownloadOllamaMessage { } } -impl Render for DownloadOllamaMessage { +impl Render for ConfigurationView { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - v_flex() - .p_4() + let is_authenticated = self.state.read(cx).is_authenticated(); + + if is_authenticated { + v_flex() + .size_full() + .child( + h_flex() + .gap_2() + .child(Indicator::dot().color(Color::Success)) + .child(Label::new("Ollama configured").size(LabelSize::Small)), + ) + .into_any() + } else { + v_flex() .size_full() .gap_2() .child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine with at least one model downloaded.").size(LabelSize::Large)) @@ -369,5 +378,6 @@ impl Render for DownloadOllamaMessage { ) .child(self.render_next_steps(cx)) .into_any() + } } } diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs index e0239d959bea80e9fe2f642679f5459fbd1af716..d8a683c7db30a7c567993effd14be692b3c06852 100644 --- a/crates/language_model/src/provider/open_ai.rs +++ b/crates/language_model/src/provider/open_ai.rs @@ -3,8 +3,8 @@ use collections::BTreeMap; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, FutureExt, StreamExt}; use gpui::{ - AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View, - WhiteSpace, + AnyView, AppContext, AsyncAppContext, FocusHandle, FocusableView, FontStyle, ModelContext, + Subscription, Task, TextStyle, View, WhiteSpace, }; use http_client::HttpClient; use open_ai::stream_completion; @@ -14,7 +14,7 @@ use settings::{Settings, SettingsStore}; use std::{future, sync::Arc, time::Duration}; use strum::IntoEnumIterator; use theme::ThemeSettings; -use ui::prelude::*; +use ui::{prelude::*, Indicator}; use util::ResultExt; use crate::{ @@ -50,6 +50,24 @@ pub struct State { _subscription: Subscription, } +impl State { + fn is_authenticated(&self) -> bool { + self.api_key.is_some() + } + + fn reset_api_key(&self, cx: &mut ModelContext) -> Task> { + let settings = &AllLanguageModelSettings::get_global(cx).openai; + let delete_credentials = cx.delete_credentials(&settings.api_url); + cx.spawn(|this, mut cx| async move { + delete_credentials.await.log_err(); + this.update(&mut cx, |this, cx| { + this.api_key = None; + cx.notify(); + }) + }) + } +} + impl OpenAiLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut AppContext) -> Self { let state = cx.new_model(|cx| State { @@ -119,7 +137,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { } fn is_authenticated(&self, cx: &AppContext) -> bool { - self.state.read(cx).api_key.is_some() + self.state.read(cx).is_authenticated() } fn authenticate(&self, cx: &mut AppContext) -> Task> { @@ -149,22 +167,14 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { } } - fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { - cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx)) - .into() + fn configuration_view(&self, cx: &mut WindowContext) -> (AnyView, Option) { + let view = cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx)); + let focus_handle = view.focus_handle(cx); + (view.into(), Some(focus_handle)) } fn reset_credentials(&self, cx: &mut AppContext) -> Task> { - let settings = &AllLanguageModelSettings::get_global(cx).openai; - let delete_credentials = cx.delete_credentials(&settings.api_url); - let state = self.state.clone(); - cx.spawn(|mut cx| async move { - delete_credentials.await.log_err(); - state.update(&mut cx, |this, cx| { - this.api_key = None; - cx.notify(); - }) - }) + self.state.update(cx, |state, cx| state.reset_api_key(cx)) } } @@ -287,15 +297,15 @@ pub fn count_open_ai_tokens( .boxed() } -struct AuthenticationPrompt { - api_key: View, +struct ConfigurationView { + api_key_editor: View, state: gpui::Model, } -impl AuthenticationPrompt { +impl ConfigurationView { fn new(state: gpui::Model, cx: &mut WindowContext) -> Self { Self { - api_key: cx.new_view(|cx| { + api_key_editor: cx.new_view(|cx| { let mut editor = Editor::single_line(cx); editor.set_placeholder_text( "sk-000000000000000000000000000000000000000000000000", @@ -308,7 +318,7 @@ impl AuthenticationPrompt { } fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { - let api_key = self.api_key.read(cx).text(cx); + let api_key = self.api_key_editor.read(cx).text(cx); if api_key.is_empty() { return; } @@ -327,6 +337,14 @@ impl AuthenticationPrompt { .detach_and_log_err(cx); } + fn reset_api_key(&mut self, cx: &mut ViewContext) { + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", cx)); + self.state.update(cx, |state, cx| { + state.reset_api_key(cx).detach_and_log_err(cx); + }) + } + fn render_api_key_editor(&self, cx: &mut ViewContext) -> impl IntoElement { let settings = ThemeSettings::get_global(cx); let text_style = TextStyle { @@ -344,7 +362,7 @@ impl AuthenticationPrompt { white_space: WhiteSpace::Normal, }; EditorElement::new( - &self.api_key, + &self.api_key_editor, EditorStyle { background: cx.theme().colors().editor_background, local_player: cx.theme().players().local(), @@ -355,7 +373,13 @@ impl AuthenticationPrompt { } } -impl Render for AuthenticationPrompt { +impl FocusableView for ConfigurationView { + fn focus_handle(&self, cx: &AppContext) -> FocusHandle { + self.api_key_editor.read(cx).focus_handle(cx) + } +} + +impl Render for ConfigurationView { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { const INSTRUCTIONS: [&str; 6] = [ "To use the assistant panel or inline assistant, you need to add your OpenAI API key.", @@ -366,38 +390,48 @@ impl Render for AuthenticationPrompt { "Paste your OpenAI API key below and hit enter to use the assistant:", ]; - v_flex() - .p_4() - .size_full() - .on_action(cx.listener(Self::save_api_key)) - .children( - INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)), - ) - .child( - h_flex() - .w_full() - .my_2() - .px_2() - .py_1() - .bg(cx.theme().colors().editor_background) - .rounded_md() - .child(self.render_api_key_editor(cx)), - ) - .child( - Label::new( - "You can also assign the OPENAI_API_KEY environment variable and restart Zed.", + if self.state.read(cx).is_authenticated() { + h_flex() + .size_full() + .justify_between() + .child( + h_flex() + .gap_2() + .child(Indicator::dot().color(Color::Success)) + .child(Label::new("API Key configured").size(LabelSize::Small)), ) - .size(LabelSize::Small), - ) - .child( - h_flex() - .gap_2() - .child(Label::new("Click on").size(LabelSize::Small)) - .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall)) - .child( - Label::new("in the status bar to close this panel.").size(LabelSize::Small), - ), - ) - .into_any() + .child( + Button::new("reset-key", "Reset key") + .icon(Some(IconName::Trash)) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))), + ) + .into_any() + } else { + v_flex() + .size_full() + .on_action(cx.listener(Self::save_api_key)) + .children( + INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)), + ) + .child( + h_flex() + .w_full() + .my_2() + .px_2() + .py_1() + .bg(cx.theme().colors().editor_background) + .rounded_md() + .child(self.render_api_key_editor(cx)), + ) + .child( + Label::new( + "You can also assign the OPENAI_API_KEY environment variable and restart Zed.", + ) + .size(LabelSize::Small), + ) + .into_any() + } } }