Detailed changes
@@ -199,6 +199,10 @@ impl MessageEditor {
)
});
+ let profile_selector = cx.new(|cx| {
+ ProfileSelector::new(thread.clone(), thread_store, editor.focus_handle(cx), cx)
+ });
+
Self {
editor: editor.clone(),
project: thread.read(cx).project().clone(),
@@ -215,8 +219,7 @@ impl MessageEditor {
model_selector,
edits_expanded: false,
editor_is_expanded: false,
- profile_selector: cx
- .new(|cx| ProfileSelector::new(fs, thread_store, editor.focus_handle(cx), cx)),
+ profile_selector,
last_estimated_token_count: None,
update_token_count_task: None,
_subscriptions: subscriptions,
@@ -1,24 +1,21 @@
-use std::sync::Arc;
-
use assistant_settings::{
AgentProfile, AgentProfileId, AssistantDockPosition, AssistantSettings, GroupedAgentProfiles,
builtin_profiles,
};
-use fs::Fs;
use gpui::{Action, Entity, FocusHandle, Subscription, WeakEntity, prelude::*};
use language_model::LanguageModelRegistry;
-use settings::{Settings as _, SettingsStore, update_settings_file};
+use settings::{Settings as _, SettingsStore};
use ui::{
ContextMenu, ContextMenuEntry, DocumentationSide, PopoverMenu, PopoverMenuHandle, Tooltip,
prelude::*,
};
use util::ResultExt as _;
-use crate::{ManageProfiles, ThreadStore, ToggleProfileSelector};
+use crate::{ManageProfiles, Thread, ThreadStore, ToggleProfileSelector};
pub struct ProfileSelector {
profiles: GroupedAgentProfiles,
- fs: Arc<dyn Fs>,
+ thread: Entity<Thread>,
thread_store: WeakEntity<ThreadStore>,
menu_handle: PopoverMenuHandle<ContextMenu>,
focus_handle: FocusHandle,
@@ -27,7 +24,7 @@ pub struct ProfileSelector {
impl ProfileSelector {
pub fn new(
- fs: Arc<dyn Fs>,
+ thread: Entity<Thread>,
thread_store: WeakEntity<ThreadStore>,
focus_handle: FocusHandle,
cx: &mut Context<Self>,
@@ -38,7 +35,7 @@ impl ProfileSelector {
Self {
profiles: GroupedAgentProfiles::from_settings(AssistantSettings::get_global(cx)),
- fs,
+ thread,
thread_store,
menu_handle: PopoverMenuHandle::default(),
focus_handle,
@@ -113,15 +110,15 @@ impl ProfileSelector {
};
entry.handler({
- let fs = self.fs.clone();
let thread_store = self.thread_store.clone();
let profile_id = profile_id.clone();
+ let profile = profile.clone();
+
+ let thread = self.thread.clone();
+
move |_window, cx| {
- update_settings_file::<AssistantSettings>(fs.clone(), cx, {
- let profile_id = profile_id.clone();
- move |settings, _cx| {
- settings.set_profile(profile_id.clone());
- }
+ thread.update(cx, |thread, cx| {
+ thread.set_configured_profile(Some(profile.clone()), cx);
});
thread_store
@@ -137,17 +134,28 @@ impl ProfileSelector {
impl Render for ProfileSelector {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let settings = AssistantSettings::get_global(cx);
- let profile_id = &settings.default_profile;
- let profile = settings.profiles.get(profile_id);
+ let profile = self
+ .thread
+ .read_with(cx, |thread, _cx| thread.configured_profile())
+ .or_else(|| {
+ let profile_id = &settings.default_profile;
+ let profile = settings.profiles.get(profile_id);
+ profile.cloned()
+ });
let selected_profile = profile
.map(|profile| profile.name.clone())
.unwrap_or_else(|| "Unknown".into());
- let model_registry = LanguageModelRegistry::read_global(cx);
- let supports_tools = model_registry
- .default_model()
- .map_or(false, |default| default.model.supports_tools());
+ let configured_model = self
+ .thread
+ .read_with(cx, |thread, _cx| thread.configured_model())
+ .or_else(|| {
+ let model_registry = LanguageModelRegistry::read_global(cx);
+ model_registry.default_model()
+ });
+ let supports_tools =
+ configured_model.map_or(false, |default| default.model.supports_tools());
if supports_tools {
let this = cx.entity().clone();
@@ -5,7 +5,7 @@ use std::sync::Arc;
use std::time::Instant;
use anyhow::{Result, anyhow};
-use assistant_settings::{AssistantSettings, CompletionMode};
+use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings, CompletionMode};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
@@ -359,6 +359,7 @@ pub struct Thread {
>,
remaining_turns: u32,
configured_model: Option<ConfiguredModel>,
+ configured_profile: Option<AgentProfile>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -379,6 +380,9 @@ impl Thread {
) -> Self {
let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
let configured_model = LanguageModelRegistry::read_global(cx).default_model();
+ let assistant_settings = AssistantSettings::get_global(cx);
+ let profile_id = &assistant_settings.default_profile;
+ let configured_profile = assistant_settings.profiles.get(profile_id).cloned();
Self {
id: ThreadId::new(),
@@ -421,6 +425,7 @@ impl Thread {
request_callback: None,
remaining_turns: u32::MAX,
configured_model,
+ configured_profile,
}
}
@@ -468,6 +473,13 @@ impl Thread {
.completion_mode
.unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);
+ let configured_profile = serialized.profile.and_then(|profile| {
+ AssistantSettings::get_global(cx)
+ .profiles
+ .get(&profile)
+ .cloned()
+ });
+
Self {
id,
updated_at: serialized.updated_at,
@@ -541,6 +553,7 @@ impl Thread {
request_callback: None,
remaining_turns: u32::MAX,
configured_model,
+ configured_profile,
}
}
@@ -596,6 +609,19 @@ impl Thread {
cx.notify();
}
+ pub fn configured_profile(&self) -> Option<AgentProfile> {
+ self.configured_profile.clone()
+ }
+
+ pub fn set_configured_profile(
+ &mut self,
+ profile: Option<AgentProfile>,
+ cx: &mut Context<Self>,
+ ) {
+ self.configured_profile = profile;
+ cx.notify();
+ }
+
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
pub fn summary_or_default(&self) -> SharedString {
@@ -1100,6 +1126,10 @@ impl Thread {
provider: model.provider.id().0.to_string(),
model: model.model.id().0.to_string(),
}),
+ profile: this
+ .configured_profile
+ .as_ref()
+ .map(|profile| AgentProfileId(profile.name.clone().into())),
completion_mode: Some(this.completion_mode),
})
})
@@ -657,6 +657,8 @@ pub struct SerializedThread {
pub model: Option<SerializedLanguageModel>,
#[serde(default)]
pub completion_mode: Option<CompletionMode>,
+ #[serde(default)]
+ pub profile: Option<AgentProfileId>,
}
#[derive(Serialize, Deserialize, Debug)]
@@ -802,6 +804,7 @@ impl LegacySerializedThread {
exceeded_window_error: None,
model: None,
completion_mode: None,
+ profile: None,
}
}
}