Detailed changes
@@ -230,6 +230,7 @@ dependencies = [
"schemars",
"serde",
"serde_json",
+ "strum",
"tokio",
]
@@ -376,6 +377,7 @@ dependencies = [
"settings",
"smol",
"strsim 0.11.1",
+ "strum",
"telemetry_events",
"theme",
"tiktoken-rs",
@@ -6983,6 +6985,7 @@ dependencies = [
"schemars",
"serde",
"serde_json",
+ "strum",
]
[[package]]
@@ -201,7 +201,8 @@
"context": "AssistantPanel",
"bindings": {
"ctrl-g": "search::SelectNextMatch",
- "ctrl-shift-g": "search::SelectPrevMatch"
+ "ctrl-shift-g": "search::SelectPrevMatch",
+ "alt-m": "assistant::ToggleModelSelector"
}
},
{
@@ -214,10 +214,11 @@
}
},
{
- "context": "AssistantPanel", // Used in the assistant crate, which we're replacing
+ "context": "AssistantPanel",
"bindings": {
"cmd-g": "search::SelectNextMatch",
- "cmd-shift-g": "search::SelectPrevMatch"
+ "cmd-shift-g": "search::SelectPrevMatch",
+ "alt-m": "assistant::ToggleModelSelector"
}
},
{
@@ -23,6 +23,7 @@ isahc.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
+strum.workspace = true
[dev-dependencies]
tokio.workspace = true
@@ -4,11 +4,12 @@ use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use std::{convert::TryFrom, time::Duration};
+use strum::EnumIter;
pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
-#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
#[default]
#[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]
@@ -49,6 +49,7 @@ serde_json.workspace = true
settings.workspace = true
smol.workspace = true
strsim = "0.11"
+strum.workspace = true
telemetry_events.workspace = true
theme.workspace = true
tiktoken-rs.workspace = true
@@ -2,6 +2,7 @@ pub mod assistant_panel;
pub mod assistant_settings;
mod codegen;
mod completion_provider;
+mod model_selector;
mod prompts;
mod saved_conversation;
mod search;
@@ -15,6 +16,7 @@ use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
pub(crate) use completion_provider::*;
use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
+pub(crate) use model_selector::*;
pub(crate) use saved_conversation::*;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::{Deserialize, Serialize};
@@ -38,7 +40,8 @@ actions!(
InsertActivePrompt,
ToggleHistory,
ApplyEdit,
- ConfirmCommand
+ ConfirmCommand,
+ ToggleModelSelector
]
);
@@ -1,7 +1,7 @@
use crate::prompts::{generate_content_prompt, PromptLibrary, PromptManager};
use crate::slash_command::{rustdoc_command, search_command, tabs_command};
use crate::{
- assistant_settings::{AssistantDockPosition, AssistantSettings, ZedDotDevModel},
+ assistant_settings::{AssistantDockPosition, AssistantSettings},
codegen::{self, Codegen, CodegenKind},
search::*,
slash_command::{
@@ -9,10 +9,11 @@ use crate::{
SlashCommandCompletionProvider, SlashCommandLine, SlashCommandRegistry,
},
ApplyEdit, Assist, CompletionProvider, ConfirmCommand, CycleMessageRole, InlineAssist,
- LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata,
- MessageStatus, QuoteSelection, ResetKey, Role, SavedConversation, SavedConversationMetadata,
- SavedMessage, Split, ToggleFocus, ToggleHistory,
+ LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, MessageStatus,
+ QuoteSelection, ResetKey, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
+ Split, ToggleFocus, ToggleHistory,
};
+use crate::{ModelSelector, ToggleModelSelector};
use anyhow::{anyhow, Result};
use assistant_slash_command::{SlashCommandOutput, SlashCommandOutputSection};
use client::telemetry::Telemetry;
@@ -64,8 +65,8 @@ use std::{
use telemetry_events::AssistantKind;
use theme::ThemeSettings;
use ui::{
- popover_menu, prelude::*, ButtonLike, ContextMenu, ElevationIndex, KeyBinding, Tab, TabBar,
- Tooltip,
+ popover_menu, prelude::*, ButtonLike, ContextMenu, ElevationIndex, KeyBinding,
+ PopoverMenuHandle, Tab, TabBar, Tooltip,
};
use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
use uuid::Uuid;
@@ -119,8 +120,8 @@ pub struct AssistantPanel {
pending_inline_assist_ids_by_editor: HashMap<WeakView<Editor>, Vec<usize>>,
inline_prompt_history: VecDeque<String>,
_watch_saved_conversations: Task<Result<()>>,
- model: LanguageModel,
authentication_prompt: Option<AnyView>,
+ model_menu_handle: PopoverMenuHandle<ContextMenu>,
}
struct ActiveConversationEditor {
@@ -203,7 +204,6 @@ impl AssistantPanel {
}
}),
];
- let model = CompletionProvider::global(cx).default_model();
cx.observe_global::<FileIcons>(|_, cx| {
cx.notify();
@@ -244,8 +244,8 @@ impl AssistantPanel {
pending_inline_assist_ids_by_editor: Default::default(),
inline_prompt_history: Default::default(),
_watch_saved_conversations,
- model,
authentication_prompt: None,
+ model_menu_handle: PopoverMenuHandle::default(),
}
})
})
@@ -277,12 +277,20 @@ impl AssistantPanel {
if self.is_authenticated(cx) {
self.authentication_prompt = None;
- let model = CompletionProvider::global(cx).default_model();
- self.set_model(model, cx);
+ if let Some(editor) = self.active_conversation_editor() {
+ editor.update(cx, |active_conversation, cx| {
+ active_conversation
+ .conversation
+ .update(cx, |conversation, cx| {
+ conversation.completion_provider_changed(cx)
+ })
+ })
+ }
if self.active_conversation_editor().is_none() {
self.new_conversation(cx);
}
+ cx.notify();
} else if self.authentication_prompt.is_none()
|| prev_settings_version != CompletionProvider::global(cx).settings_version()
{
@@ -290,6 +298,7 @@ impl AssistantPanel {
Some(cx.update_global::<CompletionProvider, _>(|provider, cx| {
provider.authentication_prompt(cx)
}));
+ cx.notify();
}
}
@@ -734,7 +743,7 @@ impl AssistantPanel {
.map(|message| message.to_request_message(buffer)),
);
}
- let model = self.model.clone();
+ let model = CompletionProvider::global(cx).model();
cx.spawn(|_, mut cx| async move {
// I Don't know if we want to return a ? here.
@@ -809,7 +818,6 @@ impl AssistantPanel {
let editor = cx.new_view(|cx| {
ConversationEditor::new(
- self.model.clone(),
self.languages.clone(),
self.slash_commands.clone(),
self.fs.clone(),
@@ -850,53 +858,6 @@ impl AssistantPanel {
cx.notify();
}
- fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
- let next_model = match &self.model {
- LanguageModel::OpenAi(model) => LanguageModel::OpenAi(match &model {
- open_ai::Model::ThreePointFiveTurbo => open_ai::Model::Four,
- open_ai::Model::Four => open_ai::Model::FourTurbo,
- open_ai::Model::FourTurbo => open_ai::Model::FourOmni,
- open_ai::Model::FourOmni => open_ai::Model::ThreePointFiveTurbo,
- }),
- LanguageModel::Anthropic(model) => LanguageModel::Anthropic(match &model {
- anthropic::Model::Claude3Opus => anthropic::Model::Claude3Sonnet,
- anthropic::Model::Claude3Sonnet => anthropic::Model::Claude3Haiku,
- anthropic::Model::Claude3Haiku => anthropic::Model::Claude3Opus,
- }),
- LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
- ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4,
- ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo,
- ZedDotDevModel::Gpt4Turbo => ZedDotDevModel::Gpt4Omni,
- ZedDotDevModel::Gpt4Omni => ZedDotDevModel::Claude3Opus,
- ZedDotDevModel::Claude3Opus => ZedDotDevModel::Claude3Sonnet,
- ZedDotDevModel::Claude3Sonnet => ZedDotDevModel::Claude3Haiku,
- ZedDotDevModel::Claude3Haiku => {
- match CompletionProvider::global(cx).default_model() {
- LanguageModel::ZedDotDev(custom @ ZedDotDevModel::Custom(_)) => custom,
- _ => ZedDotDevModel::Gpt3Point5Turbo,
- }
- }
- ZedDotDevModel::Custom(_) => ZedDotDevModel::Gpt3Point5Turbo,
- }),
- };
-
- self.set_model(next_model, cx);
- }
-
- fn set_model(&mut self, model: LanguageModel, cx: &mut ViewContext<Self>) {
- self.model = model.clone();
- if let Some(editor) = self.active_conversation_editor() {
- editor.update(cx, |active_conversation, cx| {
- active_conversation
- .conversation
- .update(cx, |conversation, cx| {
- conversation.set_model(model, cx);
- })
- })
- }
- cx.notify();
- }
-
fn handle_conversation_editor_event(
&mut self,
_: View<ConversationEditor>,
@@ -978,6 +939,10 @@ impl AssistantPanel {
.detach_and_log_err(cx);
}
+ fn toggle_model_selector(&mut self, _: &ToggleModelSelector, cx: &mut ViewContext<Self>) {
+ self.model_menu_handle.toggle(cx);
+ }
+
fn active_conversation_editor(&self) -> Option<&View<ConversationEditor>> {
Some(&self.active_conversation_editor.as_ref()?.editor)
}
@@ -1133,10 +1098,8 @@ impl AssistantPanel {
cx.spawn(|this, mut cx| async move {
let saved_conversation = SavedConversation::load(&path, fs.as_ref()).await?;
- let model = this.update(&mut cx, |this, _| this.model.clone())?;
let conversation = Conversation::deserialize(
saved_conversation,
- model,
path.clone(),
languages,
slash_commands,
@@ -1206,7 +1169,10 @@ impl AssistantPanel {
this.child(
h_flex()
.gap_1()
- .child(self.render_model(&conversation, cx))
+ .child(ModelSelector::new(
+ self.model_menu_handle.clone(),
+ self.fs.clone(),
+ ))
.children(self.render_remaining_tokens(&conversation, cx)),
)
.child(
@@ -1256,6 +1222,7 @@ impl AssistantPanel {
.on_action(cx.listener(AssistantPanel::select_prev_match))
.on_action(cx.listener(AssistantPanel::handle_editor_cancel))
.on_action(cx.listener(AssistantPanel::reset_credentials))
+ .on_action(cx.listener(AssistantPanel::toggle_model_selector))
.track_focus(&self.focus_handle)
.child(header)
.children(if self.toolbar.read(cx).hidden() {
@@ -1314,23 +1281,12 @@ impl AssistantPanel {
))
}
- fn render_model(
- &self,
- conversation: &Model<Conversation>,
- cx: &mut ViewContext<Self>,
- ) -> impl IntoElement {
- Button::new("current_model", conversation.read(cx).model.display_name())
- .style(ButtonStyle::Filled)
- .tooltip(move |cx| Tooltip::text("Change Model", cx))
- .on_click(cx.listener(|this, _, cx| this.cycle_model(cx)))
- }
-
fn render_remaining_tokens(
&self,
conversation: &Model<Conversation>,
cx: &mut ViewContext<Self>,
) -> Option<impl IntoElement> {
- let remaining_tokens = conversation.read(cx).remaining_tokens()?;
+ let remaining_tokens = conversation.read(cx).remaining_tokens(cx)?;
let remaining_tokens_color = if remaining_tokens <= 0 {
Color::Error
} else if remaining_tokens <= 500 {
@@ -1486,7 +1442,6 @@ pub struct Conversation {
pending_summary: Task<Option<()>>,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
- model: LanguageModel,
token_count: Option<usize>,
pending_token_count: Task<Option<()>>,
pending_edit_suggestion_parse: Option<Task<()>>,
@@ -1502,7 +1457,6 @@ impl EventEmitter<ConversationEvent> for Conversation {}
impl Conversation {
fn new(
- model: LanguageModel,
language_registry: Arc<LanguageRegistry>,
slash_command_registry: Arc<SlashCommandRegistry>,
telemetry: Option<Arc<Telemetry>>,
@@ -1530,7 +1484,6 @@ impl Conversation {
token_count: None,
pending_token_count: Task::ready(None),
pending_edit_suggestion_parse: None,
- model,
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: None,
@@ -1583,7 +1536,6 @@ impl Conversation {
#[allow(clippy::too_many_arguments)]
async fn deserialize(
saved_conversation: SavedConversation,
- model: LanguageModel,
path: PathBuf,
language_registry: Arc<LanguageRegistry>,
slash_command_registry: Arc<SlashCommandRegistry>,
@@ -1640,7 +1592,6 @@ impl Conversation {
token_count: None,
pending_edit_suggestion_parse: None,
pending_token_count: Task::ready(None),
- model,
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: Some(path),
@@ -1938,12 +1889,12 @@ impl Conversation {
}
}
- fn remaining_tokens(&self) -> Option<isize> {
- Some(self.model.max_token_count() as isize - self.token_count? as isize)
+ fn remaining_tokens(&self, cx: &AppContext) -> Option<isize> {
+ let model = CompletionProvider::global(cx).model();
+ Some(model.max_token_count() as isize - self.token_count? as isize)
}
- fn set_model(&mut self, model: LanguageModel, cx: &mut ModelContext<Self>) {
- self.model = model;
+ fn completion_provider_changed(&mut self, cx: &mut ModelContext<Self>) {
self.count_remaining_tokens(cx);
}
@@ -2079,10 +2030,11 @@ impl Conversation {
}
if let Some(telemetry) = this.telemetry.as_ref() {
+ let model = CompletionProvider::global(cx).model();
telemetry.report_assistant_event(
this.id.clone(),
AssistantKind::Panel,
- this.model.telemetry_id(),
+ model.telemetry_id(),
response_latency,
error_message,
);
@@ -2111,7 +2063,7 @@ impl Conversation {
.map(|message| message.to_request_message(self.buffer.read(cx)));
LanguageModelRequest {
- model: self.model.clone(),
+ model: CompletionProvider::global(cx).model(),
messages: messages.collect(),
stop: vec![],
temperature: 1.0,
@@ -2300,7 +2252,7 @@ impl Conversation {
.into(),
}));
let request = LanguageModelRequest {
- model: self.model.clone(),
+ model: CompletionProvider::global(cx).model(),
messages: messages.collect(),
stop: vec![],
temperature: 1.0,
@@ -2605,7 +2557,6 @@ pub struct ConversationEditor {
impl ConversationEditor {
fn new(
- model: LanguageModel,
language_registry: Arc<LanguageRegistry>,
slash_command_registry: Arc<SlashCommandRegistry>,
fs: Arc<dyn Fs>,
@@ -2618,7 +2569,6 @@ impl ConversationEditor {
let conversation = cx.new_model(|cx| {
Conversation::new(
- model,
language_registry,
slash_command_registry,
Some(telemetry),
@@ -3847,15 +3797,8 @@ mod tests {
init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
- let conversation = cx.new_model(|cx| {
- Conversation::new(
- LanguageModel::default(),
- registry,
- Default::default(),
- None,
- cx,
- )
- });
+ let conversation =
+ cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx));
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3986,15 +3929,8 @@ mod tests {
init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
- let conversation = cx.new_model(|cx| {
- Conversation::new(
- LanguageModel::default(),
- registry,
- Default::default(),
- None,
- cx,
- )
- });
+ let conversation =
+ cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx));
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -4092,15 +4028,8 @@ mod tests {
cx.set_global(settings_store);
init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
- let conversation = cx.new_model(|cx| {
- Conversation::new(
- LanguageModel::default(),
- registry,
- Default::default(),
- None,
- cx,
- )
- });
+ let conversation =
+ cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx));
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -4209,15 +4138,8 @@ mod tests {
));
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
- let conversation = cx.new_model(|cx| {
- Conversation::new(
- LanguageModel::default(),
- registry.clone(),
- slash_command_registry,
- None,
- cx,
- )
- });
+ let conversation = cx
+ .new_model(|cx| Conversation::new(registry.clone(), slash_command_registry, None, cx));
let output_ranges = Rc::new(RefCell::new(HashSet::default()));
conversation.update(cx, |_, cx| {
@@ -4390,15 +4312,8 @@ mod tests {
cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
cx.update(init);
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
- let conversation = cx.new_model(|cx| {
- Conversation::new(
- LanguageModel::default(),
- registry.clone(),
- Default::default(),
- None,
- cx,
- )
- });
+ let conversation =
+ cx.new_model(|cx| Conversation::new(registry.clone(), Default::default(), None, cx));
let buffer = conversation.read_with(cx, |conversation, _| conversation.buffer.clone());
let message_0 =
conversation.read_with(cx, |conversation, _| conversation.message_anchors[0].id);
@@ -4434,7 +4349,6 @@ mod tests {
let deserialized_conversation = Conversation::deserialize(
conversation.read_with(cx, |conversation, cx| conversation.serialize(cx)),
- LanguageModel::default(),
Default::default(),
registry.clone(),
Default::default(),
@@ -12,8 +12,11 @@ use serde::{
Deserialize, Deserializer, Serialize, Serializer,
};
use settings::{Settings, SettingsSources};
+use strum::{EnumIter, IntoEnumIterator};
-#[derive(Clone, Debug, Default, PartialEq)]
+use crate::LanguageModel;
+
+#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum ZedDotDevModel {
Gpt3Point5Turbo,
Gpt4,
@@ -53,13 +56,10 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
where
E: de::Error,
{
- match value {
- "gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo),
- "gpt-4" => Ok(ZedDotDevModel::Gpt4),
- "gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo),
- "gpt-4o" => Ok(ZedDotDevModel::Gpt4Omni),
- _ => Ok(ZedDotDevModel::Custom(value.to_owned())),
- }
+ let model = ZedDotDevModel::iter()
+ .find(|model| model.id() == value)
+ .unwrap_or_else(|| ZedDotDevModel::Custom(value.to_string()));
+ Ok(model)
}
}
@@ -73,24 +73,23 @@ impl JsonSchema for ZedDotDevModel {
}
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
- let variants = vec![
- "gpt-3.5-turbo".to_owned(),
- "gpt-4".to_owned(),
- "gpt-4-turbo-preview".to_owned(),
- "gpt-4o".to_owned(),
- ];
+ let variants = ZedDotDevModel::iter()
+ .filter_map(|model| {
+ let id = model.id();
+ if id.is_empty() {
+ None
+ } else {
+ Some(id.to_string())
+ }
+ })
+ .collect::<Vec<_>>();
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
- enum_values: Some(variants.into_iter().map(|s| s.into()).collect()),
+ enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
metadata: Some(Box::new(Metadata {
title: Some("ZedDotDevModel".to_owned()),
- default: Some(serde_json::json!("gpt-4-turbo-preview")),
- examples: vec![
- serde_json::json!("gpt-3.5-turbo"),
- serde_json::json!("gpt-4"),
- serde_json::json!("gpt-4-turbo-preview"),
- serde_json::json!("custom-model-name"),
- ],
+ default: Some(ZedDotDevModel::default().id().into()),
+ examples: variants.into_iter().map(Into::into).collect(),
..Default::default()
})),
..Default::default()
@@ -145,51 +144,55 @@ pub enum AssistantDockPosition {
Bottom,
}
-#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
-#[serde(tag = "name", rename_all = "snake_case")]
+#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
- #[serde(rename = "zed.dev")]
ZedDotDev {
- #[serde(default)]
- default_model: ZedDotDevModel,
+ model: ZedDotDevModel,
},
- #[serde(rename = "openai")]
OpenAi {
- #[serde(default)]
- default_model: OpenAiModel,
- #[serde(default = "open_ai_url")]
+ model: OpenAiModel,
api_url: String,
- #[serde(default)]
low_speed_timeout_in_seconds: Option<u64>,
},
- #[serde(rename = "anthropic")]
Anthropic {
- #[serde(default)]
- default_model: AnthropicModel,
- #[serde(default = "anthropic_api_url")]
+ model: AnthropicModel,
api_url: String,
- #[serde(default)]
low_speed_timeout_in_seconds: Option<u64>,
},
}
impl Default for AssistantProvider {
fn default() -> Self {
- Self::ZedDotDev {
- default_model: ZedDotDevModel::default(),
+ Self::OpenAi {
+ model: OpenAiModel::default(),
+ api_url: open_ai::OPEN_AI_API_URL.into(),
+ low_speed_timeout_in_seconds: None,
}
}
}
-fn open_ai_url() -> String {
- open_ai::OPEN_AI_API_URL.to_string()
-}
-
-fn anthropic_api_url() -> String {
- anthropic::ANTHROPIC_API_URL.to_string()
+#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
+#[serde(tag = "name", rename_all = "snake_case")]
+pub enum AssistantProviderContent {
+ #[serde(rename = "zed.dev")]
+ ZedDotDev {
+ default_model: Option<ZedDotDevModel>,
+ },
+ #[serde(rename = "openai")]
+ OpenAi {
+ default_model: Option<OpenAiModel>,
+ api_url: Option<String>,
+ low_speed_timeout_in_seconds: Option<u64>,
+ },
+ #[serde(rename = "anthropic")]
+ Anthropic {
+ default_model: Option<AnthropicModel>,
+ api_url: Option<String>,
+ low_speed_timeout_in_seconds: Option<u64>,
+ },
}
-#[derive(Default, Debug, Deserialize, Serialize)]
+#[derive(Debug, Default)]
pub struct AssistantSettings {
pub enabled: bool,
pub button: bool,
@@ -240,16 +243,16 @@ impl AssistantSettingsContent {
default_width: settings.default_width,
default_height: settings.default_height,
provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
- Some(AssistantProvider::OpenAi {
- default_model: settings.default_open_ai_model.clone().unwrap_or_default(),
- api_url: open_ai_api_url.clone(),
+ Some(AssistantProviderContent::OpenAi {
+ default_model: settings.default_open_ai_model.clone(),
+ api_url: Some(open_ai_api_url.clone()),
low_speed_timeout_in_seconds: None,
})
} else {
settings.default_open_ai_model.clone().map(|open_ai_model| {
- AssistantProvider::OpenAi {
- default_model: open_ai_model,
- api_url: open_ai_url(),
+ AssistantProviderContent::OpenAi {
+ default_model: Some(open_ai_model),
+ api_url: None,
low_speed_timeout_in_seconds: None,
}
})
@@ -270,6 +273,64 @@ impl AssistantSettingsContent {
}
}
}
+
+ pub fn set_model(&mut self, new_model: LanguageModel) {
+ match self {
+ AssistantSettingsContent::Versioned(settings) => match settings {
+ VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
+ Some(AssistantProviderContent::ZedDotDev {
+ default_model: model,
+ }) => {
+ if let LanguageModel::ZedDotDev(new_model) = new_model {
+ *model = Some(new_model);
+ }
+ }
+ Some(AssistantProviderContent::OpenAi {
+ default_model: model,
+ ..
+ }) => {
+ if let LanguageModel::OpenAi(new_model) = new_model {
+ *model = Some(new_model);
+ }
+ }
+ Some(AssistantProviderContent::Anthropic {
+ default_model: model,
+ ..
+ }) => {
+ if let LanguageModel::Anthropic(new_model) = new_model {
+ *model = Some(new_model);
+ }
+ }
+ provider => match new_model {
+ LanguageModel::ZedDotDev(model) => {
+ *provider = Some(AssistantProviderContent::ZedDotDev {
+ default_model: Some(model),
+ })
+ }
+ LanguageModel::OpenAi(model) => {
+ *provider = Some(AssistantProviderContent::OpenAi {
+ default_model: Some(model),
+ api_url: None,
+ low_speed_timeout_in_seconds: None,
+ })
+ }
+ LanguageModel::Anthropic(model) => {
+ *provider = Some(AssistantProviderContent::Anthropic {
+ default_model: Some(model),
+ api_url: None,
+ low_speed_timeout_in_seconds: None,
+ })
+ }
+ },
+ },
+ },
+ AssistantSettingsContent::Legacy(settings) => {
+ if let LanguageModel::OpenAi(model) = new_model {
+ settings.default_open_ai_model = Some(model);
+ }
+ }
+ }
+ }
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
@@ -318,7 +379,7 @@ pub struct AssistantSettingsContentV1 {
///
/// This can either be the internal `zed.dev` service or an external `openai` service,
/// each with their respective default models and configurations.
- provider: Option<AssistantProvider>,
+ provider: Option<AssistantProviderContent>,
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
@@ -376,31 +437,82 @@ impl Settings for AssistantSettings {
if let Some(provider) = value.provider.clone() {
match (&mut settings.provider, provider) {
(
- AssistantProvider::ZedDotDev { default_model },
- AssistantProvider::ZedDotDev {
- default_model: default_model_override,
+ AssistantProvider::ZedDotDev { model },
+ AssistantProviderContent::ZedDotDev {
+ default_model: model_override,
},
) => {
- *default_model = default_model_override;
+ merge(model, model_override);
}
(
AssistantProvider::OpenAi {
- default_model,
+ model,
api_url,
low_speed_timeout_in_seconds,
},
- AssistantProvider::OpenAi {
- default_model: default_model_override,
+ AssistantProviderContent::OpenAi {
+ default_model: model_override,
+ api_url: api_url_override,
+ low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
+ },
+ ) => {
+ merge(model, model_override);
+ merge(api_url, api_url_override);
+ if let Some(low_speed_timeout_in_seconds_override) =
+ low_speed_timeout_in_seconds_override
+ {
+ *low_speed_timeout_in_seconds =
+ Some(low_speed_timeout_in_seconds_override);
+ }
+ }
+ (
+ AssistantProvider::Anthropic {
+ model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ },
+ AssistantProviderContent::Anthropic {
+ default_model: model_override,
api_url: api_url_override,
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
},
) => {
- *default_model = default_model_override;
- *api_url = api_url_override;
- *low_speed_timeout_in_seconds = low_speed_timeout_in_seconds_override;
+ merge(model, model_override);
+ merge(api_url, api_url_override);
+ if let Some(low_speed_timeout_in_seconds_override) =
+ low_speed_timeout_in_seconds_override
+ {
+ *low_speed_timeout_in_seconds =
+ Some(low_speed_timeout_in_seconds_override);
+ }
}
- (merged, provider_override) => {
- *merged = provider_override;
+ (provider, provider_override) => {
+ *provider = match provider_override {
+ AssistantProviderContent::ZedDotDev {
+ default_model: model,
+ } => AssistantProvider::ZedDotDev {
+ model: model.unwrap_or_default(),
+ },
+ AssistantProviderContent::OpenAi {
+ default_model: model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ } => AssistantProvider::OpenAi {
+ model: model.unwrap_or_default(),
+ api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
+ low_speed_timeout_in_seconds,
+ },
+ AssistantProviderContent::Anthropic {
+ default_model: model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ } => AssistantProvider::Anthropic {
+ model: model.unwrap_or_default(),
+ api_url: api_url
+ .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
+ low_speed_timeout_in_seconds,
+ },
+ };
}
}
}
@@ -410,7 +522,7 @@ impl Settings for AssistantSettings {
}
}
-fn merge<T: Copy>(target: &mut T, value: Option<T>) {
+fn merge<T>(target: &mut T, value: Option<T>) {
if let Some(value) = value {
*target = value;
}
@@ -433,8 +545,8 @@ mod tests {
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
- default_model: OpenAiModel::FourOmni,
- api_url: open_ai_url(),
+ model: OpenAiModel::FourOmni,
+ api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None,
}
);
@@ -455,7 +567,7 @@ mod tests {
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
- default_model: OpenAiModel::FourOmni,
+ model: OpenAiModel::FourOmni,
api_url: "test-url".into(),
low_speed_timeout_in_seconds: None,
}
@@ -475,8 +587,8 @@ mod tests {
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
- default_model: OpenAiModel::Four,
- api_url: open_ai_url(),
+ model: OpenAiModel::Four,
+ api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None,
}
);
@@ -501,7 +613,7 @@ mod tests {
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::ZedDotDev {
- default_model: ZedDotDevModel::Custom("custom".into())
+ model: ZedDotDevModel::Custom("custom".into())
}
);
}
@@ -25,31 +25,26 @@ use std::time::Duration;
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
let mut settings_version = 0;
let provider = match &AssistantSettings::get_global(cx).provider {
- AssistantProvider::ZedDotDev { default_model } => {
- CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
- default_model.clone(),
- client.clone(),
- settings_version,
- cx,
- ))
- }
+ AssistantProvider::ZedDotDev { model } => CompletionProvider::ZedDotDev(
+ ZedDotDevCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
+ ),
AssistantProvider::OpenAi {
- default_model,
+ model,
api_url,
low_speed_timeout_in_seconds,
} => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
- default_model.clone(),
+ model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
)),
AssistantProvider::Anthropic {
- default_model,
+ model,
api_url,
low_speed_timeout_in_seconds,
} => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
- default_model.clone(),
+ model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
@@ -65,13 +60,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
(
CompletionProvider::OpenAi(provider),
AssistantProvider::OpenAi {
- default_model,
+ model,
api_url,
low_speed_timeout_in_seconds,
},
) => {
provider.update(
- default_model.clone(),
+ model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
@@ -80,13 +75,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
(
CompletionProvider::Anthropic(provider),
AssistantProvider::Anthropic {
- default_model,
+ model,
api_url,
low_speed_timeout_in_seconds,
},
) => {
provider.update(
- default_model.clone(),
+ model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
@@ -94,13 +89,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
}
(
CompletionProvider::ZedDotDev(provider),
- AssistantProvider::ZedDotDev { default_model },
+ AssistantProvider::ZedDotDev { model },
) => {
- provider.update(default_model.clone(), settings_version);
+ provider.update(model.clone(), settings_version);
}
- (_, AssistantProvider::ZedDotDev { default_model }) => {
+ (_, AssistantProvider::ZedDotDev { model }) => {
*provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
- default_model.clone(),
+ model.clone(),
client.clone(),
settings_version,
cx,
@@ -109,13 +104,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
(
_,
AssistantProvider::OpenAi {
- default_model,
+ model,
api_url,
low_speed_timeout_in_seconds,
},
) => {
*provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
- default_model.clone(),
+ model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
@@ -125,13 +120,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
(
_,
AssistantProvider::Anthropic {
- default_model,
+ model,
api_url,
low_speed_timeout_in_seconds,
},
) => {
*provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
- default_model.clone(),
+ model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
@@ -159,6 +154,25 @@ impl CompletionProvider {
cx.global::<Self>()
}
+ pub fn available_models(&self) -> Vec<LanguageModel> {
+ match self {
+ CompletionProvider::OpenAi(provider) => provider
+ .available_models()
+ .map(LanguageModel::OpenAi)
+ .collect(),
+ CompletionProvider::Anthropic(provider) => provider
+ .available_models()
+ .map(LanguageModel::Anthropic)
+ .collect(),
+ CompletionProvider::ZedDotDev(provider) => provider
+ .available_models()
+ .map(LanguageModel::ZedDotDev)
+ .collect(),
+ #[cfg(test)]
+ CompletionProvider::Fake(_) => unimplemented!(),
+ }
+ }
+
pub fn settings_version(&self) -> usize {
match self {
CompletionProvider::OpenAi(provider) => provider.settings_version(),
@@ -209,17 +223,13 @@ impl CompletionProvider {
}
}
- pub fn default_model(&self) -> LanguageModel {
+ pub fn model(&self) -> LanguageModel {
match self {
- CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
- CompletionProvider::Anthropic(provider) => {
- LanguageModel::Anthropic(provider.default_model())
- }
- CompletionProvider::ZedDotDev(provider) => {
- LanguageModel::ZedDotDev(provider.default_model())
- }
+ CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
+ CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
+ CompletionProvider::ZedDotDev(provider) => LanguageModel::ZedDotDev(provider.model()),
#[cfg(test)]
- CompletionProvider::Fake(_) => unimplemented!(),
+ CompletionProvider::Fake(_) => LanguageModel::default(),
}
}
@@ -12,6 +12,7 @@ use http::HttpClient;
use settings::Settings;
use std::time::Duration;
use std::{env, sync::Arc};
+use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
use util::ResultExt;
@@ -19,7 +20,7 @@ use util::ResultExt;
pub struct AnthropicCompletionProvider {
api_key: Option<String>,
api_url: String,
- default_model: AnthropicModel,
+ model: AnthropicModel,
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
@@ -27,7 +28,7 @@ pub struct AnthropicCompletionProvider {
impl AnthropicCompletionProvider {
pub fn new(
- default_model: AnthropicModel,
+ model: AnthropicModel,
api_url: String,
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
@@ -36,7 +37,7 @@ impl AnthropicCompletionProvider {
Self {
api_key: None,
api_url,
- default_model,
+ model,
http_client,
low_speed_timeout,
settings_version,
@@ -45,17 +46,21 @@ impl AnthropicCompletionProvider {
pub fn update(
&mut self,
- default_model: AnthropicModel,
+ model: AnthropicModel,
api_url: String,
low_speed_timeout: Option<Duration>,
settings_version: usize,
) {
- self.default_model = default_model;
+ self.model = model;
self.api_url = api_url;
self.low_speed_timeout = low_speed_timeout;
self.settings_version = settings_version;
}
+ pub fn available_models(&self) -> impl Iterator<Item = AnthropicModel> {
+ AnthropicModel::iter()
+ }
+
pub fn settings_version(&self) -> usize {
self.settings_version
}
@@ -105,8 +110,8 @@ impl AnthropicCompletionProvider {
.into()
}
- pub fn default_model(&self) -> AnthropicModel {
- self.default_model.clone()
+ pub fn model(&self) -> AnthropicModel {
+ self.model.clone()
}
pub fn count_tokens(
@@ -165,7 +170,7 @@ impl AnthropicCompletionProvider {
fn to_anthropic_request(&self, request: LanguageModelRequest) -> Request {
let model = match request.model {
LanguageModel::Anthropic(model) => model,
- _ => self.default_model(),
+ _ => self.model(),
};
let mut system_message = String::new();
@@ -11,6 +11,7 @@ use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
use settings::Settings;
use std::time::Duration;
use std::{env, sync::Arc};
+use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
use util::ResultExt;
@@ -18,7 +19,7 @@ use util::ResultExt;
pub struct OpenAiCompletionProvider {
api_key: Option<String>,
api_url: String,
- default_model: OpenAiModel,
+ model: OpenAiModel,
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
@@ -26,7 +27,7 @@ pub struct OpenAiCompletionProvider {
impl OpenAiCompletionProvider {
pub fn new(
- default_model: OpenAiModel,
+ model: OpenAiModel,
api_url: String,
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
@@ -35,7 +36,7 @@ impl OpenAiCompletionProvider {
Self {
api_key: None,
api_url,
- default_model,
+ model,
http_client,
low_speed_timeout,
settings_version,
@@ -44,17 +45,21 @@ impl OpenAiCompletionProvider {
pub fn update(
&mut self,
- default_model: OpenAiModel,
+ model: OpenAiModel,
api_url: String,
low_speed_timeout: Option<Duration>,
settings_version: usize,
) {
- self.default_model = default_model;
+ self.model = model;
self.api_url = api_url;
self.low_speed_timeout = low_speed_timeout;
self.settings_version = settings_version;
}
+ pub fn available_models(&self) -> impl Iterator<Item = OpenAiModel> {
+ OpenAiModel::iter()
+ }
+
pub fn settings_version(&self) -> usize {
self.settings_version
}
@@ -104,8 +109,8 @@ impl OpenAiCompletionProvider {
.into()
}
- pub fn default_model(&self) -> OpenAiModel {
- self.default_model.clone()
+ pub fn model(&self) -> OpenAiModel {
+ self.model.clone()
}
pub fn count_tokens(
@@ -152,7 +157,7 @@ impl OpenAiCompletionProvider {
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
let model = match request.model {
LanguageModel::OpenAi(model) => model,
- _ => self.default_model(),
+ _ => self.model(),
};
Request {
@@ -7,11 +7,12 @@ use client::{proto, Client};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
use gpui::{AnyView, AppContext, Task};
use std::{future, sync::Arc};
+use strum::IntoEnumIterator;
use ui::prelude::*;
pub struct ZedDotDevCompletionProvider {
client: Arc<Client>,
- default_model: ZedDotDevModel,
+ model: ZedDotDevModel,
settings_version: usize,
status: client::Status,
_maintain_client_status: Task<()>,
@@ -19,7 +20,7 @@ pub struct ZedDotDevCompletionProvider {
impl ZedDotDevCompletionProvider {
pub fn new(
- default_model: ZedDotDevModel,
+ model: ZedDotDevModel,
client: Arc<Client>,
settings_version: usize,
cx: &mut AppContext,
@@ -39,24 +40,39 @@ impl ZedDotDevCompletionProvider {
});
Self {
client,
- default_model,
+ model,
settings_version,
status,
_maintain_client_status: maintain_client_status,
}
}
- pub fn update(&mut self, default_model: ZedDotDevModel, settings_version: usize) {
- self.default_model = default_model;
+ pub fn update(&mut self, model: ZedDotDevModel, settings_version: usize) {
+ self.model = model;
self.settings_version = settings_version;
}
+ pub fn available_models(&self) -> impl Iterator<Item = ZedDotDevModel> {
+ let mut custom_model = if let ZedDotDevModel::Custom(custom_model) = self.model.clone() {
+ Some(custom_model)
+ } else {
+ None
+ };
+ ZedDotDevModel::iter().filter_map(move |model| {
+ if let ZedDotDevModel::Custom(_) = model {
+ Some(ZedDotDevModel::Custom(custom_model.take()?))
+ } else {
+ Some(model)
+ }
+ })
+ }
+
pub fn settings_version(&self) -> usize {
self.settings_version
}
- pub fn default_model(&self) -> ZedDotDevModel {
- self.default_model.clone()
+ pub fn model(&self) -> ZedDotDevModel {
+ self.model.clone()
}
pub fn is_authenticated(&self) -> bool {
@@ -0,0 +1,84 @@
+use std::sync::Arc;
+
+use crate::{assistant_settings::AssistantSettings, CompletionProvider, ToggleModelSelector};
+use fs::Fs;
+use settings::update_settings_file;
+use ui::{popover_menu, prelude::*, ButtonLike, ContextMenu, PopoverMenuHandle, Tooltip};
+
+#[derive(IntoElement)]
+pub struct ModelSelector {
+ handle: PopoverMenuHandle<ContextMenu>,
+ fs: Arc<dyn Fs>,
+}
+
+impl ModelSelector {
+ pub fn new(handle: PopoverMenuHandle<ContextMenu>, fs: Arc<dyn Fs>) -> Self {
+ ModelSelector { handle, fs }
+ }
+}
+
+impl RenderOnce for ModelSelector {
+ fn render(self, cx: &mut WindowContext) -> impl IntoElement {
+ popover_menu("model-switcher")
+ .with_handle(self.handle)
+ .menu(move |cx| {
+ ContextMenu::build(cx, |mut menu, cx| {
+ for model in CompletionProvider::global(cx).available_models() {
+ menu = menu.custom_entry(
+ {
+ let model = model.clone();
+ move |_| Label::new(model.display_name()).into_any_element()
+ },
+ {
+ let fs = self.fs.clone();
+ let model = model.clone();
+ move |cx| {
+ let model = model.clone();
+ update_settings_file::<AssistantSettings>(
+ fs.clone(),
+ cx,
+ move |settings| settings.set_model(model),
+ );
+ }
+ },
+ );
+ }
+ menu
+ })
+ .into()
+ })
+ .trigger(
+ ButtonLike::new("active-model")
+ .child(
+ h_flex()
+ .w_full()
+ .gap_0p5()
+ .child(
+ div()
+ .overflow_x_hidden()
+ .flex_grow()
+ .whitespace_nowrap()
+ .child(
+ Label::new(
+ CompletionProvider::global(cx).model().display_name(),
+ )
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
+ )
+ .child(
+ div().child(
+ Icon::new(IconName::ChevronDown)
+ .color(Color::Muted)
+ .size(IconSize::XSmall),
+ ),
+ ),
+ )
+ .style(ButtonStyle::Subtle)
+ .tooltip(move |cx| {
+ Tooltip::for_action("Change Model", &ToggleModelSelector, cx)
+ }),
+ )
+ .anchor(gpui::AnchorCorner::BottomRight)
+ }
+}
@@ -20,3 +20,4 @@ isahc.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
+strum.workspace = true
@@ -4,8 +4,8 @@ use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
-use std::time::Duration;
-use std::{convert::TryFrom, future::Future};
+use std::{convert::TryFrom, future::Future, time::Duration};
+use strum::EnumIter;
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
@@ -44,7 +44,7 @@ impl From<Role> for String {
}
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
-#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
#[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
ThreePointFiveTurbo,
@@ -13,6 +13,51 @@ pub trait PopoverTrigger: IntoElement + Clickable + Selectable + 'static {}
impl<T: IntoElement + Clickable + Selectable + 'static> PopoverTrigger for T {}
+pub struct PopoverMenuHandle<M>(Rc<RefCell<Option<PopoverMenuHandleState<M>>>>);
+
+impl<M> Clone for PopoverMenuHandle<M> {
+ fn clone(&self) -> Self {
+ Self(self.0.clone())
+ }
+}
+
+impl<M> Default for PopoverMenuHandle<M> {
+ fn default() -> Self {
+ Self(Rc::default())
+ }
+}
+
+struct PopoverMenuHandleState<M> {
+ menu_builder: Rc<dyn Fn(&mut WindowContext) -> Option<View<M>>>,
+ menu: Rc<RefCell<Option<View<M>>>>,
+}
+
+impl<M: ManagedView> PopoverMenuHandle<M> {
+ pub fn show(&self, cx: &mut WindowContext) {
+ if let Some(state) = self.0.borrow().as_ref() {
+ show_menu(&state.menu_builder, &state.menu, cx);
+ }
+ }
+
+ pub fn hide(&self, cx: &mut WindowContext) {
+ if let Some(state) = self.0.borrow().as_ref() {
+ if let Some(menu) = state.menu.borrow().as_ref() {
+ menu.update(cx, |_, cx| cx.emit(DismissEvent));
+ }
+ }
+ }
+
+ pub fn toggle(&self, cx: &mut WindowContext) {
+ if let Some(state) = self.0.borrow().as_ref() {
+ if state.menu.borrow().is_some() {
+ self.hide(cx);
+ } else {
+ self.show(cx);
+ }
+ }
+ }
+}
+
pub struct PopoverMenu<M: ManagedView> {
id: ElementId,
child_builder: Option<
@@ -28,6 +73,7 @@ pub struct PopoverMenu<M: ManagedView> {
anchor: AnchorCorner,
attach: Option<AnchorCorner>,
offset: Option<Point<Pixels>>,
+ trigger_handle: Option<PopoverMenuHandle<M>>,
}
impl<M: ManagedView> PopoverMenu<M> {
@@ -36,35 +82,17 @@ impl<M: ManagedView> PopoverMenu<M> {
self
}
+ pub fn with_handle(mut self, handle: PopoverMenuHandle<M>) -> Self {
+ self.trigger_handle = Some(handle);
+ self
+ }
+
pub fn trigger<T: PopoverTrigger>(mut self, t: T) -> Self {
self.child_builder = Some(Box::new(|menu, builder| {
let open = menu.borrow().is_some();
t.selected(open)
.when_some(builder, |el, builder| {
- el.on_click({
- move |_, cx| {
- let Some(new_menu) = (builder)(cx) else {
- return;
- };
- let menu2 = menu.clone();
- let previous_focus_handle = cx.focused();
-
- cx.subscribe(&new_menu, move |modal, _: &DismissEvent, cx| {
- if modal.focus_handle(cx).contains_focused(cx) {
- if let Some(previous_focus_handle) =
- previous_focus_handle.as_ref()
- {
- cx.focus(previous_focus_handle);
- }
- }
- *menu2.borrow_mut() = None;
- cx.refresh();
- })
- .detach();
- cx.focus_view(&new_menu);
- *menu.borrow_mut() = Some(new_menu);
- }
- })
+ el.on_click(move |_, cx| show_menu(&builder, &menu, cx))
})
.into_any_element()
}));
@@ -111,6 +139,32 @@ impl<M: ManagedView> PopoverMenu<M> {
}
}
+fn show_menu<M: ManagedView>(
+ builder: &Rc<dyn Fn(&mut WindowContext) -> Option<View<M>>>,
+ menu: &Rc<RefCell<Option<View<M>>>>,
+ cx: &mut WindowContext,
+) {
+ let Some(new_menu) = (builder)(cx) else {
+ return;
+ };
+ let menu2 = menu.clone();
+ let previous_focus_handle = cx.focused();
+
+ cx.subscribe(&new_menu, move |modal, _: &DismissEvent, cx| {
+ if modal.focus_handle(cx).contains_focused(cx) {
+ if let Some(previous_focus_handle) = previous_focus_handle.as_ref() {
+ cx.focus(previous_focus_handle);
+ }
+ }
+ *menu2.borrow_mut() = None;
+ cx.refresh();
+ })
+ .detach();
+ cx.focus_view(&new_menu);
+ *menu.borrow_mut() = Some(new_menu);
+ cx.refresh();
+}
+
/// Creates a [`PopoverMenu`]
pub fn popover_menu<M: ManagedView>(id: impl Into<ElementId>) -> PopoverMenu<M> {
PopoverMenu {
@@ -120,6 +174,7 @@ pub fn popover_menu<M: ManagedView>(id: impl Into<ElementId>) -> PopoverMenu<M>
anchor: AnchorCorner::TopLeft,
attach: None,
offset: None,
+ trigger_handle: None,
}
}
@@ -190,6 +245,15 @@ impl<M: ManagedView> Element for PopoverMenu<M> {
(child_builder)(element_state.menu.clone(), self.menu_builder.clone())
});
+ if let Some(trigger_handle) = self.trigger_handle.take() {
+ if let Some(menu_builder) = self.menu_builder.clone() {
+ *trigger_handle.0.borrow_mut() = Some(PopoverMenuHandleState {
+ menu_builder,
+ menu: element_state.menu.clone(),
+ });
+ }
+ }
+
let child_layout_id = child_element
.as_mut()
.map(|child_element| child_element.request_layout(cx));