Detailed changes
@@ -21,7 +21,7 @@ use gpui::{
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
};
use language::{Buffer, LanguageRegistry};
-use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
+use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelToolUseId, Role};
use markdown::{Markdown, MarkdownStyle};
use project::ProjectItem as _;
use settings::{Settings as _, update_settings_file};
@@ -606,7 +606,7 @@ impl ActiveThread {
if self.thread.read(cx).all_tools_finished() {
let model_registry = LanguageModelRegistry::read_global(cx);
- if let Some(model) = model_registry.active_model() {
+ if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
self.thread.update(cx, |thread, cx| {
thread.attach_tool_results(cx);
if !canceled {
@@ -814,21 +814,17 @@ impl ActiveThread {
}
});
- let provider = LanguageModelRegistry::read_global(cx).active_provider();
- if provider
- .as_ref()
- .map_or(false, |provider| provider.must_accept_terms(cx))
- {
+ let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
+ return;
+ };
+
+ if model.provider.must_accept_terms(cx) {
cx.notify();
return;
}
- let model_registry = LanguageModelRegistry::read_global(cx);
- let Some(model) = model_registry.active_model() else {
- return;
- };
self.thread.update(cx, |thread, cx| {
- thread.send_to_model(model, RequestKind::Chat, cx)
+ thread.send_to_model(model.model, RequestKind::Chat, cx)
});
cx.notify();
}
@@ -202,43 +202,43 @@ impl PickerDelegate for ToolPickerDelegate {
let default_profile = self.profile.clone();
let tool = tool.clone();
move |settings, _cx| match settings {
- AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(
- settings,
- )) => {
- let profiles = settings.profiles.get_or_insert_default();
- let profile =
- profiles
- .entry(profile_id)
- .or_insert_with(|| AgentProfileContent {
- name: default_profile.name.into(),
- tools: default_profile.tools,
- enable_all_context_servers: Some(
- default_profile.enable_all_context_servers,
- ),
- context_servers: default_profile
- .context_servers
- .into_iter()
- .map(|(server_id, preset)| {
- (
- server_id,
- ContextServerPresetContent {
- tools: preset.tools,
- },
- )
- })
- .collect(),
- });
+ AssistantSettingsContent::Versioned(boxed) => {
+ if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
+ let profiles = settings.profiles.get_or_insert_default();
+ let profile =
+ profiles
+ .entry(profile_id)
+ .or_insert_with(|| AgentProfileContent {
+ name: default_profile.name.into(),
+ tools: default_profile.tools,
+ enable_all_context_servers: Some(
+ default_profile.enable_all_context_servers,
+ ),
+ context_servers: default_profile
+ .context_servers
+ .into_iter()
+ .map(|(server_id, preset)| {
+ (
+ server_id,
+ ContextServerPresetContent {
+ tools: preset.tools,
+ },
+ )
+ })
+ .collect(),
+ });
- match tool.source {
- ToolSource::Native => {
- *profile.tools.entry(tool.name).or_default() = is_enabled;
- }
- ToolSource::ContextServer { id } => {
- let preset = profile
- .context_servers
- .entry(id.clone().into())
- .or_default();
- *preset.tools.entry(tool.name.clone()).or_default() = is_enabled;
+ match tool.source {
+ ToolSource::Native => {
+ *profile.tools.entry(tool.name).or_default() = is_enabled;
+ }
+ ToolSource::ContextServer { id } => {
+ let preset = profile
+ .context_servers
+ .entry(id.clone().into())
+ .or_default();
+ *preset.tools.entry(tool.name.clone()).or_default() = is_enabled;
+ }
}
}
}
@@ -9,10 +9,17 @@ use settings::update_settings_file;
use std::sync::Arc;
use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*};
+#[derive(Clone, Copy)]
+pub enum ModelType {
+ Default,
+ InlineAssistant,
+}
+
pub struct AssistantModelSelector {
selector: Entity<LanguageModelSelector>,
menu_handle: PopoverMenuHandle<LanguageModelSelector>,
focus_handle: FocusHandle,
+ model_type: ModelType,
}
impl AssistantModelSelector {
@@ -20,6 +27,7 @@ impl AssistantModelSelector {
fs: Arc<dyn Fs>,
menu_handle: PopoverMenuHandle<LanguageModelSelector>,
focus_handle: FocusHandle,
+ model_type: ModelType,
window: &mut Window,
cx: &mut App,
) -> Self {
@@ -28,11 +36,32 @@ impl AssistantModelSelector {
let fs = fs.clone();
LanguageModelSelector::new(
move |model, cx| {
- update_settings_file::<AssistantSettings>(
- fs.clone(),
- cx,
- move |settings, _cx| settings.set_model(model.clone()),
- );
+ let provider = model.provider_id().0.to_string();
+ let model_id = model.id().0.to_string();
+
+ match model_type {
+ ModelType::Default => {
+ update_settings_file::<AssistantSettings>(
+ fs.clone(),
+ cx,
+ move |settings, _cx| {
+ settings.set_model(model.clone());
+ },
+ );
+ }
+ ModelType::InlineAssistant => {
+ update_settings_file::<AssistantSettings>(
+ fs.clone(),
+ cx,
+ move |settings, _cx| {
+ settings.set_inline_assistant_model(
+ provider.clone(),
+ model_id.clone(),
+ );
+ },
+ );
+ }
+ }
},
window,
cx,
@@ -40,6 +69,7 @@ impl AssistantModelSelector {
}),
menu_handle,
focus_handle,
+ model_type,
}
}
@@ -50,10 +80,16 @@ impl AssistantModelSelector {
impl Render for AssistantModelSelector {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let active_model = LanguageModelRegistry::read_global(cx).active_model();
+ let model_registry = LanguageModelRegistry::read_global(cx);
+
+ let model = match self.model_type {
+ ModelType::Default => model_registry.default_model(),
+ ModelType::InlineAssistant => model_registry.inline_assistant_model(),
+ };
+
let focus_handle = self.focus_handle.clone();
- let model_name = match active_model {
- Some(model) => model.name().0,
+ let model_name = match model {
+ Some(model) => model.model.name().0,
_ => SharedString::from("No model selected"),
};
@@ -571,10 +571,8 @@ impl AssistantPanel {
match event {
AssistantConfigurationEvent::NewThread(provider) => {
if LanguageModelRegistry::read_global(cx)
- .active_provider()
- .map_or(true, |active_provider| {
- active_provider.id() != provider.id()
- })
+ .default_model()
+ .map_or(true, |model| model.provider.id() != provider.id())
{
if let Some(model) = provider.default_model(cx) {
update_settings_file::<AssistantSettings>(
@@ -922,16 +920,18 @@ impl AssistantPanel {
}
fn configuration_error(&self, cx: &App) -> Option<ConfigurationError> {
- let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
+ let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
return Some(ConfigurationError::NoProvider);
};
- if !provider.is_authenticated(cx) {
+ if !model.provider.is_authenticated(cx) {
return Some(ConfigurationError::ProviderNotAuthenticated);
}
- if provider.must_accept_terms(cx) {
- return Some(ConfigurationError::ProviderPendingTermsAcceptance(provider));
+ if model.provider.must_accept_terms(cx) {
+ return Some(ConfigurationError::ProviderPendingTermsAcceptance(
+ model.provider,
+ ));
}
None
@@ -156,8 +156,9 @@ impl BufferCodegen {
}
let primary_model = LanguageModelRegistry::read_global(cx)
- .active_model()
- .context("no active model")?;
+ .default_model()
+ .context("no active model")?
+ .model;
for (model, alternative) in iter::once(primary_model)
.chain(alternative_models)
@@ -239,8 +239,8 @@ impl InlineAssistant {
let is_authenticated = || {
LanguageModelRegistry::read_global(cx)
- .active_provider()
- .map_or(false, |provider| provider.is_authenticated(cx))
+ .inline_assistant_model()
+ .map_or(false, |model| model.provider.is_authenticated(cx))
};
let thread_store = workspace
@@ -279,8 +279,8 @@ impl InlineAssistant {
cx.spawn_in(window, async move |_workspace, cx| {
let Some(task) = cx.update(|_, cx| {
LanguageModelRegistry::read_global(cx)
- .active_provider()
- .map_or(None, |provider| Some(provider.authenticate(cx)))
+ .inline_assistant_model()
+ .map_or(None, |model| Some(model.provider.authenticate(cx)))
})?
else {
let answer = cx
@@ -401,14 +401,14 @@ impl InlineAssistant {
codegen_ranges.push(anchor_range);
- if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
+ if let Some(model) = LanguageModelRegistry::read_global(cx).inline_assistant_model() {
self.telemetry.report_assistant_event(AssistantEvent {
conversation_id: None,
kind: AssistantKind::Inline,
phase: AssistantPhase::Invoked,
message_id: None,
- model: model.telemetry_id(),
- model_provider: model.provider_id().to_string(),
+ model: model.model.telemetry_id(),
+ model_provider: model.provider.id().to_string(),
response_latency: None,
error_message: None,
language_name: buffer.language().map(|language| language.name().to_proto()),
@@ -976,7 +976,7 @@ impl InlineAssistant {
let active_alternative = assist.codegen.read(cx).active_alternative().clone();
let message_id = active_alternative.read(cx).message_id.clone();
- if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
+ if let Some(model) = LanguageModelRegistry::read_global(cx).inline_assistant_model() {
let language_name = assist.editor.upgrade().and_then(|editor| {
let multibuffer = editor.read(cx).buffer().read(cx);
let snapshot = multibuffer.snapshot(cx);
@@ -996,15 +996,15 @@ impl InlineAssistant {
} else {
AssistantPhase::Accepted
},
- model: model.telemetry_id(),
- model_provider: model.provider_id().to_string(),
+ model: model.model.telemetry_id(),
+ model_provider: model.provider.id().to_string(),
response_latency: None,
error_message: None,
language_name: language_name.map(|name| name.to_proto()),
},
Some(self.telemetry.clone()),
cx.http_client(),
- model.api_key(cx),
+ model.model.api_key(cx),
cx.background_executor(),
);
}
@@ -1,4 +1,4 @@
-use crate::assistant_model_selector::AssistantModelSelector;
+use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
use crate::buffer_codegen::BufferCodegen;
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
@@ -582,7 +582,7 @@ impl<T: 'static> PromptEditor<T> {
let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
let model_registry = LanguageModelRegistry::read_global(cx);
- let default_model = model_registry.active_model();
+ let default_model = model_registry.default_model().map(|default| default.model);
let alternative_models = model_registry.inline_alternative_models();
let get_model_name = |index: usize| -> String {
@@ -890,6 +890,7 @@ impl PromptEditor<BufferCodegen> {
fs,
model_selector_menu_handle,
prompt_editor.focus_handle(cx),
+ ModelType::InlineAssistant,
window,
cx,
)
@@ -1042,6 +1043,7 @@ impl PromptEditor<TerminalCodegen> {
fs,
model_selector_menu_handle.clone(),
prompt_editor.focus_handle(cx),
+ ModelType::InlineAssistant,
window,
cx,
)
@@ -1,5 +1,6 @@
use std::sync::Arc;
+use crate::assistant_model_selector::ModelType;
use collections::HashSet;
use editor::actions::MoveUp;
use editor::{ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorStyle};
@@ -10,7 +11,7 @@ use gpui::{
WeakEntity, linear_color_stop, linear_gradient, point,
};
use language::Buffer;
-use language_model::LanguageModelRegistry;
+use language_model::{ConfiguredModel, LanguageModelRegistry};
use language_model_selector::ToggleModelSelector;
use multi_buffer;
use project::Project;
@@ -139,6 +140,7 @@ impl MessageEditor {
fs.clone(),
model_selector_menu_handle,
editor.focus_handle(cx),
+ ModelType::Default,
window,
cx,
)
@@ -191,7 +193,7 @@ impl MessageEditor {
fn is_model_selected(&self, cx: &App) -> bool {
LanguageModelRegistry::read_global(cx)
- .active_model()
+ .default_model()
.is_some()
}
@@ -201,20 +203,16 @@ impl MessageEditor {
window: &mut Window,
cx: &mut Context<Self>,
) {
- let provider = LanguageModelRegistry::read_global(cx).active_provider();
- if provider
- .as_ref()
- .map_or(false, |provider| provider.must_accept_terms(cx))
- {
- cx.notify();
- return;
- }
-
let model_registry = LanguageModelRegistry::read_global(cx);
- let Some(model) = model_registry.active_model() else {
+ let Some(ConfiguredModel { model, provider }) = model_registry.default_model() else {
return;
};
+ if provider.must_accept_terms(cx) {
+ cx.notify();
+ return;
+ }
+
let user_message = self.editor.update(cx, |editor, cx| {
let text = editor.text(cx);
editor.clear(window, cx);
@@ -130,8 +130,8 @@ impl Render for ProfileSelector {
let model_registry = LanguageModelRegistry::read_global(cx);
let supports_tools = model_registry
- .active_model()
- .map_or(false, |model| model.supports_tools());
+ .default_model()
+ .map_or(false, |default| default.model.supports_tools());
let icon = match profile_id.as_str() {
"write" => IconName::Pencil,
@@ -2,7 +2,9 @@ use crate::inline_prompt_editor::CodegenStatus;
use client::telemetry::Telemetry;
use futures::{SinkExt, StreamExt, channel::mpsc};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Task};
-use language_model::{LanguageModelRegistry, LanguageModelRequest, report_assistant_event};
+use language_model::{
+ ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, report_assistant_event,
+};
use std::{sync::Arc, time::Instant};
use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
use terminal::Terminal;
@@ -31,7 +33,9 @@ impl TerminalCodegen {
}
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
- let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+ let Some(ConfiguredModel { model, .. }) =
+ LanguageModelRegistry::read_global(cx).inline_assistant_model()
+ else {
return;
};
@@ -13,8 +13,8 @@ use fs::Fs;
use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity};
use language::Buffer;
use language_model::{
- LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
- report_assistant_event,
+ ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
+ Role, report_assistant_event,
};
use prompt_store::PromptBuilder;
use std::sync::Arc;
@@ -286,7 +286,9 @@ impl TerminalInlineAssistant {
})
.log_err();
- if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
+ if let Some(ConfiguredModel { model, .. }) =
+ LanguageModelRegistry::read_global(cx).inline_assistant_model()
+ {
let codegen = assist.codegen.read(cx);
let executor = cx.background_executor().clone();
report_assistant_event(
@@ -14,10 +14,10 @@ use futures::{FutureExt, StreamExt as _};
use git::repository::DiffType;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{
- LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
- LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
- LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
- Role, StopReason, TokenUsage,
+ ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
+ LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
+ LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
+ PaymentRequiredError, Role, StopReason, TokenUsage,
};
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
use project::{Project, Worktree};
@@ -1250,14 +1250,11 @@ impl Thread {
}
pub fn summarize(&mut self, cx: &mut Context<Self>) {
- let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
- return;
- };
- let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+ let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
return;
};
- if !provider.is_authenticated(cx) {
+ if !model.provider.is_authenticated(cx) {
return;
}
@@ -1276,7 +1273,7 @@ impl Thread {
self.pending_summary = cx.spawn(async move |this, cx| {
async move {
- let stream = model.stream_completion_text(request, &cx);
+ let stream = model.model.stream_completion_text(request, &cx);
let mut messages = stream.await?;
let mut new_summary = String::new();
@@ -1320,8 +1317,8 @@ impl Thread {
_ => {}
}
- let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
- let model = LanguageModelRegistry::read_global(cx).active_model()?;
+ let ConfiguredModel { model, provider } =
+ LanguageModelRegistry::read_global(cx).thread_summary_model()?;
if !provider.is_authenticated(cx) {
return None;
@@ -1782,11 +1779,11 @@ impl Thread {
pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
let model_registry = LanguageModelRegistry::read_global(cx);
- let Some(model) = model_registry.active_model() else {
+ let Some(model) = model_registry.default_model() else {
return TotalTokenUsage::default();
};
- let max = model.max_token_count();
+ let max = model.model.max_token_count();
#[cfg(debug_assertions)]
let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
@@ -161,12 +161,38 @@ fn init_language_model_settings(cx: &mut App) {
fn update_active_language_model_from_settings(cx: &mut App) {
let settings = AssistantSettings::get_global(cx);
+ // Default model - used as fallback
let active_model_provider_name =
LanguageModelProviderId::from(settings.default_model.provider.clone());
let active_model_id = LanguageModelId::from(settings.default_model.model.clone());
- let editor_provider_name =
- LanguageModelProviderId::from(settings.editor_model.provider.clone());
- let editor_model_id = LanguageModelId::from(settings.editor_model.model.clone());
+
+ // Inline assistant model
+ let inline_assistant_model = settings
+ .inline_assistant_model
+ .as_ref()
+ .unwrap_or(&settings.default_model);
+ let inline_assistant_provider_name =
+ LanguageModelProviderId::from(inline_assistant_model.provider.clone());
+ let inline_assistant_model_id = LanguageModelId::from(inline_assistant_model.model.clone());
+
+ // Commit message model
+ let commit_message_model = settings
+ .commit_message_model
+ .as_ref()
+ .unwrap_or(&settings.default_model);
+ let commit_message_provider_name =
+ LanguageModelProviderId::from(commit_message_model.provider.clone());
+ let commit_message_model_id = LanguageModelId::from(commit_message_model.model.clone());
+
+ // Thread summary model
+ let thread_summary_model = settings
+ .thread_summary_model
+ .as_ref()
+ .unwrap_or(&settings.default_model);
+ let thread_summary_provider_name =
+ LanguageModelProviderId::from(thread_summary_model.provider.clone());
+ let thread_summary_model_id = LanguageModelId::from(thread_summary_model.model.clone());
+
let inline_alternatives = settings
.inline_alternatives
.iter()
@@ -177,9 +203,29 @@ fn update_active_language_model_from_settings(cx: &mut App) {
)
})
.collect::<Vec<_>>();
+
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
- registry.select_active_model(&active_model_provider_name, &active_model_id, cx);
- registry.select_editor_model(&editor_provider_name, &editor_model_id, cx);
+ // Set the default model
+ registry.select_default_model(&active_model_provider_name, &active_model_id, cx);
+
+ // Set the specific models
+ registry.select_inline_assistant_model(
+ &inline_assistant_provider_name,
+ &inline_assistant_model_id,
+ cx,
+ );
+ registry.select_commit_message_model(
+ &commit_message_provider_name,
+ &commit_message_model_id,
+ cx,
+ );
+ registry.select_thread_summary_model(
+ &thread_summary_provider_name,
+ &thread_summary_model_id,
+ cx,
+ );
+
+ // Set the alternatives
registry.select_inline_alternative_models(inline_alternatives, cx);
});
}
@@ -22,7 +22,8 @@ use gpui::{
};
use language::LanguageRegistry;
use language_model::{
- AuthenticateError, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID,
+ AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
+ ZED_CLOUD_PROVIDER_ID,
};
use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library};
@@ -298,8 +299,10 @@ impl AssistantPanel {
&LanguageModelRegistry::global(cx),
window,
|this, _, event: &language_model::Event, window, cx| match event {
- language_model::Event::ActiveModelChanged
- | language_model::Event::EditorModelChanged => {
+ language_model::Event::DefaultModelChanged
+ | language_model::Event::InlineAssistantModelChanged
+ | language_model::Event::CommitMessageModelChanged
+ | language_model::Event::ThreadSummaryModelChanged => {
this.completion_provider_changed(window, cx);
}
language_model::Event::ProviderStateChanged => {
@@ -468,12 +471,12 @@ impl AssistantPanel {
}
fn update_zed_ai_notice_visibility(&mut self, client_status: Status, cx: &mut Context<Self>) {
- let active_provider = LanguageModelRegistry::read_global(cx).active_provider();
+ let model = LanguageModelRegistry::read_global(cx).default_model();
// If we're signed out and don't have a provider configured, or we're signed-out AND Zed.dev is
// the provider, we want to show a nudge to sign in.
let show_zed_ai_notice = client_status.is_signed_out()
- && active_provider.map_or(true, |provider| provider.id().0 == ZED_CLOUD_PROVIDER_ID);
+ && model.map_or(true, |model| model.provider.id().0 == ZED_CLOUD_PROVIDER_ID);
self.show_zed_ai_notice = show_zed_ai_notice;
cx.notify();
@@ -541,8 +544,8 @@ impl AssistantPanel {
}
let Some(new_provider_id) = LanguageModelRegistry::read_global(cx)
- .active_provider()
- .map(|p| p.id())
+ .default_model()
+ .map(|default| default.provider.id())
else {
return;
};
@@ -568,7 +571,9 @@ impl AssistantPanel {
return;
}
- let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
+ let Some(ConfiguredModel { provider, .. }) =
+ LanguageModelRegistry::read_global(cx).default_model()
+ else {
return;
};
@@ -976,8 +981,8 @@ impl AssistantPanel {
|this, _, event: &ConfigurationViewEvent, window, cx| match event {
ConfigurationViewEvent::NewProviderContextEditor(provider) => {
if LanguageModelRegistry::read_global(cx)
- .active_provider()
- .map_or(true, |p| p.id() != provider.id())
+ .default_model()
+ .map_or(true, |default| default.provider.id() != provider.id())
{
if let Some(model) = provider.default_model(cx) {
update_settings_file::<AssistantSettings>(
@@ -1155,8 +1160,8 @@ impl AssistantPanel {
fn is_authenticated(&mut self, cx: &mut Context<Self>) -> bool {
LanguageModelRegistry::read_global(cx)
- .active_provider()
- .map_or(false, |provider| provider.is_authenticated(cx))
+ .default_model()
+ .map_or(false, |default| default.provider.is_authenticated(cx))
}
fn authenticate(
@@ -1164,8 +1169,8 @@ impl AssistantPanel {
cx: &mut Context<Self>,
) -> Option<Task<Result<(), AuthenticateError>>> {
LanguageModelRegistry::read_global(cx)
- .active_provider()
- .map_or(None, |provider| Some(provider.authenticate(cx)))
+ .default_model()
+ .map_or(None, |default| Some(default.provider.authenticate(cx)))
}
fn restart_context_servers(
@@ -34,8 +34,8 @@ use gpui::{
};
use language::{Buffer, IndentKind, Point, Selection, TransactionId, line_diff};
use language_model::{
- LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
- LanguageModelTextStream, Role, report_assistant_event,
+ ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
+ LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event,
};
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
use multi_buffer::MultiBufferRow;
@@ -312,7 +312,9 @@ impl InlineAssistant {
start..end,
));
- if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
+ if let Some(ConfiguredModel { model, .. }) =
+ LanguageModelRegistry::read_global(cx).default_model()
+ {
self.telemetry.report_assistant_event(AssistantEvent {
conversation_id: None,
kind: AssistantKind::Inline,
@@ -877,7 +879,9 @@ impl InlineAssistant {
let active_alternative = assist.codegen.read(cx).active_alternative().clone();
let message_id = active_alternative.read(cx).message_id.clone();
- if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
+ if let Some(ConfiguredModel { model, .. }) =
+ LanguageModelRegistry::read_global(cx).default_model()
+ {
let language_name = assist.editor.upgrade().and_then(|editor| {
let multibuffer = editor.read(cx).buffer().read(cx);
let multibuffer_snapshot = multibuffer.snapshot(cx);
@@ -1629,8 +1633,8 @@ impl Render for PromptEditor {
format!(
"Using {}",
LanguageModelRegistry::read_global(cx)
- .active_model()
- .map(|model| model.name().0)
+ .default_model()
+ .map(|default| default.model.name().0)
.unwrap_or_else(|| "No model selected".into()),
),
None,
@@ -2077,7 +2081,7 @@ impl PromptEditor {
let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
let model_registry = LanguageModelRegistry::read_global(cx);
- let default_model = model_registry.active_model();
+ let default_model = model_registry.default_model().map(|default| default.model);
let alternative_models = model_registry.inline_alternative_models();
let get_model_name = |index: usize| -> String {
@@ -2183,7 +2187,9 @@ impl PromptEditor {
}
fn render_token_count(&self, cx: &mut Context<Self>) -> Option<impl IntoElement> {
- let model = LanguageModelRegistry::read_global(cx).active_model()?;
+ let model = LanguageModelRegistry::read_global(cx)
+ .default_model()?
+ .model;
let token_counts = self.token_counts?;
let max_token_count = model.max_token_count();
@@ -2638,8 +2644,9 @@ impl Codegen {
}
let primary_model = LanguageModelRegistry::read_global(cx)
- .active_model()
- .context("no active model")?;
+ .default_model()
+ .context("no active model")?
+ .model;
for (model, alternative) in iter::once(primary_model)
.chain(alternative_models)
@@ -2863,7 +2870,9 @@ impl CodegenAlternative {
assistant_panel_context: Option<LanguageModelRequest>,
cx: &App,
) -> BoxFuture<'static, Result<TokenCounts>> {
- if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
+ if let Some(ConfiguredModel { model, .. }) =
+ LanguageModelRegistry::read_global(cx).inline_assistant_model()
+ {
let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx);
match request {
Ok(request) => {
@@ -16,8 +16,8 @@ use gpui::{
};
use language::Buffer;
use language_model::{
- LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
- report_assistant_event,
+ ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
+ Role, report_assistant_event,
};
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
use prompt_store::PromptBuilder;
@@ -318,7 +318,9 @@ impl TerminalInlineAssistant {
})
.log_err();
- if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
+ if let Some(ConfiguredModel { model, .. }) =
+ LanguageModelRegistry::read_global(cx).inline_assistant_model()
+ {
let codegen = assist.codegen.read(cx);
let executor = cx.background_executor().clone();
report_assistant_event(
@@ -652,8 +654,8 @@ impl Render for PromptEditor {
format!(
"Using {}",
LanguageModelRegistry::read_global(cx)
- .active_model()
- .map(|model| model.name().0)
+ .inline_assistant_model()
+ .map(|inline_assistant| inline_assistant.model.name().0)
.unwrap_or_else(|| "No model selected".into()),
),
None,
@@ -822,7 +824,9 @@ impl PromptEditor {
fn count_tokens(&mut self, cx: &mut Context<Self>) {
let assist_id = self.id;
- let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+ let Some(ConfiguredModel { model, .. }) =
+ LanguageModelRegistry::read_global(cx).inline_assistant_model()
+ else {
return;
};
self.pending_token_count = cx.spawn(async move |this, cx| {
@@ -980,7 +984,9 @@ impl PromptEditor {
}
fn render_token_count(&self, cx: &mut Context<Self>) -> Option<impl IntoElement> {
- let model = LanguageModelRegistry::read_global(cx).active_model()?;
+ let model = LanguageModelRegistry::read_global(cx)
+ .inline_assistant_model()?
+ .model;
let token_count = self.token_count?;
let max_token_count = model.max_token_count();
@@ -1131,7 +1137,9 @@ impl Codegen {
}
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
- let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+ let Some(ConfiguredModel { model, .. }) =
+ LanguageModelRegistry::read_global(cx).inline_assistant_model()
+ else {
return;
};
@@ -1272,7 +1272,7 @@ impl AssistantContext {
// Assume it will be a Chat request, even though that takes fewer tokens (and risks going over the limit),
// because otherwise you see in the UI that your empty message has a bunch of tokens already used.
let request = self.to_completion_request(RequestType::Chat, cx);
- let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+ let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
return;
};
let debounce = self.token_count.is_some();
@@ -1284,10 +1284,12 @@ impl AssistantContext {
.await;
}
- let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
+ let token_count = cx
+ .update(|cx| model.model.count_tokens(request, cx))?
+ .await?;
this.update(cx, |this, cx| {
this.token_count = Some(token_count);
- this.start_cache_warming(&model, cx);
+ this.start_cache_warming(&model.model, cx);
cx.notify()
})
}
@@ -2304,14 +2306,16 @@ impl AssistantContext {
cx: &mut Context<Self>,
) -> Option<MessageAnchor> {
let model_registry = LanguageModelRegistry::read_global(cx);
- let provider = model_registry.active_provider()?;
- let model = model_registry.active_model()?;
+ let model = model_registry.default_model()?;
let last_message_id = self.get_last_valid_message_id(cx)?;
- if !provider.is_authenticated(cx) {
+ if !model.provider.is_authenticated(cx) {
log::info!("completion provider has no credentials");
return None;
}
+
+ let model = model.model;
+
// Compute which messages to cache, including the last one.
self.mark_cache_anchors(&model.cache_configuration(), false, cx);
@@ -2940,15 +2944,12 @@ impl AssistantContext {
}
pub fn summarize(&mut self, replace_old: bool, cx: &mut Context<Self>) {
- let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
- return;
- };
- let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+ let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
return;
};
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
- if !provider.is_authenticated(cx) {
+ if !model.provider.is_authenticated(cx) {
return;
}
@@ -2964,7 +2965,7 @@ impl AssistantContext {
self.pending_summary = cx.spawn(async move |this, cx| {
async move {
- let stream = model.stream_completion_text(request, &cx);
+ let stream = model.model.stream_completion_text(request, &cx);
let mut messages = stream.await?;
let mut replaced = !replace_old;
@@ -384,7 +384,9 @@ impl ContextEditor {
window: &mut Window,
cx: &mut Context<Self>,
) {
- let provider = LanguageModelRegistry::read_global(cx).active_provider();
+ let provider = LanguageModelRegistry::read_global(cx)
+ .default_model()
+ .map(|default| default.provider);
if provider
.as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx))
@@ -2395,13 +2397,13 @@ impl ContextEditor {
None => (ButtonStyle::Filled, None),
};
- let provider = LanguageModelRegistry::read_global(cx).active_provider();
+ let model = LanguageModelRegistry::read_global(cx).default_model();
let has_configuration_error = configuration_error(cx).is_some();
let needs_to_accept_terms = self.show_accept_terms
- && provider
+ && model
.as_ref()
- .map_or(false, |provider| provider.must_accept_terms(cx));
+ .map_or(false, |model| model.provider.must_accept_terms(cx));
let disabled = has_configuration_error || needs_to_accept_terms;
ButtonLike::new("send_button")
@@ -2454,7 +2456,9 @@ impl ContextEditor {
None => (ButtonStyle::Filled, None),
};
- let provider = LanguageModelRegistry::read_global(cx).active_provider();
+ let provider = LanguageModelRegistry::read_global(cx)
+ .default_model()
+ .map(|default| default.provider);
let has_configuration_error = configuration_error(cx).is_some();
let needs_to_accept_terms = self.show_accept_terms
@@ -2500,7 +2504,9 @@ impl ContextEditor {
}
fn render_language_model_selector(&self, cx: &mut Context<Self>) -> impl IntoElement {
- let active_model = LanguageModelRegistry::read_global(cx).active_model();
+ let active_model = LanguageModelRegistry::read_global(cx)
+ .default_model()
+ .map(|default| default.model);
let focus_handle = self.editor().focus_handle(cx).clone();
let model_name = match active_model {
Some(model) => model.name().0,
@@ -3020,7 +3026,9 @@ impl EventEmitter<SearchEvent> for ContextEditor {}
impl Render for ContextEditor {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let provider = LanguageModelRegistry::read_global(cx).active_provider();
+ let provider = LanguageModelRegistry::read_global(cx)
+ .default_model()
+ .map(|default| default.provider);
let accept_terms = if self.show_accept_terms {
provider.as_ref().and_then(|provider| {
provider.render_accept_terms(LanguageModelProviderTosView::PromptEditorPopup, cx)
@@ -3616,7 +3624,9 @@ enum TokenState {
fn token_state(context: &Entity<AssistantContext>, cx: &App) -> Option<TokenState> {
const WARNING_TOKEN_THRESHOLD: f32 = 0.8;
- let model = LanguageModelRegistry::read_global(cx).active_model()?;
+ let model = LanguageModelRegistry::read_global(cx)
+ .default_model()?
+ .model;
let token_count = context.read(cx).token_count()?;
let max_token_count = model.max_token_count();
@@ -3669,16 +3679,16 @@ pub enum ConfigurationError {
}
fn configuration_error(cx: &App) -> Option<ConfigurationError> {
- let provider = LanguageModelRegistry::read_global(cx).active_provider();
- let is_authenticated = provider
+ let model = LanguageModelRegistry::read_global(cx).default_model();
+ let is_authenticated = model
.as_ref()
- .map_or(false, |provider| provider.is_authenticated(cx));
+ .map_or(false, |model| model.provider.is_authenticated(cx));
- if provider.is_some() && is_authenticated {
+ if model.is_some() && is_authenticated {
return None;
}
- if provider.is_none() {
+ if model.is_none() {
return Some(ConfigurationError::NoProvider);
}
@@ -156,10 +156,10 @@ impl HeadlessAssistant {
}
if thread.read(cx).all_tools_finished() {
let model_registry = LanguageModelRegistry::read_global(cx);
- if let Some(model) = model_registry.active_model() {
+ if let Some(model) = model_registry.default_model() {
thread.update(cx, |thread, cx| {
thread.attach_tool_results(cx);
- thread.send_to_model(model, RequestKind::Chat, cx);
+ thread.send_to_model(model.model, RequestKind::Chat, cx);
});
} else {
println!(
@@ -37,9 +37,6 @@ struct Args {
/// Name of the model (default: "claude-3-7-sonnet-latest")
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
model_name: String,
- /// Name of the editor model (default: value of `--model_name`).
- #[arg(long)]
- editor_model_name: Option<String>,
/// Name of the judge model (default: value of `--model_name`).
#[arg(long)]
judge_model_name: Option<String>,
@@ -79,11 +76,6 @@ fn main() {
let app_state = headless_assistant::init(cx);
let model = find_model(&args.model_name, cx).unwrap();
- let editor_model = if let Some(model_name) = &args.editor_model_name {
- find_model(model_name, cx).unwrap()
- } else {
- model.clone()
- };
let judge_model = if let Some(model_name) = &args.judge_model_name {
find_model(model_name, cx).unwrap()
} else {
@@ -91,12 +83,10 @@ fn main() {
};
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
- registry.set_active_model(Some(model.clone()), cx);
- registry.set_editor_model(Some(editor_model.clone()), cx);
+ registry.set_default_model(Some(model.clone()), cx);
});
let model_provider_id = model.provider_id();
- let editor_model_provider_id = editor_model.provider_id();
let judge_model_provider_id = judge_model.provider_id();
let framework_path_clone = framework_path.clone();
@@ -110,10 +100,6 @@ fn main() {
.unwrap()
.await
.unwrap();
- cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
- .unwrap()
- .await
- .unwrap();
cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx))
.unwrap()
.await
@@ -77,7 +77,9 @@ pub struct AssistantSettings {
pub default_width: Pixels,
pub default_height: Pixels,
pub default_model: LanguageModelSelection,
- pub editor_model: LanguageModelSelection,
+ pub inline_assistant_model: Option<LanguageModelSelection>,
+ pub commit_message_model: Option<LanguageModelSelection>,
+ pub thread_summary_model: Option<LanguageModelSelection>,
pub inline_alternatives: Vec<LanguageModelSelection>,
pub using_outdated_settings_version: bool,
pub enable_experimental_live_diffs: bool,
@@ -95,13 +97,25 @@ impl AssistantSettings {
cx.is_staff() || self.enable_experimental_live_diffs
}
+
+ pub fn set_inline_assistant_model(&mut self, provider: String, model: String) {
+ self.inline_assistant_model = Some(LanguageModelSelection { provider, model });
+ }
+
+ pub fn set_commit_message_model(&mut self, provider: String, model: String) {
+ self.commit_message_model = Some(LanguageModelSelection { provider, model });
+ }
+
+ pub fn set_thread_summary_model(&mut self, provider: String, model: String) {
+ self.thread_summary_model = Some(LanguageModelSelection { provider, model });
+ }
}
/// Assistant panel settings
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
- Versioned(VersionedAssistantSettingsContent),
+ Versioned(Box<VersionedAssistantSettingsContent>),
Legacy(LegacyAssistantSettingsContent),
}
@@ -121,14 +135,14 @@ impl JsonSchema for AssistantSettingsContent {
impl Default for AssistantSettingsContent {
fn default() -> Self {
- Self::Versioned(VersionedAssistantSettingsContent::default())
+ Self::Versioned(Box::new(VersionedAssistantSettingsContent::default()))
}
}
impl AssistantSettingsContent {
pub fn is_version_outdated(&self) -> bool {
match self {
- AssistantSettingsContent::Versioned(settings) => match settings {
+ AssistantSettingsContent::Versioned(settings) => match **settings {
VersionedAssistantSettingsContent::V1(_) => true,
VersionedAssistantSettingsContent::V2(_) => false,
},
@@ -138,8 +152,8 @@ impl AssistantSettingsContent {
fn upgrade(&self) -> AssistantSettingsContentV2 {
match self {
- AssistantSettingsContent::Versioned(settings) => match settings {
- VersionedAssistantSettingsContent::V1(settings) => AssistantSettingsContentV2 {
+ AssistantSettingsContent::Versioned(settings) => match **settings {
+ VersionedAssistantSettingsContent::V1(ref settings) => AssistantSettingsContentV2 {
enabled: settings.enabled,
button: settings.button,
dock: settings.dock,
@@ -186,7 +200,9 @@ impl AssistantSettingsContent {
})
}
}),
- editor_model: None,
+ inline_assistant_model: None,
+ commit_message_model: None,
+ thread_summary_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
default_profile: None,
@@ -194,7 +210,7 @@ impl AssistantSettingsContent {
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
},
- VersionedAssistantSettingsContent::V2(settings) => settings.clone(),
+ VersionedAssistantSettingsContent::V2(ref settings) => settings.clone(),
},
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV2 {
enabled: None,
@@ -211,7 +227,9 @@ impl AssistantSettingsContent {
.id()
.to_string(),
}),
- editor_model: None,
+ inline_assistant_model: None,
+ commit_message_model: None,
+ thread_summary_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
default_profile: None,
@@ -224,11 +242,11 @@ impl AssistantSettingsContent {
pub fn set_dock(&mut self, dock: AssistantDockPosition) {
match self {
- AssistantSettingsContent::Versioned(settings) => match settings {
- VersionedAssistantSettingsContent::V1(settings) => {
+ AssistantSettingsContent::Versioned(settings) => match **settings {
+ VersionedAssistantSettingsContent::V1(ref mut settings) => {
settings.dock = Some(dock);
}
- VersionedAssistantSettingsContent::V2(settings) => {
+ VersionedAssistantSettingsContent::V2(ref mut settings) => {
settings.dock = Some(dock);
}
},
@@ -243,77 +261,79 @@ impl AssistantSettingsContent {
let provider = language_model.provider_id().0.to_string();
match self {
- AssistantSettingsContent::Versioned(settings) => match settings {
- VersionedAssistantSettingsContent::V1(settings) => match provider.as_ref() {
- "zed.dev" => {
- log::warn!("attempted to set zed.dev model on outdated settings");
- }
- "anthropic" => {
- let api_url = match &settings.provider {
- Some(AssistantProviderContentV1::Anthropic { api_url, .. }) => {
- api_url.clone()
- }
- _ => None,
- };
- settings.provider = Some(AssistantProviderContentV1::Anthropic {
- default_model: AnthropicModel::from_id(&model).ok(),
- api_url,
- });
- }
- "ollama" => {
- let api_url = match &settings.provider {
- Some(AssistantProviderContentV1::Ollama { api_url, .. }) => {
- api_url.clone()
- }
- _ => None,
- };
- settings.provider = Some(AssistantProviderContentV1::Ollama {
- default_model: Some(ollama::Model::new(&model, None, None)),
- api_url,
- });
- }
- "lmstudio" => {
- let api_url = match &settings.provider {
- Some(AssistantProviderContentV1::LmStudio { api_url, .. }) => {
- api_url.clone()
- }
- _ => None,
- };
- settings.provider = Some(AssistantProviderContentV1::LmStudio {
- default_model: Some(lmstudio::Model::new(&model, None, None)),
- api_url,
- });
- }
- "openai" => {
- let (api_url, available_models) = match &settings.provider {
- Some(AssistantProviderContentV1::OpenAi {
+ AssistantSettingsContent::Versioned(settings) => match **settings {
+ VersionedAssistantSettingsContent::V1(ref mut settings) => {
+ match provider.as_ref() {
+ "zed.dev" => {
+ log::warn!("attempted to set zed.dev model on outdated settings");
+ }
+ "anthropic" => {
+ let api_url = match &settings.provider {
+ Some(AssistantProviderContentV1::Anthropic { api_url, .. }) => {
+ api_url.clone()
+ }
+ _ => None,
+ };
+ settings.provider = Some(AssistantProviderContentV1::Anthropic {
+ default_model: AnthropicModel::from_id(&model).ok(),
+ api_url,
+ });
+ }
+ "ollama" => {
+ let api_url = match &settings.provider {
+ Some(AssistantProviderContentV1::Ollama { api_url, .. }) => {
+ api_url.clone()
+ }
+ _ => None,
+ };
+ settings.provider = Some(AssistantProviderContentV1::Ollama {
+ default_model: Some(ollama::Model::new(&model, None, None)),
+ api_url,
+ });
+ }
+ "lmstudio" => {
+ let api_url = match &settings.provider {
+ Some(AssistantProviderContentV1::LmStudio { api_url, .. }) => {
+ api_url.clone()
+ }
+ _ => None,
+ };
+ settings.provider = Some(AssistantProviderContentV1::LmStudio {
+ default_model: Some(lmstudio::Model::new(&model, None, None)),
+ api_url,
+ });
+ }
+ "openai" => {
+ let (api_url, available_models) = match &settings.provider {
+ Some(AssistantProviderContentV1::OpenAi {
+ api_url,
+ available_models,
+ ..
+ }) => (api_url.clone(), available_models.clone()),
+ _ => (None, None),
+ };
+ settings.provider = Some(AssistantProviderContentV1::OpenAi {
+ default_model: OpenAiModel::from_id(&model).ok(),
api_url,
available_models,
- ..
- }) => (api_url.clone(), available_models.clone()),
- _ => (None, None),
- };
- settings.provider = Some(AssistantProviderContentV1::OpenAi {
- default_model: OpenAiModel::from_id(&model).ok(),
- api_url,
- available_models,
- });
- }
- "deepseek" => {
- let api_url = match &settings.provider {
- Some(AssistantProviderContentV1::DeepSeek { api_url, .. }) => {
- api_url.clone()
- }
- _ => None,
- };
- settings.provider = Some(AssistantProviderContentV1::DeepSeek {
- default_model: DeepseekModel::from_id(&model).ok(),
- api_url,
- });
+ });
+ }
+ "deepseek" => {
+ let api_url = match &settings.provider {
+ Some(AssistantProviderContentV1::DeepSeek { api_url, .. }) => {
+ api_url.clone()
+ }
+ _ => None,
+ };
+ settings.provider = Some(AssistantProviderContentV1::DeepSeek {
+ default_model: DeepseekModel::from_id(&model).ok(),
+ api_url,
+ });
+ }
+ _ => {}
}
- _ => {}
- },
- VersionedAssistantSettingsContent::V2(settings) => {
+ }
+ VersionedAssistantSettingsContent::V2(ref mut settings) => {
settings.default_model = Some(LanguageModelSelection { provider, model });
}
},
@@ -325,23 +345,48 @@ impl AssistantSettingsContent {
}
}
+ pub fn set_inline_assistant_model(&mut self, provider: String, model: String) {
+ if let AssistantSettingsContent::Versioned(boxed) = self {
+ if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
+ settings.inline_assistant_model = Some(LanguageModelSelection { provider, model });
+ }
+ }
+ }
+
+ pub fn set_commit_message_model(&mut self, provider: String, model: String) {
+ if let AssistantSettingsContent::Versioned(boxed) = self {
+ if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
+ settings.commit_message_model = Some(LanguageModelSelection { provider, model });
+ }
+ }
+ }
+
+ pub fn set_thread_summary_model(&mut self, provider: String, model: String) {
+ if let AssistantSettingsContent::Versioned(boxed) = self {
+ if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
+ settings.thread_summary_model = Some(LanguageModelSelection { provider, model });
+ }
+ }
+ }
+
pub fn set_always_allow_tool_actions(&mut self, allow: bool) {
- let AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(settings)) =
- self
- else {
+ let AssistantSettingsContent::Versioned(boxed) = self else {
return;
};
- settings.always_allow_tool_actions = Some(allow);
+
+ if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
+ settings.always_allow_tool_actions = Some(allow);
+ }
}
pub fn set_profile(&mut self, profile_id: AgentProfileId) {
- let AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(settings)) =
- self
- else {
+ let AssistantSettingsContent::Versioned(boxed) = self else {
return;
};
- settings.default_profile = Some(profile_id);
+ if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
+ settings.default_profile = Some(profile_id);
+ }
}
pub fn create_profile(
@@ -349,37 +394,37 @@ impl AssistantSettingsContent {
profile_id: AgentProfileId,
profile: AgentProfile,
) -> Result<()> {
- let AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(settings)) =
- self
- else {
+ let AssistantSettingsContent::Versioned(boxed) = self else {
return Ok(());
};
- let profiles = settings.profiles.get_or_insert_default();
- if profiles.contains_key(&profile_id) {
- bail!("profile with ID '{profile_id}' already exists");
- }
+ if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
+ let profiles = settings.profiles.get_or_insert_default();
+ if profiles.contains_key(&profile_id) {
+ bail!("profile with ID '{profile_id}' already exists");
+ }
- profiles.insert(
- profile_id,
- AgentProfileContent {
- name: profile.name.into(),
- tools: profile.tools,
- enable_all_context_servers: Some(profile.enable_all_context_servers),
- context_servers: profile
- .context_servers
- .into_iter()
- .map(|(server_id, preset)| {
- (
- server_id,
- ContextServerPresetContent {
- tools: preset.tools,
- },
- )
- })
- .collect(),
- },
- );
+ profiles.insert(
+ profile_id,
+ AgentProfileContent {
+ name: profile.name.into(),
+ tools: profile.tools,
+ enable_all_context_servers: Some(profile.enable_all_context_servers),
+ context_servers: profile
+ .context_servers
+ .into_iter()
+ .map(|(server_id, preset)| {
+ (
+ server_id,
+ ContextServerPresetContent {
+ tools: preset.tools,
+ },
+ )
+ })
+ .collect(),
+ },
+ );
+ }
Ok(())
}
@@ -403,7 +448,9 @@ impl Default for VersionedAssistantSettingsContent {
default_width: None,
default_height: None,
default_model: None,
- editor_model: None,
+ inline_assistant_model: None,
+ commit_message_model: None,
+ thread_summary_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
default_profile: None,
@@ -436,10 +483,14 @@ pub struct AssistantSettingsContentV2 {
///
/// Default: 320
default_height: Option<f32>,
- /// The default model to use when creating new chats.
+ /// The default model to use when creating new chats and for other features when a specific model is not specified.
default_model: Option<LanguageModelSelection>,
- /// The model to use when applying edits from the assistant.
- editor_model: Option<LanguageModelSelection>,
+ /// Model to use for the inline assistant. Defaults to default_model when not specified.
+ inline_assistant_model: Option<LanguageModelSelection>,
+ /// Model to use for generating git commit messages. Defaults to default_model when not specified.
+ commit_message_model: Option<LanguageModelSelection>,
+ /// Model to use for generating thread summaries. Defaults to default_model when not specified.
+ thread_summary_model: Option<LanguageModelSelection>,
/// Additional models with which to generate alternatives when performing inline assists.
inline_alternatives: Option<Vec<LanguageModelSelection>>,
/// Enable experimental live diffs in the assistant panel.
@@ -601,7 +652,15 @@ impl Settings for AssistantSettings {
value.default_height.map(Into::into),
);
merge(&mut settings.default_model, value.default_model);
- merge(&mut settings.editor_model, value.editor_model);
+ settings.inline_assistant_model = value
+ .inline_assistant_model
+ .or(settings.inline_assistant_model.take());
+ settings.commit_message_model = value
+ .commit_message_model
+ .or(settings.commit_message_model.take());
+ settings.thread_summary_model = value
+ .thread_summary_model
+ .or(settings.thread_summary_model.take());
merge(&mut settings.inline_alternatives, value.inline_alternatives);
merge(
&mut settings.enable_experimental_live_diffs,
@@ -692,16 +751,15 @@ mod tests {
settings::SettingsStore::global(cx).update_settings_file::<AssistantSettings>(
fs.clone(),
|settings, _| {
- *settings = AssistantSettingsContent::Versioned(
+ *settings = AssistantSettingsContent::Versioned(Box::new(
VersionedAssistantSettingsContent::V2(AssistantSettingsContentV2 {
default_model: Some(LanguageModelSelection {
provider: "test-provider".into(),
model: "gpt-99".into(),
}),
- editor_model: Some(LanguageModelSelection {
- provider: "test-provider".into(),
- model: "gpt-99".into(),
- }),
+ inline_assistant_model: None,
+ commit_message_model: None,
+ thread_summary_model: None,
inline_alternatives: None,
enabled: None,
button: None,
@@ -714,7 +772,7 @@ mod tests {
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
}),
- )
+ ))
},
);
});
@@ -9,7 +9,7 @@ use collections::HashSet;
use edit_action::{EditAction, EditActionParser, edit_model_prompt};
use futures::{SinkExt, StreamExt, channel::mpsc};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
-use language_model::LanguageModelToolSchemaFormat;
+use language_model::{ConfiguredModel, LanguageModelToolSchemaFormat};
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
};
@@ -205,8 +205,8 @@ impl EditToolRequest {
cx: &mut App,
) -> Task<Result<String>> {
let model_registry = LanguageModelRegistry::read_global(cx);
- let Some(model) = model_registry.editor_model() else {
- return Task::ready(Err(anyhow!("No editor model configured")));
+ let Some(ConfiguredModel { model, .. }) = model_registry.default_model() else {
+ return Task::ready(Err(anyhow!("No model configured")));
};
let mut messages = messages.to_vec();
@@ -37,7 +37,8 @@ use gpui::{
use itertools::Itertools;
use language::{Buffer, File};
use language_model::{
- LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
+ ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
+ LanguageModelRequestMessage, Role,
};
use menu::{Confirm, SecondaryConfirm, SelectFirst, SelectLast, SelectNext, SelectPrevious};
use multi_buffer::ExcerptInfo;
@@ -3764,8 +3765,9 @@ fn current_language_model(cx: &Context<'_, GitPanel>) -> Option<Arc<dyn Language
assistant_settings::AssistantSettings::get_global(cx)
.enabled
.then(|| {
- let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
- let model = LanguageModelRegistry::read_global(cx).active_model()?;
+ let ConfiguredModel { provider, model } =
+ LanguageModelRegistry::read_global(cx).commit_message_model()?;
+
provider.is_authenticated(cx).then(|| model)
})
.flatten()
@@ -17,20 +17,25 @@ impl Global for GlobalLanguageModelRegistry {}
#[derive(Default)]
pub struct LanguageModelRegistry {
- active_model: Option<ActiveModel>,
- editor_model: Option<ActiveModel>,
+ default_model: Option<ConfiguredModel>,
+ inline_assistant_model: Option<ConfiguredModel>,
+ commit_message_model: Option<ConfiguredModel>,
+ thread_summary_model: Option<ConfiguredModel>,
providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
-pub struct ActiveModel {
- provider: Arc<dyn LanguageModelProvider>,
- model: Option<Arc<dyn LanguageModel>>,
+#[derive(Clone)]
+pub struct ConfiguredModel {
+ pub provider: Arc<dyn LanguageModelProvider>,
+ pub model: Arc<dyn LanguageModel>,
}
pub enum Event {
- ActiveModelChanged,
- EditorModelChanged,
+ DefaultModelChanged,
+ InlineAssistantModelChanged,
+ CommitMessageModelChanged,
+ ThreadSummaryModelChanged,
ProviderStateChanged,
AddedProvider(LanguageModelProviderId),
RemovedProvider(LanguageModelProviderId),
@@ -54,7 +59,7 @@ impl LanguageModelRegistry {
let mut registry = Self::default();
registry.register_provider(fake_provider.clone(), cx);
let model = fake_provider.provided_models(cx)[0].clone();
- registry.set_active_model(Some(model), cx);
+ registry.set_default_model(Some(model), cx);
registry
});
cx.set_global(GlobalLanguageModelRegistry(registry));
@@ -114,7 +119,7 @@ impl LanguageModelRegistry {
self.providers.get(id).cloned()
}
- pub fn select_active_model(
+ pub fn select_default_model(
&mut self,
provider: &LanguageModelProviderId,
model_id: &LanguageModelId,
@@ -126,11 +131,11 @@ impl LanguageModelRegistry {
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
- self.set_active_model(Some(model), cx);
+ self.set_default_model(Some(model), cx);
}
}
- pub fn select_editor_model(
+ pub fn select_inline_assistant_model(
&mut self,
provider: &LanguageModelProviderId,
model_id: &LanguageModelId,
@@ -142,23 +147,43 @@ impl LanguageModelRegistry {
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
- self.set_editor_model(Some(model), cx);
+ self.set_inline_assistant_model(Some(model), cx);
}
}
- pub fn set_active_provider(
+ pub fn select_commit_message_model(
&mut self,
- provider: Option<Arc<dyn LanguageModelProvider>>,
+ provider: &LanguageModelProviderId,
+ model_id: &LanguageModelId,
cx: &mut Context<Self>,
) {
- self.active_model = provider.map(|provider| ActiveModel {
- provider,
- model: None,
- });
- cx.emit(Event::ActiveModelChanged);
+ let Some(provider) = self.provider(provider) else {
+ return;
+ };
+
+ let models = provider.provided_models(cx);
+ if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
+ self.set_commit_message_model(Some(model), cx);
+ }
}
- pub fn set_active_model(
+ pub fn select_thread_summary_model(
+ &mut self,
+ provider: &LanguageModelProviderId,
+ model_id: &LanguageModelId,
+ cx: &mut Context<Self>,
+ ) {
+ let Some(provider) = self.provider(provider) else {
+ return;
+ };
+
+ let models = provider.provided_models(cx);
+ if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
+ self.set_thread_summary_model(Some(model), cx);
+ }
+ }
+
+ pub fn set_default_model(
&mut self,
model: Option<Arc<dyn LanguageModel>>,
cx: &mut Context<Self>,
@@ -166,21 +191,18 @@ impl LanguageModelRegistry {
if let Some(model) = model {
let provider_id = model.provider_id();
if let Some(provider) = self.providers.get(&provider_id).cloned() {
- self.active_model = Some(ActiveModel {
- provider,
- model: Some(model),
- });
- cx.emit(Event::ActiveModelChanged);
+ self.default_model = Some(ConfiguredModel { provider, model });
+ cx.emit(Event::DefaultModelChanged);
} else {
log::warn!("Active model's provider not found in registry");
}
} else {
- self.active_model = None;
- cx.emit(Event::ActiveModelChanged);
+ self.default_model = None;
+ cx.emit(Event::DefaultModelChanged);
}
}
- pub fn set_editor_model(
+ pub fn set_inline_assistant_model(
&mut self,
model: Option<Arc<dyn LanguageModel>>,
cx: &mut Context<Self>,
@@ -188,35 +210,80 @@ impl LanguageModelRegistry {
if let Some(model) = model {
let provider_id = model.provider_id();
if let Some(provider) = self.providers.get(&provider_id).cloned() {
- self.editor_model = Some(ActiveModel {
- provider,
- model: Some(model),
- });
- cx.emit(Event::EditorModelChanged);
+ self.inline_assistant_model = Some(ConfiguredModel { provider, model });
+ cx.emit(Event::InlineAssistantModelChanged);
} else {
- log::warn!("Active model's provider not found in registry");
+ log::warn!("Inline assistant model's provider not found in registry");
}
} else {
- self.editor_model = None;
- cx.emit(Event::EditorModelChanged);
+ self.inline_assistant_model = None;
+ cx.emit(Event::InlineAssistantModelChanged);
}
}
- pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
+ pub fn set_commit_message_model(
+ &mut self,
+ model: Option<Arc<dyn LanguageModel>>,
+ cx: &mut Context<Self>,
+ ) {
+ if let Some(model) = model {
+ let provider_id = model.provider_id();
+ if let Some(provider) = self.providers.get(&provider_id).cloned() {
+ self.commit_message_model = Some(ConfiguredModel { provider, model });
+ cx.emit(Event::CommitMessageModelChanged);
+ } else {
+ log::warn!("Commit message model's provider not found in registry");
+ }
+ } else {
+ self.commit_message_model = None;
+ cx.emit(Event::CommitMessageModelChanged);
+ }
+ }
+
+ pub fn set_thread_summary_model(
+ &mut self,
+ model: Option<Arc<dyn LanguageModel>>,
+ cx: &mut Context<Self>,
+ ) {
+ if let Some(model) = model {
+ let provider_id = model.provider_id();
+ if let Some(provider) = self.providers.get(&provider_id).cloned() {
+ self.thread_summary_model = Some(ConfiguredModel { provider, model });
+ cx.emit(Event::ThreadSummaryModelChanged);
+ } else {
+ log::warn!("Thread summary model's provider not found in registry");
+ }
+ } else {
+ self.thread_summary_model = None;
+ cx.emit(Event::ThreadSummaryModelChanged);
+ }
+ }
+
+ pub fn default_model(&self) -> Option<ConfiguredModel> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
return None;
}
- Some(self.active_model.as_ref()?.provider.clone())
+ self.default_model.clone()
+ }
+
+ pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
+ self.inline_assistant_model
+ .clone()
+ .or_else(|| self.default_model())
}
- pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
- self.active_model.as_ref()?.model.clone()
+ pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
+ self.commit_message_model
+ .clone()
+ .or_else(|| self.default_model())
}
- pub fn editor_model(&self) -> Option<Arc<dyn LanguageModel>> {
- self.editor_model.as_ref()?.model.clone()
+ pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
+ self.thread_summary_model
+ .clone()
+ .or_else(|| self.default_model())
}
/// Selects and sets the inline alternatives for language models based on
@@ -168,11 +168,11 @@ impl LanguageModelSelector {
}
fn get_active_model_index(cx: &App) -> usize {
- let active_model = LanguageModelRegistry::read_global(cx).active_model();
+ let active_model = LanguageModelRegistry::read_global(cx).default_model();
Self::all_models(cx)
.iter()
.position(|model_info| {
- Some(model_info.model.id()) == active_model.as_ref().map(|model| model.id())
+ Some(model_info.model.id()) == active_model.as_ref().map(|model| model.model.id())
})
.unwrap_or(0)
}
@@ -406,13 +406,10 @@ impl PickerDelegate for LanguageModelPickerDelegate {
let model_info = self.filtered_models.get(ix)?;
let provider_name: String = model_info.model.provider_name().0.clone().into();
- let active_provider_id = LanguageModelRegistry::read_global(cx)
- .active_provider()
- .map(|m| m.id());
+ let active_model = LanguageModelRegistry::read_global(cx).default_model();
- let active_model_id = LanguageModelRegistry::read_global(cx)
- .active_model()
- .map(|m| m.id());
+ let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
+ let active_model_id = active_model.map(|m| m.model.id());
let is_selected = Some(model_info.model.provider_id()) == active_provider_id
&& Some(model_info.model.id()) == active_model_id;
@@ -9,7 +9,7 @@ use gpui::{
};
use language::{Buffer, LanguageRegistry, language_settings::SoftWrap};
use language_model::{
- LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
+ ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
use picker::{Picker, PickerDelegate};
use release_channel::ReleaseChannel;
@@ -777,7 +777,9 @@ impl PromptLibrary {
};
let prompt_editor = &self.prompt_editors[&active_prompt_id].body_editor;
- let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
+ let Some(ConfiguredModel { provider, .. }) =
+ LanguageModelRegistry::read_global(cx).inline_assistant_model()
+ else {
return;
};
@@ -880,7 +882,9 @@ impl PromptLibrary {
}
fn count_tokens(&mut self, prompt_id: PromptId, window: &mut Window, cx: &mut Context<Self>) {
- let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+ let Some(ConfiguredModel { model, .. }) =
+ LanguageModelRegistry::read_global(cx).default_model()
+ else {
return;
};
if let Some(prompt) = self.prompt_editors.get_mut(&prompt_id) {
@@ -967,7 +971,9 @@ impl PromptLibrary {
let prompt_metadata = self.store.metadata(prompt_id)?;
let prompt_editor = &self.prompt_editors[&prompt_id];
let focus_handle = prompt_editor.body_editor.focus_handle(cx);
- let model = LanguageModelRegistry::read_global(cx).active_model();
+ let model = LanguageModelRegistry::read_global(cx)
+ .default_model()
+ .map(|default| default.model);
let settings = ThemeSettings::get_global(cx);
Some(
@@ -19,7 +19,8 @@ To further customize providers, you can use `settings.json` to do that as follow
- [Configuring endpoints](#custom-endpoint)
- [Configuring timeouts](#provider-timeout)
-- [Configuring default model](#default-model)
+- [Configuring models](#default-model)
+- [Configuring feature-specific models](#feature-specific-models)
- [Configuring alternative models for inline assists](#alternative-assists)
### Zed AI {#zed-ai}
@@ -281,8 +282,20 @@ Example configuration for using X.ai Grok with Zed:
 "enabled": true,
 "default_model": {
 "provider": "zed.dev",
+ "model": "claude-3-7-sonnet"
+ },
+ "inline_assistant_model": {
+ "provider": "anthropic",
"model": "claude-3-5-sonnet"
},
+ "commit_message_model": {
+ "provider": "openai",
+ "model": "gpt-4o-mini"
+ },
+ "thread_summary_model": {
+ "provider": "google",
+ "model": "gemini-1.5-flash"
+ },
"version": "2",
"button": true,
"default_width": 480,
@@ -328,7 +345,7 @@ To do so, add the following to your Zed `settings.json`:
Where `some-provider` can be any of the following values: `anthropic`, `google`, `ollama`, `openai`.
-#### Configuring the default model {#default-model}
+#### Configuring models {#default-model}
The default model can be set via the model dropdown in the assistant panel's top-right corner. Selecting a model saves it as the default.
You can also manually edit the `default_model` object in your settings:
@@ -345,6 +362,47 @@ You can also manually edit the `default_model` object in your settings:
}
```
+#### Feature-specific models {#feature-specific-models}
+
+> Currently only available in [Preview](https://zed.dev/releases/preview).
+
+Zed allows you to configure different models for specific features.
+This provides flexibility to use more powerful models for certain tasks while using faster or more efficient models for others.
+
+If a feature-specific model is not set, it will fall back to using the default model, which is the one you set in the Agent Panel.
+
+You can configure the following feature-specific models:
+
+- Thread summary model: Used for generating thread summaries
+- Inline assistant model: Used for the inline assistant feature
+- Commit message model: Used for generating Git commit messages
+
+Example configuration:
+
+```json
+{
+ "assistant": {
+ "version": "2",
+ "default_model": {
+ "provider": "zed.dev",
+ "model": "claude-3-7-sonnet"
+ },
+ "inline_assistant_model": {
+ "provider": "anthropic",
+ "model": "claude-3-5-sonnet"
+ },
+ "commit_message_model": {
+ "provider": "openai",
+ "model": "gpt-4o-mini"
+ },
+ "thread_summary_model": {
+ "provider": "google",
+ "model": "gemini-2.0-flash"
+ }
+ }
+}
+```
+
#### Configuring alternative models for inline assists {#alternative-assists}
You can configure additional models that will be used to perform inline assists in parallel. When you do this,