Improve model selection in the assistant (#12472)

Antonio Scandurra created

https://github.com/zed-industries/zed/assets/482957/3b017850-b7b6-457a-9b2f-324d5533442e


Release Notes:

- Improved the UX for selecting a model in the assistant panel. You can
now switch models using just the keyboard by pressing `alt-m`. Also, when
switching models via the UI, settings will now be updated automatically.

Change summary

Cargo.lock                                            |   3 
assets/keymaps/default-linux.json                     |   3 
assets/keymaps/default-macos.json                     |   5 
crates/anthropic/Cargo.toml                           |   1 
crates/anthropic/src/anthropic.rs                     |   3 
crates/assistant/Cargo.toml                           |   1 
crates/assistant/src/assistant.rs                     |   5 
crates/assistant/src/assistant_panel.rs               | 184 ++------
crates/assistant/src/assistant_settings.rs            | 256 +++++++++---
crates/assistant/src/completion_provider.rs           |  76 ++-
crates/assistant/src/completion_provider/anthropic.rs |  21 
crates/assistant/src/completion_provider/open_ai.rs   |  21 
crates/assistant/src/completion_provider/zed.rs       |  30 +
crates/assistant/src/model_selector.rs                |  84 ++++
crates/open_ai/Cargo.toml                             |   1 
crates/open_ai/src/open_ai.rs                         |   6 
crates/ui/src/components/popover_menu.rs              | 112 ++++-
17 files changed, 517 insertions(+), 295 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -230,6 +230,7 @@ dependencies = [
  "schemars",
  "serde",
  "serde_json",
+ "strum",
  "tokio",
 ]
 
@@ -376,6 +377,7 @@ dependencies = [
  "settings",
  "smol",
  "strsim 0.11.1",
+ "strum",
  "telemetry_events",
  "theme",
  "tiktoken-rs",
@@ -6983,6 +6985,7 @@ dependencies = [
  "schemars",
  "serde",
  "serde_json",
+ "strum",
 ]
 
 [[package]]

assets/keymaps/default-linux.json 🔗

@@ -201,7 +201,8 @@
     "context": "AssistantPanel",
     "bindings": {
       "ctrl-g": "search::SelectNextMatch",
-      "ctrl-shift-g": "search::SelectPrevMatch"
+      "ctrl-shift-g": "search::SelectPrevMatch",
+      "alt-m": "assistant::ToggleModelSelector"
     }
   },
   {

assets/keymaps/default-macos.json 🔗

@@ -214,10 +214,11 @@
     }
   },
   {
-    "context": "AssistantPanel", // Used in the assistant crate, which we're replacing
+    "context": "AssistantPanel",
     "bindings": {
       "cmd-g": "search::SelectNextMatch",
-      "cmd-shift-g": "search::SelectPrevMatch"
+      "cmd-shift-g": "search::SelectPrevMatch",
+      "alt-m": "assistant::ToggleModelSelector"
     }
   },
   {

crates/anthropic/Cargo.toml 🔗

@@ -23,6 +23,7 @@ isahc.workspace = true
 schemars = { workspace = true, optional = true }
 serde.workspace = true
 serde_json.workspace = true
+strum.workspace = true
 
 [dev-dependencies]
 tokio.workspace = true

crates/anthropic/src/anthropic.rs 🔗

@@ -4,11 +4,12 @@ use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use serde::{Deserialize, Serialize};
 use std::{convert::TryFrom, time::Duration};
+use strum::EnumIter;
 
 pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
 
 #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
-#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 pub enum Model {
     #[default]
     #[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]

crates/assistant/Cargo.toml 🔗

@@ -49,6 +49,7 @@ serde_json.workspace = true
 settings.workspace = true
 smol.workspace = true
 strsim = "0.11"
+strum.workspace = true
 telemetry_events.workspace = true
 theme.workspace = true
 tiktoken-rs.workspace = true

crates/assistant/src/assistant.rs 🔗

@@ -2,6 +2,7 @@ pub mod assistant_panel;
 pub mod assistant_settings;
 mod codegen;
 mod completion_provider;
+mod model_selector;
 mod prompts;
 mod saved_conversation;
 mod search;
@@ -15,6 +16,7 @@ use client::{proto, Client};
 use command_palette_hooks::CommandPaletteFilter;
 pub(crate) use completion_provider::*;
 use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
+pub(crate) use model_selector::*;
 pub(crate) use saved_conversation::*;
 use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 use serde::{Deserialize, Serialize};
@@ -38,7 +40,8 @@ actions!(
         InsertActivePrompt,
         ToggleHistory,
         ApplyEdit,
-        ConfirmCommand
+        ConfirmCommand,
+        ToggleModelSelector
     ]
 );
 

crates/assistant/src/assistant_panel.rs 🔗

@@ -1,7 +1,7 @@
 use crate::prompts::{generate_content_prompt, PromptLibrary, PromptManager};
 use crate::slash_command::{rustdoc_command, search_command, tabs_command};
 use crate::{
-    assistant_settings::{AssistantDockPosition, AssistantSettings, ZedDotDevModel},
+    assistant_settings::{AssistantDockPosition, AssistantSettings},
     codegen::{self, Codegen, CodegenKind},
     search::*,
     slash_command::{
@@ -9,10 +9,11 @@ use crate::{
         SlashCommandCompletionProvider, SlashCommandLine, SlashCommandRegistry,
     },
     ApplyEdit, Assist, CompletionProvider, ConfirmCommand, CycleMessageRole, InlineAssist,
-    LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata,
-    MessageStatus, QuoteSelection, ResetKey, Role, SavedConversation, SavedConversationMetadata,
-    SavedMessage, Split, ToggleFocus, ToggleHistory,
+    LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, MessageStatus,
+    QuoteSelection, ResetKey, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
+    Split, ToggleFocus, ToggleHistory,
 };
+use crate::{ModelSelector, ToggleModelSelector};
 use anyhow::{anyhow, Result};
 use assistant_slash_command::{SlashCommandOutput, SlashCommandOutputSection};
 use client::telemetry::Telemetry;
@@ -64,8 +65,8 @@ use std::{
 use telemetry_events::AssistantKind;
 use theme::ThemeSettings;
 use ui::{
-    popover_menu, prelude::*, ButtonLike, ContextMenu, ElevationIndex, KeyBinding, Tab, TabBar,
-    Tooltip,
+    popover_menu, prelude::*, ButtonLike, ContextMenu, ElevationIndex, KeyBinding,
+    PopoverMenuHandle, Tab, TabBar, Tooltip,
 };
 use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
 use uuid::Uuid;
@@ -119,8 +120,8 @@ pub struct AssistantPanel {
     pending_inline_assist_ids_by_editor: HashMap<WeakView<Editor>, Vec<usize>>,
     inline_prompt_history: VecDeque<String>,
     _watch_saved_conversations: Task<Result<()>>,
-    model: LanguageModel,
     authentication_prompt: Option<AnyView>,
+    model_menu_handle: PopoverMenuHandle<ContextMenu>,
 }
 
 struct ActiveConversationEditor {
@@ -203,7 +204,6 @@ impl AssistantPanel {
                             }
                         }),
                     ];
-                    let model = CompletionProvider::global(cx).default_model();
 
                     cx.observe_global::<FileIcons>(|_, cx| {
                         cx.notify();
@@ -244,8 +244,8 @@ impl AssistantPanel {
                         pending_inline_assist_ids_by_editor: Default::default(),
                         inline_prompt_history: Default::default(),
                         _watch_saved_conversations,
-                        model,
                         authentication_prompt: None,
+                        model_menu_handle: PopoverMenuHandle::default(),
                     }
                 })
             })
@@ -277,12 +277,20 @@ impl AssistantPanel {
         if self.is_authenticated(cx) {
             self.authentication_prompt = None;
 
-            let model = CompletionProvider::global(cx).default_model();
-            self.set_model(model, cx);
+            if let Some(editor) = self.active_conversation_editor() {
+                editor.update(cx, |active_conversation, cx| {
+                    active_conversation
+                        .conversation
+                        .update(cx, |conversation, cx| {
+                            conversation.completion_provider_changed(cx)
+                        })
+                })
+            }
 
             if self.active_conversation_editor().is_none() {
                 self.new_conversation(cx);
             }
+            cx.notify();
         } else if self.authentication_prompt.is_none()
             || prev_settings_version != CompletionProvider::global(cx).settings_version()
         {
@@ -290,6 +298,7 @@ impl AssistantPanel {
                 Some(cx.update_global::<CompletionProvider, _>(|provider, cx| {
                     provider.authentication_prompt(cx)
                 }));
+            cx.notify();
         }
     }
 
@@ -734,7 +743,7 @@ impl AssistantPanel {
                     .map(|message| message.to_request_message(buffer)),
             );
         }
-        let model = self.model.clone();
+        let model = CompletionProvider::global(cx).model();
 
         cx.spawn(|_, mut cx| async move {
             // I Don't know if we want to return a ? here.
@@ -809,7 +818,6 @@ impl AssistantPanel {
 
         let editor = cx.new_view(|cx| {
             ConversationEditor::new(
-                self.model.clone(),
                 self.languages.clone(),
                 self.slash_commands.clone(),
                 self.fs.clone(),
@@ -850,53 +858,6 @@ impl AssistantPanel {
         cx.notify();
     }
 
-    fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
-        let next_model = match &self.model {
-            LanguageModel::OpenAi(model) => LanguageModel::OpenAi(match &model {
-                open_ai::Model::ThreePointFiveTurbo => open_ai::Model::Four,
-                open_ai::Model::Four => open_ai::Model::FourTurbo,
-                open_ai::Model::FourTurbo => open_ai::Model::FourOmni,
-                open_ai::Model::FourOmni => open_ai::Model::ThreePointFiveTurbo,
-            }),
-            LanguageModel::Anthropic(model) => LanguageModel::Anthropic(match &model {
-                anthropic::Model::Claude3Opus => anthropic::Model::Claude3Sonnet,
-                anthropic::Model::Claude3Sonnet => anthropic::Model::Claude3Haiku,
-                anthropic::Model::Claude3Haiku => anthropic::Model::Claude3Opus,
-            }),
-            LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
-                ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4,
-                ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo,
-                ZedDotDevModel::Gpt4Turbo => ZedDotDevModel::Gpt4Omni,
-                ZedDotDevModel::Gpt4Omni => ZedDotDevModel::Claude3Opus,
-                ZedDotDevModel::Claude3Opus => ZedDotDevModel::Claude3Sonnet,
-                ZedDotDevModel::Claude3Sonnet => ZedDotDevModel::Claude3Haiku,
-                ZedDotDevModel::Claude3Haiku => {
-                    match CompletionProvider::global(cx).default_model() {
-                        LanguageModel::ZedDotDev(custom @ ZedDotDevModel::Custom(_)) => custom,
-                        _ => ZedDotDevModel::Gpt3Point5Turbo,
-                    }
-                }
-                ZedDotDevModel::Custom(_) => ZedDotDevModel::Gpt3Point5Turbo,
-            }),
-        };
-
-        self.set_model(next_model, cx);
-    }
-
-    fn set_model(&mut self, model: LanguageModel, cx: &mut ViewContext<Self>) {
-        self.model = model.clone();
-        if let Some(editor) = self.active_conversation_editor() {
-            editor.update(cx, |active_conversation, cx| {
-                active_conversation
-                    .conversation
-                    .update(cx, |conversation, cx| {
-                        conversation.set_model(model, cx);
-                    })
-            })
-        }
-        cx.notify();
-    }
-
     fn handle_conversation_editor_event(
         &mut self,
         _: View<ConversationEditor>,
@@ -978,6 +939,10 @@ impl AssistantPanel {
             .detach_and_log_err(cx);
     }
 
+    fn toggle_model_selector(&mut self, _: &ToggleModelSelector, cx: &mut ViewContext<Self>) {
+        self.model_menu_handle.toggle(cx);
+    }
+
     fn active_conversation_editor(&self) -> Option<&View<ConversationEditor>> {
         Some(&self.active_conversation_editor.as_ref()?.editor)
     }
@@ -1133,10 +1098,8 @@ impl AssistantPanel {
 
         cx.spawn(|this, mut cx| async move {
             let saved_conversation = SavedConversation::load(&path, fs.as_ref()).await?;
-            let model = this.update(&mut cx, |this, _| this.model.clone())?;
             let conversation = Conversation::deserialize(
                 saved_conversation,
-                model,
                 path.clone(),
                 languages,
                 slash_commands,
@@ -1206,7 +1169,10 @@ impl AssistantPanel {
                         this.child(
                             h_flex()
                                 .gap_1()
-                                .child(self.render_model(&conversation, cx))
+                                .child(ModelSelector::new(
+                                    self.model_menu_handle.clone(),
+                                    self.fs.clone(),
+                                ))
                                 .children(self.render_remaining_tokens(&conversation, cx)),
                         )
                         .child(
@@ -1256,6 +1222,7 @@ impl AssistantPanel {
             .on_action(cx.listener(AssistantPanel::select_prev_match))
             .on_action(cx.listener(AssistantPanel::handle_editor_cancel))
             .on_action(cx.listener(AssistantPanel::reset_credentials))
+            .on_action(cx.listener(AssistantPanel::toggle_model_selector))
             .track_focus(&self.focus_handle)
             .child(header)
             .children(if self.toolbar.read(cx).hidden() {
@@ -1314,23 +1281,12 @@ impl AssistantPanel {
             ))
     }
 
-    fn render_model(
-        &self,
-        conversation: &Model<Conversation>,
-        cx: &mut ViewContext<Self>,
-    ) -> impl IntoElement {
-        Button::new("current_model", conversation.read(cx).model.display_name())
-            .style(ButtonStyle::Filled)
-            .tooltip(move |cx| Tooltip::text("Change Model", cx))
-            .on_click(cx.listener(|this, _, cx| this.cycle_model(cx)))
-    }
-
     fn render_remaining_tokens(
         &self,
         conversation: &Model<Conversation>,
         cx: &mut ViewContext<Self>,
     ) -> Option<impl IntoElement> {
-        let remaining_tokens = conversation.read(cx).remaining_tokens()?;
+        let remaining_tokens = conversation.read(cx).remaining_tokens(cx)?;
         let remaining_tokens_color = if remaining_tokens <= 0 {
             Color::Error
         } else if remaining_tokens <= 500 {
@@ -1486,7 +1442,6 @@ pub struct Conversation {
     pending_summary: Task<Option<()>>,
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
-    model: LanguageModel,
     token_count: Option<usize>,
     pending_token_count: Task<Option<()>>,
     pending_edit_suggestion_parse: Option<Task<()>>,
@@ -1502,7 +1457,6 @@ impl EventEmitter<ConversationEvent> for Conversation {}
 
 impl Conversation {
     fn new(
-        model: LanguageModel,
         language_registry: Arc<LanguageRegistry>,
         slash_command_registry: Arc<SlashCommandRegistry>,
         telemetry: Option<Arc<Telemetry>>,
@@ -1530,7 +1484,6 @@ impl Conversation {
             token_count: None,
             pending_token_count: Task::ready(None),
             pending_edit_suggestion_parse: None,
-            model,
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
             path: None,
@@ -1583,7 +1536,6 @@ impl Conversation {
     #[allow(clippy::too_many_arguments)]
     async fn deserialize(
         saved_conversation: SavedConversation,
-        model: LanguageModel,
         path: PathBuf,
         language_registry: Arc<LanguageRegistry>,
         slash_command_registry: Arc<SlashCommandRegistry>,
@@ -1640,7 +1592,6 @@ impl Conversation {
                 token_count: None,
                 pending_edit_suggestion_parse: None,
                 pending_token_count: Task::ready(None),
-                model,
                 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
                 pending_save: Task::ready(Ok(())),
                 path: Some(path),
@@ -1938,12 +1889,12 @@ impl Conversation {
         }
     }
 
-    fn remaining_tokens(&self) -> Option<isize> {
-        Some(self.model.max_token_count() as isize - self.token_count? as isize)
+    fn remaining_tokens(&self, cx: &AppContext) -> Option<isize> {
+        let model = CompletionProvider::global(cx).model();
+        Some(model.max_token_count() as isize - self.token_count? as isize)
     }
 
-    fn set_model(&mut self, model: LanguageModel, cx: &mut ModelContext<Self>) {
-        self.model = model;
+    fn completion_provider_changed(&mut self, cx: &mut ModelContext<Self>) {
         self.count_remaining_tokens(cx);
     }
 
@@ -2079,10 +2030,11 @@ impl Conversation {
                             }
 
                             if let Some(telemetry) = this.telemetry.as_ref() {
+                                let model = CompletionProvider::global(cx).model();
                                 telemetry.report_assistant_event(
                                     this.id.clone(),
                                     AssistantKind::Panel,
-                                    this.model.telemetry_id(),
+                                    model.telemetry_id(),
                                     response_latency,
                                     error_message,
                                 );
@@ -2111,7 +2063,7 @@ impl Conversation {
             .map(|message| message.to_request_message(self.buffer.read(cx)));
 
         LanguageModelRequest {
-            model: self.model.clone(),
+            model: CompletionProvider::global(cx).model(),
             messages: messages.collect(),
             stop: vec![],
             temperature: 1.0,
@@ -2300,7 +2252,7 @@ impl Conversation {
                         .into(),
                 }));
             let request = LanguageModelRequest {
-                model: self.model.clone(),
+                model: CompletionProvider::global(cx).model(),
                 messages: messages.collect(),
                 stop: vec![],
                 temperature: 1.0,
@@ -2605,7 +2557,6 @@ pub struct ConversationEditor {
 
 impl ConversationEditor {
     fn new(
-        model: LanguageModel,
         language_registry: Arc<LanguageRegistry>,
         slash_command_registry: Arc<SlashCommandRegistry>,
         fs: Arc<dyn Fs>,
@@ -2618,7 +2569,6 @@ impl ConversationEditor {
 
         let conversation = cx.new_model(|cx| {
             Conversation::new(
-                model,
                 language_registry,
                 slash_command_registry,
                 Some(telemetry),
@@ -3847,15 +3797,8 @@ mod tests {
         init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 
-        let conversation = cx.new_model(|cx| {
-            Conversation::new(
-                LanguageModel::default(),
-                registry,
-                Default::default(),
-                None,
-                cx,
-            )
-        });
+        let conversation =
+            cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx));
         let buffer = conversation.read(cx).buffer.clone();
 
         let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3986,15 +3929,8 @@ mod tests {
         init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 
-        let conversation = cx.new_model(|cx| {
-            Conversation::new(
-                LanguageModel::default(),
-                registry,
-                Default::default(),
-                None,
-                cx,
-            )
-        });
+        let conversation =
+            cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx));
         let buffer = conversation.read(cx).buffer.clone();
 
         let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -4092,15 +4028,8 @@ mod tests {
         cx.set_global(settings_store);
         init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
-        let conversation = cx.new_model(|cx| {
-            Conversation::new(
-                LanguageModel::default(),
-                registry,
-                Default::default(),
-                None,
-                cx,
-            )
-        });
+        let conversation =
+            cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx));
         let buffer = conversation.read(cx).buffer.clone();
 
         let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -4209,15 +4138,8 @@ mod tests {
         ));
 
         let registry = Arc::new(LanguageRegistry::test(cx.executor()));
-        let conversation = cx.new_model(|cx| {
-            Conversation::new(
-                LanguageModel::default(),
-                registry.clone(),
-                slash_command_registry,
-                None,
-                cx,
-            )
-        });
+        let conversation = cx
+            .new_model(|cx| Conversation::new(registry.clone(), slash_command_registry, None, cx));
 
         let output_ranges = Rc::new(RefCell::new(HashSet::default()));
         conversation.update(cx, |_, cx| {
@@ -4390,15 +4312,8 @@ mod tests {
         cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
         cx.update(init);
         let registry = Arc::new(LanguageRegistry::test(cx.executor()));
-        let conversation = cx.new_model(|cx| {
-            Conversation::new(
-                LanguageModel::default(),
-                registry.clone(),
-                Default::default(),
-                None,
-                cx,
-            )
-        });
+        let conversation =
+            cx.new_model(|cx| Conversation::new(registry.clone(), Default::default(), None, cx));
         let buffer = conversation.read_with(cx, |conversation, _| conversation.buffer.clone());
         let message_0 =
             conversation.read_with(cx, |conversation, _| conversation.message_anchors[0].id);
@@ -4434,7 +4349,6 @@ mod tests {
 
         let deserialized_conversation = Conversation::deserialize(
             conversation.read_with(cx, |conversation, cx| conversation.serialize(cx)),
-            LanguageModel::default(),
             Default::default(),
             registry.clone(),
             Default::default(),

crates/assistant/src/assistant_settings.rs 🔗

@@ -12,8 +12,11 @@ use serde::{
     Deserialize, Deserializer, Serialize, Serializer,
 };
 use settings::{Settings, SettingsSources};
+use strum::{EnumIter, IntoEnumIterator};
 
-#[derive(Clone, Debug, Default, PartialEq)]
+use crate::LanguageModel;
+
+#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
 pub enum ZedDotDevModel {
     Gpt3Point5Turbo,
     Gpt4,
@@ -53,13 +56,10 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
             where
                 E: de::Error,
             {
-                match value {
-                    "gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo),
-                    "gpt-4" => Ok(ZedDotDevModel::Gpt4),
-                    "gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo),
-                    "gpt-4o" => Ok(ZedDotDevModel::Gpt4Omni),
-                    _ => Ok(ZedDotDevModel::Custom(value.to_owned())),
-                }
+                let model = ZedDotDevModel::iter()
+                    .find(|model| model.id() == value)
+                    .unwrap_or_else(|| ZedDotDevModel::Custom(value.to_string()));
+                Ok(model)
             }
         }
 
@@ -73,24 +73,23 @@ impl JsonSchema for ZedDotDevModel {
     }
 
     fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
-        let variants = vec![
-            "gpt-3.5-turbo".to_owned(),
-            "gpt-4".to_owned(),
-            "gpt-4-turbo-preview".to_owned(),
-            "gpt-4o".to_owned(),
-        ];
+        let variants = ZedDotDevModel::iter()
+            .filter_map(|model| {
+                let id = model.id();
+                if id.is_empty() {
+                    None
+                } else {
+                    Some(id.to_string())
+                }
+            })
+            .collect::<Vec<_>>();
         Schema::Object(SchemaObject {
             instance_type: Some(InstanceType::String.into()),
-            enum_values: Some(variants.into_iter().map(|s| s.into()).collect()),
+            enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
             metadata: Some(Box::new(Metadata {
                 title: Some("ZedDotDevModel".to_owned()),
-                default: Some(serde_json::json!("gpt-4-turbo-preview")),
-                examples: vec![
-                    serde_json::json!("gpt-3.5-turbo"),
-                    serde_json::json!("gpt-4"),
-                    serde_json::json!("gpt-4-turbo-preview"),
-                    serde_json::json!("custom-model-name"),
-                ],
+                default: Some(ZedDotDevModel::default().id().into()),
+                examples: variants.into_iter().map(Into::into).collect(),
                 ..Default::default()
             })),
             ..Default::default()
@@ -145,51 +144,55 @@ pub enum AssistantDockPosition {
     Bottom,
 }
 
-#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
-#[serde(tag = "name", rename_all = "snake_case")]
+#[derive(Debug, PartialEq)]
 pub enum AssistantProvider {
-    #[serde(rename = "zed.dev")]
     ZedDotDev {
-        #[serde(default)]
-        default_model: ZedDotDevModel,
+        model: ZedDotDevModel,
     },
-    #[serde(rename = "openai")]
     OpenAi {
-        #[serde(default)]
-        default_model: OpenAiModel,
-        #[serde(default = "open_ai_url")]
+        model: OpenAiModel,
         api_url: String,
-        #[serde(default)]
         low_speed_timeout_in_seconds: Option<u64>,
     },
-    #[serde(rename = "anthropic")]
     Anthropic {
-        #[serde(default)]
-        default_model: AnthropicModel,
-        #[serde(default = "anthropic_api_url")]
+        model: AnthropicModel,
         api_url: String,
-        #[serde(default)]
         low_speed_timeout_in_seconds: Option<u64>,
     },
 }
 
 impl Default for AssistantProvider {
     fn default() -> Self {
-        Self::ZedDotDev {
-            default_model: ZedDotDevModel::default(),
+        Self::OpenAi {
+            model: OpenAiModel::default(),
+            api_url: open_ai::OPEN_AI_API_URL.into(),
+            low_speed_timeout_in_seconds: None,
         }
     }
 }
 
-fn open_ai_url() -> String {
-    open_ai::OPEN_AI_API_URL.to_string()
-}
-
-fn anthropic_api_url() -> String {
-    anthropic::ANTHROPIC_API_URL.to_string()
+#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
+#[serde(tag = "name", rename_all = "snake_case")]
+pub enum AssistantProviderContent {
+    #[serde(rename = "zed.dev")]
+    ZedDotDev {
+        default_model: Option<ZedDotDevModel>,
+    },
+    #[serde(rename = "openai")]
+    OpenAi {
+        default_model: Option<OpenAiModel>,
+        api_url: Option<String>,
+        low_speed_timeout_in_seconds: Option<u64>,
+    },
+    #[serde(rename = "anthropic")]
+    Anthropic {
+        default_model: Option<AnthropicModel>,
+        api_url: Option<String>,
+        low_speed_timeout_in_seconds: Option<u64>,
+    },
 }
 
-#[derive(Default, Debug, Deserialize, Serialize)]
+#[derive(Debug, Default)]
 pub struct AssistantSettings {
     pub enabled: bool,
     pub button: bool,
@@ -240,16 +243,16 @@ impl AssistantSettingsContent {
                 default_width: settings.default_width,
                 default_height: settings.default_height,
                 provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
-                    Some(AssistantProvider::OpenAi {
-                        default_model: settings.default_open_ai_model.clone().unwrap_or_default(),
-                        api_url: open_ai_api_url.clone(),
+                    Some(AssistantProviderContent::OpenAi {
+                        default_model: settings.default_open_ai_model.clone(),
+                        api_url: Some(open_ai_api_url.clone()),
                         low_speed_timeout_in_seconds: None,
                     })
                 } else {
                     settings.default_open_ai_model.clone().map(|open_ai_model| {
-                        AssistantProvider::OpenAi {
-                            default_model: open_ai_model,
-                            api_url: open_ai_url(),
+                        AssistantProviderContent::OpenAi {
+                            default_model: Some(open_ai_model),
+                            api_url: None,
                             low_speed_timeout_in_seconds: None,
                         }
                     })
@@ -270,6 +273,64 @@ impl AssistantSettingsContent {
             }
         }
     }
+
+    pub fn set_model(&mut self, new_model: LanguageModel) {
+        match self {
+            AssistantSettingsContent::Versioned(settings) => match settings {
+                VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
+                    Some(AssistantProviderContent::ZedDotDev {
+                        default_model: model,
+                    }) => {
+                        if let LanguageModel::ZedDotDev(new_model) = new_model {
+                            *model = Some(new_model);
+                        }
+                    }
+                    Some(AssistantProviderContent::OpenAi {
+                        default_model: model,
+                        ..
+                    }) => {
+                        if let LanguageModel::OpenAi(new_model) = new_model {
+                            *model = Some(new_model);
+                        }
+                    }
+                    Some(AssistantProviderContent::Anthropic {
+                        default_model: model,
+                        ..
+                    }) => {
+                        if let LanguageModel::Anthropic(new_model) = new_model {
+                            *model = Some(new_model);
+                        }
+                    }
+                    provider => match new_model {
+                        LanguageModel::ZedDotDev(model) => {
+                            *provider = Some(AssistantProviderContent::ZedDotDev {
+                                default_model: Some(model),
+                            })
+                        }
+                        LanguageModel::OpenAi(model) => {
+                            *provider = Some(AssistantProviderContent::OpenAi {
+                                default_model: Some(model),
+                                api_url: None,
+                                low_speed_timeout_in_seconds: None,
+                            })
+                        }
+                        LanguageModel::Anthropic(model) => {
+                            *provider = Some(AssistantProviderContent::Anthropic {
+                                default_model: Some(model),
+                                api_url: None,
+                                low_speed_timeout_in_seconds: None,
+                            })
+                        }
+                    },
+                },
+            },
+            AssistantSettingsContent::Legacy(settings) => {
+                if let LanguageModel::OpenAi(model) = new_model {
+                    settings.default_open_ai_model = Some(model);
+                }
+            }
+        }
+    }
 }
 
 #[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
@@ -318,7 +379,7 @@ pub struct AssistantSettingsContentV1 {
     ///
     /// This can either be the internal `zed.dev` service or an external `openai` service,
     /// each with their respective default models and configurations.
-    provider: Option<AssistantProvider>,
+    provider: Option<AssistantProviderContent>,
 }
 
 #[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
@@ -376,31 +437,82 @@ impl Settings for AssistantSettings {
             if let Some(provider) = value.provider.clone() {
                 match (&mut settings.provider, provider) {
                     (
-                        AssistantProvider::ZedDotDev { default_model },
-                        AssistantProvider::ZedDotDev {
-                            default_model: default_model_override,
+                        AssistantProvider::ZedDotDev { model },
+                        AssistantProviderContent::ZedDotDev {
+                            default_model: model_override,
                         },
                     ) => {
-                        *default_model = default_model_override;
+                        merge(model, model_override);
                     }
                     (
                         AssistantProvider::OpenAi {
-                            default_model,
+                            model,
                             api_url,
                             low_speed_timeout_in_seconds,
                         },
-                        AssistantProvider::OpenAi {
-                            default_model: default_model_override,
+                        AssistantProviderContent::OpenAi {
+                            default_model: model_override,
+                            api_url: api_url_override,
+                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
+                        },
+                    ) => {
+                        merge(model, model_override);
+                        merge(api_url, api_url_override);
+                        if let Some(low_speed_timeout_in_seconds_override) =
+                            low_speed_timeout_in_seconds_override
+                        {
+                            *low_speed_timeout_in_seconds =
+                                Some(low_speed_timeout_in_seconds_override);
+                        }
+                    }
+                    (
+                        AssistantProvider::Anthropic {
+                            model,
+                            api_url,
+                            low_speed_timeout_in_seconds,
+                        },
+                        AssistantProviderContent::Anthropic {
+                            default_model: model_override,
                             api_url: api_url_override,
                             low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                         },
                     ) => {
-                        *default_model = default_model_override;
-                        *api_url = api_url_override;
-                        *low_speed_timeout_in_seconds = low_speed_timeout_in_seconds_override;
+                        merge(model, model_override);
+                        merge(api_url, api_url_override);
+                        if let Some(low_speed_timeout_in_seconds_override) =
+                            low_speed_timeout_in_seconds_override
+                        {
+                            *low_speed_timeout_in_seconds =
+                                Some(low_speed_timeout_in_seconds_override);
+                        }
                     }
-                    (merged, provider_override) => {
-                        *merged = provider_override;
+                    (provider, provider_override) => {
+                        *provider = match provider_override {
+                            AssistantProviderContent::ZedDotDev {
+                                default_model: model,
+                            } => AssistantProvider::ZedDotDev {
+                                model: model.unwrap_or_default(),
+                            },
+                            AssistantProviderContent::OpenAi {
+                                default_model: model,
+                                api_url,
+                                low_speed_timeout_in_seconds,
+                            } => AssistantProvider::OpenAi {
+                                model: model.unwrap_or_default(),
+                                api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
+                                low_speed_timeout_in_seconds,
+                            },
+                            AssistantProviderContent::Anthropic {
+                                default_model: model,
+                                api_url,
+                                low_speed_timeout_in_seconds,
+                            } => AssistantProvider::Anthropic {
+                                model: model.unwrap_or_default(),
+                                api_url: api_url
+                                    .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
+                                low_speed_timeout_in_seconds,
+                            },
+                        };
                     }
                 }
             }
@@ -410,7 +522,7 @@ impl Settings for AssistantSettings {
     }
 }
 
-fn merge<T: Copy>(target: &mut T, value: Option<T>) {
+fn merge<T>(target: &mut T, value: Option<T>) {
     if let Some(value) = value {
         *target = value;
     }
@@ -433,8 +545,8 @@ mod tests {
         assert_eq!(
             AssistantSettings::get_global(cx).provider,
             AssistantProvider::OpenAi {
-                default_model: OpenAiModel::FourOmni,
-                api_url: open_ai_url(),
+                model: OpenAiModel::FourOmni,
+                api_url: open_ai::OPEN_AI_API_URL.into(),
                 low_speed_timeout_in_seconds: None,
             }
         );
@@ -455,7 +567,7 @@ mod tests {
         assert_eq!(
             AssistantSettings::get_global(cx).provider,
             AssistantProvider::OpenAi {
-                default_model: OpenAiModel::FourOmni,
+                model: OpenAiModel::FourOmni,
                 api_url: "test-url".into(),
                 low_speed_timeout_in_seconds: None,
             }
@@ -475,8 +587,8 @@ mod tests {
         assert_eq!(
             AssistantSettings::get_global(cx).provider,
             AssistantProvider::OpenAi {
-                default_model: OpenAiModel::Four,
-                api_url: open_ai_url(),
+                model: OpenAiModel::Four,
+                api_url: open_ai::OPEN_AI_API_URL.into(),
                 low_speed_timeout_in_seconds: None,
             }
         );
@@ -501,7 +613,7 @@ mod tests {
         assert_eq!(
             AssistantSettings::get_global(cx).provider,
             AssistantProvider::ZedDotDev {
-                default_model: ZedDotDevModel::Custom("custom".into())
+                model: ZedDotDevModel::Custom("custom".into())
             }
         );
     }

crates/assistant/src/completion_provider.rs 🔗

@@ -25,31 +25,26 @@ use std::time::Duration;
 pub fn init(client: Arc<Client>, cx: &mut AppContext) {
     let mut settings_version = 0;
     let provider = match &AssistantSettings::get_global(cx).provider {
-        AssistantProvider::ZedDotDev { default_model } => {
-            CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
-                default_model.clone(),
-                client.clone(),
-                settings_version,
-                cx,
-            ))
-        }
+        AssistantProvider::ZedDotDev { model } => CompletionProvider::ZedDotDev(
+            ZedDotDevCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
+        ),
         AssistantProvider::OpenAi {
-            default_model,
+            model,
             api_url,
             low_speed_timeout_in_seconds,
         } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
-            default_model.clone(),
+            model.clone(),
             api_url.clone(),
             client.http_client(),
             low_speed_timeout_in_seconds.map(Duration::from_secs),
             settings_version,
         )),
         AssistantProvider::Anthropic {
-            default_model,
+            model,
             api_url,
             low_speed_timeout_in_seconds,
         } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
-            default_model.clone(),
+            model.clone(),
             api_url.clone(),
             client.http_client(),
             low_speed_timeout_in_seconds.map(Duration::from_secs),
@@ -65,13 +60,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                 (
                     CompletionProvider::OpenAi(provider),
                     AssistantProvider::OpenAi {
-                        default_model,
+                        model,
                         api_url,
                         low_speed_timeout_in_seconds,
                     },
                 ) => {
                     provider.update(
-                        default_model.clone(),
+                        model.clone(),
                         api_url.clone(),
                         low_speed_timeout_in_seconds.map(Duration::from_secs),
                         settings_version,
@@ -80,13 +75,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                 (
                     CompletionProvider::Anthropic(provider),
                     AssistantProvider::Anthropic {
-                        default_model,
+                        model,
                         api_url,
                         low_speed_timeout_in_seconds,
                     },
                 ) => {
                     provider.update(
-                        default_model.clone(),
+                        model.clone(),
                         api_url.clone(),
                         low_speed_timeout_in_seconds.map(Duration::from_secs),
                         settings_version,
@@ -94,13 +89,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                 }
                 (
                     CompletionProvider::ZedDotDev(provider),
-                    AssistantProvider::ZedDotDev { default_model },
+                    AssistantProvider::ZedDotDev { model },
                 ) => {
-                    provider.update(default_model.clone(), settings_version);
+                    provider.update(model.clone(), settings_version);
                 }
-                (_, AssistantProvider::ZedDotDev { default_model }) => {
+                (_, AssistantProvider::ZedDotDev { model }) => {
                     *provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
-                        default_model.clone(),
+                        model.clone(),
                         client.clone(),
                         settings_version,
                         cx,
@@ -109,13 +104,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                 (
                     _,
                     AssistantProvider::OpenAi {
-                        default_model,
+                        model,
                         api_url,
                         low_speed_timeout_in_seconds,
                     },
                 ) => {
                     *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
-                        default_model.clone(),
+                        model.clone(),
                         api_url.clone(),
                         client.http_client(),
                         low_speed_timeout_in_seconds.map(Duration::from_secs),
@@ -125,13 +120,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                 (
                     _,
                     AssistantProvider::Anthropic {
-                        default_model,
+                        model,
                         api_url,
                         low_speed_timeout_in_seconds,
                     },
                 ) => {
                     *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
-                        default_model.clone(),
+                        model.clone(),
                         api_url.clone(),
                         client.http_client(),
                         low_speed_timeout_in_seconds.map(Duration::from_secs),
@@ -159,6 +154,25 @@ impl CompletionProvider {
         cx.global::<Self>()
     }
 
+    pub fn available_models(&self) -> Vec<LanguageModel> {
+        match self {
+            CompletionProvider::OpenAi(provider) => provider
+                .available_models()
+                .map(LanguageModel::OpenAi)
+                .collect(),
+            CompletionProvider::Anthropic(provider) => provider
+                .available_models()
+                .map(LanguageModel::Anthropic)
+                .collect(),
+            CompletionProvider::ZedDotDev(provider) => provider
+                .available_models()
+                .map(LanguageModel::ZedDotDev)
+                .collect(),
+            #[cfg(test)]
+            CompletionProvider::Fake(_) => unimplemented!(),
+        }
+    }
+
     pub fn settings_version(&self) -> usize {
         match self {
             CompletionProvider::OpenAi(provider) => provider.settings_version(),
@@ -209,17 +223,13 @@ impl CompletionProvider {
         }
     }
 
-    pub fn default_model(&self) -> LanguageModel {
+    pub fn model(&self) -> LanguageModel {
         match self {
-            CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
-            CompletionProvider::Anthropic(provider) => {
-                LanguageModel::Anthropic(provider.default_model())
-            }
-            CompletionProvider::ZedDotDev(provider) => {
-                LanguageModel::ZedDotDev(provider.default_model())
-            }
+            CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
+            CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
+            CompletionProvider::ZedDotDev(provider) => LanguageModel::ZedDotDev(provider.model()),
             #[cfg(test)]
-            CompletionProvider::Fake(_) => unimplemented!(),
+            CompletionProvider::Fake(_) => LanguageModel::default(),
         }
     }
 

crates/assistant/src/completion_provider/anthropic.rs 🔗

@@ -12,6 +12,7 @@ use http::HttpClient;
 use settings::Settings;
 use std::time::Duration;
 use std::{env, sync::Arc};
+use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::prelude::*;
 use util::ResultExt;
@@ -19,7 +20,7 @@ use util::ResultExt;
 pub struct AnthropicCompletionProvider {
     api_key: Option<String>,
     api_url: String,
-    default_model: AnthropicModel,
+    model: AnthropicModel,
     http_client: Arc<dyn HttpClient>,
     low_speed_timeout: Option<Duration>,
     settings_version: usize,
@@ -27,7 +28,7 @@ pub struct AnthropicCompletionProvider {
 
 impl AnthropicCompletionProvider {
     pub fn new(
-        default_model: AnthropicModel,
+        model: AnthropicModel,
         api_url: String,
         http_client: Arc<dyn HttpClient>,
         low_speed_timeout: Option<Duration>,
@@ -36,7 +37,7 @@ impl AnthropicCompletionProvider {
         Self {
             api_key: None,
             api_url,
-            default_model,
+            model,
             http_client,
             low_speed_timeout,
             settings_version,
@@ -45,17 +46,21 @@ impl AnthropicCompletionProvider {
 
     pub fn update(
         &mut self,
-        default_model: AnthropicModel,
+        model: AnthropicModel,
         api_url: String,
         low_speed_timeout: Option<Duration>,
         settings_version: usize,
     ) {
-        self.default_model = default_model;
+        self.model = model;
         self.api_url = api_url;
         self.low_speed_timeout = low_speed_timeout;
         self.settings_version = settings_version;
     }
 
+    pub fn available_models(&self) -> impl Iterator<Item = AnthropicModel> {
+        AnthropicModel::iter()
+    }
+
     pub fn settings_version(&self) -> usize {
         self.settings_version
     }
@@ -105,8 +110,8 @@ impl AnthropicCompletionProvider {
             .into()
     }
 
-    pub fn default_model(&self) -> AnthropicModel {
-        self.default_model.clone()
+    pub fn model(&self) -> AnthropicModel {
+        self.model.clone()
     }
 
     pub fn count_tokens(
@@ -165,7 +170,7 @@ impl AnthropicCompletionProvider {
     fn to_anthropic_request(&self, request: LanguageModelRequest) -> Request {
         let model = match request.model {
             LanguageModel::Anthropic(model) => model,
-            _ => self.default_model(),
+            _ => self.model(),
         };
 
         let mut system_message = String::new();

crates/assistant/src/completion_provider/open_ai.rs 🔗

@@ -11,6 +11,7 @@ use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
 use settings::Settings;
 use std::time::Duration;
 use std::{env, sync::Arc};
+use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::prelude::*;
 use util::ResultExt;
@@ -18,7 +19,7 @@ use util::ResultExt;
 pub struct OpenAiCompletionProvider {
     api_key: Option<String>,
     api_url: String,
-    default_model: OpenAiModel,
+    model: OpenAiModel,
     http_client: Arc<dyn HttpClient>,
     low_speed_timeout: Option<Duration>,
     settings_version: usize,
@@ -26,7 +27,7 @@ pub struct OpenAiCompletionProvider {
 
 impl OpenAiCompletionProvider {
     pub fn new(
-        default_model: OpenAiModel,
+        model: OpenAiModel,
         api_url: String,
         http_client: Arc<dyn HttpClient>,
         low_speed_timeout: Option<Duration>,
@@ -35,7 +36,7 @@ impl OpenAiCompletionProvider {
         Self {
             api_key: None,
             api_url,
-            default_model,
+            model,
             http_client,
             low_speed_timeout,
             settings_version,
@@ -44,17 +45,21 @@ impl OpenAiCompletionProvider {
 
     pub fn update(
         &mut self,
-        default_model: OpenAiModel,
+        model: OpenAiModel,
         api_url: String,
         low_speed_timeout: Option<Duration>,
         settings_version: usize,
     ) {
-        self.default_model = default_model;
+        self.model = model;
         self.api_url = api_url;
         self.low_speed_timeout = low_speed_timeout;
         self.settings_version = settings_version;
     }
 
+    pub fn available_models(&self) -> impl Iterator<Item = OpenAiModel> {
+        OpenAiModel::iter()
+    }
+
     pub fn settings_version(&self) -> usize {
         self.settings_version
     }
@@ -104,8 +109,8 @@ impl OpenAiCompletionProvider {
             .into()
     }
 
-    pub fn default_model(&self) -> OpenAiModel {
-        self.default_model.clone()
+    pub fn model(&self) -> OpenAiModel {
+        self.model.clone()
     }
 
     pub fn count_tokens(
@@ -152,7 +157,7 @@ impl OpenAiCompletionProvider {
     fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
         let model = match request.model {
             LanguageModel::OpenAi(model) => model,
-            _ => self.default_model(),
+            _ => self.model(),
         };
 
         Request {

crates/assistant/src/completion_provider/zed.rs 🔗

@@ -7,11 +7,12 @@ use client::{proto, Client};
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
 use gpui::{AnyView, AppContext, Task};
 use std::{future, sync::Arc};
+use strum::IntoEnumIterator;
 use ui::prelude::*;
 
 pub struct ZedDotDevCompletionProvider {
     client: Arc<Client>,
-    default_model: ZedDotDevModel,
+    model: ZedDotDevModel,
     settings_version: usize,
     status: client::Status,
     _maintain_client_status: Task<()>,
@@ -19,7 +20,7 @@ pub struct ZedDotDevCompletionProvider {
 
 impl ZedDotDevCompletionProvider {
     pub fn new(
-        default_model: ZedDotDevModel,
+        model: ZedDotDevModel,
         client: Arc<Client>,
         settings_version: usize,
         cx: &mut AppContext,
@@ -39,24 +40,39 @@ impl ZedDotDevCompletionProvider {
         });
         Self {
             client,
-            default_model,
+            model,
             settings_version,
             status,
             _maintain_client_status: maintain_client_status,
         }
     }
 
-    pub fn update(&mut self, default_model: ZedDotDevModel, settings_version: usize) {
-        self.default_model = default_model;
+    pub fn update(&mut self, model: ZedDotDevModel, settings_version: usize) {
+        self.model = model;
         self.settings_version = settings_version;
     }
 
+    pub fn available_models(&self) -> impl Iterator<Item = ZedDotDevModel> {
+        let mut custom_model = if let ZedDotDevModel::Custom(custom_model) = self.model.clone() {
+            Some(custom_model)
+        } else {
+            None
+        };
+        ZedDotDevModel::iter().filter_map(move |model| {
+            if let ZedDotDevModel::Custom(_) = model {
+                Some(ZedDotDevModel::Custom(custom_model.take()?))
+            } else {
+                Some(model)
+            }
+        })
+    }
+
     pub fn settings_version(&self) -> usize {
         self.settings_version
     }
 
-    pub fn default_model(&self) -> ZedDotDevModel {
-        self.default_model.clone()
+    pub fn model(&self) -> ZedDotDevModel {
+        self.model.clone()
     }
 
     pub fn is_authenticated(&self) -> bool {

crates/assistant/src/model_selector.rs 🔗

@@ -0,0 +1,84 @@
+use std::sync::Arc;
+
+use crate::{assistant_settings::AssistantSettings, CompletionProvider, ToggleModelSelector};
+use fs::Fs;
+use settings::update_settings_file;
+use ui::{popover_menu, prelude::*, ButtonLike, ContextMenu, PopoverMenuHandle, Tooltip};
+
+#[derive(IntoElement)]
+pub struct ModelSelector {
+    handle: PopoverMenuHandle<ContextMenu>,
+    fs: Arc<dyn Fs>,
+}
+
+impl ModelSelector {
+    pub fn new(handle: PopoverMenuHandle<ContextMenu>, fs: Arc<dyn Fs>) -> Self {
+        ModelSelector { handle, fs }
+    }
+}
+
+impl RenderOnce for ModelSelector {
+    fn render(self, cx: &mut WindowContext) -> impl IntoElement {
+        popover_menu("model-switcher")
+            .with_handle(self.handle)
+            .menu(move |cx| {
+                ContextMenu::build(cx, |mut menu, cx| {
+                    for model in CompletionProvider::global(cx).available_models() {
+                        menu = menu.custom_entry(
+                            {
+                                let model = model.clone();
+                                move |_| Label::new(model.display_name()).into_any_element()
+                            },
+                            {
+                                let fs = self.fs.clone();
+                                let model = model.clone();
+                                move |cx| {
+                                    let model = model.clone();
+                                    update_settings_file::<AssistantSettings>(
+                                        fs.clone(),
+                                        cx,
+                                        move |settings| settings.set_model(model),
+                                    );
+                                }
+                            },
+                        );
+                    }
+                    menu
+                })
+                .into()
+            })
+            .trigger(
+                ButtonLike::new("active-model")
+                    .child(
+                        h_flex()
+                            .w_full()
+                            .gap_0p5()
+                            .child(
+                                div()
+                                    .overflow_x_hidden()
+                                    .flex_grow()
+                                    .whitespace_nowrap()
+                                    .child(
+                                        Label::new(
+                                            CompletionProvider::global(cx).model().display_name(),
+                                        )
+                                        .size(LabelSize::Small)
+                                        .color(Color::Muted),
+                                    ),
+                            )
+                            .child(
+                                div().child(
+                                    Icon::new(IconName::ChevronDown)
+                                        .color(Color::Muted)
+                                        .size(IconSize::XSmall),
+                                ),
+                            ),
+                    )
+                    .style(ButtonStyle::Subtle)
+                    .tooltip(move |cx| {
+                        Tooltip::for_action("Change Model", &ToggleModelSelector, cx)
+                    }),
+            )
+            .anchor(gpui::AnchorCorner::BottomRight)
+    }
+}

crates/open_ai/Cargo.toml 🔗

@@ -20,3 +20,4 @@ isahc.workspace = true
 schemars = { workspace = true, optional = true }
 serde.workspace = true
 serde_json.workspace = true
+strum.workspace = true

crates/open_ai/src/open_ai.rs 🔗

@@ -4,8 +4,8 @@ use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use serde::{Deserialize, Serialize};
 use serde_json::{Map, Value};
-use std::time::Duration;
-use std::{convert::TryFrom, future::Future};
+use std::{convert::TryFrom, future::Future, time::Duration};
+use strum::EnumIter;
 
 pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 
@@ -44,7 +44,7 @@ impl From<Role> for String {
 }
 
 #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
-#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 pub enum Model {
     #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
     ThreePointFiveTurbo,

crates/ui/src/components/popover_menu.rs 🔗

@@ -13,6 +13,51 @@ pub trait PopoverTrigger: IntoElement + Clickable + Selectable + 'static {}
 
 impl<T: IntoElement + Clickable + Selectable + 'static> PopoverTrigger for T {}
 
+pub struct PopoverMenuHandle<M>(Rc<RefCell<Option<PopoverMenuHandleState<M>>>>);
+
+impl<M> Clone for PopoverMenuHandle<M> {
+    fn clone(&self) -> Self {
+        Self(self.0.clone())
+    }
+}
+
+impl<M> Default for PopoverMenuHandle<M> {
+    fn default() -> Self {
+        Self(Rc::default())
+    }
+}
+
+struct PopoverMenuHandleState<M> {
+    menu_builder: Rc<dyn Fn(&mut WindowContext) -> Option<View<M>>>,
+    menu: Rc<RefCell<Option<View<M>>>>,
+}
+
+impl<M: ManagedView> PopoverMenuHandle<M> {
+    pub fn show(&self, cx: &mut WindowContext) {
+        if let Some(state) = self.0.borrow().as_ref() {
+            show_menu(&state.menu_builder, &state.menu, cx);
+        }
+    }
+
+    pub fn hide(&self, cx: &mut WindowContext) {
+        if let Some(state) = self.0.borrow().as_ref() {
+            if let Some(menu) = state.menu.borrow().as_ref() {
+                menu.update(cx, |_, cx| cx.emit(DismissEvent));
+            }
+        }
+    }
+
+    pub fn toggle(&self, cx: &mut WindowContext) {
+        if let Some(state) = self.0.borrow().as_ref() {
+            if state.menu.borrow().is_some() {
+                self.hide(cx);
+            } else {
+                self.show(cx);
+            }
+        }
+    }
+}
+
 pub struct PopoverMenu<M: ManagedView> {
     id: ElementId,
     child_builder: Option<
@@ -28,6 +73,7 @@ pub struct PopoverMenu<M: ManagedView> {
     anchor: AnchorCorner,
     attach: Option<AnchorCorner>,
     offset: Option<Point<Pixels>>,
+    trigger_handle: Option<PopoverMenuHandle<M>>,
 }
 
 impl<M: ManagedView> PopoverMenu<M> {
@@ -36,35 +82,17 @@ impl<M: ManagedView> PopoverMenu<M> {
         self
     }
 
+    pub fn with_handle(mut self, handle: PopoverMenuHandle<M>) -> Self {
+        self.trigger_handle = Some(handle);
+        self
+    }
+
     pub fn trigger<T: PopoverTrigger>(mut self, t: T) -> Self {
         self.child_builder = Some(Box::new(|menu, builder| {
             let open = menu.borrow().is_some();
             t.selected(open)
                 .when_some(builder, |el, builder| {
-                    el.on_click({
-                        move |_, cx| {
-                            let Some(new_menu) = (builder)(cx) else {
-                                return;
-                            };
-                            let menu2 = menu.clone();
-                            let previous_focus_handle = cx.focused();
-
-                            cx.subscribe(&new_menu, move |modal, _: &DismissEvent, cx| {
-                                if modal.focus_handle(cx).contains_focused(cx) {
-                                    if let Some(previous_focus_handle) =
-                                        previous_focus_handle.as_ref()
-                                    {
-                                        cx.focus(previous_focus_handle);
-                                    }
-                                }
-                                *menu2.borrow_mut() = None;
-                                cx.refresh();
-                            })
-                            .detach();
-                            cx.focus_view(&new_menu);
-                            *menu.borrow_mut() = Some(new_menu);
-                        }
-                    })
+                    el.on_click(move |_, cx| show_menu(&builder, &menu, cx))
                 })
                 .into_any_element()
         }));
@@ -111,6 +139,32 @@ impl<M: ManagedView> PopoverMenu<M> {
     }
 }
 
+fn show_menu<M: ManagedView>(
+    builder: &Rc<dyn Fn(&mut WindowContext) -> Option<View<M>>>,
+    menu: &Rc<RefCell<Option<View<M>>>>,
+    cx: &mut WindowContext,
+) {
+    let Some(new_menu) = (builder)(cx) else {
+        return;
+    };
+    let menu2 = menu.clone();
+    let previous_focus_handle = cx.focused();
+
+    cx.subscribe(&new_menu, move |modal, _: &DismissEvent, cx| {
+        if modal.focus_handle(cx).contains_focused(cx) {
+            if let Some(previous_focus_handle) = previous_focus_handle.as_ref() {
+                cx.focus(previous_focus_handle);
+            }
+        }
+        *menu2.borrow_mut() = None;
+        cx.refresh();
+    })
+    .detach();
+    cx.focus_view(&new_menu);
+    *menu.borrow_mut() = Some(new_menu);
+    cx.refresh();
+}
+
 /// Creates a [`PopoverMenu`]
 pub fn popover_menu<M: ManagedView>(id: impl Into<ElementId>) -> PopoverMenu<M> {
     PopoverMenu {
@@ -120,6 +174,7 @@ pub fn popover_menu<M: ManagedView>(id: impl Into<ElementId>) -> PopoverMenu<M>
         anchor: AnchorCorner::TopLeft,
         attach: None,
         offset: None,
+        trigger_handle: None,
     }
 }
 
@@ -190,6 +245,15 @@ impl<M: ManagedView> Element for PopoverMenu<M> {
                     (child_builder)(element_state.menu.clone(), self.menu_builder.clone())
                 });
 
+                if let Some(trigger_handle) = self.trigger_handle.take() {
+                    if let Some(menu_builder) = self.menu_builder.clone() {
+                        *trigger_handle.0.borrow_mut() = Some(PopoverMenuHandleState {
+                            menu_builder,
+                            menu: element_state.menu.clone(),
+                        });
+                    }
+                }
+
                 let child_layout_id = child_element
                     .as_mut()
                     .map(|child_element| child_element.request_layout(cx));