Cargo.lock 🔗
@@ -8932,6 +8932,8 @@ dependencies = [
"credentials_provider",
"deepseek",
"editor",
+ "extension",
+ "extension_host",
"fs",
"futures 0.3.31",
"google_ai",
Richard Feldman created
This adds support for provider extensions but doesn't actually add any
yet.
Release Notes:
- N/A
Cargo.lock | 2
crates/acp_thread/src/connection.rs | 11
crates/agent/src/agent.rs | 15
crates/agent_ui/src/acp/model_selector.rs | 8
crates/agent_ui/src/acp/model_selector_popover.rs | 13
crates/agent_ui/src/agent_configuration.rs | 16
crates/agent_ui/src/agent_model_selector.rs | 12
crates/agent_ui/src/agent_panel.rs | 2
crates/agent_ui/src/agent_ui.rs | 3
crates/agent_ui/src/language_model_selector.rs | 19
crates/agent_ui/src/text_thread_editor.rs | 20
crates/agent_ui/src/ui/model_selector_components.rs | 23
crates/ai_onboarding/src/agent_api_keys_onboarding.rs | 20
crates/ai_onboarding/src/agent_panel_onboarding_content.rs | 21
crates/extension/src/extension_host_proxy.rs | 48
crates/extension/src/extension_manifest.rs | 14
crates/extension_host/benches/extension_compilation_benchmark.rs | 1
crates/extension_host/src/capability_granter.rs | 1
crates/extension_host/src/extension_store_test.rs | 3
crates/language_model/src/language_model.rs | 21
crates/language_model/src/registry.rs | 195 +
crates/language_models/Cargo.toml | 2
crates/language_models/src/extension.rs | 67
crates/language_models/src/language_models.rs | 53
crates/language_models/src/provider/anthropic.rs | 6
crates/language_models/src/provider/bedrock.rs | 6
crates/language_models/src/provider/cloud.rs | 6
crates/language_models/src/provider/copilot_chat.rs | 16
crates/language_models/src/provider/deepseek.rs | 6
crates/language_models/src/provider/google.rs | 6
crates/language_models/src/provider/lmstudio.rs | 6
crates/language_models/src/provider/mistral.rs | 6
crates/language_models/src/provider/ollama.rs | 6
crates/language_models/src/provider/open_ai.rs | 6
crates/language_models/src/provider/open_ai_compatible.rs | 6
crates/language_models/src/provider/open_router.rs | 6
crates/language_models/src/provider/vercel.rs | 6
crates/language_models/src/provider/x_ai.rs | 6
crates/ui/src/components/icon.rs | 22
39 files changed, 585 insertions(+), 121 deletions(-)
@@ -8932,6 +8932,8 @@ dependencies = [
"credentials_provider",
"deepseek",
"editor",
+ "extension",
+ "extension_host",
"fs",
"futures 0.3.31",
"google_ai",
@@ -210,12 +210,21 @@ pub trait AgentModelSelector: 'static {
}
}
+/// Icon for a model in the model selector.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum AgentModelIcon {
+ /// A built-in icon from Zed's icon set.
+ Named(IconName),
+ /// Path to a custom SVG icon file.
+ Path(SharedString),
+}
+
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentModelInfo {
pub id: acp::ModelId,
pub name: SharedString,
pub description: Option<SharedString>,
- pub icon: Option<IconName>,
+ pub icon: Option<AgentModelIcon>,
}
impl From<acp::ModelInfo> for AgentModelInfo {
@@ -30,7 +30,7 @@ use futures::{StreamExt, future};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
};
-use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
+use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
use project::{Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::{
ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
@@ -93,7 +93,7 @@ impl LanguageModels {
fn refresh_list(&mut self, cx: &App) {
let providers = LanguageModelRegistry::global(cx)
.read(cx)
- .providers()
+ .visible_providers()
.into_iter()
.filter(|provider| provider.is_authenticated(cx))
.collect::<Vec<_>>();
@@ -153,7 +153,10 @@ impl LanguageModels {
id: Self::model_id(model),
name: model.name().0,
description: None,
- icon: Some(provider.icon()),
+ icon: Some(match provider.icon() {
+ IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
+ IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
+ }),
}
}
@@ -164,7 +167,7 @@ impl LanguageModels {
fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
let authenticate_all_providers = LanguageModelRegistry::global(cx)
.read(cx)
- .providers()
+ .visible_providers()
.iter()
.map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
.collect::<Vec<_>>();
@@ -1630,7 +1633,9 @@ mod internal_tests {
id: acp::ModelId::new("fake/fake"),
name: "Fake".into(),
description: None,
- icon: Some(ui::IconName::ZedAssistant),
+ icon: Some(acp_thread::AgentModelIcon::Named(
+ ui::IconName::ZedAssistant
+ )),
}]
)])
);
@@ -1,6 +1,6 @@
use std::{cmp::Reverse, rc::Rc, sync::Arc};
-use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
+use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
use agent_client_protocol::ModelId;
use agent_servers::AgentServer;
use agent_settings::AgentSettings;
@@ -350,7 +350,11 @@ impl PickerDelegate for AcpModelPickerDelegate {
})
.child(
ModelSelectorListItem::new(ix, model_info.name.clone())
- .when_some(model_info.icon, |this, icon| this.icon(icon))
+ .map(|this| match &model_info.icon {
+ Some(AgentModelIcon::Path(path)) => this.icon_path(path.clone()),
+ Some(AgentModelIcon::Named(icon)) => this.icon(*icon),
+ None => this,
+ })
.is_selected(is_selected)
.is_focused(selected)
.when(supports_favorites, |this| {
@@ -1,7 +1,7 @@
use std::rc::Rc;
use std::sync::Arc;
-use acp_thread::{AgentModelInfo, AgentModelSelector};
+use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelSelector};
use agent_servers::AgentServer;
use agent_settings::AgentSettings;
use fs::Fs;
@@ -70,7 +70,7 @@ impl Render for AcpModelSelectorPopover {
.map(|model| model.name.clone())
.unwrap_or_else(|| SharedString::from("Select a Model"));
- let model_icon = model.as_ref().and_then(|model| model.icon);
+ let model_icon = model.as_ref().and_then(|model| model.icon.clone());
let focus_handle = self.focus_handle.clone();
@@ -125,7 +125,14 @@ impl Render for AcpModelSelectorPopover {
ButtonLike::new("active-model")
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
.when_some(model_icon, |this, icon| {
- this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
+ this.child(
+ match icon {
+ AgentModelIcon::Path(path) => Icon::from_external_svg(path),
+ AgentModelIcon::Named(icon_name) => Icon::new(icon_name),
+ }
+ .color(color)
+ .size(IconSize::XSmall),
+ )
})
.child(
Label::new(model_name)
@@ -22,7 +22,8 @@ use gpui::{
};
use language::LanguageRegistry;
use language_model::{
- LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID,
+ IconOrSvg, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
+ ZED_CLOUD_PROVIDER_ID,
};
use language_models::AllLanguageModelSettings;
use notifications::status_toast::{StatusToast, ToastIcon};
@@ -117,7 +118,7 @@ impl AgentConfiguration {
}
fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- let providers = LanguageModelRegistry::read_global(cx).providers();
+ let providers = LanguageModelRegistry::read_global(cx).visible_providers();
for provider in providers {
self.add_provider_configuration_view(&provider, window, cx);
}
@@ -261,9 +262,12 @@ impl AgentConfiguration {
.w_full()
.gap_1p5()
.child(
- Icon::new(provider.icon())
- .size(IconSize::Small)
- .color(Color::Muted),
+ match provider.icon() {
+ IconOrSvg::Svg(path) => Icon::from_external_svg(path),
+ IconOrSvg::Icon(name) => Icon::new(name),
+ }
+ .size(IconSize::Small)
+ .color(Color::Muted),
)
.child(
h_flex()
@@ -416,7 +420,7 @@ impl AgentConfiguration {
&mut self,
cx: &mut Context<Self>,
) -> impl IntoElement {
- let providers = LanguageModelRegistry::read_global(cx).providers();
+ let providers = LanguageModelRegistry::read_global(cx).visible_providers();
let popover_menu = PopoverMenu::new("add-provider-popover")
.trigger(
@@ -4,6 +4,7 @@ use crate::{
};
use fs::Fs;
use gpui::{Entity, FocusHandle, SharedString};
+use language_model::IconOrSvg;
use picker::popover_menu::PickerPopoverMenu;
use settings::update_settings_file;
use std::sync::Arc;
@@ -103,7 +104,14 @@ impl Render for AgentModelSelector {
self.selector.clone(),
ButtonLike::new("active-model")
.when_some(provider_icon, |this, icon| {
- this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
+ this.child(
+ match icon {
+ IconOrSvg::Svg(path) => Icon::from_external_svg(path),
+ IconOrSvg::Icon(name) => Icon::new(name),
+ }
+ .color(color)
+ .size(IconSize::XSmall),
+ )
})
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
.child(
@@ -115,7 +123,7 @@ impl Render for AgentModelSelector {
.child(
Icon::new(IconName::ChevronDown)
.color(color)
- .size(IconSize::Small),
+ .size(IconSize::XSmall),
),
move |_window, cx| {
Tooltip::for_action_in("Change Model", &ToggleModelSelector, &focus_handle, cx)
@@ -2428,7 +2428,7 @@ impl AgentPanel {
let history_is_empty = self.history_store.read(cx).is_empty(cx);
let has_configured_non_zed_providers = LanguageModelRegistry::read_global(cx)
- .providers()
+ .visible_providers()
.iter()
.any(|provider| {
provider.is_authenticated(cx)
@@ -348,7 +348,8 @@ fn init_language_model_settings(cx: &mut App) {
|_, event: &language_model::Event, cx| match event {
language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_)
- | language_model::Event::RemovedProvider(_) => {
+ | language_model::Event::RemovedProvider(_)
+ | language_model::Event::ProvidersChanged => {
update_active_language_model_from_settings(cx);
}
_ => {}
@@ -7,8 +7,8 @@ use gpui::{
Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
};
use language_model::{
- AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelId, LanguageModelProvider,
- LanguageModelProviderId, LanguageModelRegistry,
+ AuthenticateError, ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId,
+ LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
};
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
@@ -55,7 +55,7 @@ pub fn language_model_selector(
fn all_models(cx: &App) -> GroupedModels {
let lm_registry = LanguageModelRegistry::global(cx).read(cx);
- let providers = lm_registry.providers();
+ let providers = lm_registry.visible_providers();
let mut favorites_index = FavoritesIndex::default();
@@ -94,7 +94,7 @@ type FavoritesIndex = HashMap<LanguageModelProviderId, HashSet<LanguageModelId>>
#[derive(Clone)]
struct ModelInfo {
model: Arc<dyn LanguageModel>,
- icon: IconName,
+ icon: IconOrSvg,
is_favorite: bool,
}
@@ -203,7 +203,7 @@ impl LanguageModelPickerDelegate {
fn authenticate_all_providers(cx: &mut App) -> Task<()> {
let authenticate_all_providers = LanguageModelRegistry::global(cx)
.read(cx)
- .providers()
+ .visible_providers()
.iter()
.map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
.collect::<Vec<_>>();
@@ -474,7 +474,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
let configured_providers = language_model_registry
.read(cx)
- .providers()
+ .visible_providers()
.into_iter()
.filter(|provider| provider.is_authenticated(cx))
.collect::<Vec<_>>();
@@ -566,7 +566,10 @@ impl PickerDelegate for LanguageModelPickerDelegate {
Some(
ModelSelectorListItem::new(ix, model_info.model.name().0)
- .icon(model_info.icon)
+ .map(|this| match &model_info.icon {
+ IconOrSvg::Icon(icon_name) => this.icon(*icon_name),
+ IconOrSvg::Svg(icon_path) => this.icon_path(icon_path.clone()),
+ })
.is_selected(is_selected)
.is_focused(selected)
.is_favorite(is_favorite)
@@ -702,7 +705,7 @@ mod tests {
.any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name);
ModelInfo {
model: Arc::new(TestLanguageModel::new(name, provider)),
- icon: IconName::Ai,
+ icon: IconOrSvg::Icon(IconName::Ai),
is_favorite,
}
})
@@ -33,7 +33,8 @@ use language::{
language_settings::{SoftWrap, all_language_settings},
};
use language_model::{
- ConfigurationError, LanguageModelExt, LanguageModelImage, LanguageModelRegistry, Role,
+ ConfigurationError, IconOrSvg, LanguageModelExt, LanguageModelImage, LanguageModelRegistry,
+ Role,
};
use multi_buffer::MultiBufferRow;
use picker::{Picker, popover_menu::PickerPopoverMenu};
@@ -2231,10 +2232,10 @@ impl TextThreadEditor {
.default_model()
.map(|default| default.provider);
- let provider_icon = match active_provider {
- Some(provider) => provider.icon(),
- None => IconName::Ai,
- };
+ let provider_icon = active_provider
+ .as_ref()
+ .map(|p| p.icon())
+ .unwrap_or(IconOrSvg::Icon(IconName::Ai));
let focus_handle = self.editor().focus_handle(cx);
@@ -2244,6 +2245,13 @@ impl TextThreadEditor {
(Color::Muted, IconName::ChevronDown)
};
+ let provider_icon_element = match provider_icon {
+ IconOrSvg::Svg(path) => Icon::from_external_svg(path),
+ IconOrSvg::Icon(name) => Icon::new(name),
+ }
+ .color(color)
+ .size(IconSize::XSmall);
+
let tooltip = Tooltip::element({
move |_, cx| {
let focus_handle = focus_handle.clone();
@@ -2291,7 +2299,7 @@ impl TextThreadEditor {
.child(
h_flex()
.gap_0p5()
- .child(Icon::new(provider_icon).color(color).size(IconSize::XSmall))
+ .child(provider_icon_element)
.child(
Label::new(model_name)
.color(color)
@@ -1,6 +1,11 @@
use gpui::{Action, FocusHandle, prelude::*};
use ui::{ElevationIndex, KeyBinding, ListItem, ListItemSpacing, Tooltip, prelude::*};
+enum ModelIcon {
+ Name(IconName),
+ Path(SharedString),
+}
+
#[derive(IntoElement)]
pub struct ModelSelectorHeader {
title: SharedString,
@@ -39,7 +44,7 @@ impl RenderOnce for ModelSelectorHeader {
pub struct ModelSelectorListItem {
index: usize,
title: SharedString,
- icon: Option<IconName>,
+ icon: Option<ModelIcon>,
is_selected: bool,
is_focused: bool,
is_favorite: bool,
@@ -60,7 +65,12 @@ impl ModelSelectorListItem {
}
pub fn icon(mut self, icon: IconName) -> Self {
- self.icon = Some(icon);
+ self.icon = Some(ModelIcon::Name(icon));
+ self
+ }
+
+ pub fn icon_path(mut self, path: SharedString) -> Self {
+ self.icon = Some(ModelIcon::Path(path));
self
}
@@ -105,9 +115,12 @@ impl RenderOnce for ModelSelectorListItem {
.gap_1p5()
.when_some(self.icon, |this, icon| {
this.child(
- Icon::new(icon)
- .color(model_icon_color)
- .size(IconSize::Small),
+ match icon {
+ ModelIcon::Name(icon_name) => Icon::new(icon_name),
+ ModelIcon::Path(icon_path) => Icon::from_external_svg(icon_path),
+ }
+ .color(model_icon_color)
+ .size(IconSize::Small),
)
})
.child(Label::new(self.title).truncate()),
@@ -1,9 +1,9 @@
use gpui::{Action, IntoElement, ParentElement, RenderOnce, point};
-use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
+use language_model::{IconOrSvg, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
use ui::{Divider, List, ListBulletItem, prelude::*};
pub struct ApiKeysWithProviders {
- configured_providers: Vec<(IconName, SharedString)>,
+ configured_providers: Vec<(IconOrSvg, SharedString)>,
}
impl ApiKeysWithProviders {
@@ -13,7 +13,8 @@ impl ApiKeysWithProviders {
|this: &mut Self, _registry, event: &language_model::Event, cx| match event {
language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_)
- | language_model::Event::RemovedProvider(_) => {
+ | language_model::Event::RemovedProvider(_)
+ | language_model::Event::ProvidersChanged => {
this.configured_providers = Self::compute_configured_providers(cx)
}
_ => {}
@@ -26,9 +27,9 @@ impl ApiKeysWithProviders {
}
}
- fn compute_configured_providers(cx: &App) -> Vec<(IconName, SharedString)> {
+ fn compute_configured_providers(cx: &App) -> Vec<(IconOrSvg, SharedString)> {
LanguageModelRegistry::read_global(cx)
- .providers()
+ .visible_providers()
.iter()
.filter(|provider| {
provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID
@@ -47,7 +48,14 @@ impl Render for ApiKeysWithProviders {
.map(|(icon, name)| {
h_flex()
.gap_1p5()
- .child(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted))
+ .child(
+ match icon {
+ IconOrSvg::Icon(icon_name) => Icon::new(icon_name),
+ IconOrSvg::Svg(icon_path) => Icon::from_external_svg(icon_path),
+ }
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ )
.child(Label::new(name))
});
div()
@@ -11,7 +11,7 @@ use crate::{AgentPanelOnboardingCard, ApiKeysWithoutProviders, ZedAiOnboarding};
pub struct AgentPanelOnboarding {
user_store: Entity<UserStore>,
client: Arc<Client>,
- configured_providers: Vec<(IconName, SharedString)>,
+ has_configured_providers: bool,
continue_with_zed_ai: Arc<dyn Fn(&mut Window, &mut App)>,
}
@@ -27,8 +27,9 @@ impl AgentPanelOnboarding {
|this: &mut Self, _registry, event: &language_model::Event, cx| match event {
language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_)
- | language_model::Event::RemovedProvider(_) => {
- this.configured_providers = Self::compute_available_providers(cx)
+ | language_model::Event::RemovedProvider(_)
+ | language_model::Event::ProvidersChanged => {
+ this.has_configured_providers = Self::has_configured_providers(cx)
}
_ => {}
},
@@ -38,20 +39,16 @@ impl AgentPanelOnboarding {
Self {
user_store,
client,
- configured_providers: Self::compute_available_providers(cx),
+ has_configured_providers: Self::has_configured_providers(cx),
continue_with_zed_ai: Arc::new(continue_with_zed_ai),
}
}
- fn compute_available_providers(cx: &App) -> Vec<(IconName, SharedString)> {
+ fn has_configured_providers(cx: &App) -> bool {
LanguageModelRegistry::read_global(cx)
- .providers()
+ .visible_providers()
.iter()
- .filter(|provider| {
- provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID
- })
- .map(|provider| (provider.icon(), provider.name().0))
- .collect()
+ .any(|provider| provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID)
}
}
@@ -81,7 +78,7 @@ impl Render for AgentPanelOnboarding {
}),
)
.map(|this| {
- if enrolled_in_trial || is_pro_user || !self.configured_providers.is_empty() {
+ if enrolled_in_trial || is_pro_user || self.has_configured_providers {
this
} else {
this.child(ApiKeysWithoutProviders::new())
@@ -19,6 +19,9 @@ impl Global for GlobalExtensionHostProxy {}
///
/// This object implements each of the individual proxy types so that their
/// methods can be called directly on it.
+/// Registration function for language model providers.
+pub type LanguageModelProviderRegistration = Box<dyn FnOnce(&mut App) + Send>;
+
#[derive(Default)]
pub struct ExtensionHostProxy {
theme_proxy: RwLock<Option<Arc<dyn ExtensionThemeProxy>>>,
@@ -29,6 +32,7 @@ pub struct ExtensionHostProxy {
slash_command_proxy: RwLock<Option<Arc<dyn ExtensionSlashCommandProxy>>>,
context_server_proxy: RwLock<Option<Arc<dyn ExtensionContextServerProxy>>>,
debug_adapter_provider_proxy: RwLock<Option<Arc<dyn ExtensionDebugAdapterProviderProxy>>>,
+ language_model_provider_proxy: RwLock<Option<Arc<dyn ExtensionLanguageModelProviderProxy>>>,
}
impl ExtensionHostProxy {
@@ -54,6 +58,7 @@ impl ExtensionHostProxy {
slash_command_proxy: RwLock::default(),
context_server_proxy: RwLock::default(),
debug_adapter_provider_proxy: RwLock::default(),
+ language_model_provider_proxy: RwLock::default(),
}
}
@@ -90,6 +95,15 @@ impl ExtensionHostProxy {
.write()
.replace(Arc::new(proxy));
}
+
+ pub fn register_language_model_provider_proxy(
+ &self,
+ proxy: impl ExtensionLanguageModelProviderProxy,
+ ) {
+ self.language_model_provider_proxy
+ .write()
+ .replace(Arc::new(proxy));
+ }
}
pub trait ExtensionThemeProxy: Send + Sync + 'static {
@@ -446,3 +460,37 @@ impl ExtensionDebugAdapterProviderProxy for ExtensionHostProxy {
proxy.unregister_debug_locator(locator_name)
}
}
+
+pub trait ExtensionLanguageModelProviderProxy: Send + Sync + 'static {
+ fn register_language_model_provider(
+ &self,
+ provider_id: Arc<str>,
+ register_fn: LanguageModelProviderRegistration,
+ cx: &mut App,
+ );
+
+ fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App);
+}
+
+impl ExtensionLanguageModelProviderProxy for ExtensionHostProxy {
+ fn register_language_model_provider(
+ &self,
+ provider_id: Arc<str>,
+ register_fn: LanguageModelProviderRegistration,
+ cx: &mut App,
+ ) {
+ let Some(proxy) = self.language_model_provider_proxy.read().clone() else {
+ return;
+ };
+
+ proxy.register_language_model_provider(provider_id, register_fn, cx)
+ }
+
+ fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App) {
+ let Some(proxy) = self.language_model_provider_proxy.read().clone() else {
+ return;
+ };
+
+ proxy.unregister_language_model_provider(provider_id, cx)
+ }
+}
@@ -93,6 +93,8 @@ pub struct ExtensionManifest {
pub debug_adapters: BTreeMap<Arc<str>, DebugAdapterManifestEntry>,
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub debug_locators: BTreeMap<Arc<str>, DebugLocatorManifestEntry>,
+ #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
+ pub language_model_providers: BTreeMap<Arc<str>, LanguageModelProviderManifestEntry>,
}
impl ExtensionManifest {
@@ -288,6 +290,16 @@ pub struct DebugAdapterManifestEntry {
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct DebugLocatorManifestEntry {}
+/// Manifest entry for a language model provider.
+#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
+pub struct LanguageModelProviderManifestEntry {
+ /// Display name for the provider.
+ pub name: String,
+ /// Path to an SVG icon file relative to the extension root (e.g., "icons/provider.svg").
+ #[serde(default)]
+ pub icon: Option<String>,
+}
+
impl ExtensionManifest {
pub async fn load(fs: Arc<dyn Fs>, extension_dir: &Path) -> Result<Self> {
let extension_name = extension_dir
@@ -358,6 +370,7 @@ fn manifest_from_old_manifest(
capabilities: Vec::new(),
debug_adapters: Default::default(),
debug_locators: Default::default(),
+ language_model_providers: Default::default(),
}
}
@@ -391,6 +404,7 @@ mod tests {
capabilities: vec![],
debug_adapters: Default::default(),
debug_locators: Default::default(),
+ language_model_providers: BTreeMap::default(),
}
}
@@ -148,6 +148,7 @@ fn manifest() -> ExtensionManifest {
)],
debug_adapters: Default::default(),
debug_locators: Default::default(),
+ language_model_providers: BTreeMap::default(),
}
}
@@ -113,6 +113,7 @@ mod tests {
capabilities: vec![],
debug_adapters: Default::default(),
debug_locators: Default::default(),
+ language_model_providers: BTreeMap::default(),
}
}
@@ -165,6 +165,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
capabilities: Vec::new(),
debug_adapters: Default::default(),
debug_locators: Default::default(),
+ language_model_providers: BTreeMap::default(),
}),
dev: false,
},
@@ -196,6 +197,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
capabilities: Vec::new(),
debug_adapters: Default::default(),
debug_locators: Default::default(),
+ language_model_providers: BTreeMap::default(),
}),
dev: false,
},
@@ -376,6 +378,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
capabilities: Vec::new(),
debug_adapters: Default::default(),
debug_locators: Default::default(),
+ language_model_providers: BTreeMap::default(),
}),
dev: false,
},
@@ -797,11 +797,26 @@ pub enum AuthenticateError {
Other(#[from] anyhow::Error),
}
+/// Either a built-in icon name or a path to an external SVG.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum IconOrSvg {
+ /// A built-in icon from Zed's icon set.
+ Icon(IconName),
+ /// Path to a custom SVG icon file.
+ Svg(SharedString),
+}
+
+impl Default for IconOrSvg {
+ fn default() -> Self {
+ Self::Icon(IconName::ZedAssistant)
+ }
+}
+
pub trait LanguageModelProvider: 'static {
fn id(&self) -> LanguageModelProviderId;
fn name(&self) -> LanguageModelProviderName;
- fn icon(&self) -> IconName {
- IconName::ZedAssistant
+ fn icon(&self) -> IconOrSvg {
+ IconOrSvg::default()
}
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
@@ -820,7 +835,7 @@ pub trait LanguageModelProvider: 'static {
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
-#[derive(Default, Clone)]
+#[derive(Default, Clone, PartialEq, Eq)]
pub enum ConfigurationViewTargetAgent {
#[default]
ZedAgent,
@@ -2,12 +2,16 @@ use crate::{
LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderState,
};
-use collections::BTreeMap;
+use collections::{BTreeMap, HashSet};
use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
use std::{str::FromStr, sync::Arc};
use thiserror::Error;
use util::maybe;
+/// Function type for checking if a built-in provider should be hidden.
+/// Returns Some(extension_id) if the provider should be hidden when that extension is installed.
+pub type BuiltinProviderHidingFn = Box<dyn Fn(&str) -> Option<&'static str> + Send + Sync>;
+
pub fn init(cx: &mut App) {
let registry = cx.new(|_cx| LanguageModelRegistry::default());
cx.set_global(GlobalLanguageModelRegistry(registry));
@@ -48,6 +52,11 @@ pub struct LanguageModelRegistry {
thread_summary_model: Option<ConfiguredModel>,
providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
inline_alternatives: Vec<Arc<dyn LanguageModel>>,
+ /// Set of installed extension IDs that provide language models.
+ /// Used to determine which built-in providers should be hidden.
+ installed_llm_extension_ids: HashSet<Arc<str>>,
+ /// Function to check if a built-in provider should be hidden by an extension.
+ builtin_provider_hiding_fn: Option<BuiltinProviderHidingFn>,
}
#[derive(Debug)]
@@ -104,6 +113,8 @@ pub enum Event {
ProviderStateChanged(LanguageModelProviderId),
AddedProvider(LanguageModelProviderId),
RemovedProvider(LanguageModelProviderId),
+ /// Emitted when provider visibility changes due to extension install/uninstall.
+ ProvidersChanged,
}
impl EventEmitter<Event> for LanguageModelRegistry {}
@@ -183,6 +194,60 @@ impl LanguageModelRegistry {
providers
}
+ /// Returns providers, filtering out hidden built-in providers.
+ pub fn visible_providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
+ self.providers()
+ .into_iter()
+ .filter(|p| !self.should_hide_provider(&p.id()))
+ .collect()
+ }
+
+ /// Sets the function used to check if a built-in provider should be hidden.
+ pub fn set_builtin_provider_hiding_fn(&mut self, hiding_fn: BuiltinProviderHidingFn) {
+ self.builtin_provider_hiding_fn = Some(hiding_fn);
+ }
+
+ /// Called when an extension is installed/loaded.
+ /// If the extension provides language models, track it so we can hide the corresponding built-in.
+ pub fn extension_installed(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
+ if self.installed_llm_extension_ids.insert(extension_id) {
+ cx.emit(Event::ProvidersChanged);
+ cx.notify();
+ }
+ }
+
+ /// Called when an extension is uninstalled/unloaded.
+ pub fn extension_uninstalled(&mut self, extension_id: &str, cx: &mut Context<Self>) {
+ if self.installed_llm_extension_ids.remove(extension_id) {
+ cx.emit(Event::ProvidersChanged);
+ cx.notify();
+ }
+ }
+
+ /// Sync the set of installed LLM extension IDs.
+ pub fn sync_installed_llm_extensions(
+ &mut self,
+ extension_ids: HashSet<Arc<str>>,
+ cx: &mut Context<Self>,
+ ) {
+ if extension_ids != self.installed_llm_extension_ids {
+ self.installed_llm_extension_ids = extension_ids;
+ cx.emit(Event::ProvidersChanged);
+ cx.notify();
+ }
+ }
+
+ /// Returns true if a provider should be hidden from the UI.
+ /// Built-in providers are hidden when their corresponding extension is installed.
+ pub fn should_hide_provider(&self, provider_id: &LanguageModelProviderId) -> bool {
+ if let Some(ref hiding_fn) = self.builtin_provider_hiding_fn {
+ if let Some(extension_id) = hiding_fn(&provider_id.0) {
+ return self.installed_llm_extension_ids.contains(extension_id);
+ }
+ }
+ false
+ }
+
pub fn configuration_error(
&self,
model: Option<ConfiguredModel>,
@@ -416,4 +481,132 @@ mod tests {
let providers = registry.read(cx).providers();
assert!(providers.is_empty());
}
+
+ #[gpui::test]
+ fn test_provider_hiding_on_extension_install(cx: &mut App) {
+ let registry = cx.new(|_| LanguageModelRegistry::default());
+
+ let provider = Arc::new(FakeLanguageModelProvider::default());
+ let provider_id = provider.id();
+
+ registry.update(cx, |registry, cx| {
+ registry.register_provider(provider.clone(), cx);
+
+ registry.set_builtin_provider_hiding_fn(Box::new(|id| {
+ if id == "fake" {
+ Some("fake-extension")
+ } else {
+ None
+ }
+ }));
+ });
+
+ let visible = registry.read(cx).visible_providers();
+ assert_eq!(visible.len(), 1);
+ assert_eq!(visible[0].id(), provider_id);
+
+ registry.update(cx, |registry, cx| {
+ registry.extension_installed("fake-extension".into(), cx);
+ });
+
+ let visible = registry.read(cx).visible_providers();
+ assert!(visible.is_empty());
+
+ let all = registry.read(cx).providers();
+ assert_eq!(all.len(), 1);
+ }
+
+ #[gpui::test]
+ fn test_provider_unhiding_on_extension_uninstall(cx: &mut App) {
+ let registry = cx.new(|_| LanguageModelRegistry::default());
+
+ let provider = Arc::new(FakeLanguageModelProvider::default());
+ let provider_id = provider.id();
+
+ registry.update(cx, |registry, cx| {
+ registry.register_provider(provider.clone(), cx);
+
+ registry.set_builtin_provider_hiding_fn(Box::new(|id| {
+ if id == "fake" {
+ Some("fake-extension")
+ } else {
+ None
+ }
+ }));
+
+ registry.extension_installed("fake-extension".into(), cx);
+ });
+
+ let visible = registry.read(cx).visible_providers();
+ assert!(visible.is_empty());
+
+ registry.update(cx, |registry, cx| {
+ registry.extension_uninstalled("fake-extension", cx);
+ });
+
+ let visible = registry.read(cx).visible_providers();
+ assert_eq!(visible.len(), 1);
+ assert_eq!(visible[0].id(), provider_id);
+ }
+
+ #[gpui::test]
+ fn test_should_hide_provider(cx: &mut App) {
+ let registry = cx.new(|_| LanguageModelRegistry::default());
+
+ registry.update(cx, |registry, cx| {
+ registry.set_builtin_provider_hiding_fn(Box::new(|id| {
+ if id == "anthropic" {
+ Some("anthropic")
+ } else if id == "openai" {
+ Some("openai")
+ } else {
+ None
+ }
+ }));
+
+ registry.extension_installed("anthropic".into(), cx);
+ });
+
+ let registry_read = registry.read(cx);
+
+ assert!(registry_read.should_hide_provider(&LanguageModelProviderId("anthropic".into())));
+
+ assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("openai".into())));
+
+ assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("unknown".into())));
+ }
+
+ #[gpui::test]
+ fn test_sync_installed_llm_extensions(cx: &mut App) {
+ let registry = cx.new(|_| LanguageModelRegistry::default());
+
+ let provider = Arc::new(FakeLanguageModelProvider::default());
+
+ registry.update(cx, |registry, cx| {
+ registry.register_provider(provider.clone(), cx);
+
+ registry.set_builtin_provider_hiding_fn(Box::new(|id| {
+ if id == "fake" {
+ Some("fake-extension")
+ } else {
+ None
+ }
+ }));
+ });
+
+ let mut extension_ids = HashSet::default();
+ extension_ids.insert(Arc::from("fake-extension"));
+
+ registry.update(cx, |registry, cx| {
+ registry.sync_installed_llm_extensions(extension_ids, cx);
+ });
+
+ assert!(registry.read(cx).visible_providers().is_empty());
+
+ registry.update(cx, |registry, cx| {
+ registry.sync_installed_llm_extensions(HashSet::default(), cx);
+ });
+
+ assert_eq!(registry.read(cx).visible_providers().len(), 1);
+ }
}
@@ -28,6 +28,8 @@ convert_case.workspace = true
copilot.workspace = true
credentials_provider.workspace = true
deepseek = { workspace = true, features = ["schemars"] }
+extension.workspace = true
+extension_host.workspace = true
fs.workspace = true
futures.workspace = true
google_ai = { workspace = true, features = ["schemars"] }
@@ -0,0 +1,67 @@
+use collections::HashMap;
+use extension::{
+ ExtensionHostProxy, ExtensionLanguageModelProviderProxy, LanguageModelProviderRegistration,
+};
+use gpui::{App, Entity};
+use language_model::{LanguageModelProviderId, LanguageModelRegistry};
+use std::sync::{Arc, LazyLock};
+
+/// Maps built-in provider IDs to their corresponding extension IDs.
+/// When an extension with this ID is installed, the built-in provider should be hidden.
+static BUILTIN_TO_EXTENSION_MAP: LazyLock<HashMap<&'static str, &'static str>> =
+ LazyLock::new(|| {
+ let mut map = HashMap::default();
+ map.insert("anthropic", "anthropic");
+ map.insert("openai", "openai");
+ map.insert("google", "google-ai");
+ map.insert("openrouter", "openrouter");
+ map.insert("copilot_chat", "copilot-chat");
+ map
+ });
+
+/// Returns the extension ID that should hide the given built-in provider.
+pub fn extension_for_builtin_provider(provider_id: &str) -> Option<&'static str> {
+ BUILTIN_TO_EXTENSION_MAP.get(provider_id).copied()
+}
+
+/// Proxy that registers extension language model providers with the LanguageModelRegistry.
+pub struct LanguageModelProviderRegistryProxy {
+ registry: Entity<LanguageModelRegistry>,
+}
+
+impl LanguageModelProviderRegistryProxy {
+ pub fn new(registry: Entity<LanguageModelRegistry>) -> Self {
+ Self { registry }
+ }
+}
+
+impl ExtensionLanguageModelProviderProxy for LanguageModelProviderRegistryProxy {
+ fn register_language_model_provider(
+ &self,
+ _provider_id: Arc<str>,
+ register_fn: LanguageModelProviderRegistration,
+ cx: &mut App,
+ ) {
+ register_fn(cx);
+ }
+
+ fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App) {
+ self.registry.update(cx, |registry, cx| {
+ registry.unregister_provider(LanguageModelProviderId::from(provider_id), cx);
+ });
+ }
+}
+
+/// Initialize the extension language model provider proxy.
+/// This must be called BEFORE extension_host::init to ensure the proxy is available
+/// when extensions try to register their language model providers.
+pub fn init_proxy(cx: &mut App) {
+ let proxy = ExtensionHostProxy::default_global(cx);
+ let registry = LanguageModelRegistry::global(cx);
+
+ registry.update(cx, |registry, _cx| {
+ registry.set_builtin_provider_hiding_fn(Box::new(extension_for_builtin_provider));
+ });
+
+ proxy.register_language_model_provider_proxy(LanguageModelProviderRegistryProxy::new(registry));
+}
@@ -7,9 +7,12 @@ use gpui::{App, Context, Entity};
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use provider::deepseek::DeepSeekLanguageModelProvider;
+pub mod extension;
pub mod provider;
mod settings;
+pub use crate::extension::init_proxy as init_extension_proxy;
+
use crate::provider::anthropic::AnthropicLanguageModelProvider;
use crate::provider::bedrock::BedrockLanguageModelProvider;
use crate::provider::cloud::CloudLanguageModelProvider;
@@ -31,6 +34,56 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
register_language_model_providers(registry, user_store, client.clone(), cx);
});
+ // Subscribe to extension store events to track LLM extension installations
+ if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
+ cx.subscribe(&extension_store, {
+ let registry = registry.clone();
+ move |extension_store, event, cx| match event {
+ extension_host::Event::ExtensionInstalled(extension_id) => {
+ if let Some(manifest) = extension_store
+ .read(cx)
+ .extension_manifest_for_id(extension_id)
+ {
+ if !manifest.language_model_providers.is_empty() {
+ registry.update(cx, |registry, cx| {
+ registry.extension_installed(extension_id.clone(), cx);
+ });
+ }
+ }
+ }
+ extension_host::Event::ExtensionUninstalled(extension_id) => {
+ registry.update(cx, |registry, cx| {
+ registry.extension_uninstalled(extension_id, cx);
+ });
+ }
+ extension_host::Event::ExtensionsUpdated => {
+ let mut new_ids = HashSet::default();
+ for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
+ if !entry.manifest.language_model_providers.is_empty() {
+ new_ids.insert(extension_id.clone());
+ }
+ }
+ registry.update(cx, |registry, cx| {
+ registry.sync_installed_llm_extensions(new_ids, cx);
+ });
+ }
+ _ => {}
+ }
+ })
+ .detach();
+
+ // Initialize with currently installed extensions
+ registry.update(cx, |registry, cx| {
+ let mut initial_ids = HashSet::default();
+ for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
+ if !entry.manifest.language_model_providers.is_empty() {
+ initial_ids.insert(extension_id.clone());
+ }
+ }
+ registry.sync_installed_llm_extensions(initial_ids, cx);
+ });
+ }
+
let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
.openai_compatible
.keys()
@@ -8,7 +8,7 @@ use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::B
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
use http_client::HttpClient;
use language_model::{
- ApiKeyState, AuthenticateError, ConfigurationViewTargetAgent, EnvVar, LanguageModel,
+ ApiKeyState, AuthenticateError, ConfigurationViewTargetAgent, EnvVar, IconOrSvg, LanguageModel,
LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
@@ -125,8 +125,8 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
PROVIDER_NAME
}
- fn icon(&self) -> IconName {
- IconName::AiAnthropic
+ fn icon(&self) -> IconOrSvg {
+ IconOrSvg::Icon(IconName::AiAnthropic)
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
@@ -30,7 +30,7 @@ use gpui::{
use gpui_tokio::Tokio;
use http_client::HttpClient;
use language_model::{
- AuthenticateError, EnvVar, LanguageModel, LanguageModelCacheConfiguration,
+ AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
@@ -426,8 +426,8 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
PROVIDER_NAME
}
- fn icon(&self) -> IconName {
- IconName::AiBedrock
+ fn icon(&self) -> IconOrSvg {
+ IconOrSvg::Icon(IconName::AiBedrock)
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
@@ -19,7 +19,7 @@ use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Ta
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
use language_model::{
- AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
+ AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
@@ -304,8 +304,8 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
PROVIDER_NAME
}
- fn icon(&self) -> IconName {
- IconName::AiZed
+ fn icon(&self) -> IconOrSvg {
+ IconOrSvg::Icon(IconName::AiZed)
}
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
@@ -18,12 +18,12 @@ use gpui::{AnyView, App, AsyncApp, Entity, Subscription, Task};
use http_client::StatusCode;
use language::language_settings::all_language_settings;
use language_model::{
- AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
- LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
- LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
- LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
- LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
- StopReason, TokenUsage,
+ AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCompletionError,
+ LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
+ LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+ LanguageModelRequest, LanguageModelRequestMessage, LanguageModelToolChoice,
+ LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
+ MessageContent, RateLimiter, Role, StopReason, TokenUsage,
};
use settings::SettingsStore;
use ui::prelude::*;
@@ -104,8 +104,8 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
PROVIDER_NAME
}
- fn icon(&self) -> IconName {
- IconName::Copilot
+ fn icon(&self) -> IconOrSvg {
+ IconOrSvg::Icon(IconName::Copilot)
}
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
@@ -7,7 +7,7 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture, stream::BoxStream
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
- ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
+ ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent,
@@ -127,8 +127,8 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
PROVIDER_NAME
}
- fn icon(&self) -> IconName {
- IconName::AiDeepSeek
+ fn icon(&self) -> IconOrSvg {
+ IconOrSvg::Icon(IconName::AiDeepSeek)
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
@@ -14,7 +14,7 @@ use language_model::{
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
};
use language_model::{
- LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
+ IconOrSvg, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, RateLimiter, Role,
};
@@ -164,8 +164,8 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
PROVIDER_NAME
}
- fn icon(&self) -> IconName {
- IconName::AiGoogle
+ fn icon(&self) -> IconOrSvg {
+ IconOrSvg::Icon(IconName::AiGoogle)
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
@@ -10,7 +10,7 @@ use language_model::{
StopReason, TokenUsage,
};
use language_model::{
- LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
+ IconOrSvg, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, RateLimiter, Role,
};
@@ -175,8 +175,8 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
PROVIDER_NAME
}
- fn icon(&self) -> IconName {
- IconName::AiLmStudio
+ fn icon(&self) -> IconOrSvg {
+ IconOrSvg::Icon(IconName::AiLmStudio)
}
fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
@@ -5,7 +5,7 @@ use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::B
use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
- ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
+ ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent,
@@ -176,8 +176,8 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
PROVIDER_NAME
}
- fn icon(&self) -> IconName {
- IconName::AiMistral
+ fn icon(&self) -> IconOrSvg {
+ IconOrSvg::Icon(IconName::AiMistral)
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
@@ -5,7 +5,7 @@ use futures::{Stream, TryFutureExt, stream};
use gpui::{AnyView, App, AsyncApp, Context, CursorStyle, Entity, Task};
use http_client::HttpClient;
use language_model::{
- ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
+ ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
@@ -221,8 +221,8 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
PROVIDER_NAME
}
- fn icon(&self) -> IconName {
- IconName::AiOllama
+ fn icon(&self) -> IconOrSvg {
+ IconOrSvg::Icon(IconName::AiOllama)
}
fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
@@ -5,7 +5,7 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
- ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
+ ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent,
@@ -122,8 +122,8 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
PROVIDER_NAME
}
- fn icon(&self) -> IconName {
- IconName::AiOpenAi
+ fn icon(&self) -> IconOrSvg {
+ IconOrSvg::Icon(IconName::AiOpenAi)
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
@@ -4,7 +4,7 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
- ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
+ ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
@@ -133,8 +133,8 @@ impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
self.name.clone()
}
- fn icon(&self) -> IconName {
- IconName::AiOpenAiCompat
+ fn icon(&self) -> IconOrSvg {
+ IconOrSvg::Icon(IconName::AiOpenAiCompat)
}
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
@@ -4,7 +4,7 @@ use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task};
use http_client::HttpClient;
use language_model::{
- ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
+ ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent,
@@ -180,8 +180,8 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider {
PROVIDER_NAME
}
- fn icon(&self) -> IconName {
- IconName::AiOpenRouter
+ fn icon(&self) -> IconOrSvg {
+ IconOrSvg::Icon(IconName::AiOpenRouter)
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
@@ -4,7 +4,7 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
- ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
+ ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, RateLimiter, Role, env_var,
@@ -117,8 +117,8 @@ impl LanguageModelProvider for VercelLanguageModelProvider {
PROVIDER_NAME
}
- fn icon(&self) -> IconName {
- IconName::AiVZero
+ fn icon(&self) -> IconOrSvg {
+ IconOrSvg::Icon(IconName::AiVZero)
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
@@ -4,7 +4,7 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task, Window};
use http_client::HttpClient;
use language_model::{
- ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
+ ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
@@ -118,8 +118,8 @@ impl LanguageModelProvider for XAiLanguageModelProvider {
PROVIDER_NAME
}
- fn icon(&self) -> IconName {
- IconName::AiXAi
+ fn icon(&self) -> IconOrSvg {
+ IconOrSvg::Icon(IconName::AiXAi)
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
@@ -126,17 +126,6 @@ enum IconSource {
ExternalSvg(SharedString),
}
-impl IconSource {
- fn from_path(path: impl Into<SharedString>) -> Self {
- let path = path.into();
- if path.starts_with("icons/") {
- Self::Embedded(path)
- } else {
- Self::External(Arc::from(PathBuf::from(path.as_ref())))
- }
- }
-}
-
#[derive(IntoElement, RegisterComponent)]
pub struct Icon {
source: IconSource,
@@ -155,9 +144,18 @@ impl Icon {
}
}
+ /// Create an icon from a path. Uses a heuristic to determine if it's embedded or external:
+ /// - Paths starting with "icons/" are treated as embedded SVGs
+ /// - Other paths are treated as external raster images (from icon themes)
pub fn from_path(path: impl Into<SharedString>) -> Self {
+ let path = path.into();
+ let source = if path.starts_with("icons/") {
+ IconSource::Embedded(path)
+ } else {
+ IconSource::External(Arc::from(PathBuf::from(path.as_ref())))
+ };
Self {
- source: IconSource::from_path(path),
+ source,
color: Color::default(),
size: IconSize::default().rems(),
transformation: Transformation::default(),