Cargo.lock 🔗
@@ -2815,6 +2815,7 @@ name = "context_servers"
version = "0.1.0"
dependencies = [
"anyhow",
+ "async-trait",
"collections",
"command_palette_hooks",
"futures 0.3.30",
Marshall Bowers created
This PR puts context servers behind the `ContextServer` trait to allow
us to provide context servers from an extension.
Release Notes:
- N/A
Cargo.lock | 1
crates/assistant/src/context_store.rs | 6
crates/assistant/src/slash_command/context_server_command.rs | 10
crates/assistant/src/tools/context_server_tool.rs | 8
crates/context_servers/Cargo.toml | 1
crates/context_servers/src/manager.rs | 133 +++--
6 files changed, 100 insertions(+), 59 deletions(-)
@@ -2815,6 +2815,7 @@ name = "context_servers"
version = "0.1.0"
dependencies = [
"anyhow",
+ "async-trait",
"collections",
"command_palette_hooks",
"futures 0.3.30",
@@ -819,7 +819,7 @@ impl ContextStore {
|context_server_manager, cx| {
for server in context_server_manager.servers() {
context_server_manager
- .restart_server(&server.id, cx)
+ .restart_server(&server.id(), cx)
.detach_and_log_err(cx);
}
},
@@ -850,7 +850,7 @@ impl ContextStore {
let server = server.clone();
let server_id = server_id.clone();
|this, mut cx| async move {
- let Some(protocol) = server.client.read().clone() else {
+ let Some(protocol) = server.client() else {
return;
};
@@ -889,7 +889,7 @@ impl ContextStore {
tool_working_set.insert(
Arc::new(tools::context_server_tool::ContextServerTool::new(
context_server_manager.clone(),
- server.id.clone(),
+ server.id(),
tool,
)),
)
@@ -20,18 +20,18 @@ use crate::slash_command::create_label_for_command;
pub struct ContextServerSlashCommand {
server_manager: Model<ContextServerManager>,
- server_id: String,
+ server_id: Arc<str>,
prompt: Prompt,
}
impl ContextServerSlashCommand {
pub fn new(
server_manager: Model<ContextServerManager>,
- server: &Arc<ContextServer>,
+ server: &Arc<dyn ContextServer>,
prompt: Prompt,
) -> Self {
Self {
- server_id: server.id.clone(),
+ server_id: server.id(),
prompt,
server_manager,
}
@@ -89,7 +89,7 @@ impl SlashCommand for ContextServerSlashCommand {
if let Some(server) = self.server_manager.read(cx).get_server(&server_id) {
cx.foreground_executor().spawn(async move {
- let Some(protocol) = server.client.read().clone() else {
+ let Some(protocol) = server.client() else {
return Err(anyhow!("Context server not initialized"));
};
@@ -143,7 +143,7 @@ impl SlashCommand for ContextServerSlashCommand {
let manager = self.server_manager.read(cx);
if let Some(server) = manager.get_server(&server_id) {
cx.foreground_executor().spawn(async move {
- let Some(protocol) = server.client.read().clone() else {
+ let Some(protocol) = server.client() else {
return Err(anyhow!("Context server not initialized"));
};
let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
use anyhow::{anyhow, bail};
use assistant_tool::Tool;
use context_servers::manager::ContextServerManager;
@@ -6,14 +8,14 @@ use gpui::{Model, Task};
pub struct ContextServerTool {
server_manager: Model<ContextServerManager>,
- server_id: String,
+ server_id: Arc<str>,
tool: types::Tool,
}
impl ContextServerTool {
pub fn new(
server_manager: Model<ContextServerManager>,
- server_id: impl Into<String>,
+ server_id: impl Into<Arc<str>>,
tool: types::Tool,
) -> Self {
Self {
@@ -55,7 +57,7 @@ impl Tool for ContextServerTool {
cx.foreground_executor().spawn({
let tool_name = self.tool.name.clone();
async move {
- let Some(protocol) = server.client.read().clone() else {
+ let Some(protocol) = server.client() else {
bail!("Context server not initialized");
};
@@ -13,6 +13,7 @@ path = "src/context_servers.rs"
[dependencies]
anyhow.workspace = true
+async-trait.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
futures.workspace = true
@@ -15,9 +15,13 @@
//! and react to changes in settings.
use std::path::Path;
+use std::pin::Pin;
use std::sync::Arc;
+use anyhow::Result;
+use async_trait::async_trait;
use collections::{HashMap, HashSet};
+use futures::{Future, FutureExt};
use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task};
use log;
use parking_lot::RwLock;
@@ -56,51 +60,84 @@ impl Settings for ContextServerSettings {
}
}
-pub struct ContextServer {
- pub id: String,
- pub config: ServerConfig,
+#[async_trait(?Send)]
+pub trait ContextServer: Send + Sync + 'static {
+ fn id(&self) -> Arc<str>;
+ fn config(&self) -> Arc<ServerConfig>;
+ fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>>;
+ fn start<'a>(
+ self: Arc<Self>,
+ cx: &'a AsyncAppContext,
+ ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>>;
+ fn stop(&self) -> Result<()>;
+}
+
+pub struct NativeContextServer {
+ pub id: Arc<str>,
+ pub config: Arc<ServerConfig>,
pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
}
-impl ContextServer {
- fn new(config: ServerConfig) -> Self {
+impl NativeContextServer {
+ fn new(config: Arc<ServerConfig>) -> Self {
Self {
- id: config.id.clone(),
+ id: config.id.clone().into(),
config,
client: RwLock::new(None),
}
}
+}
- async fn start(&self, cx: &AsyncAppContext) -> anyhow::Result<()> {
- log::info!("starting context server {}", self.config.id,);
- let client = Client::new(
- client::ContextServerId(self.config.id.clone()),
- client::ModelContextServerBinary {
- executable: Path::new(&self.config.executable).to_path_buf(),
- args: self.config.args.clone(),
- env: self.config.env.clone(),
- },
- cx.clone(),
- )?;
-
- let protocol = crate::protocol::ModelContextProtocol::new(client);
- let client_info = types::Implementation {
- name: "Zed".to_string(),
- version: env!("CARGO_PKG_VERSION").to_string(),
- };
- let initialized_protocol = protocol.initialize(client_info).await?;
+#[async_trait(?Send)]
+impl ContextServer for NativeContextServer {
+ fn id(&self) -> Arc<str> {
+ self.id.clone()
+ }
- log::debug!(
- "context server {} initialized: {:?}",
- self.config.id,
- initialized_protocol.initialize,
- );
+ fn config(&self) -> Arc<ServerConfig> {
+ self.config.clone()
+ }
- *self.client.write() = Some(Arc::new(initialized_protocol));
- Ok(())
+ fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
+ self.client.read().clone()
+ }
+
+ fn start<'a>(
+ self: Arc<Self>,
+ cx: &'a AsyncAppContext,
+ ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> {
+ async move {
+ log::info!("starting context server {}", self.config.id,);
+ let client = Client::new(
+ client::ContextServerId(self.config.id.clone()),
+ client::ModelContextServerBinary {
+ executable: Path::new(&self.config.executable).to_path_buf(),
+ args: self.config.args.clone(),
+ env: self.config.env.clone(),
+ },
+ cx.clone(),
+ )?;
+
+ let protocol = crate::protocol::ModelContextProtocol::new(client);
+ let client_info = types::Implementation {
+ name: "Zed".to_string(),
+ version: env!("CARGO_PKG_VERSION").to_string(),
+ };
+ let initialized_protocol = protocol.initialize(client_info).await?;
+
+ log::debug!(
+ "context server {} initialized: {:?}",
+ self.config.id,
+ initialized_protocol.initialize,
+ );
+
+ *self.client.write() = Some(Arc::new(initialized_protocol));
+ Ok(())
+ }
+ .boxed_local()
}
- async fn stop(&self) -> anyhow::Result<()> {
+ fn stop(&self) -> Result<()> {
let mut client = self.client.write();
if let Some(protocol) = client.take() {
drop(protocol);
@@ -114,7 +151,7 @@ impl ContextServer {
/// must go through the `GlobalContextServerManager` which holds
/// a model to the ContextServerManager.
pub struct ContextServerManager {
- servers: HashMap<String, Arc<ContextServer>>,
+ servers: HashMap<String, Arc<dyn ContextServer>>,
pending_servers: HashSet<String>,
}
@@ -141,7 +178,7 @@ impl ContextServerManager {
pub fn add_server(
&mut self,
- config: ServerConfig,
+ config: Arc<ServerConfig>,
cx: &ModelContext<Self>,
) -> Task<anyhow::Result<()>> {
let server_id = config.id.clone();
@@ -153,8 +190,8 @@ impl ContextServerManager {
let task = {
let server_id = server_id.clone();
cx.spawn(|this, mut cx| async move {
- let server = Arc::new(ContextServer::new(config));
- server.start(&cx).await?;
+ let server = Arc::new(NativeContextServer::new(config));
+ server.clone().start(&cx).await?;
this.update(&mut cx, |this, cx| {
this.servers.insert(server_id.clone(), server);
this.pending_servers.remove(&server_id);
@@ -170,7 +207,7 @@ impl ContextServerManager {
task
}
- pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
+ pub fn get_server(&self, id: &str) -> Option<Arc<dyn ContextServer>> {
self.servers.get(id).cloned()
}
@@ -178,7 +215,7 @@ impl ContextServerManager {
let id = id.to_string();
cx.spawn(|this, mut cx| async move {
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
- server.stop().await?;
+ server.stop()?;
}
this.update(&mut cx, |this, cx| {
this.pending_servers.remove(&id);
@@ -192,16 +229,16 @@ impl ContextServerManager {
pub fn restart_server(
&mut self,
- id: &str,
+ id: &Arc<str>,
cx: &mut ModelContext<Self>,
) -> Task<anyhow::Result<()>> {
let id = id.to_string();
cx.spawn(|this, mut cx| async move {
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
- server.stop().await?;
- let config = server.config.clone();
- let new_server = Arc::new(ContextServer::new(config));
- new_server.start(&cx).await?;
+ server.stop()?;
+ let config = server.config();
+ let new_server = Arc::new(NativeContextServer::new(config));
+ new_server.clone().start(&cx).await?;
this.update(&mut cx, |this, cx| {
this.servers.insert(id.clone(), new_server);
cx.emit(Event::ServerStopped {
@@ -216,7 +253,7 @@ impl ContextServerManager {
})
}
- pub fn servers(&self) -> Vec<Arc<ContextServer>> {
+ pub fn servers(&self) -> Vec<Arc<dyn ContextServer>> {
self.servers.values().cloned().collect()
}
@@ -224,7 +261,7 @@ impl ContextServerManager {
let current_servers = self
.servers()
.into_iter()
- .map(|server| (server.id.clone(), server.config.clone()))
+ .map(|server| (server.id(), server.config()))
.collect::<HashMap<_, _>>();
let new_servers = settings
@@ -235,19 +272,19 @@ impl ContextServerManager {
let servers_to_add = new_servers
.values()
- .filter(|config| !current_servers.contains_key(&config.id))
+ .filter(|config| !current_servers.contains_key(config.id.as_str()))
.cloned()
.collect::<Vec<_>>();
let servers_to_remove = current_servers
.keys()
- .filter(|id| !new_servers.contains_key(*id))
+ .filter(|id| !new_servers.contains_key(id.as_ref()))
.cloned()
.collect::<Vec<_>>();
log::trace!("servers_to_add={:?}", servers_to_add);
for config in servers_to_add {
- self.add_server(config, cx).detach_and_log_err(cx);
+ self.add_server(Arc::new(config), cx).detach_and_log_err(cx);
}
for id in servers_to_remove {