From 4e6e424fd78f2c024eff88efaaca4878379a5472 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Mon, 22 Sep 2025 17:07:40 +0200 Subject: [PATCH] acp: Support model selection for ACP agents (#38652) It requires the agent to implement the (still unstable) model selection API. Will allow us to test it out before stabilizing. Release Notes: - N/A --- Cargo.lock | 9 +- Cargo.toml | 2 +- crates/acp_thread/src/connection.rs | 51 +++--- crates/agent2/src/agent.rs | 119 ++++++++------ crates/agent2/src/tests/mod.rs | 29 ++-- crates/agent_servers/src/acp.rs | 103 +++++++++++- crates/agent_ui/src/acp/model_selector.rs | 154 +++++++++++------- .../src/acp/model_selector_popover.rs | 4 +- crates/agent_ui/src/acp/thread_view.rs | 32 ++-- crates/picker/Cargo.toml | 1 + crates/picker/src/picker.rs | 67 +++++++- crates/ui/src/components/context_menu.rs | 6 +- tooling/workspace-hack/Cargo.toml | 4 +- 13 files changed, 391 insertions(+), 190 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dbe2467499ad1c5d6f67c4de82546e2b560451bb..e51968b0262a91d3a1ed78a10656e75b9d0d4523 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -195,9 +195,9 @@ dependencies = [ [[package]] name = "agent-client-protocol" -version = "0.4.0" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc2526e80463b9742afed4829aedd6ae5632d6db778c6cc1fecb80c960c3521b" +checksum = "00e33b9f4bd34d342b6f80b7156d3a37a04aeec16313f264001e52d6a9118600" dependencies = [ "anyhow", "async-broadcast", @@ -4932,7 +4932,7 @@ dependencies = [ "libc", "option-ext", "redox_users 0.5.0", - "windows-sys 0.60.2", + "windows-sys 0.61.0", ] [[package]] @@ -12677,6 +12677,7 @@ dependencies = [ "schemars 1.0.1", "serde", "serde_json", + "theme", "ui", "workspace", "workspace-hack", @@ -20853,7 +20854,7 @@ dependencies = [ "windows-sys 0.48.0", "windows-sys 0.52.0", "windows-sys 0.59.0", - "windows-sys 0.60.2", + "windows-sys 0.61.0", "winnow", "zeroize", "zvariant", diff --git a/Cargo.toml b/Cargo.toml index d4812908ac8292caf8371ce1d6dd9c9ee4042ca0..fd552c6e9d117bd03b251f231dee8294b02ba928 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -439,7 +439,7 @@ zlog_settings = { path = "crates/zlog_settings" } # External crates # -agent-client-protocol = { version = "0.4.0", features = ["unstable"] } +agent-client-protocol = { version = "0.4.2", features = ["unstable"] } aho-corasick = "1.1" alacritty_terminal = "0.25.1-rc1" any_vec = "0.14" diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 10c9dd22b6ec476f17fabeae7f6bd4f1a9672db7..fe66f954370f8118d054ee56f1e9f68f2de7e6f4 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -68,7 +68,7 @@ pub trait AgentConnection { /// /// If the agent does not support model selection, returns [None]. /// This allows sharing the selector in UI components. - fn model_selector(&self) -> Option> { + fn model_selector(&self, _session_id: &acp::SessionId) -> Option> { None } @@ -177,61 +177,48 @@ pub trait AgentModelSelector: 'static { /// If the session doesn't exist or the model is invalid, it returns an error. /// /// # Parameters - /// - `session_id`: The ID of the session (thread) to apply the model to. /// - `model`: The model to select (should be one from [list_models]). /// - `cx`: The GPUI app context. /// /// # Returns /// A task resolving to `Ok(())` on success or an error. - fn select_model( - &self, - session_id: acp::SessionId, - model_id: AgentModelId, - cx: &mut App, - ) -> Task>; + fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task>; /// Retrieves the currently selected model for a specific session (thread). /// /// # Parameters - /// - `session_id`: The ID of the session (thread) to query. /// - `cx`: The GPUI app context. /// /// # Returns /// A task resolving to the selected model (always set) or an error (e.g., session not found). - fn selected_model( - &self, - session_id: &acp::SessionId, - cx: &mut App, - ) -> Task>; + fn selected_model(&self, cx: &mut App) -> Task>; /// Whenever the model list is updated the receiver will be notified. - fn watch(&self, cx: &mut App) -> watch::Receiver<()>; -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct AgentModelId(pub SharedString); - -impl std::ops::Deref for AgentModelId { - type Target = SharedString; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl fmt::Display for AgentModelId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) + /// Optional for agents that don't update their model list. + fn watch(&self, _cx: &mut App) -> Option> { + None } } #[derive(Debug, Clone, PartialEq, Eq)] pub struct AgentModelInfo { - pub id: AgentModelId, + pub id: acp::ModelId, pub name: SharedString, + pub description: Option, pub icon: Option, } +impl From for AgentModelInfo { + fn from(info: acp::ModelInfo) -> Self { + Self { + id: info.model_id, + name: info.name.into(), + description: info.description.map(|desc| desc.into()), + icon: None, + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct AgentModelGroupName(pub SharedString); diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 86fb50242c64917248df5c620782af066e639b54..36ab1be9ef79221b530258c4fdd55be2ac1e8b29 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -56,7 +56,7 @@ struct Session { pub struct LanguageModels { /// Access language model by ID - models: HashMap>, + models: HashMap>, /// Cached list for returning language model information model_list: acp_thread::AgentModelList, refresh_models_rx: watch::Receiver<()>, @@ -132,10 +132,7 @@ impl LanguageModels { self.refresh_models_rx.clone() } - pub fn model_from_id( - &self, - model_id: &acp_thread::AgentModelId, - ) -> Option> { + pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option> { self.models.get(model_id).cloned() } @@ -146,12 +143,13 @@ impl LanguageModels { acp_thread::AgentModelInfo { id: Self::model_id(model), name: model.name().0, + description: None, icon: Some(provider.icon()), } } - fn model_id(model: &Arc) -> acp_thread::AgentModelId { - acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into()) + fn model_id(model: &Arc) -> acp::ModelId { + acp::ModelId(format!("{}/{}", model.provider_id().0, model.id().0).into()) } fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> { @@ -836,10 +834,15 @@ impl NativeAgentConnection { } } -impl AgentModelSelector for NativeAgentConnection { +struct NativeAgentModelSelector { + session_id: acp::SessionId, + connection: NativeAgentConnection, +} + +impl acp_thread::AgentModelSelector for NativeAgentModelSelector { fn list_models(&self, cx: &mut App) -> Task> { log::debug!("NativeAgentConnection::list_models called"); - let list = self.0.read(cx).models.model_list.clone(); + let list = self.connection.0.read(cx).models.model_list.clone(); Task::ready(if list.is_empty() { Err(anyhow::anyhow!("No models available")) } else { @@ -847,24 +850,24 @@ impl AgentModelSelector for NativeAgentConnection { }) } - fn select_model( - &self, - session_id: acp::SessionId, - model_id: acp_thread::AgentModelId, - cx: &mut App, - ) -> Task> { - log::debug!("Setting model for session {}: {}", session_id, model_id); + fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task> { + log::debug!( + "Setting model for session {}: {}", + self.session_id, + model_id + ); let Some(thread) = self + .connection .0 .read(cx) .sessions - .get(&session_id) + .get(&self.session_id) .map(|session| session.thread.clone()) else { return Task::ready(Err(anyhow!("Session not found"))); }; - let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else { + let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else { return Task::ready(Err(anyhow!("Invalid model ID {}", model_id))); }; @@ -872,33 +875,32 @@ impl AgentModelSelector for NativeAgentConnection { thread.set_model(model.clone(), cx); }); - update_settings_file(self.0.read(cx).fs.clone(), cx, move |settings, _cx| { - let provider = model.provider_id().0.to_string(); - let model = model.id().0.to_string(); - settings - .agent - .get_or_insert_default() - .set_model(LanguageModelSelection { - provider: provider.into(), - model, - }); - }); + update_settings_file( + self.connection.0.read(cx).fs.clone(), + cx, + move |settings, _cx| { + let provider = model.provider_id().0.to_string(); + let model = model.id().0.to_string(); + settings + .agent + .get_or_insert_default() + .set_model(LanguageModelSelection { + provider: provider.into(), + model, + }); + }, + ); Task::ready(Ok(())) } - fn selected_model( - &self, - session_id: &acp::SessionId, - cx: &mut App, - ) -> Task> { - let session_id = session_id.clone(); - + fn selected_model(&self, cx: &mut App) -> Task> { let Some(thread) = self + .connection .0 .read(cx) .sessions - .get(&session_id) + .get(&self.session_id) .map(|session| session.thread.clone()) else { return Task::ready(Err(anyhow!("Session not found"))); @@ -915,8 +917,8 @@ impl AgentModelSelector for NativeAgentConnection { ))) } - fn watch(&self, cx: &mut App) -> watch::Receiver<()> { - self.0.read(cx).models.watch() + fn watch(&self, cx: &mut App) -> Option> { + Some(self.connection.0.read(cx).models.watch()) } } @@ -972,8 +974,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection { Task::ready(Ok(())) } - fn model_selector(&self) -> Option> { - Some(Rc::new(self.clone()) as Rc) + fn model_selector(&self, session_id: &acp::SessionId) -> Option> { + Some(Rc::new(NativeAgentModelSelector { + session_id: session_id.clone(), + connection: self.clone(), + }) as Rc) } fn prompt( @@ -1196,9 +1201,7 @@ mod tests { use crate::HistoryEntryId; use super::*; - use acp_thread::{ - AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo, MentionUri, - }; + use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri}; use fs::FakeFs; use gpui::TestAppContext; use indoc::indoc; @@ -1292,7 +1295,25 @@ mod tests { .unwrap(), ); - let models = cx.update(|cx| connection.list_models(cx)).await.unwrap(); + // Create a thread/session + let acp_thread = cx + .update(|cx| { + Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx) + }) + .await + .unwrap(); + + let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone()); + + let models = cx + .update(|cx| { + connection + .model_selector(&session_id) + .unwrap() + .list_models(cx) + }) + .await + .unwrap(); let acp_thread::AgentModelList::Grouped(models) = models else { panic!("Unexpected model group"); @@ -1302,8 +1323,9 @@ mod tests { IndexMap::from_iter([( AgentModelGroupName("Fake".into()), vec![AgentModelInfo { - id: AgentModelId("fake/fake".into()), + id: acp::ModelId("fake/fake".into()), name: "Fake".into(), + description: None, icon: Some(ui::IconName::ZedAssistant), }] )]) @@ -1360,8 +1382,9 @@ mod tests { let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone()); // Select a model - let model_id = AgentModelId("fake/fake".into()); - cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx)) + let selector = connection.model_selector(&session_id).unwrap(); + let model_id = acp::ModelId("fake/fake".into()); + cx.update(|cx| selector.select_model(model_id.clone(), cx)) .await .unwrap(); diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index c0f693afe6dc0decdce4447471191bd78cf345f1..2e63aa5856501f880fec94f7659b13be321b03b3 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1850,8 +1850,18 @@ async fn test_agent_connection(cx: &mut TestAppContext) { .unwrap(); let connection = NativeAgentConnection(agent.clone()); + // Create a thread using new_thread + let connection_rc = Rc::new(connection.clone()); + let acp_thread = cx + .update(|cx| connection_rc.new_thread(project, cwd, cx)) + .await + .expect("new_thread should succeed"); + + // Get the session_id from the AcpThread + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + // Test model_selector returns Some - let selector_opt = connection.model_selector(); + let selector_opt = connection.model_selector(&session_id); assert!( selector_opt.is_some(), "agent2 should always support ModelSelector" @@ -1868,23 +1878,16 @@ async fn test_agent_connection(cx: &mut TestAppContext) { }; assert!(!listed_models.is_empty(), "should have at least one model"); assert_eq!( - listed_models[&AgentModelGroupName("Fake".into())][0].id.0, + listed_models[&AgentModelGroupName("Fake".into())][0] + .id + .0 + .as_ref(), "fake/fake" ); - // Create a thread using new_thread - let connection_rc = Rc::new(connection.clone()); - let acp_thread = cx - .update(|cx| connection_rc.new_thread(project, cwd, cx)) - .await - .expect("new_thread should succeed"); - - // Get the session_id from the AcpThread - let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); - // Test selected_model returns the default let model = cx - .update(|cx| selector.selected_model(&session_id, cx)) + .update(|cx| selector.selected_model(cx)) .await .expect("selected_model should succeed"); let model = cx diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index b8c75a01a2e2965c255e32bd3c0746b26d78ecab..b14c0467c58d3f41e32e602996560e2cc672d76a 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -44,6 +44,7 @@ pub struct AcpConnection { pub struct AcpSession { thread: WeakEntity, suppress_abort_err: bool, + models: Option>>, session_modes: Option>>, } @@ -264,6 +265,7 @@ impl AgentConnection for AcpConnection { })?; let modes = response.modes.map(|modes| Rc::new(RefCell::new(modes))); + let models = response.models.map(|models| Rc::new(RefCell::new(models))); if let Some(default_mode) = default_mode { if let Some(modes) = modes.as_ref() { @@ -326,10 +328,12 @@ impl AgentConnection for AcpConnection { ) })?; + let session = AcpSession { thread: thread.downgrade(), suppress_abort_err: false, - session_modes: modes + session_modes: modes, + models, }; sessions.borrow_mut().insert(session_id, session); @@ -450,6 +454,27 @@ impl AgentConnection for AcpConnection { } } + fn model_selector( + &self, + session_id: &acp::SessionId, + ) -> Option> { + let sessions = self.sessions.clone(); + let sessions_ref = sessions.borrow(); + let Some(session) = sessions_ref.get(session_id) else { + return None; + }; + + if let Some(models) = session.models.as_ref() { + Some(Rc::new(AcpModelSelector::new( + session_id.clone(), + self.connection.clone(), + models.clone(), + )) as _) + } else { + None + } + } + fn into_any(self: Rc) -> Rc { self } @@ -500,6 +525,82 @@ impl acp_thread::AgentSessionModes for AcpSessionModes { } } +struct AcpModelSelector { + session_id: acp::SessionId, + connection: Rc, + state: Rc>, +} + +impl AcpModelSelector { + fn new( + session_id: acp::SessionId, + connection: Rc, + state: Rc>, + ) -> Self { + Self { + session_id, + connection, + state, + } + } +} + +impl acp_thread::AgentModelSelector for AcpModelSelector { + fn list_models(&self, _cx: &mut App) -> Task> { + Task::ready(Ok(acp_thread::AgentModelList::Flat( + self.state + .borrow() + .available_models + .clone() + .into_iter() + .map(acp_thread::AgentModelInfo::from) + .collect(), + ))) + } + + fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task> { + let connection = self.connection.clone(); + let session_id = self.session_id.clone(); + let old_model_id; + { + let mut state = self.state.borrow_mut(); + old_model_id = state.current_model_id.clone(); + state.current_model_id = model_id.clone(); + }; + let state = self.state.clone(); + cx.foreground_executor().spawn(async move { + let result = connection + .set_session_model(acp::SetSessionModelRequest { + session_id, + model_id, + meta: None, + }) + .await; + + if result.is_err() { + state.borrow_mut().current_model_id = old_model_id; + } + + result?; + + Ok(()) + }) + } + + fn selected_model(&self, _cx: &mut App) -> Task> { + let state = self.state.borrow(); + Task::ready( + state + .available_models + .iter() + .find(|m| m.model_id == state.current_model_id) + .cloned() + .map(acp_thread::AgentModelInfo::from) + .ok_or_else(|| anyhow::anyhow!("Model not found")), + ) + } +} + struct ClientDelegate { sessions: Rc>>, cx: AsyncApp, diff --git a/crates/agent_ui/src/acp/model_selector.rs b/crates/agent_ui/src/acp/model_selector.rs index 95c0478aa3cf6b1ca78cf391a5bd734820c41454..381bdb01edec49e222c9bd9b3a97ce9ba21a9789 100644 --- a/crates/agent_ui/src/acp/model_selector.rs +++ b/crates/agent_ui/src/acp/model_selector.rs @@ -1,7 +1,6 @@ use std::{cmp::Reverse, rc::Rc, sync::Arc}; use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector}; -use agent_client_protocol as acp; use anyhow::Result; use collections::IndexMap; use futures::FutureExt; @@ -10,20 +9,19 @@ use gpui::{Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, W use ordered_float::OrderedFloat; use picker::{Picker, PickerDelegate}; use ui::{ - AnyElement, App, Context, IntoElement, ListItem, ListItemSpacing, SharedString, Window, - prelude::*, rems, + AnyElement, App, Context, DocumentationAside, DocumentationEdge, DocumentationSide, + IntoElement, ListItem, ListItemSpacing, SharedString, Window, prelude::*, rems, }; use util::ResultExt; pub type AcpModelSelector = Picker; pub fn acp_model_selector( - session_id: acp::SessionId, selector: Rc, window: &mut Window, cx: &mut Context, ) -> AcpModelSelector { - let delegate = AcpModelPickerDelegate::new(session_id, selector, window, cx); + let delegate = AcpModelPickerDelegate::new(selector, window, cx); Picker::list(delegate, window, cx) .show_scrollbar(true) .width(rems(20.)) @@ -36,61 +34,63 @@ enum AcpModelPickerEntry { } pub struct AcpModelPickerDelegate { - session_id: acp::SessionId, selector: Rc, filtered_entries: Vec, models: Option, selected_index: usize, + selected_description: Option<(usize, SharedString)>, selected_model: Option, _refresh_models_task: Task<()>, } impl AcpModelPickerDelegate { fn new( - session_id: acp::SessionId, selector: Rc, window: &mut Window, cx: &mut Context, ) -> Self { - let mut rx = selector.watch(cx); - let refresh_models_task = cx.spawn_in(window, { - let session_id = session_id.clone(); - async move |this, cx| { - async fn refresh( - this: &WeakEntity>, - session_id: &acp::SessionId, - cx: &mut AsyncWindowContext, - ) -> Result<()> { - let (models_task, selected_model_task) = this.update(cx, |this, cx| { - ( - this.delegate.selector.list_models(cx), - this.delegate.selector.selected_model(session_id, cx), - ) - })?; - - let (models, selected_model) = futures::join!(models_task, selected_model_task); + let rx = selector.watch(cx); + let refresh_models_task = { + cx.spawn_in(window, { + async move |this, cx| { + async fn refresh( + this: &WeakEntity>, + cx: &mut AsyncWindowContext, + ) -> Result<()> { + let (models_task, selected_model_task) = this.update(cx, |this, cx| { + ( + this.delegate.selector.list_models(cx), + this.delegate.selector.selected_model(cx), + ) + })?; - this.update_in(cx, |this, window, cx| { - this.delegate.models = models.ok(); - this.delegate.selected_model = selected_model.ok(); - this.refresh(window, cx) - }) - } + let (models, selected_model) = + futures::join!(models_task, selected_model_task); - refresh(&this, &session_id, cx).await.log_err(); - while let Ok(()) = rx.recv().await { - refresh(&this, &session_id, cx).await.log_err(); + this.update_in(cx, |this, window, cx| { + this.delegate.models = models.ok(); + this.delegate.selected_model = selected_model.ok(); + this.refresh(window, cx) + }) + } + + refresh(&this, cx).await.log_err(); + if let Some(mut rx) = rx { + while let Ok(()) = rx.recv().await { + refresh(&this, cx).await.log_err(); + } + } } - } - }); + }) + }; Self { - session_id, selector, filtered_entries: Vec::new(), models: None, selected_model: None, selected_index: 0, + selected_description: None, _refresh_models_task: refresh_models_task, } } @@ -182,7 +182,7 @@ impl PickerDelegate for AcpModelPickerDelegate { self.filtered_entries.get(self.selected_index) { self.selector - .select_model(self.session_id.clone(), model_info.id.clone(), cx) + .select_model(model_info.id.clone(), cx) .detach_and_log_err(cx); self.selected_model = Some(model_info.clone()); let current_index = self.selected_index; @@ -233,31 +233,46 @@ impl PickerDelegate for AcpModelPickerDelegate { }; Some( - ListItem::new(ix) - .inset(true) - .spacing(ListItemSpacing::Sparse) - .toggle_state(selected) - .start_slot::(model_info.icon.map(|icon| { - Icon::new(icon) - .color(model_icon_color) - .size(IconSize::Small) - })) + div() + .id(("model-picker-menu-child", ix)) + .when_some(model_info.description.clone(), |this, description| { + this + .on_hover(cx.listener(move |menu, hovered, _, cx| { + if *hovered { + menu.delegate.selected_description = Some((ix, description.clone())); + } else if matches!(menu.delegate.selected_description, Some((id, _)) if id == ix) { + menu.delegate.selected_description = None; + } + cx.notify(); + })) + }) .child( - h_flex() - .w_full() - .pl_0p5() - .gap_1p5() - .w(px(240.)) - .child(Label::new(model_info.name.clone()).truncate()), + ListItem::new(ix) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .start_slot::(model_info.icon.map(|icon| { + Icon::new(icon) + .color(model_icon_color) + .size(IconSize::Small) + })) + .child( + h_flex() + .w_full() + .pl_0p5() + .gap_1p5() + .w(px(240.)) + .child(Label::new(model_info.name.clone()).truncate()), + ) + .end_slot(div().pr_3().when(is_selected, |this| { + this.child( + Icon::new(IconName::Check) + .color(Color::Accent) + .size(IconSize::Small), + ) + })), ) - .end_slot(div().pr_3().when(is_selected, |this| { - this.child( - Icon::new(IconName::Check) - .color(Color::Accent) - .size(IconSize::Small), - ) - })) - .into_any_element(), + .into_any_element() ) } } @@ -292,6 +307,21 @@ impl PickerDelegate for AcpModelPickerDelegate { .into_any(), ) } + + fn documentation_aside( + &self, + _window: &mut Window, + _cx: &mut Context>, + ) -> Option { + self.selected_description.as_ref().map(|(_, description)| { + let description = description.clone(); + DocumentationAside::new( + DocumentationSide::Left, + DocumentationEdge::Bottom, + Rc::new(move |_| Label::new(description.clone()).into_any_element()), + ) + }) + } } fn info_list_to_picker_entries( @@ -371,6 +401,7 @@ async fn fuzzy_search( #[cfg(test)] mod tests { + use agent_client_protocol as acp; use gpui::TestAppContext; use super::*; @@ -383,8 +414,9 @@ mod tests { models .into_iter() .map(|model| acp_thread::AgentModelInfo { - id: acp_thread::AgentModelId(model.to_string().into()), + id: acp::ModelId(model.to_string().into()), name: model.to_string().into(), + description: None, icon: None, }) .collect::>(), diff --git a/crates/agent_ui/src/acp/model_selector_popover.rs b/crates/agent_ui/src/acp/model_selector_popover.rs index fa771c695ecf8175859d145b8d08d2cf3447a77a..55f530c81b1cead74fd4ec4f6cc29ececcf2bf7e 100644 --- a/crates/agent_ui/src/acp/model_selector_popover.rs +++ b/crates/agent_ui/src/acp/model_selector_popover.rs @@ -1,7 +1,6 @@ use std::rc::Rc; use acp_thread::AgentModelSelector; -use agent_client_protocol as acp; use gpui::{Entity, FocusHandle}; use picker::popover_menu::PickerPopoverMenu; use ui::{ @@ -20,7 +19,6 @@ pub struct AcpModelSelectorPopover { impl AcpModelSelectorPopover { pub(crate) fn new( - session_id: acp::SessionId, selector: Rc, menu_handle: PopoverMenuHandle, focus_handle: FocusHandle, @@ -28,7 +26,7 @@ impl AcpModelSelectorPopover { cx: &mut Context, ) -> Self { Self { - selector: cx.new(move |cx| acp_model_selector(session_id, selector, window, cx)), + selector: cx.new(move |cx| acp_model_selector(selector, window, cx)), menu_handle, focus_handle, } diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 8658e2c285997c18ece2b9783c25fbcaa614dc83..391486a68eca87e238f9efb88288bc970e3eb412 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -577,23 +577,21 @@ impl AcpThreadView { AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx); - this.model_selector = - thread - .read(cx) - .connection() - .model_selector() - .map(|selector| { - cx.new(|cx| { - AcpModelSelectorPopover::new( - thread.read(cx).session_id().clone(), - selector, - PopoverMenuHandle::default(), - this.focus_handle(cx), - window, - cx, - ) - }) - }); + this.model_selector = thread + .read(cx) + .connection() + .model_selector(thread.read(cx).session_id()) + .map(|selector| { + cx.new(|cx| { + AcpModelSelectorPopover::new( + selector, + PopoverMenuHandle::default(), + this.focus_handle(cx), + window, + cx, + ) + }) + }); let mode_selector = thread .read(cx) diff --git a/crates/picker/Cargo.toml b/crates/picker/Cargo.toml index d785cb5b3a96502165b10e2bf0def0d8bf66cd67..23c867b6f30aa64d5916e8939d836dda27ebf6c9 100644 --- a/crates/picker/Cargo.toml +++ b/crates/picker/Cargo.toml @@ -22,6 +22,7 @@ gpui.workspace = true menu.workspace = true schemars.workspace = true serde.workspace = true +theme.workspace = true ui.workspace = true workspace.workspace = true workspace-hack.workspace = true diff --git a/crates/picker/src/picker.rs b/crates/picker/src/picker.rs index 8816fb5424ff25788cec9cb602d2960ab753c135..247fcbdd875ffc2e52d90d9b1309f874c508e588 100644 --- a/crates/picker/src/picker.rs +++ b/crates/picker/src/picker.rs @@ -18,11 +18,12 @@ use head::Head; use schemars::JsonSchema; use serde::Deserialize; use std::{ops::Range, sync::Arc, time::Duration}; +use theme::ThemeSettings; use ui::{ - Color, Divider, Label, ListItem, ListItemSpacing, ScrollAxes, Scrollbars, WithScrollbar, - prelude::*, v_flex, + Color, Divider, DocumentationAside, DocumentationEdge, DocumentationSide, Label, ListItem, + ListItemSpacing, ScrollAxes, Scrollbars, WithScrollbar, prelude::*, utils::WithRemSize, v_flex, }; -use workspace::ModalView; +use workspace::{ModalView, item::Settings}; enum ElementContainer { List(ListState), @@ -222,6 +223,14 @@ pub trait PickerDelegate: Sized + 'static { ) -> Option { None } + + fn documentation_aside( + &self, + _window: &mut Window, + _cx: &mut Context>, + ) -> Option { + None + } } impl Focusable for Picker { @@ -781,8 +790,15 @@ impl ModalView for Picker {} impl Render for Picker { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + let ui_font_size = ThemeSettings::get_global(cx).ui_font_size(cx); + let window_size = window.viewport_size(); + let rem_size = window.rem_size(); + let is_wide_window = window_size.width / rem_size > rems_from_px(800.).0; + + let aside = self.delegate.documentation_aside(window, cx); + let editor_position = self.delegate.editor_position(); - v_flex() + let menu = v_flex() .key_context("Picker") .size_full() .when_some(self.width, |el, width| el.w(width)) @@ -865,6 +881,47 @@ impl Render for Picker { } } Head::Empty(empty_head) => Some(div().child(empty_head.clone())), - }) + }); + + let Some(aside) = aside else { + return menu; + }; + + let render_aside = |aside: DocumentationAside, cx: &mut Context| { + WithRemSize::new(ui_font_size) + .occlude() + .elevation_2(cx) + .w_full() + .p_2() + .overflow_hidden() + .when(is_wide_window, |this| this.max_w_96()) + .when(!is_wide_window, |this| this.max_w_48()) + .child((aside.render)(cx)) + }; + + if is_wide_window { + div().relative().child(menu).child( + h_flex() + .absolute() + .when(aside.side == DocumentationSide::Left, |this| { + this.right_full().mr_1() + }) + .when(aside.side == DocumentationSide::Right, |this| { + this.left_full().ml_1() + }) + .when(aside.edge == DocumentationEdge::Top, |this| this.top_0()) + .when(aside.edge == DocumentationEdge::Bottom, |this| { + this.bottom_0() + }) + .child(render_aside(aside, cx)), + ) + } else { + v_flex() + .w_full() + .gap_1() + .justify_end() + .child(render_aside(aside, cx)) + .child(menu) + } } } diff --git a/crates/ui/src/components/context_menu.rs b/crates/ui/src/components/context_menu.rs index e57f02be915fdecec7a5af4894c6f4fdd72f48bc..7b61789b3c87d54ff231e1d635266d6502fb944f 100644 --- a/crates/ui/src/components/context_menu.rs +++ b/crates/ui/src/components/context_menu.rs @@ -180,9 +180,9 @@ pub enum DocumentationEdge { #[derive(Clone)] pub struct DocumentationAside { - side: DocumentationSide, - edge: DocumentationEdge, - render: Rc AnyElement>, + pub side: DocumentationSide, + pub edge: DocumentationEdge, + pub render: Rc AnyElement>, } impl DocumentationAside { diff --git a/tooling/workspace-hack/Cargo.toml b/tooling/workspace-hack/Cargo.toml index 68fd84b32b64e15b0ea63ef851ec5aac457179c2..342d675bf38c3f9233d3dee4f8eefd77bfbc7836 100644 --- a/tooling/workspace-hack/Cargo.toml +++ b/tooling/workspace-hack/Cargo.toml @@ -600,10 +600,10 @@ tower = { version = "0.5", default-features = false, features = ["timeout", "uti winapi = { version = "0.3", default-features = false, features = ["cfg", "commapi", "consoleapi", "evntrace", "fileapi", "handleapi", "impl-debug", "impl-default", "in6addr", "inaddr", "ioapiset", "knownfolders", "minwinbase", "minwindef", "namedpipeapi", "ntsecapi", "objbase", "processenv", "processthreadsapi", "shlobj", "std", "synchapi", "sysinfoapi", "timezoneapi", "winbase", "windef", "winerror", "winioctl", "winnt", "winreg", "winsock2", "winuser"] } windows-core = { version = "0.61" } windows-numerics = { version = "0.2" } -windows-sys-4db8c43aad08e7ae = { package = "windows-sys", version = "0.60", features = ["Win32_Globalization", "Win32_System_Com", "Win32_UI_Shell"] } windows-sys-73dcd821b1037cfd = { package = "windows-sys", version = "0.59", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials", "Win32_Security_Cryptography", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Console", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Ioctl", "Win32_System_Kernel", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Performance", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell", "Win32_UI_WindowsAndMessaging"] } windows-sys-b21d60becc0929df = { package = "windows-sys", version = "0.52", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_IO", "Win32_Foundation", "Win32_Networking_WinSock", "Win32_Security_Authorization", "Win32_Storage_FileSystem", "Win32_System_Console", "Win32_System_IO", "Win32_System_Memory", "Win32_System_Pipes", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_WindowsProgramming"] } windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Shell"] } +windows-sys-d4189bed749088b6 = { package = "windows-sys", version = "0.61", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_IO", "Win32_System_LibraryLoader", "Win32_System_Threading", "Win32_System_WindowsProgramming", "Win32_UI_Shell"] } [target.x86_64-pc-windows-msvc.build-dependencies] codespan-reporting = { version = "0.12" } @@ -627,10 +627,10 @@ tower = { version = "0.5", default-features = false, features = ["timeout", "uti winapi = { version = "0.3", default-features = false, features = ["cfg", "commapi", "consoleapi", "evntrace", "fileapi", "handleapi", "impl-debug", "impl-default", "in6addr", "inaddr", "ioapiset", "knownfolders", "minwinbase", "minwindef", "namedpipeapi", "ntsecapi", "objbase", "processenv", "processthreadsapi", "shlobj", "std", "synchapi", "sysinfoapi", "timezoneapi", "winbase", "windef", "winerror", "winioctl", "winnt", "winreg", "winsock2", "winuser"] } windows-core = { version = "0.61" } windows-numerics = { version = "0.2" } -windows-sys-4db8c43aad08e7ae = { package = "windows-sys", version = "0.60", features = ["Win32_Globalization", "Win32_System_Com", "Win32_UI_Shell"] } windows-sys-73dcd821b1037cfd = { package = "windows-sys", version = "0.59", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials", "Win32_Security_Cryptography", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Console", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Ioctl", "Win32_System_Kernel", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Performance", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell", "Win32_UI_WindowsAndMessaging"] } windows-sys-b21d60becc0929df = { package = "windows-sys", version = "0.52", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_IO", "Win32_Foundation", "Win32_Networking_WinSock", "Win32_Security_Authorization", "Win32_Storage_FileSystem", "Win32_System_Console", "Win32_System_IO", "Win32_System_Memory", "Win32_System_Pipes", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_WindowsProgramming"] } windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Shell"] } +windows-sys-d4189bed749088b6 = { package = "windows-sys", version = "0.61", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_IO", "Win32_System_LibraryLoader", "Win32_System_Threading", "Win32_System_WindowsProgramming", "Win32_UI_Shell"] } [target.x86_64-unknown-linux-musl.dependencies] aes = { version = "0.8", default-features = false, features = ["zeroize"] }