thread_store.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{ToolId, ToolWorkingSet};
  5use collections::HashMap;
  6use context_server::manager::ContextServerManager;
  7use context_server::{ContextServerFactoryRegistry, ContextServerTool};
  8use gpui::{prelude::*, AppContext, Model, ModelContext, Task};
  9use project::Project;
 10use util::ResultExt as _;
 11
 12pub struct ThreadStore {
 13    #[allow(unused)]
 14    project: Model<Project>,
 15    tools: Arc<ToolWorkingSet>,
 16    context_server_manager: Model<ContextServerManager>,
 17    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
 18}
 19
 20impl ThreadStore {
 21    pub fn new(
 22        project: Model<Project>,
 23        tools: Arc<ToolWorkingSet>,
 24        cx: &mut AppContext,
 25    ) -> Task<Result<Model<Self>>> {
 26        cx.spawn(|mut cx| async move {
 27            let this = cx.new_model(|cx: &mut ModelContext<Self>| {
 28                let context_server_factory_registry =
 29                    ContextServerFactoryRegistry::default_global(cx);
 30                let context_server_manager = cx.new_model(|cx| {
 31                    ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 32                });
 33
 34                let this = Self {
 35                    project,
 36                    tools,
 37                    context_server_manager,
 38                    context_server_tool_ids: HashMap::default(),
 39                };
 40                this.register_context_server_handlers(cx);
 41
 42                this
 43            })?;
 44
 45            Ok(this)
 46        })
 47    }
 48
 49    fn register_context_server_handlers(&self, cx: &mut ModelContext<Self>) {
 50        cx.subscribe(
 51            &self.context_server_manager.clone(),
 52            Self::handle_context_server_event,
 53        )
 54        .detach();
 55    }
 56
 57    fn handle_context_server_event(
 58        &mut self,
 59        context_server_manager: Model<ContextServerManager>,
 60        event: &context_server::manager::Event,
 61        cx: &mut ModelContext<Self>,
 62    ) {
 63        let tool_working_set = self.tools.clone();
 64        match event {
 65            context_server::manager::Event::ServerStarted { server_id } => {
 66                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
 67                    let context_server_manager = context_server_manager.clone();
 68                    cx.spawn({
 69                        let server = server.clone();
 70                        let server_id = server_id.clone();
 71                        |this, mut cx| async move {
 72                            let Some(protocol) = server.client() else {
 73                                return;
 74                            };
 75
 76                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
 77                                if let Some(tools) = protocol.list_tools().await.log_err() {
 78                                    let tool_ids = tools
 79                                        .tools
 80                                        .into_iter()
 81                                        .map(|tool| {
 82                                            log::info!(
 83                                                "registering context server tool: {:?}",
 84                                                tool.name
 85                                            );
 86                                            tool_working_set.insert(Arc::new(
 87                                                ContextServerTool::new(
 88                                                    context_server_manager.clone(),
 89                                                    server.id(),
 90                                                    tool,
 91                                                ),
 92                                            ))
 93                                        })
 94                                        .collect::<Vec<_>>();
 95
 96                                    this.update(&mut cx, |this, _cx| {
 97                                        this.context_server_tool_ids.insert(server_id, tool_ids);
 98                                    })
 99                                    .log_err();
100                                }
101                            }
102                        }
103                    })
104                    .detach();
105                }
106            }
107            context_server::manager::Event::ServerStopped { server_id } => {
108                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
109                    tool_working_set.remove(&tool_ids);
110                }
111            }
112        }
113    }
114}