@@ -1,5 +1,5 @@
use crate::{
- assistant_settings::{AssistantDockPosition, AssistantSettings},
+ assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
};
@@ -833,7 +833,7 @@ struct Conversation {
pending_summary: Task<Option<()>>,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
- model: String,
+ model: OpenAIModel,
token_count: Option<usize>,
max_token_count: usize,
pending_token_count: Task<Option<()>>,
@@ -853,7 +853,6 @@ impl Conversation {
language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>,
) -> Self {
- let model = "gpt-3.5-turbo-0613";
let markdown = language_registry.language_for_name("Markdown");
let buffer = cx.add_model(|cx| {
let mut buffer = Buffer::new(0, "", cx);
@@ -872,6 +871,9 @@ impl Conversation {
buffer
});
+ let settings = settings::get::<AssistantSettings>(cx);
+ let model = settings.default_open_ai_model.clone();
+
let mut this = Self {
message_anchors: Default::default(),
messages_metadata: Default::default(),
@@ -881,9 +883,9 @@ impl Conversation {
completion_count: Default::default(),
pending_completions: Default::default(),
token_count: None,
- max_token_count: tiktoken_rs::model::get_context_size(model),
+ max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
pending_token_count: Task::ready(None),
- model: model.into(),
+ model: model.clone(),
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: None,
@@ -977,7 +979,7 @@ impl Conversation {
completion_count: Default::default(),
pending_completions: Default::default(),
token_count: None,
- max_token_count: tiktoken_rs::model::get_context_size(&model),
+ max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
pending_token_count: Task::ready(None),
model,
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
@@ -1031,13 +1033,16 @@ impl Conversation {
cx.background().timer(Duration::from_millis(200)).await;
let token_count = cx
.background()
- .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
+ .spawn(async move {
+ tiktoken_rs::num_tokens_from_messages(&model.full_name(), &messages)
+ })
.await?;
this.upgrade(&cx)
.ok_or_else(|| anyhow!("conversation was dropped"))?
.update(&mut cx, |this, cx| {
- this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
+ this.max_token_count =
+ tiktoken_rs::model::get_context_size(&this.model.full_name());
this.token_count = Some(token_count);
cx.notify()
});
@@ -1051,7 +1056,7 @@ impl Conversation {
Some(self.max_token_count as isize - self.token_count? as isize)
}
- fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
+ fn set_model(&mut self, model: OpenAIModel, cx: &mut ModelContext<Self>) {
self.model = model;
self.count_remaining_tokens(cx);
cx.notify();
@@ -1093,7 +1098,7 @@ impl Conversation {
}
} else {
let request = OpenAIRequest {
- model: self.model.clone(),
+ model: self.model.full_name().to_string(),
messages: self
.messages(cx)
.filter(|message| matches!(message.status, MessageStatus::Done))
@@ -1419,7 +1424,7 @@ impl Conversation {
.into(),
}));
let request = OpenAIRequest {
- model: self.model.clone(),
+ model: self.model.full_name().to_string(),
messages: messages.collect(),
stream: true,
};
@@ -2023,11 +2028,8 @@ impl ConversationEditor {
fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
self.conversation.update(cx, |conversation, cx| {
- let new_model = match conversation.model.as_str() {
- "gpt-4-0613" => "gpt-3.5-turbo-0613",
- _ => "gpt-4-0613",
- };
- conversation.set_model(new_model.into(), cx);
+ let new_model = conversation.model.cycle(); // Enum-driven toggle replaces the stringly-typed match.
+ conversation.set_model(new_model, cx);
});
}
@@ -2049,7 +2051,8 @@ impl ConversationEditor {
MouseEventHandler::new::<Model, _>(0, cx, |state, cx| {
let style = style.model.style_for(state);
- Label::new(self.conversation.read(cx).model.clone(), style.text.clone())
+ let model_display_name = self.conversation.read(cx).model.short_name();
+ Label::new(model_display_name, style.text.clone())
.contained()
.with_style(style.container)
})
@@ -2238,6 +2241,8 @@ mod tests {
#[gpui::test]
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
+ cx.set_global(SettingsStore::test(cx));
+ init(cx);
let registry = Arc::new(LanguageRegistry::test());
let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
let buffer = conversation.read(cx).buffer.clone();
@@ -2364,6 +2369,8 @@ mod tests {
#[gpui::test]
fn test_message_splitting(cx: &mut AppContext) {
+ cx.set_global(SettingsStore::test(cx));
+ init(cx);
let registry = Arc::new(LanguageRegistry::test());
let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
let buffer = conversation.read(cx).buffer.clone();
@@ -2458,6 +2465,8 @@ mod tests {
#[gpui::test]
fn test_messages_for_offsets(cx: &mut AppContext) {
+ cx.set_global(SettingsStore::test(cx));
+ init(cx);
let registry = Arc::new(LanguageRegistry::test());
let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
let buffer = conversation.read(cx).buffer.clone();
@@ -2538,6 +2547,8 @@ mod tests {
#[gpui::test]
fn test_serialization(cx: &mut AppContext) {
+ cx.set_global(SettingsStore::test(cx));
+ init(cx);
let registry = Arc::new(LanguageRegistry::test());
let conversation =
cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
@@ -3,6 +3,37 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Setting;
+#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
+pub enum OpenAIModel { // The OpenAI chat model the assistant uses for completions.
+ #[serde(rename = "gpt-3.5-turbo-0613")] // Persisted in settings as the full, dated API model id.
+ ThreePointFiveTurbo,
+ #[serde(rename = "gpt-4-0613")]
+ Four,
+}
+
+impl OpenAIModel {
+ pub fn full_name(&self) -> &'static str { // Full, dated model id — sent in OpenAIRequest and passed to tiktoken-rs for token counting.
+ match self {
+ OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
+ OpenAIModel::Four => "gpt-4-0613",
+ }
+ }
+
+ pub fn short_name(&self) -> &'static str { // Compact, undated label shown in the conversation editor UI.
+ match self {
+ OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo",
+ OpenAIModel::Four => "gpt-4",
+ }
+ }
+
+ pub fn cycle(&self) -> Self { // The next model when the user toggles; wraps between the two variants.
+ match self {
+ OpenAIModel::ThreePointFiveTurbo => OpenAIModel::Four,
+ OpenAIModel::Four => OpenAIModel::ThreePointFiveTurbo,
+ }
+ }
+}
+
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
@@ -17,6 +48,7 @@ pub struct AssistantSettings {
pub dock: AssistantDockPosition,
pub default_width: f32,
pub default_height: f32,
+ pub default_open_ai_model: OpenAIModel,
}
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
@@ -25,6 +57,7 @@ pub struct AssistantSettingsContent {
pub dock: Option<AssistantDockPosition>,
pub default_width: Option<f32>,
pub default_height: Option<f32>,
+ pub default_open_ai_model: Option<OpenAIModel>,
}
impl Setting for AssistantSettings {