From a836f9c23d6566ad21e9aacf6e92612419196a3a Mon Sep 17 00:00:00 2001 From: "Joseph T. Lyons" Date: Tue, 22 Aug 2023 02:55:27 -0400 Subject: [PATCH] Add a `default_open_ai_model` setting for the assistant (#2876) [This PR has been sitting around for a bit](https://github.com/zed-industries/zed/pull/2845). I received a bit of mixed opinions from the team on how this setting should work, if it should use the full model names or some simpler form of it, etc. I went ahead and made the decision to do the following: - Use the full model names in settings - ex: `gpt-4-0613` - Default to `gpt-4-0613` when no setting is present - Save the full model names in the conversation history files (this is how it was prior) - ex: `gpt-4-0613` - Display the shortened model names in the assistant - ex: `gpt-4` - Not worry about adding an option to add custom models (can add in a follow-up PR) - Not query what models are available to the user via their api key (can add in a follow-up PR) Release Notes: - Added a `default_open_ai_model` setting for the assistant (defaults to `gpt-4-0613`). --------- Co-authored-by: Mikayla --- assets/settings/default.json | 8 ++++- crates/ai/src/ai.rs | 3 +- crates/ai/src/assistant.rs | 45 +++++++++++++++++---------- crates/ai/src/assistant_settings.rs | 33 ++++++++++++++++++++ crates/settings/src/settings_store.rs | 3 +- 5 files changed, 72 insertions(+), 20 deletions(-) diff --git a/assets/settings/default.json b/assets/settings/default.json index 08faedbed6bd1e872f26ef6a0112f6afee0a5fd2..24412b883bf0be12cb2639dd54dec7f70adf6882 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -138,7 +138,13 @@ // Default width when the assistant is docked to the left or right. "default_width": 640, // Default height when the assistant is docked to the bottom. - "default_height": 320 + "default_height": 320, + // The default OpenAI model to use when starting new conversations. This + // setting can take two values: + // + // 1. 
"gpt-3.5-turbo-0613" + // 2. "gpt-4-0613" + "default_open_ai_model": "gpt-4-0613" }, // Whether the screen sharing icon is shown in the os status bar. "show_call_status_icon": true, diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index 7cc5f08f7c1f30102308188865299ff8ee1af833..d2be651bd564878bb6ff081f8071675a61bcdbe2 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -3,6 +3,7 @@ mod assistant_settings; use anyhow::Result; pub use assistant::AssistantPanel; +use assistant_settings::OpenAIModel; use chrono::{DateTime, Local}; use collections::HashMap; use fs::Fs; @@ -60,7 +61,7 @@ struct SavedConversation { messages: Vec, message_metadata: HashMap, summary: String, - model: String, + model: OpenAIModel, } impl SavedConversation { diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index e5026182ed2f0b9cc7516df9e31f73a2c53d8c5d..81299bbdc26f5001f407901893b7c8d3e0f1b166 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -1,5 +1,5 @@ use crate::{ - assistant_settings::{AssistantDockPosition, AssistantSettings}, + assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel}, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent, RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage, }; @@ -833,7 +833,7 @@ struct Conversation { pending_summary: Task>, completion_count: usize, pending_completions: Vec, - model: String, + model: OpenAIModel, token_count: Option, max_token_count: usize, pending_token_count: Task>, @@ -853,7 +853,6 @@ impl Conversation { language_registry: Arc, cx: &mut ModelContext, ) -> Self { - let model = "gpt-3.5-turbo-0613"; let markdown = language_registry.language_for_name("Markdown"); let buffer = cx.add_model(|cx| { let mut buffer = Buffer::new(0, "", cx); @@ -872,6 +871,9 @@ impl Conversation { buffer }); + let settings = settings::get::(cx); + let model = settings.default_open_ai_model.clone(); + let mut this = 
Self { message_anchors: Default::default(), messages_metadata: Default::default(), @@ -881,9 +883,9 @@ impl Conversation { completion_count: Default::default(), pending_completions: Default::default(), token_count: None, - max_token_count: tiktoken_rs::model::get_context_size(model), + max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()), pending_token_count: Task::ready(None), - model: model.into(), + model: model.clone(), _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: None, @@ -977,7 +979,7 @@ impl Conversation { completion_count: Default::default(), pending_completions: Default::default(), token_count: None, - max_token_count: tiktoken_rs::model::get_context_size(&model), + max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()), pending_token_count: Task::ready(None), model, _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], @@ -1031,13 +1033,16 @@ impl Conversation { cx.background().timer(Duration::from_millis(200)).await; let token_count = cx .background() - .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) }) + .spawn(async move { + tiktoken_rs::num_tokens_from_messages(&model.full_name(), &messages) + }) .await?; this.upgrade(&cx) .ok_or_else(|| anyhow!("conversation was dropped"))? .update(&mut cx, |this, cx| { - this.max_token_count = tiktoken_rs::model::get_context_size(&this.model); + this.max_token_count = + tiktoken_rs::model::get_context_size(&this.model.full_name()); this.token_count = Some(token_count); cx.notify() }); @@ -1051,7 +1056,7 @@ impl Conversation { Some(self.max_token_count as isize - self.token_count? 
as isize) } - fn set_model(&mut self, model: String, cx: &mut ModelContext) { + fn set_model(&mut self, model: OpenAIModel, cx: &mut ModelContext) { self.model = model; self.count_remaining_tokens(cx); cx.notify(); @@ -1093,7 +1098,7 @@ impl Conversation { } } else { let request = OpenAIRequest { - model: self.model.clone(), + model: self.model.full_name().to_string(), messages: self .messages(cx) .filter(|message| matches!(message.status, MessageStatus::Done)) @@ -1419,7 +1424,7 @@ impl Conversation { .into(), })); let request = OpenAIRequest { - model: self.model.clone(), + model: self.model.full_name().to_string(), messages: messages.collect(), stream: true, }; @@ -2023,11 +2028,8 @@ impl ConversationEditor { fn cycle_model(&mut self, cx: &mut ViewContext) { self.conversation.update(cx, |conversation, cx| { - let new_model = match conversation.model.as_str() { - "gpt-4-0613" => "gpt-3.5-turbo-0613", - _ => "gpt-4-0613", - }; - conversation.set_model(new_model.into(), cx); + let new_model = conversation.model.cycle(); + conversation.set_model(new_model, cx); }); } @@ -2049,7 +2051,8 @@ impl ConversationEditor { MouseEventHandler::new::(0, cx, |state, cx| { let style = style.model.style_for(state); - Label::new(self.conversation.read(cx).model.clone(), style.text.clone()) + let model_display_name = self.conversation.read(cx).model.short_name(); + Label::new(model_display_name, style.text.clone()) .contained() .with_style(style.container) }) @@ -2238,6 +2241,8 @@ mod tests { #[gpui::test] fn test_inserting_and_removing_messages(cx: &mut AppContext) { + cx.set_global(SettingsStore::test(cx)); + init(cx); let registry = Arc::new(LanguageRegistry::test()); let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx)); let buffer = conversation.read(cx).buffer.clone(); @@ -2364,6 +2369,8 @@ mod tests { #[gpui::test] fn test_message_splitting(cx: &mut AppContext) { + cx.set_global(SettingsStore::test(cx)); + init(cx); let registry = 
Arc::new(LanguageRegistry::test()); let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx)); let buffer = conversation.read(cx).buffer.clone(); @@ -2458,6 +2465,8 @@ mod tests { #[gpui::test] fn test_messages_for_offsets(cx: &mut AppContext) { + cx.set_global(SettingsStore::test(cx)); + init(cx); let registry = Arc::new(LanguageRegistry::test()); let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx)); let buffer = conversation.read(cx).buffer.clone(); @@ -2538,6 +2547,8 @@ mod tests { #[gpui::test] fn test_serialization(cx: &mut AppContext) { + cx.set_global(SettingsStore::test(cx)); + init(cx); let registry = Arc::new(LanguageRegistry::test()); let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx)); diff --git a/crates/ai/src/assistant_settings.rs b/crates/ai/src/assistant_settings.rs index 04ba8fb946eb7f450fe95bc7565bb304f8b4a1d7..05d8d9ffebe485204108969bc4ec308eedbd2d1d 100644 --- a/crates/ai/src/assistant_settings.rs +++ b/crates/ai/src/assistant_settings.rs @@ -3,6 +3,37 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Setting; +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)] +pub enum OpenAIModel { + #[serde(rename = "gpt-3.5-turbo-0613")] + ThreePointFiveTurbo, + #[serde(rename = "gpt-4-0613")] + Four, +} + +impl OpenAIModel { + pub fn full_name(&self) -> &'static str { + match self { + OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo-0613", + OpenAIModel::Four => "gpt-4-0613", + } + } + + pub fn short_name(&self) -> &'static str { + match self { + OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo", + OpenAIModel::Four => "gpt-4", + } + } + + pub fn cycle(&self) -> Self { + match self { + OpenAIModel::ThreePointFiveTurbo => OpenAIModel::Four, + OpenAIModel::Four => OpenAIModel::ThreePointFiveTurbo, + } + } +} + #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = 
"snake_case")] pub enum AssistantDockPosition { @@ -17,6 +48,7 @@ pub struct AssistantSettings { pub dock: AssistantDockPosition, pub default_width: f32, pub default_height: f32, + pub default_open_ai_model: OpenAIModel, } #[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)] @@ -25,6 +57,7 @@ pub struct AssistantSettingsContent { pub dock: Option, pub default_width: Option, pub default_height: Option, + pub default_open_ai_model: Option, } impl Setting for AssistantSettings { diff --git a/crates/settings/src/settings_store.rs b/crates/settings/src/settings_store.rs index 1188018cd892143788c0f6629aa72425eee2dc85..da84074d2a9cd80eec0c84a8f4daaf4c681d5c8f 100644 --- a/crates/settings/src/settings_store.rs +++ b/crates/settings/src/settings_store.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; use collections::{btree_map, hash_map, BTreeMap, HashMap}; use gpui::AppContext; use lazy_static::lazy_static; @@ -162,6 +162,7 @@ impl SettingsStore { if let Some(setting) = setting_value .load_setting(&default_settings, &user_values_stack, cx) + .context("A default setting must be added to the `default.json` file") .log_err() { setting_value.set_global_value(setting);