Add a `default_open_ai_model` setting for the assistant (#2876)

Created by Joseph T. Lyons and Mikayla

[This PR has been sitting around for a
bit](https://github.com/zed-industries/zed/pull/2845). I received mixed
opinions from the team on how this setting should work: whether it
should use the full model names or some simpler form of them, and so on.
I went ahead and made the following decisions:

- Use the full model names in settings, e.g., `gpt-4-0613` (see the
sketch after this list)
- Default to `gpt-4-0613` when no setting is present
- Save the full model names in the conversation history files, e.g.,
`gpt-4-0613` (this is how it worked previously)
- Display the shortened model names in the assistant, e.g., `gpt-4`
- Not add an option for custom models yet (this can come in a
follow-up PR)
- Not query which models are available to the user via their API key
(this can also come in a follow-up PR)
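
For reference, here is a minimal sketch (not code from this PR) of how the serde renames tie the full model names used in settings to the new enum; it assumes `serde` and `serde_json`, as used in the `ai` crate:

```rust
use serde::{Deserialize, Serialize};

// Stand-in for the OpenAIModel enum added in assistant_settings.rs.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
enum OpenAIModel {
    #[serde(rename = "gpt-3.5-turbo-0613")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4-0613")]
    Four,
}

fn main() {
    // The settings value is the full model name...
    let model: OpenAIModel = serde_json::from_str("\"gpt-4-0613\"").unwrap();
    assert_eq!(model, OpenAIModel::Four);

    // ...and serializing writes the same full name back out, which keeps
    // conversation history files in their prior format.
    assert_eq!(serde_json::to_string(&model).unwrap(), "\"gpt-4-0613\"");
}
```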

Release Notes:

- Added a `default_open_ai_model` setting for the assistant (defaults to
`gpt-4-0613`).

---------

Co-authored-by: Mikayla <mikayla@zed.dev>

Change summary

assets/settings/default.json          |  8 ++++
crates/ai/src/ai.rs                   |  3 +
crates/ai/src/assistant.rs            | 45 ++++++++++++++++++----------
crates/ai/src/assistant_settings.rs   | 33 +++++++++++++++++++++
crates/settings/src/settings_store.rs |  3 +
5 files changed, 72 insertions(+), 20 deletions(-)

Detailed changes

assets/settings/default.json

@@ -138,7 +138,13 @@
     // Default width when the assistant is docked to the left or right.
     "default_width": 640,
     // Default height when the assistant is docked to the bottom.
-    "default_height": 320
+    "default_height": 320,
+    // The default OpenAI model to use when starting new conversations. This
+    // setting can take two values:
+    //
+    // 1. "gpt-3.5-turbo-0613""
+    // 2. "gpt-4-0613""
+    "default_open_ai_model": "gpt-4-0613"
   },
   // Whether the screen sharing icon is shown in the os status bar.
   "show_call_status_icon": true,

crates/ai/src/ai.rs

@@ -3,6 +3,7 @@ mod assistant_settings;
 
 use anyhow::Result;
 pub use assistant::AssistantPanel;
+use assistant_settings::OpenAIModel;
 use chrono::{DateTime, Local};
 use collections::HashMap;
 use fs::Fs;
@@ -60,7 +61,7 @@ struct SavedConversation {
     messages: Vec<SavedMessage>,
     message_metadata: HashMap<MessageId, MessageMetadata>,
     summary: String,
-    model: String,
+    model: OpenAIModel,
 }
 
 impl SavedConversation {
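
Since `OpenAIModel` round-trips as the same full model-name string that the previous `model: String` field held (see the sketch above), conversation files saved before this change should continue to deserialize unchanged.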

crates/ai/src/assistant.rs

@@ -1,5 +1,5 @@
 use crate::{
-    assistant_settings::{AssistantDockPosition, AssistantSettings},
+    assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
     MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
     RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
 };
@@ -833,7 +833,7 @@ struct Conversation {
     pending_summary: Task<Option<()>>,
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
-    model: String,
+    model: OpenAIModel,
     token_count: Option<usize>,
     max_token_count: usize,
     pending_token_count: Task<Option<()>>,
@@ -853,7 +853,6 @@ impl Conversation {
         language_registry: Arc<LanguageRegistry>,
         cx: &mut ModelContext<Self>,
     ) -> Self {
-        let model = "gpt-3.5-turbo-0613";
         let markdown = language_registry.language_for_name("Markdown");
         let buffer = cx.add_model(|cx| {
             let mut buffer = Buffer::new(0, "", cx);
@@ -872,6 +871,9 @@ impl Conversation {
             buffer
         });
 
+        let settings = settings::get::<AssistantSettings>(cx);
+        let model = settings.default_open_ai_model.clone();
+
         let mut this = Self {
             message_anchors: Default::default(),
             messages_metadata: Default::default(),
@@ -881,9 +883,9 @@ impl Conversation {
             completion_count: Default::default(),
             pending_completions: Default::default(),
             token_count: None,
-            max_token_count: tiktoken_rs::model::get_context_size(model),
+            max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
             pending_token_count: Task::ready(None),
-            model: model.into(),
+            model: model.clone(),
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
             path: None,
@@ -977,7 +979,7 @@ impl Conversation {
             completion_count: Default::default(),
             pending_completions: Default::default(),
             token_count: None,
-            max_token_count: tiktoken_rs::model::get_context_size(&model),
+            max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
             pending_token_count: Task::ready(None),
             model,
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
@@ -1031,13 +1033,16 @@ impl Conversation {
                 cx.background().timer(Duration::from_millis(200)).await;
                 let token_count = cx
                     .background()
-                    .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
+                    .spawn(async move {
+                        tiktoken_rs::num_tokens_from_messages(&model.full_name(), &messages)
+                    })
                     .await?;
 
                 this.upgrade(&cx)
                     .ok_or_else(|| anyhow!("conversation was dropped"))?
                     .update(&mut cx, |this, cx| {
-                        this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
+                        this.max_token_count =
+                            tiktoken_rs::model::get_context_size(&this.model.full_name());
                         this.token_count = Some(token_count);
                         cx.notify()
                     });
@@ -1051,7 +1056,7 @@ impl Conversation {
         Some(self.max_token_count as isize - self.token_count? as isize)
     }
 
-    fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
+    fn set_model(&mut self, model: OpenAIModel, cx: &mut ModelContext<Self>) {
         self.model = model;
         self.count_remaining_tokens(cx);
         cx.notify();
@@ -1093,7 +1098,7 @@ impl Conversation {
                 }
             } else {
                 let request = OpenAIRequest {
-                    model: self.model.clone(),
+                    model: self.model.full_name().to_string(),
                     messages: self
                         .messages(cx)
                         .filter(|message| matches!(message.status, MessageStatus::Done))
@@ -1419,7 +1424,7 @@ impl Conversation {
                                 .into(),
                     }));
                 let request = OpenAIRequest {
-                    model: self.model.clone(),
+                    model: self.model.full_name().to_string(),
                     messages: messages.collect(),
                     stream: true,
                 };
@@ -2023,11 +2028,8 @@ impl ConversationEditor {
 
     fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
         self.conversation.update(cx, |conversation, cx| {
-            let new_model = match conversation.model.as_str() {
-                "gpt-4-0613" => "gpt-3.5-turbo-0613",
-                _ => "gpt-4-0613",
-            };
-            conversation.set_model(new_model.into(), cx);
+            let new_model = conversation.model.cycle();
+            conversation.set_model(new_model, cx);
         });
     }
 
@@ -2049,7 +2051,8 @@ impl ConversationEditor {
 
         MouseEventHandler::new::<Model, _>(0, cx, |state, cx| {
             let style = style.model.style_for(state);
-            Label::new(self.conversation.read(cx).model.clone(), style.text.clone())
+            let model_display_name = self.conversation.read(cx).model.short_name();
+            Label::new(model_display_name, style.text.clone())
                 .contained()
                 .with_style(style.container)
         })
@@ -2238,6 +2241,8 @@ mod tests {
 
     #[gpui::test]
     fn test_inserting_and_removing_messages(cx: &mut AppContext) {
+        cx.set_global(SettingsStore::test(cx));
+        init(cx);
         let registry = Arc::new(LanguageRegistry::test());
         let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
         let buffer = conversation.read(cx).buffer.clone();
@@ -2364,6 +2369,8 @@ mod tests {
 
     #[gpui::test]
     fn test_message_splitting(cx: &mut AppContext) {
+        cx.set_global(SettingsStore::test(cx));
+        init(cx);
         let registry = Arc::new(LanguageRegistry::test());
         let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
         let buffer = conversation.read(cx).buffer.clone();
@@ -2458,6 +2465,8 @@ mod tests {
 
     #[gpui::test]
     fn test_messages_for_offsets(cx: &mut AppContext) {
+        cx.set_global(SettingsStore::test(cx));
+        init(cx);
         let registry = Arc::new(LanguageRegistry::test());
         let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
         let buffer = conversation.read(cx).buffer.clone();
@@ -2538,6 +2547,8 @@ mod tests {
 
     #[gpui::test]
     fn test_serialization(cx: &mut AppContext) {
+        cx.set_global(SettingsStore::test(cx));
+        init(cx);
         let registry = Arc::new(LanguageRegistry::test());
         let conversation =
             cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));

crates/ai/src/assistant_settings.rs

@@ -3,6 +3,37 @@ use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::Setting;
 
+#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
+pub enum OpenAIModel {
+    #[serde(rename = "gpt-3.5-turbo-0613")]
+    ThreePointFiveTurbo,
+    #[serde(rename = "gpt-4-0613")]
+    Four,
+}
+
+impl OpenAIModel {
+    pub fn full_name(&self) -> &'static str {
+        match self {
+            OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
+            OpenAIModel::Four => "gpt-4-0613",
+        }
+    }
+
+    pub fn short_name(&self) -> &'static str {
+        match self {
+            OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo",
+            OpenAIModel::Four => "gpt-4",
+        }
+    }
+
+    pub fn cycle(&self) -> Self {
+        match self {
+            OpenAIModel::ThreePointFiveTurbo => OpenAIModel::Four,
+            OpenAIModel::Four => OpenAIModel::ThreePointFiveTurbo,
+        }
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
 #[serde(rename_all = "snake_case")]
 pub enum AssistantDockPosition {
@@ -17,6 +48,7 @@ pub struct AssistantSettings {
     pub dock: AssistantDockPosition,
     pub default_width: f32,
     pub default_height: f32,
+    pub default_open_ai_model: OpenAIModel,
 }
 
 #[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
@@ -25,6 +57,7 @@ pub struct AssistantSettingsContent {
     pub dock: Option<AssistantDockPosition>,
     pub default_width: Option<f32>,
     pub default_height: Option<f32>,
+    pub default_open_ai_model: Option<OpenAIModel>,
 }
 
 impl Setting for AssistantSettings {
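
To make the intended split concrete, here is a small sketch of how the three methods are used, assuming the `OpenAIModel` definition above is in scope:

```rust
fn main() {
    let model = OpenAIModel::Four;
    assert_eq!(model.full_name(), "gpt-4-0613"); // used for API requests, token counting, and saved files
    assert_eq!(model.short_name(), "gpt-4"); // shown on the assistant's model label
    assert_eq!(model.cycle(), OpenAIModel::ThreePointFiveTurbo); // what cycle_model switches to
}
```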

crates/settings/src/settings_store.rs

@@ -1,4 +1,4 @@
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, Context, Result};
 use collections::{btree_map, hash_map, BTreeMap, HashMap};
 use gpui::AppContext;
 use lazy_static::lazy_static;
@@ -162,6 +162,7 @@ impl SettingsStore {
 
             if let Some(setting) = setting_value
                 .load_setting(&default_settings, &user_values_stack, cx)
+                .context("A default setting must be added to the `default.json` file")
                 .log_err()
             {
                 setting_value.set_global_value(setting);
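
The added `.context(...)` means that when a setting has no default in `default.json`, loading fails with an actionable message instead of a bare deserialization error. A minimal sketch of the effect, with a hypothetical `load_setting` stand-in and assuming only the `anyhow` crate:

```rust
use anyhow::{anyhow, Context, Result};

// Hypothetical stand-in for a load_setting call failing because the
// default value is absent from default.json.
fn load_setting() -> Result<()> {
    Err(anyhow!("missing field `default_open_ai_model`"))
}

fn main() {
    let err = load_setting()
        .context("A default setting must be added to the `default.json` file")
        .unwrap_err();
    // The alternate format prints the context first, then the root cause.
    println!("{err:#}");
}
```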