Detailed changes
@@ -43,6 +43,7 @@ pub struct UserMessage {
pub content: ContentBlock,
pub chunks: Vec<acp::ContentBlock>,
pub checkpoint: Option<Checkpoint>,
+ pub indented: bool,
}
#[derive(Debug)]
@@ -73,6 +74,7 @@ impl UserMessage {
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
pub chunks: Vec<AssistantMessageChunk>,
+ pub indented: bool,
}
impl AssistantMessage {
@@ -123,6 +125,14 @@ pub enum AgentThreadEntry {
}
impl AgentThreadEntry {
+ pub fn is_indented(&self) -> bool {
+ match self {
+ Self::UserMessage(message) => message.indented,
+ Self::AssistantMessage(message) => message.indented,
+ Self::ToolCall(_) => false,
+ }
+ }
+
pub fn to_markdown(&self, cx: &App) -> String {
match self {
Self::UserMessage(message) => message.to_markdown(cx),
@@ -1184,6 +1194,16 @@ impl AcpThread {
message_id: Option<UserMessageId>,
chunk: acp::ContentBlock,
cx: &mut Context<Self>,
+ ) {
+ self.push_user_content_block_with_indent(message_id, chunk, false, cx)
+ }
+
+ pub fn push_user_content_block_with_indent(
+ &mut self,
+ message_id: Option<UserMessageId>,
+ chunk: acp::ContentBlock,
+ indented: bool,
+ cx: &mut Context<Self>,
) {
let language_registry = self.project.read(cx).languages().clone();
let path_style = self.project.read(cx).path_style(cx);
@@ -1194,8 +1214,10 @@ impl AcpThread {
id,
content,
chunks,
+ indented: existing_indented,
..
}) = last_entry
+ && *existing_indented == indented
{
*id = message_id.or(id.take());
content.append(chunk.clone(), &language_registry, path_style, cx);
@@ -1210,6 +1232,7 @@ impl AcpThread {
content,
chunks: vec![chunk],
checkpoint: None,
+ indented,
}),
cx,
);
@@ -1221,12 +1244,26 @@ impl AcpThread {
chunk: acp::ContentBlock,
is_thought: bool,
cx: &mut Context<Self>,
+ ) {
+ self.push_assistant_content_block_with_indent(chunk, is_thought, false, cx)
+ }
+
+ pub fn push_assistant_content_block_with_indent(
+ &mut self,
+ chunk: acp::ContentBlock,
+ is_thought: bool,
+ indented: bool,
+ cx: &mut Context<Self>,
) {
let language_registry = self.project.read(cx).languages().clone();
let path_style = self.project.read(cx).path_style(cx);
let entries_len = self.entries.len();
if let Some(last_entry) = self.entries.last_mut()
- && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
+ && let AgentThreadEntry::AssistantMessage(AssistantMessage {
+ chunks,
+ indented: existing_indented,
+ }) = last_entry
+ && *existing_indented == indented
{
let idx = entries_len - 1;
cx.emit(AcpThreadEvent::EntryUpdated(idx));
@@ -1255,6 +1292,7 @@ impl AcpThread {
self.push_entry(
AgentThreadEntry::AssistantMessage(AssistantMessage {
chunks: vec![chunk],
+ indented,
}),
cx,
);
@@ -1704,6 +1742,7 @@ impl AcpThread {
content: block,
chunks: message,
checkpoint: None,
+ indented: false,
}),
cx,
);
@@ -5,12 +5,12 @@ mod legacy_thread;
mod native_agent_server;
pub mod outline;
mod templates;
-mod thread;
-mod tools;
-
#[cfg(test)]
mod tests;
+mod thread;
+mod tools;
+use context_server::ContextServerId;
pub use db::*;
pub use history_store::*;
pub use native_agent_server::NativeAgentServer;
@@ -18,11 +18,11 @@ pub use templates::*;
pub use thread::*;
pub use tools::*;
-use acp_thread::{AcpThread, AgentModelSelector};
+use acp_thread::{AcpThread, AgentModelSelector, UserMessageId};
use agent_client_protocol as acp;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
-use collections::{HashSet, IndexMap};
+use collections::{HashMap, HashSet, IndexMap};
use fs::Fs;
use futures::channel::{mpsc, oneshot};
use futures::future::Shared;
@@ -39,7 +39,6 @@ use prompt_store::{
use serde::{Deserialize, Serialize};
use settings::{LanguageModelSelection, update_settings_file};
use std::any::Any;
-use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::rc::Rc;
use std::sync::Arc;
@@ -252,12 +251,24 @@ impl NativeAgent {
.await;
cx.new(|cx| {
+ let context_server_store = project.read(cx).context_server_store();
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+
let mut subscriptions = vec![
cx.subscribe(&project, Self::handle_project_event),
cx.subscribe(
&LanguageModelRegistry::global(cx),
Self::handle_models_updated_event,
),
+ cx.subscribe(
+ &context_server_store,
+ Self::handle_context_server_store_updated,
+ ),
+ cx.subscribe(
+ &context_server_registry,
+ Self::handle_context_server_registry_event,
+ ),
];
if let Some(prompt_store) = prompt_store.as_ref() {
subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
@@ -266,16 +277,14 @@ impl NativeAgent {
let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
watch::channel(());
Self {
- sessions: HashMap::new(),
+ sessions: HashMap::default(),
history,
project_context: cx.new(|_| project_context),
project_context_needs_refresh: project_context_needs_refresh_tx,
_maintain_project_context: cx.spawn(async move |this, cx| {
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
}),
- context_server_registry: cx.new(|cx| {
- ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
- }),
+ context_server_registry,
templates,
models: LanguageModels::new(cx),
project,
@@ -344,6 +353,9 @@ impl NativeAgent {
pending_save: Task::ready(()),
},
);
+
+ self.update_available_commands(cx);
+
acp_thread
}
@@ -608,6 +620,99 @@ impl NativeAgent {
}
}
+ fn handle_context_server_store_updated(
+ &mut self,
+ _store: Entity<project::context_server_store::ContextServerStore>,
+ _event: &project::context_server_store::Event,
+ cx: &mut Context<Self>,
+ ) {
+ self.update_available_commands(cx);
+ }
+
+ fn handle_context_server_registry_event(
+ &mut self,
+ _registry: Entity<ContextServerRegistry>,
+ event: &ContextServerRegistryEvent,
+ cx: &mut Context<Self>,
+ ) {
+ match event {
+ ContextServerRegistryEvent::ToolsChanged => {}
+ ContextServerRegistryEvent::PromptsChanged => {
+ self.update_available_commands(cx);
+ }
+ }
+ }
+
+ fn update_available_commands(&self, cx: &mut Context<Self>) {
+ let available_commands = self.build_available_commands(cx);
+ for session in self.sessions.values() {
+ if let Some(acp_thread) = session.acp_thread.upgrade() {
+ acp_thread.update(cx, |thread, cx| {
+ thread
+ .handle_session_update(
+ acp::SessionUpdate::AvailableCommandsUpdate(
+ acp::AvailableCommandsUpdate::new(available_commands.clone()),
+ ),
+ cx,
+ )
+ .log_err();
+ });
+ }
+ }
+ }
+
+ fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
+ let registry = self.context_server_registry.read(cx);
+
+ let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
+ for context_server_prompt in registry.prompts() {
+ *prompt_name_counts
+ .entry(context_server_prompt.prompt.name.as_str())
+ .or_insert(0) += 1;
+ }
+
+ registry
+ .prompts()
+ .flat_map(|context_server_prompt| {
+ let prompt = &context_server_prompt.prompt;
+
+ let should_prefix = prompt_name_counts
+ .get(prompt.name.as_str())
+ .copied()
+ .unwrap_or(0)
+ > 1;
+
+ let name = if should_prefix {
+ format!("{}.{}", context_server_prompt.server_id, prompt.name)
+ } else {
+ prompt.name.clone()
+ };
+
+ let mut command = acp::AvailableCommand::new(
+ name,
+ prompt.description.clone().unwrap_or_default(),
+ );
+
+ match prompt.arguments.as_deref() {
+ Some([arg]) => {
+ let hint = format!("<{}>", arg.name);
+
+ command = command.input(acp::AvailableCommandInput::Unstructured(
+ acp::UnstructuredCommandInput::new(hint),
+ ));
+ }
+ Some([]) | None => {}
+ Some(_) => {
+ // skip >1 argument commands since we don't support them yet
+ return None;
+ }
+ }
+
+ Some(command)
+ })
+ .collect()
+ }
+
pub fn load_thread(
&mut self,
id: acp::SessionId,
@@ -706,6 +811,102 @@ impl NativeAgent {
history.update(cx, |history, cx| history.reload(cx)).ok();
});
}
+
+ fn send_mcp_prompt(
+ &self,
+ message_id: UserMessageId,
+ session_id: agent_client_protocol::SessionId,
+ prompt_name: String,
+ server_id: ContextServerId,
+ arguments: HashMap<String, String>,
+ original_content: Vec<acp::ContentBlock>,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<acp::PromptResponse>> {
+ let server_store = self.context_server_registry.read(cx).server_store().clone();
+ let path_style = self.project.read(cx).path_style(cx);
+
+ cx.spawn(async move |this, cx| {
+ let prompt =
+ crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
+
+ let (acp_thread, thread) = this.update(cx, |this, _cx| {
+ let session = this
+ .sessions
+ .get(&session_id)
+ .context("Failed to get session")?;
+ anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
+ })??;
+
+ let mut last_is_user = true;
+
+ thread.update(cx, |thread, cx| {
+ thread.push_acp_user_block(
+ message_id,
+ original_content.into_iter().skip(1),
+ path_style,
+ cx,
+ );
+ })?;
+
+ for message in prompt.messages {
+ let context_server::types::PromptMessage { role, content } = message;
+ let block = mcp_message_content_to_acp_content_block(content);
+
+ match role {
+ context_server::types::Role::User => {
+ let id = acp_thread::UserMessageId::new();
+
+ acp_thread.update(cx, |acp_thread, cx| {
+ acp_thread.push_user_content_block_with_indent(
+ Some(id.clone()),
+ block.clone(),
+ true,
+ cx,
+ );
+ anyhow::Ok(())
+ })??;
+
+ thread.update(cx, |thread, cx| {
+ thread.push_acp_user_block(id, [block], path_style, cx);
+ anyhow::Ok(())
+ })??;
+ }
+ context_server::types::Role::Assistant => {
+ acp_thread.update(cx, |acp_thread, cx| {
+ acp_thread.push_assistant_content_block_with_indent(
+ block.clone(),
+ false,
+ true,
+ cx,
+ );
+ anyhow::Ok(())
+ })??;
+
+ thread.update(cx, |thread, cx| {
+ thread.push_acp_agent_block(block, cx);
+ anyhow::Ok(())
+ })??;
+ }
+ }
+
+ last_is_user = role == context_server::types::Role::User;
+ }
+
+ let response_stream = thread.update(cx, |thread, cx| {
+ if last_is_user {
+ thread.send_existing(cx)
+ } else {
+ // Resume if MCP prompt did not end with a user message
+ thread.resume(cx)
+ }
+ })??;
+
+ cx.update(|cx| {
+ NativeAgentConnection::handle_thread_events(response_stream, acp_thread, cx)
+ })?
+ .await
+ })
+ }
}
/// Wrapper struct that implements the AgentConnection trait
@@ -840,6 +1041,39 @@ impl NativeAgentConnection {
}
}
+struct Command<'a> {
+ prompt_name: &'a str,
+ arg_value: &'a str,
+ explicit_server_id: Option<&'a str>,
+}
+
+impl<'a> Command<'a> {
+ fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
+ let acp::ContentBlock::Text(text_content) = prompt.first()? else {
+ return None;
+ };
+ let text = text_content.text.trim();
+ let command = text.strip_prefix('/')?;
+ let (command, arg_value) = command
+ .split_once(char::is_whitespace)
+ .unwrap_or((command, ""));
+
+ if let Some((server_id, prompt_name)) = command.split_once('.') {
+ Some(Self {
+ prompt_name,
+ arg_value,
+ explicit_server_id: Some(server_id),
+ })
+ } else {
+ Some(Self {
+ prompt_name: command,
+ arg_value,
+ explicit_server_id: None,
+ })
+ }
+ }
+}
+
struct NativeAgentModelSelector {
session_id: acp::SessionId,
connection: NativeAgentConnection,
@@ -1005,6 +1239,47 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
let session_id = params.session_id.clone();
log::info!("Received prompt request for session: {}", session_id);
log::debug!("Prompt blocks count: {}", params.prompt.len());
+
+ if let Some(parsed_command) = Command::parse(¶ms.prompt) {
+ let registry = self.0.read(cx).context_server_registry.read(cx);
+
+ let explicit_server_id = parsed_command
+ .explicit_server_id
+ .map(|server_id| ContextServerId(server_id.into()));
+
+ if let Some(prompt) =
+ registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
+ {
+ let arguments = if !parsed_command.arg_value.is_empty()
+ && let Some(arg_name) = prompt
+ .prompt
+ .arguments
+ .as_ref()
+ .and_then(|args| args.first())
+ .map(|arg| arg.name.clone())
+ {
+ HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
+ } else {
+ Default::default()
+ };
+
+ let prompt_name = prompt.prompt.name.clone();
+ let server_id = prompt.server_id.clone();
+
+ return self.0.update(cx, |agent, cx| {
+ agent.send_mcp_prompt(
+ id,
+ session_id.clone(),
+ prompt_name,
+ server_id,
+ arguments,
+ params.prompt,
+ cx,
+ )
+ });
+ };
+ };
+
let path_style = self.0.read(cx).project.read(cx).path_style(cx);
self.run_turn(session_id, cx, move |thread, cx| {
@@ -1601,3 +1876,35 @@ mod internal_tests {
});
}
}
+
+fn mcp_message_content_to_acp_content_block(
+ content: context_server::types::MessageContent,
+) -> acp::ContentBlock {
+ match content {
+ context_server::types::MessageContent::Text {
+ text,
+ annotations: _,
+ } => text.into(),
+ context_server::types::MessageContent::Image {
+ data,
+ mime_type,
+ annotations: _,
+ } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
+ context_server::types::MessageContent::Audio {
+ data,
+ mime_type,
+ annotations: _,
+ } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
+ context_server::types::MessageContent::Resource {
+ resource,
+ annotations: _,
+ } => {
+ let mut link =
+ acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
+ if let Some(mime_type) = resource.mime_type {
+ link = link.mime_type(mime_type);
+ }
+ acp::ContentBlock::ResourceLink(link)
+ }
+ }
+}
@@ -108,7 +108,13 @@ impl Message {
pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
match self {
- Message::User(message) => vec![message.to_request()],
+ Message::User(message) => {
+ if message.content.is_empty() {
+ vec![]
+ } else {
+ vec![message.to_request()]
+ }
+ }
Message::Agent(message) => message.to_request(),
Message::Resume => vec![LanguageModelRequestMessage {
role: Role::User,
@@ -1141,20 +1147,64 @@ impl Thread {
where
T: Into<UserMessageContent>,
{
+ let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
+ log::debug!("Thread::send content: {:?}", content);
+
+ self.messages
+ .push(Message::User(UserMessage { id, content }));
+ cx.notify();
+
+ self.send_existing(cx)
+ }
+
+ pub fn send_existing(
+ &mut self,
+ cx: &mut Context<Self>,
+ ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
let model = self.model().context("No language model configured")?;
log::info!("Thread::send called with model: {}", model.name().0);
self.advance_prompt_id();
- let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
- log::debug!("Thread::send content: {:?}", content);
+ log::debug!("Total messages in thread: {}", self.messages.len());
+ self.run_turn(cx)
+ }
+ pub fn push_acp_user_block(
+ &mut self,
+ id: UserMessageId,
+ blocks: impl IntoIterator<Item = acp::ContentBlock>,
+ path_style: PathStyle,
+ cx: &mut Context<Self>,
+ ) {
+ let content = blocks
+ .into_iter()
+ .map(|block| UserMessageContent::from_content_block(block, path_style))
+ .collect::<Vec<_>>();
self.messages
.push(Message::User(UserMessage { id, content }));
cx.notify();
+ }
- log::debug!("Total messages in thread: {}", self.messages.len());
- self.run_turn(cx)
+ pub fn push_acp_agent_block(&mut self, block: acp::ContentBlock, cx: &mut Context<Self>) {
+ let text = match block {
+ acp::ContentBlock::Text(text_content) => text_content.text,
+ acp::ContentBlock::Image(_) => "[image]".to_string(),
+ acp::ContentBlock::Audio(_) => "[audio]".to_string(),
+ acp::ContentBlock::ResourceLink(resource_link) => resource_link.uri,
+ acp::ContentBlock::Resource(resource) => match resource.resource {
+ acp::EmbeddedResourceResource::TextResourceContents(resource) => resource.uri,
+ acp::EmbeddedResourceResource::BlobResourceContents(resource) => resource.uri,
+ _ => "[resource]".to_string(),
+ },
+ _ => "[unknown]".to_string(),
+ };
+
+ self.messages.push(Message::Agent(AgentMessage {
+ content: vec![AgentMessageContent::Text(text)],
+ ..Default::default()
+ }));
+ cx.notify();
}
#[cfg(feature = "eval")]
@@ -3,11 +3,23 @@ use agent_client_protocol::ToolKind;
use anyhow::{Result, anyhow, bail};
use collections::{BTreeMap, HashMap};
use context_server::ContextServerId;
-use gpui::{App, Context, Entity, SharedString, Task};
+use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task};
use project::context_server_store::{ContextServerStatus, ContextServerStore};
use std::sync::Arc;
use util::ResultExt;
+pub struct ContextServerPrompt {
+ pub server_id: ContextServerId,
+ pub prompt: context_server::types::Prompt,
+}
+
+pub enum ContextServerRegistryEvent {
+ ToolsChanged,
+ PromptsChanged,
+}
+
+impl EventEmitter<ContextServerRegistryEvent> for ContextServerRegistry {}
+
pub struct ContextServerRegistry {
server_store: Entity<ContextServerStore>,
registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
@@ -16,7 +28,20 @@ pub struct ContextServerRegistry {
struct RegisteredContextServer {
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
+ prompts: BTreeMap<SharedString, ContextServerPrompt>,
load_tools: Task<Result<()>>,
+ load_prompts: Task<Result<()>>,
+}
+
+impl RegisteredContextServer {
+ fn new() -> Self {
+ Self {
+ tools: BTreeMap::default(),
+ prompts: BTreeMap::default(),
+ load_tools: Task::ready(Ok(())),
+ load_prompts: Task::ready(Ok(())),
+ }
+ }
}
impl ContextServerRegistry {
@@ -28,6 +53,7 @@ impl ContextServerRegistry {
};
for server in server_store.read(cx).running_servers() {
this.reload_tools_for_server(server.id(), cx);
+ this.reload_prompts_for_server(server.id(), cx);
}
this
}
@@ -56,6 +82,41 @@ impl ContextServerRegistry {
.map(|(id, server)| (id, &server.tools))
}
+ pub fn prompts(&self) -> impl Iterator<Item = &ContextServerPrompt> {
+ self.registered_servers
+ .values()
+ .flat_map(|server| server.prompts.values())
+ }
+
+ pub fn find_prompt(
+ &self,
+ server_id: Option<&ContextServerId>,
+ name: &str,
+ ) -> Option<&ContextServerPrompt> {
+ if let Some(server_id) = server_id {
+ self.registered_servers
+ .get(server_id)
+ .and_then(|server| server.prompts.get(name))
+ } else {
+ self.registered_servers
+ .values()
+ .find_map(|server| server.prompts.get(name))
+ }
+ }
+
+ pub fn server_store(&self) -> &Entity<ContextServerStore> {
+ &self.server_store
+ }
+
+ fn get_or_register_server(
+ &mut self,
+ server_id: &ContextServerId,
+ ) -> &mut RegisteredContextServer {
+ self.registered_servers
+ .entry(server_id.clone())
+ .or_insert_with(RegisteredContextServer::new)
+ }
+
fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
return;
@@ -67,13 +128,7 @@ impl ContextServerRegistry {
return;
}
- let registered_server =
- self.registered_servers
- .entry(server_id.clone())
- .or_insert(RegisteredContextServer {
- tools: BTreeMap::default(),
- load_tools: Task::ready(Ok(())),
- });
+ let registered_server = self.get_or_register_server(&server_id);
registered_server.load_tools = cx.spawn(async move |this, cx| {
let response = client
.request::<context_server::types::requests::ListTools>(())
@@ -94,6 +149,49 @@ impl ContextServerRegistry {
));
registered_server.tools.insert(tool.name(), tool);
}
+ cx.emit(ContextServerRegistryEvent::ToolsChanged);
+ cx.notify();
+ }
+ })
+ });
+ }
+
+ fn reload_prompts_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
+ let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
+ return;
+ };
+ let Some(client) = server.client() else {
+ return;
+ };
+ if !client.capable(context_server::protocol::ServerCapability::Prompts) {
+ return;
+ }
+
+ let registered_server = self.get_or_register_server(&server_id);
+
+ registered_server.load_prompts = cx.spawn(async move |this, cx| {
+ let response = client
+ .request::<context_server::types::requests::PromptsList>(())
+ .await;
+
+ this.update(cx, |this, cx| {
+ let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
+ return;
+ };
+
+ registered_server.prompts.clear();
+ if let Some(response) = response.log_err() {
+ for prompt in response.prompts {
+ let name: SharedString = prompt.name.clone().into();
+ registered_server.prompts.insert(
+ name,
+ ContextServerPrompt {
+ server_id: server_id.clone(),
+ prompt,
+ },
+ );
+ }
+ cx.emit(ContextServerRegistryEvent::PromptsChanged);
cx.notify();
}
})
@@ -112,9 +210,17 @@ impl ContextServerRegistry {
ContextServerStatus::Starting => {}
ContextServerStatus::Running => {
self.reload_tools_for_server(server_id.clone(), cx);
+ self.reload_prompts_for_server(server_id.clone(), cx);
}
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
- self.registered_servers.remove(server_id);
+ if let Some(registered_server) = self.registered_servers.remove(server_id) {
+ if !registered_server.tools.is_empty() {
+ cx.emit(ContextServerRegistryEvent::ToolsChanged);
+ }
+ if !registered_server.prompts.is_empty() {
+ cx.emit(ContextServerRegistryEvent::PromptsChanged);
+ }
+ }
cx.notify();
}
}
@@ -251,3 +357,39 @@ impl AnyAgentTool for ContextServerTool {
Ok(())
}
}
+
+pub fn get_prompt(
+ server_store: &Entity<ContextServerStore>,
+ server_id: &ContextServerId,
+ prompt_name: &str,
+ arguments: HashMap<String, String>,
+ cx: &mut AsyncApp,
+) -> Task<Result<context_server::types::PromptsGetResponse>> {
+ let server = match cx.update(|cx| server_store.read(cx).get_running_server(server_id)) {
+ Ok(server) => server,
+ Err(error) => return Task::ready(Err(error)),
+ };
+ let Some(server) = server else {
+ return Task::ready(Err(anyhow::anyhow!("Context server not found")));
+ };
+
+ let Some(protocol) = server.client() else {
+ return Task::ready(Err(anyhow::anyhow!("Context server not initialized")));
+ };
+
+ let prompt_name = prompt_name.to_string();
+
+ cx.background_spawn(async move {
+ let response = protocol
+ .request::<context_server::types::requests::PromptsGet>(
+ context_server::types::PromptsGetParams {
+ name: prompt_name,
+ arguments: (!arguments.is_empty()).then(|| arguments),
+ meta: None,
+ },
+ )
+ .await?;
+
+ Ok(response)
+ })
+}
@@ -1315,7 +1315,7 @@ impl AcpThreadView {
})?;
anyhow::Ok(())
})
- .detach();
+ .detach_and_log_err(cx);
}
fn open_edited_buffer(
@@ -1940,6 +1940,16 @@ impl AcpThreadView {
window: &mut Window,
cx: &Context<Self>,
) -> AnyElement {
+ let is_indented = entry.is_indented();
+ let is_first_indented = is_indented
+ && self.thread().is_some_and(|thread| {
+ thread
+ .read(cx)
+ .entries()
+ .get(entry_ix.saturating_sub(1))
+ .is_none_or(|entry| !entry.is_indented())
+ });
+
let primary = match &entry {
AgentThreadEntry::UserMessage(message) => {
let Some(editor) = self
@@ -1972,7 +1982,9 @@ impl AcpThreadView {
v_flex()
.id(("user_message", entry_ix))
.map(|this| {
- if entry_ix == 0 && !has_checkpoint_button && rules_item.is_none() {
+ if is_first_indented {
+ this.pt_0p5()
+ } else if entry_ix == 0 && !has_checkpoint_button && rules_item.is_none() {
this.pt(rems_from_px(18.))
} else if rules_item.is_some() {
this.pt_3()
@@ -2018,6 +2030,9 @@ impl AcpThreadView {
.shadow_md()
.bg(cx.theme().colors().editor_background)
.border_1()
+ .when(is_indented, |this| {
+ this.py_2().px_2().shadow_sm()
+ })
.when(editing && !editor_focus, |this| this.border_dashed())
.border_color(cx.theme().colors().border)
.map(|this|{
@@ -2112,7 +2127,10 @@ impl AcpThreadView {
)
.into_any()
}
- AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) => {
+ AgentThreadEntry::AssistantMessage(AssistantMessage {
+ chunks,
+ indented: _,
+ }) => {
let is_last = entry_ix + 1 == total_entries;
let style = default_markdown_style(false, false, window, cx);
@@ -2146,6 +2164,7 @@ impl AcpThreadView {
v_flex()
.px_5()
.py_1p5()
+ .when(is_first_indented, |this| this.pt_0p5())
.when(is_last, |this| this.pb_4())
.w_full()
.text_ui(cx)
@@ -2155,19 +2174,48 @@ impl AcpThreadView {
AgentThreadEntry::ToolCall(tool_call) => {
let has_terminals = tool_call.terminals().next().is_some();
- div().w_full().map(|this| {
- if has_terminals {
- this.children(tool_call.terminals().map(|terminal| {
- self.render_terminal_tool_call(
- entry_ix, terminal, tool_call, window, cx,
- )
- }))
- } else {
- this.child(self.render_tool_call(entry_ix, tool_call, window, cx))
- }
- })
+ div()
+ .w_full()
+ .map(|this| {
+ if has_terminals {
+ this.children(tool_call.terminals().map(|terminal| {
+ self.render_terminal_tool_call(
+ entry_ix, terminal, tool_call, window, cx,
+ )
+ }))
+ } else {
+ this.child(self.render_tool_call(entry_ix, tool_call, window, cx))
+ }
+ })
+ .into_any()
}
- .into_any(),
+ };
+
+ let primary = if is_indented {
+ let line_top = if is_first_indented {
+ rems_from_px(-12.0)
+ } else {
+ rems_from_px(0.0)
+ };
+
+ div()
+ .relative()
+ .w_full()
+ .pl(rems_from_px(20.0))
+ .bg(cx.theme().colors().panel_background.opacity(0.2))
+ .child(
+ div()
+ .absolute()
+ .left(rems_from_px(18.0))
+ .top(line_top)
+ .bottom_0()
+ .w_px()
+ .bg(cx.theme().colors().border.opacity(0.6)),
+ )
+ .child(primary)
+ .into_any_element()
+ } else {
+ primary
};
let needs_confirmation = if let AgentThreadEntry::ToolCall(tool_call) = entry {
@@ -330,7 +330,7 @@ pub struct PromptMessage {
pub content: MessageContent,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,