context_server_registry.rs

  1use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
  2use agent_client_protocol::ToolKind;
  3use anyhow::{Result, anyhow, bail};
  4use collections::{BTreeMap, HashMap};
  5use context_server::ContextServerId;
  6use gpui::{App, Context, Entity, SharedString, Task};
  7use project::context_server_store::{ContextServerStatus, ContextServerStore};
  8use std::sync::Arc;
  9use util::ResultExt;
 10
 11pub struct ContextServerRegistry {
 12    server_store: Entity<ContextServerStore>,
 13    registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
 14    _subscription: gpui::Subscription,
 15}
 16
 17struct RegisteredContextServer {
 18    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 19    load_tools: Task<Result<()>>,
 20}
 21
 22impl ContextServerRegistry {
 23    pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
 24        let mut this = Self {
 25            server_store: server_store.clone(),
 26            registered_servers: HashMap::default(),
 27            _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
 28        };
 29        for server in server_store.read(cx).running_servers() {
 30            this.reload_tools_for_server(server.id(), cx);
 31        }
 32        this
 33    }
 34
 35    pub fn servers(
 36        &self,
 37    ) -> impl Iterator<
 38        Item = (
 39            &ContextServerId,
 40            &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 41        ),
 42    > {
 43        self.registered_servers
 44            .iter()
 45            .map(|(id, server)| (id, &server.tools))
 46    }
 47
 48    fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
 49        let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
 50            return;
 51        };
 52        let Some(client) = server.client() else {
 53            return;
 54        };
 55        if !client.capable(context_server::protocol::ServerCapability::Tools) {
 56            return;
 57        }
 58
 59        let registered_server =
 60            self.registered_servers
 61                .entry(server_id.clone())
 62                .or_insert(RegisteredContextServer {
 63                    tools: BTreeMap::default(),
 64                    load_tools: Task::ready(Ok(())),
 65                });
 66        registered_server.load_tools = cx.spawn(async move |this, cx| {
 67            let response = client
 68                .request::<context_server::types::requests::ListTools>(())
 69                .await;
 70
 71            this.update(cx, |this, cx| {
 72                let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
 73                    return;
 74                };
 75
 76                registered_server.tools.clear();
 77                if let Some(response) = response.log_err() {
 78                    for tool in response.tools {
 79                        let tool = Arc::new(ContextServerTool::new(
 80                            this.server_store.clone(),
 81                            server.id(),
 82                            tool,
 83                        ));
 84                        registered_server.tools.insert(tool.name(), tool);
 85                    }
 86                    cx.notify();
 87                }
 88            })
 89        });
 90    }
 91
 92    fn handle_context_server_store_event(
 93        &mut self,
 94        _: Entity<ContextServerStore>,
 95        event: &project::context_server_store::Event,
 96        cx: &mut Context<Self>,
 97    ) {
 98        match event {
 99            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
100                match status {
101                    ContextServerStatus::Starting => {}
102                    ContextServerStatus::Running => {
103                        self.reload_tools_for_server(server_id.clone(), cx);
104                    }
105                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
106                        self.registered_servers.remove(server_id);
107                        cx.notify();
108                    }
109                }
110            }
111        }
112    }
113}
114
115struct ContextServerTool {
116    store: Entity<ContextServerStore>,
117    server_id: ContextServerId,
118    tool: context_server::types::Tool,
119}
120
121impl ContextServerTool {
122    fn new(
123        store: Entity<ContextServerStore>,
124        server_id: ContextServerId,
125        tool: context_server::types::Tool,
126    ) -> Self {
127        Self {
128            store,
129            server_id,
130            tool,
131        }
132    }
133}
134
135impl AnyAgentTool for ContextServerTool {
136    fn name(&self) -> SharedString {
137        self.tool.name.clone().into()
138    }
139
140    fn description(&self) -> SharedString {
141        self.tool.description.clone().unwrap_or_default().into()
142    }
143
144    fn kind(&self) -> ToolKind {
145        ToolKind::Other
146    }
147
148    fn initial_title(&self, _input: serde_json::Value, _cx: &mut App) -> SharedString {
149        format!("Run MCP tool `{}`", self.tool.name).into()
150    }
151
152    fn input_schema(
153        &self,
154        format: language_model::LanguageModelToolSchemaFormat,
155    ) -> Result<serde_json::Value> {
156        let mut schema = self.tool.input_schema.clone();
157        assistant_tool::adapt_schema_to_format(&mut schema, format)?;
158        Ok(match schema {
159            serde_json::Value::Null => {
160                serde_json::json!({ "type": "object", "properties": [] })
161            }
162            serde_json::Value::Object(map) if map.is_empty() => {
163                serde_json::json!({ "type": "object", "properties": [] })
164            }
165            _ => schema,
166        })
167    }
168
169    fn run(
170        self: Arc<Self>,
171        input: serde_json::Value,
172        event_stream: ToolCallEventStream,
173        cx: &mut App,
174    ) -> Task<Result<AgentToolOutput>> {
175        let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
176            return Task::ready(Err(anyhow!("Context server not found")));
177        };
178        let tool_name = self.tool.name.clone();
179        let authorize = event_stream.authorize(self.initial_title(input.clone(), cx), cx);
180
181        cx.spawn(async move |_cx| {
182            authorize.await?;
183
184            let Some(protocol) = server.client() else {
185                bail!("Context server not initialized");
186            };
187
188            let arguments = if let serde_json::Value::Object(map) = input {
189                Some(map.into_iter().collect())
190            } else {
191                None
192            };
193
194            log::trace!(
195                "Running tool: {} with arguments: {:?}",
196                tool_name,
197                arguments
198            );
199            let response = protocol
200                .request::<context_server::types::requests::CallTool>(
201                    context_server::types::CallToolParams {
202                        name: tool_name,
203                        arguments,
204                        meta: None,
205                    },
206                )
207                .await?;
208
209            let mut result = String::new();
210            for content in response.content {
211                match content {
212                    context_server::types::ToolResponseContent::Text { text } => {
213                        result.push_str(&text);
214                    }
215                    context_server::types::ToolResponseContent::Image { .. } => {
216                        log::warn!("Ignoring image content from tool response");
217                    }
218                    context_server::types::ToolResponseContent::Audio { .. } => {
219                        log::warn!("Ignoring audio content from tool response");
220                    }
221                    context_server::types::ToolResponseContent::Resource { .. } => {
222                        log::warn!("Ignoring resource content from tool response");
223                    }
224                }
225            }
226            Ok(AgentToolOutput {
227                raw_output: result.clone().into(),
228                llm_output: result.into(),
229            })
230        })
231    }
232
233    fn replay(
234        &self,
235        _input: serde_json::Value,
236        _output: serde_json::Value,
237        _event_stream: ToolCallEventStream,
238        _cx: &mut App,
239    ) -> Result<()> {
240        Ok(())
241    }
242}