diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs index ccfa42a63d2e0f4cce4f419639b12151010a939d..8a0f17faa64a242114bdd3c2f0713ed116331230 100644 --- a/crates/agent/src/message_editor.rs +++ b/crates/agent/src/message_editor.rs @@ -199,6 +199,10 @@ impl MessageEditor { ) }); + let profile_selector = cx.new(|cx| { + ProfileSelector::new(thread.clone(), thread_store, editor.focus_handle(cx), cx) + }); + Self { editor: editor.clone(), project: thread.read(cx).project().clone(), @@ -215,8 +219,7 @@ impl MessageEditor { model_selector, edits_expanded: false, editor_is_expanded: false, - profile_selector: cx - .new(|cx| ProfileSelector::new(fs, thread_store, editor.focus_handle(cx), cx)), + profile_selector, last_estimated_token_count: None, update_token_count_task: None, _subscriptions: subscriptions, diff --git a/crates/agent/src/profile_selector.rs b/crates/agent/src/profile_selector.rs index 51eb20934488a3463ce3e113c747757bc4f7dfa6..8879c9b3c7ae5a1fab5df4a8430b1235222de6f8 100644 --- a/crates/agent/src/profile_selector.rs +++ b/crates/agent/src/profile_selector.rs @@ -1,24 +1,21 @@ -use std::sync::Arc; - use assistant_settings::{ AgentProfile, AgentProfileId, AssistantDockPosition, AssistantSettings, GroupedAgentProfiles, builtin_profiles, }; -use fs::Fs; use gpui::{Action, Entity, FocusHandle, Subscription, WeakEntity, prelude::*}; use language_model::LanguageModelRegistry; -use settings::{Settings as _, SettingsStore, update_settings_file}; +use settings::{Settings as _, SettingsStore}; use ui::{ ContextMenu, ContextMenuEntry, DocumentationSide, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*, }; use util::ResultExt as _; -use crate::{ManageProfiles, ThreadStore, ToggleProfileSelector}; +use crate::{ManageProfiles, Thread, ThreadStore, ToggleProfileSelector}; pub struct ProfileSelector { profiles: GroupedAgentProfiles, - fs: Arc, + thread: Entity, thread_store: WeakEntity, menu_handle: PopoverMenuHandle, focus_handle: FocusHandle, @@ -27,7 +24,7 @@ pub struct ProfileSelector { impl ProfileSelector { pub fn new( - fs: Arc, + thread: Entity, thread_store: WeakEntity, focus_handle: FocusHandle, cx: &mut Context, @@ -38,7 +35,7 @@ impl ProfileSelector { Self { profiles: GroupedAgentProfiles::from_settings(AssistantSettings::get_global(cx)), - fs, + thread, thread_store, menu_handle: PopoverMenuHandle::default(), focus_handle, @@ -113,15 +110,15 @@ impl ProfileSelector { }; entry.handler({ - let fs = self.fs.clone(); let thread_store = self.thread_store.clone(); let profile_id = profile_id.clone(); + let profile = profile.clone(); + + let thread = self.thread.clone(); + move |_window, cx| { - update_settings_file::(fs.clone(), cx, { - let profile_id = profile_id.clone(); - move |settings, _cx| { - settings.set_profile(profile_id.clone()); - } + thread.update(cx, |thread, cx| { + thread.set_configured_profile(Some(profile.clone()), cx); }); thread_store @@ -137,17 +134,28 @@ impl ProfileSelector { impl Render for ProfileSelector { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { let settings = AssistantSettings::get_global(cx); - let profile_id = &settings.default_profile; - let profile = settings.profiles.get(profile_id); + let profile = self + .thread + .read_with(cx, |thread, _cx| thread.configured_profile()) + .or_else(|| { + let profile_id = &settings.default_profile; + let profile = settings.profiles.get(profile_id); + profile.cloned() + }); let selected_profile = profile .map(|profile| profile.name.clone()) .unwrap_or_else(|| "Unknown".into()); - let model_registry = LanguageModelRegistry::read_global(cx); - let supports_tools = model_registry - .default_model() - .map_or(false, |default| default.model.supports_tools()); + let configured_model = self + .thread + .read_with(cx, |thread, _cx| thread.configured_model()) + .or_else(|| { + let model_registry = LanguageModelRegistry::read_global(cx); + model_registry.default_model() + }); + let supports_tools = + configured_model.map_or(false, |default| default.model.supports_tools()); if supports_tools { let this = cx.entity().clone(); diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index f82a7a339a06f4804b8349ce3895d739b018a541..8c4d4a06da40def09847c980063049083302085d 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use std::time::Instant; use anyhow::{Result, anyhow}; -use assistant_settings::{AssistantSettings, CompletionMode}; +use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings, CompletionMode}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; use collections::HashMap; @@ -359,6 +359,7 @@ pub struct Thread { >, remaining_turns: u32, configured_model: Option, + configured_profile: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -379,6 +380,9 @@ impl Thread { ) -> Self { let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel(); let configured_model = LanguageModelRegistry::read_global(cx).default_model(); + let assistant_settings = AssistantSettings::get_global(cx); + let profile_id = &assistant_settings.default_profile; + let configured_profile = assistant_settings.profiles.get(profile_id).cloned(); Self { id: ThreadId::new(), @@ -421,6 +425,7 @@ impl Thread { request_callback: None, remaining_turns: u32::MAX, configured_model, + configured_profile, } } @@ -468,6 +473,13 @@ impl Thread { .completion_mode .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode); + let configured_profile = serialized.profile.and_then(|profile| { + AssistantSettings::get_global(cx) + .profiles + .get(&profile) + .cloned() + }); + Self { id, updated_at: serialized.updated_at, @@ -541,6 +553,7 @@ impl Thread { request_callback: None, remaining_turns: u32::MAX, configured_model, + configured_profile, } } @@ -596,6 +609,19 @@ impl Thread { cx.notify(); } + pub fn configured_profile(&self) -> Option { + self.configured_profile.clone() + } + + pub fn set_configured_profile( + &mut self, + profile: Option, + cx: &mut Context, + ) { + self.configured_profile = profile; + cx.notify(); + } + pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread"); pub fn summary_or_default(&self) -> SharedString { @@ -1100,6 +1126,10 @@ impl Thread { provider: model.provider.id().0.to_string(), model: model.model.id().0.to_string(), }), + profile: this + .configured_profile + .as_ref() + .map(|profile| AgentProfileId(profile.name.clone().into())), completion_mode: Some(this.completion_mode), }) }) diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 99ecd3d4420fdbc374ed7875dad43880dd967136..62482a822f964a0ed9fdc6f4fa8647a4a5ffd119 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -657,6 +657,8 @@ pub struct SerializedThread { pub model: Option, #[serde(default)] pub completion_mode: Option, + #[serde(default)] + pub profile: Option, } #[derive(Serialize, Deserialize, Debug)] @@ -802,6 +804,7 @@ impl LegacySerializedThread { exceeded_window_error: None, model: None, completion_mode: None, + profile: None, } } }