diff --git a/Cargo.lock b/Cargo.lock index 0d83b2b9b912ab112d9b38fd1ef1d5ff21f9049c..1bb39f2bdf8c5745b3e5c0e5ad1200be34ec6ab0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8932,6 +8932,8 @@ dependencies = [ "credentials_provider", "deepseek", "editor", + "extension", + "extension_host", "fs", "futures 0.3.31", "google_ai", diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index a670ba601159ec323ad2c88695c30bf4aeae4118..598d0428174eb2fc124739a18ddeff1098521cb7 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -210,12 +210,21 @@ pub trait AgentModelSelector: 'static { } } +/// Icon for a model in the model selector. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AgentModelIcon { + /// A built-in icon from Zed's icon set. + Named(IconName), + /// Path to a custom SVG icon file. + Path(SharedString), +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct AgentModelInfo { pub id: acp::ModelId, pub name: SharedString, pub description: Option, - pub icon: Option, + pub icon: Option, } impl From for AgentModelInfo { diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 43ed3b90f3556eb24e45440a7fe0038e7a1b9535..4baa7f4ea4004d2137b5cddb255346fa91523091 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -30,7 +30,7 @@ use futures::{StreamExt, future}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, }; -use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry}; +use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry}; use project::{Project, ProjectItem, ProjectPath, Worktree}; use prompt_store::{ ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext, @@ -93,7 +93,7 @@ impl LanguageModels { fn refresh_list(&mut self, cx: &App) { let providers = LanguageModelRegistry::global(cx) .read(cx) - .providers() + .visible_providers() .into_iter() .filter(|provider| provider.is_authenticated(cx)) .collect::>(); @@ -153,7 +153,10 @@ impl LanguageModels { id: Self::model_id(model), name: model.name().0, description: None, - icon: Some(provider.icon()), + icon: Some(match provider.icon() { + IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path), + IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name), + }), } } @@ -164,7 +167,7 @@ impl LanguageModels { fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> { let authenticate_all_providers = LanguageModelRegistry::global(cx) .read(cx) - .providers() + .visible_providers() .iter() .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx))) .collect::>(); @@ -1630,7 +1633,9 @@ mod internal_tests { id: acp::ModelId::new("fake/fake"), name: "Fake".into(), description: None, - icon: Some(ui::IconName::ZedAssistant), + icon: Some(acp_thread::AgentModelIcon::Named( + ui::IconName::ZedAssistant + )), }] )]) ); diff --git a/crates/agent_ui/src/acp/model_selector.rs b/crates/agent_ui/src/acp/model_selector.rs index f3c07250de3cefc798b97d9ffad444489d153219..903d5fe425d99389aae0e2a8028d9a31b986fbb3 100644 --- a/crates/agent_ui/src/acp/model_selector.rs +++ b/crates/agent_ui/src/acp/model_selector.rs @@ -1,6 +1,6 @@ use std::{cmp::Reverse, rc::Rc, sync::Arc}; -use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector}; +use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector}; use agent_client_protocol::ModelId; use agent_servers::AgentServer; use agent_settings::AgentSettings; @@ -350,7 +350,11 @@ impl PickerDelegate for AcpModelPickerDelegate { }) .child( ModelSelectorListItem::new(ix, model_info.name.clone()) - .when_some(model_info.icon, |this, icon| this.icon(icon)) + .map(|this| match &model_info.icon { + Some(AgentModelIcon::Path(path)) => this.icon_path(path.clone()), + Some(AgentModelIcon::Named(icon)) => this.icon(*icon), + None => this, + }) .is_selected(is_selected) .is_focused(selected) .when(supports_favorites, |this| { diff --git a/crates/agent_ui/src/acp/model_selector_popover.rs b/crates/agent_ui/src/acp/model_selector_popover.rs index d6709081863c9545fba4c6e2304f195e77b013df..a15c01445dd8e9845f6744e795ed90a1ede6c7fc 100644 --- a/crates/agent_ui/src/acp/model_selector_popover.rs +++ b/crates/agent_ui/src/acp/model_selector_popover.rs @@ -1,7 +1,7 @@ use std::rc::Rc; use std::sync::Arc; -use acp_thread::{AgentModelInfo, AgentModelSelector}; +use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelSelector}; use agent_servers::AgentServer; use agent_settings::AgentSettings; use fs::Fs; @@ -70,7 +70,7 @@ impl Render for AcpModelSelectorPopover { .map(|model| model.name.clone()) .unwrap_or_else(|| SharedString::from("Select a Model")); - let model_icon = model.as_ref().and_then(|model| model.icon); + let model_icon = model.as_ref().and_then(|model| model.icon.clone()); let focus_handle = self.focus_handle.clone(); @@ -125,7 +125,14 @@ impl Render for AcpModelSelectorPopover { ButtonLike::new("active-model") .selected_style(ButtonStyle::Tinted(TintColor::Accent)) .when_some(model_icon, |this, icon| { - this.child(Icon::new(icon).color(color).size(IconSize::XSmall)) + this.child( + match icon { + AgentModelIcon::Path(path) => Icon::from_external_svg(path), + AgentModelIcon::Named(icon_name) => Icon::new(icon_name), + } + .color(color) + .size(IconSize::XSmall), + ) }) .child( Label::new(model_name) diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index 24f019c605d1b167e62a6e68dfc1f3ed07c73f1c..562976453d963db65f9033536e528000de2b510f 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -22,7 +22,8 @@ use gpui::{ }; use language::LanguageRegistry; use language_model::{ - LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID, + IconOrSvg, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry, + ZED_CLOUD_PROVIDER_ID, }; use language_models::AllLanguageModelSettings; use notifications::status_toast::{StatusToast, ToastIcon}; @@ -117,7 +118,7 @@ impl AgentConfiguration { } fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context) { - let providers = LanguageModelRegistry::read_global(cx).providers(); + let providers = LanguageModelRegistry::read_global(cx).visible_providers(); for provider in providers { self.add_provider_configuration_view(&provider, window, cx); } @@ -261,9 +262,12 @@ impl AgentConfiguration { .w_full() .gap_1p5() .child( - Icon::new(provider.icon()) - .size(IconSize::Small) - .color(Color::Muted), + match provider.icon() { + IconOrSvg::Svg(path) => Icon::from_external_svg(path), + IconOrSvg::Icon(name) => Icon::new(name), + } + .size(IconSize::Small) + .color(Color::Muted), ) .child( h_flex() @@ -416,7 +420,7 @@ impl AgentConfiguration { &mut self, cx: &mut Context, ) -> impl IntoElement { - let providers = LanguageModelRegistry::read_global(cx).providers(); + let providers = LanguageModelRegistry::read_global(cx).visible_providers(); let popover_menu = PopoverMenu::new("add-provider-popover") .trigger( diff --git a/crates/agent_ui/src/agent_model_selector.rs b/crates/agent_ui/src/agent_model_selector.rs index ac57ed575d9d1b6de2c53d3e0e4a91b4bd16ab1a..45cefbf2b9f8d4b1639a9849f2ee2e4468e530b1 100644 --- a/crates/agent_ui/src/agent_model_selector.rs +++ b/crates/agent_ui/src/agent_model_selector.rs @@ -4,6 +4,7 @@ use crate::{ }; use fs::Fs; use gpui::{Entity, FocusHandle, SharedString}; +use language_model::IconOrSvg; use picker::popover_menu::PickerPopoverMenu; use settings::update_settings_file; use std::sync::Arc; @@ -103,7 +104,14 @@ impl Render for AgentModelSelector { self.selector.clone(), ButtonLike::new("active-model") .when_some(provider_icon, |this, icon| { - this.child(Icon::new(icon).color(color).size(IconSize::XSmall)) + this.child( + match icon { + IconOrSvg::Svg(path) => Icon::from_external_svg(path), + IconOrSvg::Icon(name) => Icon::new(name), + } + .color(color) + .size(IconSize::XSmall), + ) }) .selected_style(ButtonStyle::Tinted(TintColor::Accent)) .child( @@ -115,7 +123,7 @@ impl Render for AgentModelSelector { .child( Icon::new(IconName::ChevronDown) .color(color) - .size(IconSize::Small), + .size(IconSize::XSmall), ), move |_window, cx| { Tooltip::for_action_in("Change Model", &ToggleModelSelector, &focus_handle, cx) diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 294cd8b4888950f6ea92d6bea1eba78c3d6d6de2..a050f75120cd73949251c09c8424314e3616c705 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -2428,7 +2428,7 @@ impl AgentPanel { let history_is_empty = self.history_store.read(cx).is_empty(cx); let has_configured_non_zed_providers = LanguageModelRegistry::read_global(cx) - .providers() + .visible_providers() .iter() .any(|provider| { provider.is_authenticated(cx) diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 02cb7e59948b10274302bd8cd6f74f1accbd30a3..401b506b302d9c2a86a36ddce0fc72df075f4c18 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -348,7 +348,8 @@ fn init_language_model_settings(cx: &mut App) { |_, event: &language_model::Event, cx| match event { language_model::Event::ProviderStateChanged(_) | language_model::Event::AddedProvider(_) - | language_model::Event::RemovedProvider(_) => { + | language_model::Event::RemovedProvider(_) + | language_model::Event::ProvidersChanged => { update_active_language_model_from_settings(cx); } _ => {} diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index 77c8c95255908dc54639ad7ac6c55f1e8b8151f0..704e340ace35f33f757ab7708f96ffc940a8eb91 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -7,8 +7,8 @@ use gpui::{ Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task, }; use language_model::{ - AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelId, LanguageModelProvider, - LanguageModelProviderId, LanguageModelRegistry, + AuthenticateError, ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId, + LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry, }; use ordered_float::OrderedFloat; use picker::{Picker, PickerDelegate}; @@ -55,7 +55,7 @@ pub fn language_model_selector( fn all_models(cx: &App) -> GroupedModels { let lm_registry = LanguageModelRegistry::global(cx).read(cx); - let providers = lm_registry.providers(); + let providers = lm_registry.visible_providers(); let mut favorites_index = FavoritesIndex::default(); @@ -94,7 +94,7 @@ type FavoritesIndex = HashMap> #[derive(Clone)] struct ModelInfo { model: Arc, - icon: IconName, + icon: IconOrSvg, is_favorite: bool, } @@ -203,7 +203,7 @@ impl LanguageModelPickerDelegate { fn authenticate_all_providers(cx: &mut App) -> Task<()> { let authenticate_all_providers = LanguageModelRegistry::global(cx) .read(cx) - .providers() + .visible_providers() .iter() .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx))) .collect::>(); @@ -474,7 +474,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { let configured_providers = language_model_registry .read(cx) - .providers() + .visible_providers() .into_iter() .filter(|provider| provider.is_authenticated(cx)) .collect::>(); @@ -566,7 +566,10 @@ impl PickerDelegate for LanguageModelPickerDelegate { Some( ModelSelectorListItem::new(ix, model_info.model.name().0) - .icon(model_info.icon) + .map(|this| match &model_info.icon { + IconOrSvg::Icon(icon_name) => this.icon(*icon_name), + IconOrSvg::Svg(icon_path) => this.icon_path(icon_path.clone()), + }) .is_selected(is_selected) .is_focused(selected) .is_favorite(is_favorite) @@ -702,7 +705,7 @@ mod tests { .any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name); ModelInfo { model: Arc::new(TestLanguageModel::new(name, provider)), - icon: IconName::Ai, + icon: IconOrSvg::Icon(IconName::Ai), is_favorite, } }) diff --git a/crates/agent_ui/src/text_thread_editor.rs b/crates/agent_ui/src/text_thread_editor.rs index 16d12cf261d3bbb8eb0b879394fedc1cc96e046c..514f45528427af89eeccf85512abf850a7a1be05 100644 --- a/crates/agent_ui/src/text_thread_editor.rs +++ b/crates/agent_ui/src/text_thread_editor.rs @@ -33,7 +33,8 @@ use language::{ language_settings::{SoftWrap, all_language_settings}, }; use language_model::{ - ConfigurationError, LanguageModelExt, LanguageModelImage, LanguageModelRegistry, Role, + ConfigurationError, IconOrSvg, LanguageModelExt, LanguageModelImage, LanguageModelRegistry, + Role, }; use multi_buffer::MultiBufferRow; use picker::{Picker, popover_menu::PickerPopoverMenu}; @@ -2231,10 +2232,10 @@ impl TextThreadEditor { .default_model() .map(|default| default.provider); - let provider_icon = match active_provider { - Some(provider) => provider.icon(), - None => IconName::Ai, - }; + let provider_icon = active_provider + .as_ref() + .map(|p| p.icon()) + .unwrap_or(IconOrSvg::Icon(IconName::Ai)); let focus_handle = self.editor().focus_handle(cx); @@ -2244,6 +2245,13 @@ impl TextThreadEditor { (Color::Muted, IconName::ChevronDown) }; + let provider_icon_element = match provider_icon { + IconOrSvg::Svg(path) => Icon::from_external_svg(path), + IconOrSvg::Icon(name) => Icon::new(name), + } + .color(color) + .size(IconSize::XSmall); + let tooltip = Tooltip::element({ move |_, cx| { let focus_handle = focus_handle.clone(); @@ -2291,7 +2299,7 @@ impl TextThreadEditor { .child( h_flex() .gap_0p5() - .child(Icon::new(provider_icon).color(color).size(IconSize::XSmall)) + .child(provider_icon_element) .child( Label::new(model_name) .color(color) diff --git a/crates/agent_ui/src/ui/model_selector_components.rs b/crates/agent_ui/src/ui/model_selector_components.rs index 061b4f58288798696b068a091fb392c033906627..beb0c13d761aa9e7e41c2ac4e35a8cfcc7e8d869 100644 --- a/crates/agent_ui/src/ui/model_selector_components.rs +++ b/crates/agent_ui/src/ui/model_selector_components.rs @@ -1,6 +1,11 @@ use gpui::{Action, FocusHandle, prelude::*}; use ui::{ElevationIndex, KeyBinding, ListItem, ListItemSpacing, Tooltip, prelude::*}; +enum ModelIcon { + Name(IconName), + Path(SharedString), +} + #[derive(IntoElement)] pub struct ModelSelectorHeader { title: SharedString, @@ -39,7 +44,7 @@ impl RenderOnce for ModelSelectorHeader { pub struct ModelSelectorListItem { index: usize, title: SharedString, - icon: Option, + icon: Option, is_selected: bool, is_focused: bool, is_favorite: bool, @@ -60,7 +65,12 @@ impl ModelSelectorListItem { } pub fn icon(mut self, icon: IconName) -> Self { - self.icon = Some(icon); + self.icon = Some(ModelIcon::Name(icon)); + self + } + + pub fn icon_path(mut self, path: SharedString) -> Self { + self.icon = Some(ModelIcon::Path(path)); self } @@ -105,9 +115,12 @@ impl RenderOnce for ModelSelectorListItem { .gap_1p5() .when_some(self.icon, |this, icon| { this.child( - Icon::new(icon) - .color(model_icon_color) - .size(IconSize::Small), + match icon { + ModelIcon::Name(icon_name) => Icon::new(icon_name), + ModelIcon::Path(icon_path) => Icon::from_external_svg(icon_path), + } + .color(model_icon_color) + .size(IconSize::Small), ) }) .child(Label::new(self.title).truncate()), diff --git a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs index fadc4222ae44f3dbad862fd9479b89321dbd3016..47197ec2331b97dd4d7561d9f14c91c7f91c9fa0 100644 --- a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs +++ b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs @@ -1,9 +1,9 @@ use gpui::{Action, IntoElement, ParentElement, RenderOnce, point}; -use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; +use language_model::{IconOrSvg, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; use ui::{Divider, List, ListBulletItem, prelude::*}; pub struct ApiKeysWithProviders { - configured_providers: Vec<(IconName, SharedString)>, + configured_providers: Vec<(IconOrSvg, SharedString)>, } impl ApiKeysWithProviders { @@ -13,7 +13,8 @@ impl ApiKeysWithProviders { |this: &mut Self, _registry, event: &language_model::Event, cx| match event { language_model::Event::ProviderStateChanged(_) | language_model::Event::AddedProvider(_) - | language_model::Event::RemovedProvider(_) => { + | language_model::Event::RemovedProvider(_) + | language_model::Event::ProvidersChanged => { this.configured_providers = Self::compute_configured_providers(cx) } _ => {} @@ -26,9 +27,9 @@ impl ApiKeysWithProviders { } } - fn compute_configured_providers(cx: &App) -> Vec<(IconName, SharedString)> { + fn compute_configured_providers(cx: &App) -> Vec<(IconOrSvg, SharedString)> { LanguageModelRegistry::read_global(cx) - .providers() + .visible_providers() .iter() .filter(|provider| { provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID @@ -47,7 +48,14 @@ impl Render for ApiKeysWithProviders { .map(|(icon, name)| { h_flex() .gap_1p5() - .child(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted)) + .child( + match icon { + IconOrSvg::Icon(icon_name) => Icon::new(icon_name), + IconOrSvg::Svg(icon_path) => Icon::from_external_svg(icon_path), + } + .size(IconSize::XSmall) + .color(Color::Muted), + ) .child(Label::new(name)) }); div() diff --git a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs index 3c8ffc1663e0660829698b5449a006de5b3c6009..c2756927136449d649996ec3b4b87471114aca38 100644 --- a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs +++ b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs @@ -11,7 +11,7 @@ use crate::{AgentPanelOnboardingCard, ApiKeysWithoutProviders, ZedAiOnboarding}; pub struct AgentPanelOnboarding { user_store: Entity, client: Arc, - configured_providers: Vec<(IconName, SharedString)>, + has_configured_providers: bool, continue_with_zed_ai: Arc, } @@ -27,8 +27,9 @@ impl AgentPanelOnboarding { |this: &mut Self, _registry, event: &language_model::Event, cx| match event { language_model::Event::ProviderStateChanged(_) | language_model::Event::AddedProvider(_) - | language_model::Event::RemovedProvider(_) => { - this.configured_providers = Self::compute_available_providers(cx) + | language_model::Event::RemovedProvider(_) + | language_model::Event::ProvidersChanged => { + this.has_configured_providers = Self::has_configured_providers(cx) } _ => {} }, @@ -38,20 +39,16 @@ impl AgentPanelOnboarding { Self { user_store, client, - configured_providers: Self::compute_available_providers(cx), + has_configured_providers: Self::has_configured_providers(cx), continue_with_zed_ai: Arc::new(continue_with_zed_ai), } } - fn compute_available_providers(cx: &App) -> Vec<(IconName, SharedString)> { + fn has_configured_providers(cx: &App) -> bool { LanguageModelRegistry::read_global(cx) - .providers() + .visible_providers() .iter() - .filter(|provider| { - provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID - }) - .map(|provider| (provider.icon(), provider.name().0)) - .collect() + .any(|provider| provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID) } } @@ -81,7 +78,7 @@ impl Render for AgentPanelOnboarding { }), ) .map(|this| { - if enrolled_in_trial || is_pro_user || !self.configured_providers.is_empty() { + if enrolled_in_trial || is_pro_user || self.has_configured_providers { this } else { this.child(ApiKeysWithoutProviders::new()) diff --git a/crates/extension/src/extension_host_proxy.rs b/crates/extension/src/extension_host_proxy.rs index 6a24e3ba3f496bd0f0b89d61e9125b29ecae0204..b445878389015d4b3b8c3e25a0d103586462fd86 100644 --- a/crates/extension/src/extension_host_proxy.rs +++ b/crates/extension/src/extension_host_proxy.rs @@ -19,6 +19,9 @@ impl Global for GlobalExtensionHostProxy {} /// /// This object implements each of the individual proxy types so that their /// methods can be called directly on it. +/// Registration function for language model providers. +pub type LanguageModelProviderRegistration = Box; + #[derive(Default)] pub struct ExtensionHostProxy { theme_proxy: RwLock>>, @@ -29,6 +32,7 @@ pub struct ExtensionHostProxy { slash_command_proxy: RwLock>>, context_server_proxy: RwLock>>, debug_adapter_provider_proxy: RwLock>>, + language_model_provider_proxy: RwLock>>, } impl ExtensionHostProxy { @@ -54,6 +58,7 @@ impl ExtensionHostProxy { slash_command_proxy: RwLock::default(), context_server_proxy: RwLock::default(), debug_adapter_provider_proxy: RwLock::default(), + language_model_provider_proxy: RwLock::default(), } } @@ -90,6 +95,15 @@ impl ExtensionHostProxy { .write() .replace(Arc::new(proxy)); } + + pub fn register_language_model_provider_proxy( + &self, + proxy: impl ExtensionLanguageModelProviderProxy, + ) { + self.language_model_provider_proxy + .write() + .replace(Arc::new(proxy)); + } } pub trait ExtensionThemeProxy: Send + Sync + 'static { @@ -446,3 +460,37 @@ impl ExtensionDebugAdapterProviderProxy for ExtensionHostProxy { proxy.unregister_debug_locator(locator_name) } } + +pub trait ExtensionLanguageModelProviderProxy: Send + Sync + 'static { + fn register_language_model_provider( + &self, + provider_id: Arc, + register_fn: LanguageModelProviderRegistration, + cx: &mut App, + ); + + fn unregister_language_model_provider(&self, provider_id: Arc, cx: &mut App); +} + +impl ExtensionLanguageModelProviderProxy for ExtensionHostProxy { + fn register_language_model_provider( + &self, + provider_id: Arc, + register_fn: LanguageModelProviderRegistration, + cx: &mut App, + ) { + let Some(proxy) = self.language_model_provider_proxy.read().clone() else { + return; + }; + + proxy.register_language_model_provider(provider_id, register_fn, cx) + } + + fn unregister_language_model_provider(&self, provider_id: Arc, cx: &mut App) { + let Some(proxy) = self.language_model_provider_proxy.read().clone() else { + return; + }; + + proxy.unregister_language_model_provider(provider_id, cx) + } +} diff --git a/crates/extension/src/extension_manifest.rs b/crates/extension/src/extension_manifest.rs index 4ecdd378ca86dbee263e439e13fa4776dab9e316..39b629db30d0d1cee3374dafc317bdeb0f368146 100644 --- a/crates/extension/src/extension_manifest.rs +++ b/crates/extension/src/extension_manifest.rs @@ -93,6 +93,8 @@ pub struct ExtensionManifest { pub debug_adapters: BTreeMap, DebugAdapterManifestEntry>, #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] pub debug_locators: BTreeMap, DebugLocatorManifestEntry>, + #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] + pub language_model_providers: BTreeMap, LanguageModelProviderManifestEntry>, } impl ExtensionManifest { @@ -288,6 +290,16 @@ pub struct DebugAdapterManifestEntry { #[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)] pub struct DebugLocatorManifestEntry {} +/// Manifest entry for a language model provider. +#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)] +pub struct LanguageModelProviderManifestEntry { + /// Display name for the provider. + pub name: String, + /// Path to an SVG icon file relative to the extension root (e.g., "icons/provider.svg"). + #[serde(default)] + pub icon: Option, +} + impl ExtensionManifest { pub async fn load(fs: Arc, extension_dir: &Path) -> Result { let extension_name = extension_dir @@ -358,6 +370,7 @@ fn manifest_from_old_manifest( capabilities: Vec::new(), debug_adapters: Default::default(), debug_locators: Default::default(), + language_model_providers: Default::default(), } } @@ -391,6 +404,7 @@ mod tests { capabilities: vec![], debug_adapters: Default::default(), debug_locators: Default::default(), + language_model_providers: BTreeMap::default(), } } diff --git a/crates/extension_host/benches/extension_compilation_benchmark.rs b/crates/extension_host/benches/extension_compilation_benchmark.rs index a28f617dc36e5cba3ad36d7ab6477e7a665dd5c4..605b98c67071155d8444639ef7043b9c8901161d 100644 --- a/crates/extension_host/benches/extension_compilation_benchmark.rs +++ b/crates/extension_host/benches/extension_compilation_benchmark.rs @@ -148,6 +148,7 @@ fn manifest() -> ExtensionManifest { )], debug_adapters: Default::default(), debug_locators: Default::default(), + language_model_providers: BTreeMap::default(), } } diff --git a/crates/extension_host/src/capability_granter.rs b/crates/extension_host/src/capability_granter.rs index 9f27b5e480bc3c22faefe67cd49a06af21614096..6278deef0a7d41e40d4444ddbe992f007cd5e53e 100644 --- a/crates/extension_host/src/capability_granter.rs +++ b/crates/extension_host/src/capability_granter.rs @@ -113,6 +113,7 @@ mod tests { capabilities: vec![], debug_adapters: Default::default(), debug_locators: Default::default(), + language_model_providers: BTreeMap::default(), } } diff --git a/crates/extension_host/src/extension_store_test.rs b/crates/extension_host/src/extension_store_test.rs index 54b090347ffad3ffed444827f5cb60c120d25ad7..c17484f26a06b3392cdbcd8f3c1578eb43c7b213 100644 --- a/crates/extension_host/src/extension_store_test.rs +++ b/crates/extension_host/src/extension_store_test.rs @@ -165,6 +165,7 @@ async fn test_extension_store(cx: &mut TestAppContext) { capabilities: Vec::new(), debug_adapters: Default::default(), debug_locators: Default::default(), + language_model_providers: BTreeMap::default(), }), dev: false, }, @@ -196,6 +197,7 @@ async fn test_extension_store(cx: &mut TestAppContext) { capabilities: Vec::new(), debug_adapters: Default::default(), debug_locators: Default::default(), + language_model_providers: BTreeMap::default(), }), dev: false, }, @@ -376,6 +378,7 @@ async fn test_extension_store(cx: &mut TestAppContext) { capabilities: Vec::new(), debug_adapters: Default::default(), debug_locators: Default::default(), + language_model_providers: BTreeMap::default(), }), dev: false, }, diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 09d44b5b408324936af00a2a5e4f1deb4f351434..56a970404419ec6042c463d26c2844eb0904f829 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -797,11 +797,26 @@ pub enum AuthenticateError { Other(#[from] anyhow::Error), } +/// Either a built-in icon name or a path to an external SVG. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum IconOrSvg { + /// A built-in icon from Zed's icon set. + Icon(IconName), + /// Path to a custom SVG icon file. + Svg(SharedString), +} + +impl Default for IconOrSvg { + fn default() -> Self { + Self::Icon(IconName::ZedAssistant) + } +} + pub trait LanguageModelProvider: 'static { fn id(&self) -> LanguageModelProviderId; fn name(&self) -> LanguageModelProviderName; - fn icon(&self) -> IconName { - IconName::ZedAssistant + fn icon(&self) -> IconOrSvg { + IconOrSvg::default() } fn default_model(&self, cx: &App) -> Option>; fn default_fast_model(&self, cx: &App) -> Option>; @@ -820,7 +835,7 @@ pub trait LanguageModelProvider: 'static { fn reset_credentials(&self, cx: &mut App) -> Task>; } -#[derive(Default, Clone)] +#[derive(Default, Clone, PartialEq, Eq)] pub enum ConfigurationViewTargetAgent { #[default] ZedAgent, diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 27b8309810962981d3c0ec78e6e67dfdfba122bf..cf7718f7b102010cc0c8a981a0425583436176b7 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -2,12 +2,16 @@ use crate::{ LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState, }; -use collections::BTreeMap; +use collections::{BTreeMap, HashSet}; use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*}; use std::{str::FromStr, sync::Arc}; use thiserror::Error; use util::maybe; +/// Function type for checking if a built-in provider should be hidden. +/// Returns Some(extension_id) if the provider should be hidden when that extension is installed. +pub type BuiltinProviderHidingFn = Box Option<&'static str> + Send + Sync>; + pub fn init(cx: &mut App) { let registry = cx.new(|_cx| LanguageModelRegistry::default()); cx.set_global(GlobalLanguageModelRegistry(registry)); @@ -48,6 +52,11 @@ pub struct LanguageModelRegistry { thread_summary_model: Option, providers: BTreeMap>, inline_alternatives: Vec>, + /// Set of installed extension IDs that provide language models. + /// Used to determine which built-in providers should be hidden. + installed_llm_extension_ids: HashSet>, + /// Function to check if a built-in provider should be hidden by an extension. + builtin_provider_hiding_fn: Option, } #[derive(Debug)] @@ -104,6 +113,8 @@ pub enum Event { ProviderStateChanged(LanguageModelProviderId), AddedProvider(LanguageModelProviderId), RemovedProvider(LanguageModelProviderId), + /// Emitted when provider visibility changes due to extension install/uninstall. + ProvidersChanged, } impl EventEmitter for LanguageModelRegistry {} @@ -183,6 +194,60 @@ impl LanguageModelRegistry { providers } + /// Returns providers, filtering out hidden built-in providers. + pub fn visible_providers(&self) -> Vec> { + self.providers() + .into_iter() + .filter(|p| !self.should_hide_provider(&p.id())) + .collect() + } + + /// Sets the function used to check if a built-in provider should be hidden. + pub fn set_builtin_provider_hiding_fn(&mut self, hiding_fn: BuiltinProviderHidingFn) { + self.builtin_provider_hiding_fn = Some(hiding_fn); + } + + /// Called when an extension is installed/loaded. + /// If the extension provides language models, track it so we can hide the corresponding built-in. + pub fn extension_installed(&mut self, extension_id: Arc, cx: &mut Context) { + if self.installed_llm_extension_ids.insert(extension_id) { + cx.emit(Event::ProvidersChanged); + cx.notify(); + } + } + + /// Called when an extension is uninstalled/unloaded. + pub fn extension_uninstalled(&mut self, extension_id: &str, cx: &mut Context) { + if self.installed_llm_extension_ids.remove(extension_id) { + cx.emit(Event::ProvidersChanged); + cx.notify(); + } + } + + /// Sync the set of installed LLM extension IDs. + pub fn sync_installed_llm_extensions( + &mut self, + extension_ids: HashSet>, + cx: &mut Context, + ) { + if extension_ids != self.installed_llm_extension_ids { + self.installed_llm_extension_ids = extension_ids; + cx.emit(Event::ProvidersChanged); + cx.notify(); + } + } + + /// Returns true if a provider should be hidden from the UI. + /// Built-in providers are hidden when their corresponding extension is installed. + pub fn should_hide_provider(&self, provider_id: &LanguageModelProviderId) -> bool { + if let Some(ref hiding_fn) = self.builtin_provider_hiding_fn { + if let Some(extension_id) = hiding_fn(&provider_id.0) { + return self.installed_llm_extension_ids.contains(extension_id); + } + } + false + } + pub fn configuration_error( &self, model: Option, @@ -416,4 +481,132 @@ mod tests { let providers = registry.read(cx).providers(); assert!(providers.is_empty()); } + + #[gpui::test] + fn test_provider_hiding_on_extension_install(cx: &mut App) { + let registry = cx.new(|_| LanguageModelRegistry::default()); + + let provider = Arc::new(FakeLanguageModelProvider::default()); + let provider_id = provider.id(); + + registry.update(cx, |registry, cx| { + registry.register_provider(provider.clone(), cx); + + registry.set_builtin_provider_hiding_fn(Box::new(|id| { + if id == "fake" { + Some("fake-extension") + } else { + None + } + })); + }); + + let visible = registry.read(cx).visible_providers(); + assert_eq!(visible.len(), 1); + assert_eq!(visible[0].id(), provider_id); + + registry.update(cx, |registry, cx| { + registry.extension_installed("fake-extension".into(), cx); + }); + + let visible = registry.read(cx).visible_providers(); + assert!(visible.is_empty()); + + let all = registry.read(cx).providers(); + assert_eq!(all.len(), 1); + } + + #[gpui::test] + fn test_provider_unhiding_on_extension_uninstall(cx: &mut App) { + let registry = cx.new(|_| LanguageModelRegistry::default()); + + let provider = Arc::new(FakeLanguageModelProvider::default()); + let provider_id = provider.id(); + + registry.update(cx, |registry, cx| { + registry.register_provider(provider.clone(), cx); + + registry.set_builtin_provider_hiding_fn(Box::new(|id| { + if id == "fake" { + Some("fake-extension") + } else { + None + } + })); + + registry.extension_installed("fake-extension".into(), cx); + }); + + let visible = registry.read(cx).visible_providers(); + assert!(visible.is_empty()); + + registry.update(cx, |registry, cx| { + registry.extension_uninstalled("fake-extension", cx); + }); + + let visible = registry.read(cx).visible_providers(); + assert_eq!(visible.len(), 1); + assert_eq!(visible[0].id(), provider_id); + } + + #[gpui::test] + fn test_should_hide_provider(cx: &mut App) { + let registry = cx.new(|_| LanguageModelRegistry::default()); + + registry.update(cx, |registry, cx| { + registry.set_builtin_provider_hiding_fn(Box::new(|id| { + if id == "anthropic" { + Some("anthropic") + } else if id == "openai" { + Some("openai") + } else { + None + } + })); + + registry.extension_installed("anthropic".into(), cx); + }); + + let registry_read = registry.read(cx); + + assert!(registry_read.should_hide_provider(&LanguageModelProviderId("anthropic".into()))); + + assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("openai".into()))); + + assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("unknown".into()))); + } + + #[gpui::test] + fn test_sync_installed_llm_extensions(cx: &mut App) { + let registry = cx.new(|_| LanguageModelRegistry::default()); + + let provider = Arc::new(FakeLanguageModelProvider::default()); + + registry.update(cx, |registry, cx| { + registry.register_provider(provider.clone(), cx); + + registry.set_builtin_provider_hiding_fn(Box::new(|id| { + if id == "fake" { + Some("fake-extension") + } else { + None + } + })); + }); + + let mut extension_ids = HashSet::default(); + extension_ids.insert(Arc::from("fake-extension")); + + registry.update(cx, |registry, cx| { + registry.sync_installed_llm_extensions(extension_ids, cx); + }); + + assert!(registry.read(cx).visible_providers().is_empty()); + + registry.update(cx, |registry, cx| { + registry.sync_installed_llm_extensions(HashSet::default(), cx); + }); + + assert_eq!(registry.read(cx).visible_providers().len(), 1); + } } diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 5531e698ab7fccae736e800f38b16e35bcd35ac4..1bec5d94d2bb35f91305c6c77a9e85ed8579e1af 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -28,6 +28,8 @@ convert_case.workspace = true copilot.workspace = true credentials_provider.workspace = true deepseek = { workspace = true, features = ["schemars"] } +extension.workspace = true +extension_host.workspace = true fs.workspace = true futures.workspace = true google_ai = { workspace = true, features = ["schemars"] } diff --git a/crates/language_models/src/extension.rs b/crates/language_models/src/extension.rs new file mode 100644 index 0000000000000000000000000000000000000000..e0b46ab5e1d667fb61449a654769ecf7c221e720 --- /dev/null +++ b/crates/language_models/src/extension.rs @@ -0,0 +1,67 @@ +use collections::HashMap; +use extension::{ + ExtensionHostProxy, ExtensionLanguageModelProviderProxy, LanguageModelProviderRegistration, +}; +use gpui::{App, Entity}; +use language_model::{LanguageModelProviderId, LanguageModelRegistry}; +use std::sync::{Arc, LazyLock}; + +/// Maps built-in provider IDs to their corresponding extension IDs. +/// When an extension with this ID is installed, the built-in provider should be hidden. +static BUILTIN_TO_EXTENSION_MAP: LazyLock> = + LazyLock::new(|| { + let mut map = HashMap::default(); + map.insert("anthropic", "anthropic"); + map.insert("openai", "openai"); + map.insert("google", "google-ai"); + map.insert("openrouter", "openrouter"); + map.insert("copilot_chat", "copilot-chat"); + map + }); + +/// Returns the extension ID that should hide the given built-in provider. +pub fn extension_for_builtin_provider(provider_id: &str) -> Option<&'static str> { + BUILTIN_TO_EXTENSION_MAP.get(provider_id).copied() +} + +/// Proxy that registers extension language model providers with the LanguageModelRegistry. +pub struct LanguageModelProviderRegistryProxy { + registry: Entity, +} + +impl LanguageModelProviderRegistryProxy { + pub fn new(registry: Entity) -> Self { + Self { registry } + } +} + +impl ExtensionLanguageModelProviderProxy for LanguageModelProviderRegistryProxy { + fn register_language_model_provider( + &self, + _provider_id: Arc, + register_fn: LanguageModelProviderRegistration, + cx: &mut App, + ) { + register_fn(cx); + } + + fn unregister_language_model_provider(&self, provider_id: Arc, cx: &mut App) { + self.registry.update(cx, |registry, cx| { + registry.unregister_provider(LanguageModelProviderId::from(provider_id), cx); + }); + } +} + +/// Initialize the extension language model provider proxy. +/// This must be called BEFORE extension_host::init to ensure the proxy is available +/// when extensions try to register their language model providers. +pub fn init_proxy(cx: &mut App) { + let proxy = ExtensionHostProxy::default_global(cx); + let registry = LanguageModelRegistry::global(cx); + + registry.update(cx, |registry, _cx| { + registry.set_builtin_provider_hiding_fn(Box::new(extension_for_builtin_provider)); + }); + + proxy.register_language_model_provider_proxy(LanguageModelProviderRegistryProxy::new(registry)); +} diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index 1038f5e233e0a5970b0e8bd969a65f6f0e2a7550..37d4ca5ddd4e5c1e7a0202c88c012d18b018cd4f 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -7,9 +7,12 @@ use gpui::{App, Context, Entity}; use language_model::{LanguageModelProviderId, LanguageModelRegistry}; use provider::deepseek::DeepSeekLanguageModelProvider; +pub mod extension; pub mod provider; mod settings; +pub use crate::extension::init_proxy as init_extension_proxy; + use crate::provider::anthropic::AnthropicLanguageModelProvider; use crate::provider::bedrock::BedrockLanguageModelProvider; use crate::provider::cloud::CloudLanguageModelProvider; @@ -31,6 +34,56 @@ pub fn init(user_store: Entity, client: Arc, cx: &mut App) { register_language_model_providers(registry, user_store, client.clone(), cx); }); + // Subscribe to extension store events to track LLM extension installations + if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) { + cx.subscribe(&extension_store, { + let registry = registry.clone(); + move |extension_store, event, cx| match event { + extension_host::Event::ExtensionInstalled(extension_id) => { + if let Some(manifest) = extension_store + .read(cx) + .extension_manifest_for_id(extension_id) + { + if !manifest.language_model_providers.is_empty() { + registry.update(cx, |registry, cx| { + registry.extension_installed(extension_id.clone(), cx); + }); + } + } + } + extension_host::Event::ExtensionUninstalled(extension_id) => { + registry.update(cx, |registry, cx| { + registry.extension_uninstalled(extension_id, cx); + }); + } + extension_host::Event::ExtensionsUpdated => { + let mut new_ids = HashSet::default(); + for (extension_id, entry) in extension_store.read(cx).installed_extensions() { + if !entry.manifest.language_model_providers.is_empty() { + new_ids.insert(extension_id.clone()); + } + } + registry.update(cx, |registry, cx| { + registry.sync_installed_llm_extensions(new_ids, cx); + }); + } + _ => {} + } + }) + .detach(); + + // Initialize with currently installed extensions + registry.update(cx, |registry, cx| { + let mut initial_ids = HashSet::default(); + for (extension_id, entry) in extension_store.read(cx).installed_extensions() { + if !entry.manifest.language_model_providers.is_empty() { + initial_ids.insert(extension_id.clone()); + } + } + registry.sync_installed_llm_extensions(initial_ids, cx); + }); + } + let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx) .openai_compatible .keys() diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index d8c972399c33922386bfba4236e1369d03d338dc..598834f85c496cd54ddd956089715cac64420202 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -8,7 +8,7 @@ use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::B use gpui::{AnyView, App, AsyncApp, Context, Entity, Task}; use http_client::HttpClient; use language_model::{ - ApiKeyState, AuthenticateError, ConfigurationViewTargetAgent, EnvVar, LanguageModel, + ApiKeyState, AuthenticateError, ConfigurationViewTargetAgent, EnvVar, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, @@ -125,8 +125,8 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { PROVIDER_NAME } - fn icon(&self) -> IconName { - IconName::AiAnthropic + fn icon(&self) -> IconOrSvg { + IconOrSvg::Icon(IconName::AiAnthropic) } fn default_model(&self, _cx: &App) -> Option> { diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 286f9ec1a4bf67c22868cf83e00e7b46e0737ba8..62237fbf376a0739fd2518bda44f51149b3457df 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -30,7 +30,7 @@ use gpui::{ use gpui_tokio::Tokio; use http_client::HttpClient; use language_model::{ - AuthenticateError, EnvVar, LanguageModel, LanguageModelCacheConfiguration, + AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, @@ -426,8 +426,8 @@ impl LanguageModelProvider for BedrockLanguageModelProvider { PROVIDER_NAME } - fn icon(&self) -> IconName { - IconName::AiBedrock + fn icon(&self) -> IconOrSvg { + IconOrSvg::Icon(IconName::AiBedrock) } fn default_model(&self, _cx: &App) -> Option> { diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index def1cef84d3166d08dcc7638ca5a29cabbd149c5..65a42740eb9a8aff830d7544ed5aa972c6697d88 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -19,7 +19,7 @@ use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Ta use http_client::http::{HeaderMap, HeaderValue}; use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode}; use language_model::{ - AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, + AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, @@ -304,8 +304,8 @@ impl LanguageModelProvider for CloudLanguageModelProvider { PROVIDER_NAME } - fn icon(&self) -> IconName { - IconName::AiZed + fn icon(&self) -> IconOrSvg { + IconOrSvg::Icon(IconName::AiZed) } fn default_model(&self, cx: &App) -> Option> { diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 70198b337e467e1618192e781d3e3be305fea9c5..68eaab1dbed33a8d983de6a919b75dc809410a70 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -18,12 +18,12 @@ use gpui::{AnyView, App, AsyncApp, Entity, Subscription, Task}; use http_client::StatusCode; use language::language_settings::all_language_settings; use language_model::{ - AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, - LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, - LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, - LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent, - LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role, - StopReason, TokenUsage, + AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCompletionError, + LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, + LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, + LanguageModelRequest, LanguageModelRequestMessage, LanguageModelToolChoice, + LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse, + MessageContent, RateLimiter, Role, StopReason, TokenUsage, }; use settings::SettingsStore; use ui::prelude::*; @@ -104,8 +104,8 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider { PROVIDER_NAME } - fn icon(&self) -> IconName { - IconName::Copilot + fn icon(&self) -> IconOrSvg { + IconOrSvg::Icon(IconName::Copilot) } fn default_model(&self, cx: &App) -> Option> { diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index b00a5d82f5665a5c87c662d1af84fbeb9ac07ebb..b3264b869195aa34d7083cd31992d8c220d20349 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -7,7 +7,7 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture, stream::BoxStream use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ - ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError, + ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, @@ -127,8 +127,8 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider { PROVIDER_NAME } - fn icon(&self) -> IconName { - IconName::AiDeepSeek + fn icon(&self) -> IconOrSvg { + IconOrSvg::Icon(IconName::AiDeepSeek) } fn default_model(&self, _cx: &App) -> Option> { diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 989b99061b6d0f4c6680f08616c55946138ae0fe..7d567d60f405c7880cb6494f6d2ff604d7f53ac2 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -14,7 +14,7 @@ use language_model::{ LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason, }; use language_model::{ - LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, + IconOrSvg, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, }; @@ -164,8 +164,8 @@ impl LanguageModelProvider for GoogleLanguageModelProvider { PROVIDER_NAME } - fn icon(&self) -> IconName { - IconName::AiGoogle + fn icon(&self) -> IconOrSvg { + IconOrSvg::Icon(IconName::AiGoogle) } fn default_model(&self, _cx: &App) -> Option> { diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 94f99f10afc8928fb7fbc8526ab46e7dca37a5ce..237b64ac7d0ed728b057f6b553ad2a2a1ebae1db 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -10,7 +10,7 @@ use language_model::{ StopReason, TokenUsage, }; use language_model::{ - LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, + IconOrSvg, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, }; @@ -175,8 +175,8 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider { PROVIDER_NAME } - fn icon(&self) -> IconName { - IconName::AiLmStudio + fn icon(&self) -> IconOrSvg { + IconOrSvg::Icon(IconName::AiLmStudio) } fn default_model(&self, _: &App) -> Option> { diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 64f3999e3aa96b2611e265a6eaf5df8063332c2a..0b8af405ade8fc00c0d1e2e57ba115560d94a71d 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -5,7 +5,7 @@ use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::B use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ - ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError, + ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, @@ -176,8 +176,8 @@ impl LanguageModelProvider for MistralLanguageModelProvider { PROVIDER_NAME } - fn icon(&self) -> IconName { - IconName::AiMistral + fn icon(&self) -> IconOrSvg { + IconOrSvg::Icon(IconName::AiMistral) } fn default_model(&self, _cx: &App) -> Option> { diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index 860d635b6ac704c2762023c463432bebae08d4a5..f5d8820e710ea6c9f89de6da5a7aae2f204c6470 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -5,7 +5,7 @@ use futures::{Stream, TryFutureExt, stream}; use gpui::{AnyView, App, AsyncApp, Context, CursorStyle, Entity, Task}; use http_client::HttpClient; use language_model::{ - ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError, + ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse, @@ -221,8 +221,8 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { PROVIDER_NAME } - fn icon(&self) -> IconName { - IconName::AiOllama + fn icon(&self) -> IconOrSvg { + IconOrSvg::Icon(IconName::AiOllama) } fn default_model(&self, _: &App) -> Option> { diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index afaffba3e53eb2496f9fae795d69b9e9c9f57249..905d2b37862eebf57c7fb56a540b388338cfd065 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -5,7 +5,7 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ - ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError, + ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, @@ -122,8 +122,8 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { PROVIDER_NAME } - fn icon(&self) -> IconName { - IconName::AiOpenAi + fn icon(&self) -> IconOrSvg { + IconOrSvg::Icon(IconName::AiOpenAi) } fn default_model(&self, _cx: &App) -> Option> { diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index e6e7a9984da3d48b9e3c0f9571b8e916359fba03..f95f567739d76670d3cfa7b835bbfaf34ddef92f 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -4,7 +4,7 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ - ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError, + ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, @@ -133,8 +133,8 @@ impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider { self.name.clone() } - fn icon(&self) -> IconName { - IconName::AiOpenAiCompat + fn icon(&self) -> IconOrSvg { + IconOrSvg::Icon(IconName::AiOpenAiCompat) } fn default_model(&self, cx: &App) -> Option> { diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index ad2e90d9dd5f4ece7e2582a867da50f6962c981c..48d68ddebff7e0c9bbe39dbca696dd2ffcf62605 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -4,7 +4,7 @@ use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task}; use http_client::HttpClient; use language_model::{ - ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError, + ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, @@ -180,8 +180,8 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider { PROVIDER_NAME } - fn icon(&self) -> IconName { - IconName::AiOpenRouter + fn icon(&self) -> IconOrSvg { + IconOrSvg::Icon(IconName::AiOpenRouter) } fn default_model(&self, _cx: &App) -> Option> { diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs index 4dfe848df80123dc4c37d27b81f76db359e076f9..e2e692eafff94c56d481dfc2bd96dbfa7adda262 100644 --- a/crates/language_models/src/provider/vercel.rs +++ b/crates/language_models/src/provider/vercel.rs @@ -4,7 +4,7 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ - ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError, + ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, RateLimiter, Role, env_var, @@ -117,8 +117,8 @@ impl LanguageModelProvider for VercelLanguageModelProvider { PROVIDER_NAME } - fn icon(&self) -> IconName { - IconName::AiVZero + fn icon(&self) -> IconOrSvg { + IconOrSvg::Icon(IconName::AiVZero) } fn default_model(&self, _cx: &App) -> Option> { diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs index 19c50d71cf4e483b68d48c8b982a975f3091ff46..f0aa0e71a83ae1a201d76a33f63ca0aadc6936a9 100644 --- a/crates/language_models/src/provider/x_ai.rs +++ b/crates/language_models/src/provider/x_ai.rs @@ -4,7 +4,7 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture}; use gpui::{AnyView, App, AsyncApp, Context, Entity, Task, Window}; use http_client::HttpClient; use language_model::{ - ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError, + ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, @@ -118,8 +118,8 @@ impl LanguageModelProvider for XAiLanguageModelProvider { PROVIDER_NAME } - fn icon(&self) -> IconName { - IconName::AiXAi + fn icon(&self) -> IconOrSvg { + IconOrSvg::Icon(IconName::AiXAi) } fn default_model(&self, _cx: &App) -> Option> { diff --git a/crates/ui/src/components/icon.rs b/crates/ui/src/components/icon.rs index 1c8e36ec18d6184b38eb6772e8f5a13be181ae00..9d2c7ae3b515744125879f4a2c0e0d3e9a4fb841 100644 --- a/crates/ui/src/components/icon.rs +++ b/crates/ui/src/components/icon.rs @@ -126,17 +126,6 @@ enum IconSource { ExternalSvg(SharedString), } -impl IconSource { - fn from_path(path: impl Into) -> Self { - let path = path.into(); - if path.starts_with("icons/") { - Self::Embedded(path) - } else { - Self::External(Arc::from(PathBuf::from(path.as_ref()))) - } - } -} - #[derive(IntoElement, RegisterComponent)] pub struct Icon { source: IconSource, @@ -155,9 +144,18 @@ impl Icon { } } + /// Create an icon from a path. Uses a heuristic to determine if it's embedded or external: + /// - Paths starting with "icons/" are treated as embedded SVGs + /// - Other paths are treated as external raster images (from icon themes) pub fn from_path(path: impl Into) -> Self { + let path = path.into(); + let source = if path.starts_with("icons/") { + IconSource::Embedded(path) + } else { + IconSource::External(Arc::from(PathBuf::from(path.as_ref()))) + }; Self { - source: IconSource::from_path(path), + source, color: Color::default(), size: IconSize::default().rems(), transformation: Transformation::default(),