context_server_registry.rs

  1use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
  2use agent_client_protocol::ToolKind;
  3use anyhow::{Result, anyhow, bail};
  4use collections::{BTreeMap, HashMap};
  5use context_server::ContextServerId;
  6use gpui::{App, Context, Entity, SharedString, Task};
  7use project::context_server_store::{ContextServerStatus, ContextServerStore};
  8use std::sync::Arc;
  9use util::ResultExt;
 10
 11pub struct ContextServerRegistry {
 12    server_store: Entity<ContextServerStore>,
 13    registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
 14    _subscription: gpui::Subscription,
 15}
 16
 17struct RegisteredContextServer {
 18    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 19    load_tools: Task<Result<()>>,
 20}
 21
 22impl ContextServerRegistry {
 23    pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
 24        let mut this = Self {
 25            server_store: server_store.clone(),
 26            registered_servers: HashMap::default(),
 27            _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
 28        };
 29        for server in server_store.read(cx).running_servers() {
 30            this.reload_tools_for_server(server.id(), cx);
 31        }
 32        this
 33    }
 34
 35    pub fn tools_for_server(
 36        &self,
 37        server_id: &ContextServerId,
 38    ) -> impl Iterator<Item = &Arc<dyn AnyAgentTool>> {
 39        self.registered_servers
 40            .get(server_id)
 41            .map(|server| server.tools.values())
 42            .into_iter()
 43            .flatten()
 44    }
 45
 46    pub fn servers(
 47        &self,
 48    ) -> impl Iterator<
 49        Item = (
 50            &ContextServerId,
 51            &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 52        ),
 53    > {
 54        self.registered_servers
 55            .iter()
 56            .map(|(id, server)| (id, &server.tools))
 57    }
 58
 59    fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
 60        let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
 61            return;
 62        };
 63        let Some(client) = server.client() else {
 64            return;
 65        };
 66        if !client.capable(context_server::protocol::ServerCapability::Tools) {
 67            return;
 68        }
 69
 70        let registered_server =
 71            self.registered_servers
 72                .entry(server_id.clone())
 73                .or_insert(RegisteredContextServer {
 74                    tools: BTreeMap::default(),
 75                    load_tools: Task::ready(Ok(())),
 76                });
 77        registered_server.load_tools = cx.spawn(async move |this, cx| {
 78            let response = client
 79                .request::<context_server::types::requests::ListTools>(())
 80                .await;
 81
 82            this.update(cx, |this, cx| {
 83                let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
 84                    return;
 85                };
 86
 87                registered_server.tools.clear();
 88                if let Some(response) = response.log_err() {
 89                    for tool in response.tools {
 90                        let tool = Arc::new(ContextServerTool::new(
 91                            this.server_store.clone(),
 92                            server.id(),
 93                            tool,
 94                        ));
 95                        registered_server.tools.insert(tool.name(), tool);
 96                    }
 97                    cx.notify();
 98                }
 99            })
100        });
101    }
102
103    fn handle_context_server_store_event(
104        &mut self,
105        _: Entity<ContextServerStore>,
106        event: &project::context_server_store::Event,
107        cx: &mut Context<Self>,
108    ) {
109        match event {
110            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
111                match status {
112                    ContextServerStatus::Starting => {}
113                    ContextServerStatus::Running => {
114                        self.reload_tools_for_server(server_id.clone(), cx);
115                    }
116                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
117                        self.registered_servers.remove(server_id);
118                        cx.notify();
119                    }
120                }
121            }
122        }
123    }
124}
125
126struct ContextServerTool {
127    store: Entity<ContextServerStore>,
128    server_id: ContextServerId,
129    tool: context_server::types::Tool,
130}
131
132impl ContextServerTool {
133    fn new(
134        store: Entity<ContextServerStore>,
135        server_id: ContextServerId,
136        tool: context_server::types::Tool,
137    ) -> Self {
138        Self {
139            store,
140            server_id,
141            tool,
142        }
143    }
144}
145
146impl AnyAgentTool for ContextServerTool {
147    fn name(&self) -> SharedString {
148        self.tool.name.clone().into()
149    }
150
151    fn description(&self) -> SharedString {
152        self.tool.description.clone().unwrap_or_default().into()
153    }
154
155    fn kind(&self) -> ToolKind {
156        ToolKind::Other
157    }
158
159    fn initial_title(&self, _input: serde_json::Value, _cx: &mut App) -> SharedString {
160        format!("Run MCP tool `{}`", self.tool.name).into()
161    }
162
163    fn input_schema(
164        &self,
165        format: language_model::LanguageModelToolSchemaFormat,
166    ) -> Result<serde_json::Value> {
167        let mut schema = self.tool.input_schema.clone();
168        crate::tool_schema::adapt_schema_to_format(&mut schema, format)?;
169        Ok(match schema {
170            serde_json::Value::Null => {
171                serde_json::json!({ "type": "object", "properties": [] })
172            }
173            serde_json::Value::Object(map) if map.is_empty() => {
174                serde_json::json!({ "type": "object", "properties": [] })
175            }
176            _ => schema,
177        })
178    }
179
180    fn run(
181        self: Arc<Self>,
182        input: serde_json::Value,
183        event_stream: ToolCallEventStream,
184        cx: &mut App,
185    ) -> Task<Result<AgentToolOutput>> {
186        let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
187            return Task::ready(Err(anyhow!("Context server not found")));
188        };
189        let tool_name = self.tool.name.clone();
190        let authorize = event_stream.authorize(self.initial_title(input.clone(), cx), cx);
191
192        cx.spawn(async move |_cx| {
193            authorize.await?;
194
195            let Some(protocol) = server.client() else {
196                bail!("Context server not initialized");
197            };
198
199            let arguments = if let serde_json::Value::Object(map) = input {
200                Some(map.into_iter().collect())
201            } else {
202                None
203            };
204
205            log::trace!(
206                "Running tool: {} with arguments: {:?}",
207                tool_name,
208                arguments
209            );
210            let response = protocol
211                .request::<context_server::types::requests::CallTool>(
212                    context_server::types::CallToolParams {
213                        name: tool_name,
214                        arguments,
215                        meta: None,
216                    },
217                )
218                .await?;
219
220            let mut result = String::new();
221            for content in response.content {
222                match content {
223                    context_server::types::ToolResponseContent::Text { text } => {
224                        result.push_str(&text);
225                    }
226                    context_server::types::ToolResponseContent::Image { .. } => {
227                        log::warn!("Ignoring image content from tool response");
228                    }
229                    context_server::types::ToolResponseContent::Audio { .. } => {
230                        log::warn!("Ignoring audio content from tool response");
231                    }
232                    context_server::types::ToolResponseContent::Resource { .. } => {
233                        log::warn!("Ignoring resource content from tool response");
234                    }
235                }
236            }
237            Ok(AgentToolOutput {
238                raw_output: result.clone().into(),
239                llm_output: result.into(),
240            })
241        })
242    }
243
244    fn replay(
245        &self,
246        _input: serde_json::Value,
247        _output: serde_json::Value,
248        _event_stream: ToolCallEventStream,
249        _cx: &mut App,
250    ) -> Result<()> {
251        Ok(())
252    }
253}