1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{ToolId, ToolWorkingSet};
5use collections::HashMap;
6use context_server::manager::ContextServerManager;
7use context_server::{ContextServerFactoryRegistry, ContextServerTool};
8use gpui::{prelude::*, AppContext, Model, ModelContext, Task};
9use project::Project;
10use util::ResultExt as _;
11
12pub struct ThreadStore {
13 #[allow(unused)]
14 project: Model<Project>,
15 tools: Arc<ToolWorkingSet>,
16 context_server_manager: Model<ContextServerManager>,
17 context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
18}
19
20impl ThreadStore {
21 pub fn new(
22 project: Model<Project>,
23 tools: Arc<ToolWorkingSet>,
24 cx: &mut AppContext,
25 ) -> Task<Result<Model<Self>>> {
26 cx.spawn(|mut cx| async move {
27 let this = cx.new_model(|cx: &mut ModelContext<Self>| {
28 let context_server_factory_registry =
29 ContextServerFactoryRegistry::default_global(cx);
30 let context_server_manager = cx.new_model(|cx| {
31 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
32 });
33
34 let this = Self {
35 project,
36 tools,
37 context_server_manager,
38 context_server_tool_ids: HashMap::default(),
39 };
40 this.register_context_server_handlers(cx);
41
42 this
43 })?;
44
45 Ok(this)
46 })
47 }
48
49 fn register_context_server_handlers(&self, cx: &mut ModelContext<Self>) {
50 cx.subscribe(
51 &self.context_server_manager.clone(),
52 Self::handle_context_server_event,
53 )
54 .detach();
55 }
56
57 fn handle_context_server_event(
58 &mut self,
59 context_server_manager: Model<ContextServerManager>,
60 event: &context_server::manager::Event,
61 cx: &mut ModelContext<Self>,
62 ) {
63 let tool_working_set = self.tools.clone();
64 match event {
65 context_server::manager::Event::ServerStarted { server_id } => {
66 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
67 let context_server_manager = context_server_manager.clone();
68 cx.spawn({
69 let server = server.clone();
70 let server_id = server_id.clone();
71 |this, mut cx| async move {
72 let Some(protocol) = server.client() else {
73 return;
74 };
75
76 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
77 if let Some(tools) = protocol.list_tools().await.log_err() {
78 let tool_ids = tools
79 .tools
80 .into_iter()
81 .map(|tool| {
82 log::info!(
83 "registering context server tool: {:?}",
84 tool.name
85 );
86 tool_working_set.insert(Arc::new(
87 ContextServerTool::new(
88 context_server_manager.clone(),
89 server.id(),
90 tool,
91 ),
92 ))
93 })
94 .collect::<Vec<_>>();
95
96 this.update(&mut cx, |this, _cx| {
97 this.context_server_tool_ids.insert(server_id, tool_ids);
98 })
99 .log_err();
100 }
101 }
102 }
103 })
104 .detach();
105 }
106 }
107 context_server::manager::Event::ServerStopped { server_id } => {
108 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
109 tool_working_set.remove(&tool_ids);
110 }
111 }
112 }
113 }
114}