1use std::sync::Arc;
2
3use anyhow::{Context as _, Result};
4use collections::HashMap;
5use command_palette_hooks::CommandPaletteFilter;
6use gpui::{App, Context, Entity, EntityId, Global, Subscription, Task, prelude::*};
7use jupyter_websocket_client::RemoteServer;
8use language::Language;
9use project::{Fs, Project, WorktreeId};
10use settings::{Settings, SettingsStore};
11
12use crate::kernels::{
13 list_remote_kernelspecs, local_kernel_specifications, python_env_kernel_specifications,
14};
15use crate::{JupyterSettings, KernelSpecification, Session};
16
17struct GlobalReplStore(Entity<ReplStore>);
18
19impl Global for GlobalReplStore {}
20
21pub struct ReplStore {
22 fs: Arc<dyn Fs>,
23 enabled: bool,
24 sessions: HashMap<EntityId, Entity<Session>>,
25 kernel_specifications: Vec<KernelSpecification>,
26 selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
27 kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
28 _subscriptions: Vec<Subscription>,
29}
30
31impl ReplStore {
32 const NAMESPACE: &'static str = "repl";
33
34 pub(crate) fn init(fs: Arc<dyn Fs>, cx: &mut App) {
35 let store = cx.new(move |cx| Self::new(fs, cx));
36
37 store
38 .update(cx, |store, cx| store.refresh_kernelspecs(cx))
39 .detach_and_log_err(cx);
40
41 cx.set_global(GlobalReplStore(store))
42 }
43
44 pub fn global(cx: &App) -> Entity<Self> {
45 cx.global::<GlobalReplStore>().0.clone()
46 }
47
48 pub fn new(fs: Arc<dyn Fs>, cx: &mut Context<Self>) -> Self {
49 let subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
50 this.set_enabled(JupyterSettings::enabled(cx), cx);
51 })];
52
53 let this = Self {
54 fs,
55 enabled: JupyterSettings::enabled(cx),
56 sessions: HashMap::default(),
57 kernel_specifications: Vec::new(),
58 _subscriptions: subscriptions,
59 kernel_specifications_for_worktree: HashMap::default(),
60 selected_kernel_for_worktree: HashMap::default(),
61 };
62 this.on_enabled_changed(cx);
63 this
64 }
65
66 pub fn fs(&self) -> &Arc<dyn Fs> {
67 &self.fs
68 }
69
70 pub const fn is_enabled(&self) -> bool {
71 self.enabled
72 }
73
74 pub fn kernel_specifications_for_worktree(
75 &self,
76 worktree_id: WorktreeId,
77 ) -> impl Iterator<Item = &KernelSpecification> {
78 self.kernel_specifications_for_worktree
79 .get(&worktree_id)
80 .into_iter()
81 .flat_map(|specs| specs.iter())
82 .chain(self.kernel_specifications.iter())
83 }
84
85 pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
86 self.kernel_specifications.iter()
87 }
88
89 pub fn sessions(&self) -> impl Iterator<Item = &Entity<Session>> {
90 self.sessions.values()
91 }
92
93 fn set_enabled(&mut self, enabled: bool, cx: &mut Context<Self>) {
94 if self.enabled == enabled {
95 return;
96 }
97
98 self.enabled = enabled;
99 self.on_enabled_changed(cx);
100 }
101
102 fn on_enabled_changed(&self, cx: &mut Context<Self>) {
103 if !self.enabled {
104 CommandPaletteFilter::update_global(cx, |filter, _cx| {
105 filter.hide_namespace(Self::NAMESPACE);
106 });
107
108 return;
109 }
110
111 CommandPaletteFilter::update_global(cx, |filter, _cx| {
112 filter.show_namespace(Self::NAMESPACE);
113 });
114
115 cx.notify();
116 }
117
118 pub fn refresh_python_kernelspecs(
119 &mut self,
120 worktree_id: WorktreeId,
121 project: &Entity<Project>,
122 cx: &mut Context<Self>,
123 ) -> Task<Result<()>> {
124 let kernel_specifications = python_env_kernel_specifications(project, worktree_id, cx);
125 cx.spawn(async move |this, cx| {
126 let kernel_specifications = kernel_specifications
127 .await
128 .context("getting python kernelspecs")?;
129
130 this.update(cx, |this, cx| {
131 this.kernel_specifications_for_worktree
132 .insert(worktree_id, kernel_specifications);
133 cx.notify();
134 })
135 })
136 }
137
138 fn get_remote_kernel_specifications(
139 &self,
140 cx: &mut Context<Self>,
141 ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
142 match (
143 std::env::var("JUPYTER_SERVER"),
144 std::env::var("JUPYTER_TOKEN"),
145 ) {
146 (Ok(server), Ok(token)) => {
147 let remote_server = RemoteServer {
148 base_url: server,
149 token,
150 };
151 let http_client = cx.http_client();
152 Some(cx.spawn(async move |_, _| {
153 list_remote_kernelspecs(remote_server, http_client)
154 .await
155 .map(|specs| specs.into_iter().map(KernelSpecification::Remote).collect())
156 }))
157 }
158 _ => None,
159 }
160 }
161
162 pub fn refresh_kernelspecs(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
163 let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
164
165 let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
166
167 let all_specs = cx.background_spawn(async move {
168 let mut all_specs = local_kernel_specifications
169 .await?
170 .into_iter()
171 .map(KernelSpecification::Jupyter)
172 .collect::<Vec<_>>();
173
174 if let Some(remote_task) = remote_kernel_specifications
175 && let Ok(remote_specs) = remote_task.await
176 {
177 all_specs.extend(remote_specs);
178 }
179
180 anyhow::Ok(all_specs)
181 });
182
183 cx.spawn(async move |this, cx| {
184 let all_specs = all_specs.await;
185
186 if let Ok(specs) = all_specs {
187 this.update(cx, |this, cx| {
188 this.kernel_specifications = specs;
189 cx.notify();
190 })
191 .ok();
192 }
193
194 anyhow::Ok(())
195 })
196 }
197
198 pub fn set_active_kernelspec(
199 &mut self,
200 worktree_id: WorktreeId,
201 kernelspec: KernelSpecification,
202 _cx: &mut Context<Self>,
203 ) {
204 self.selected_kernel_for_worktree
205 .insert(worktree_id, kernelspec);
206 }
207
208 pub fn active_kernelspec(
209 &self,
210 worktree_id: WorktreeId,
211 language_at_cursor: Option<Arc<Language>>,
212 cx: &App,
213 ) -> Option<KernelSpecification> {
214 let selected_kernelspec = self.selected_kernel_for_worktree.get(&worktree_id).cloned();
215
216 if let Some(language_at_cursor) = language_at_cursor {
217 selected_kernelspec.or_else(|| {
218 self.kernelspec_legacy_by_lang_only(worktree_id, language_at_cursor, cx)
219 })
220 } else {
221 selected_kernelspec
222 }
223 }
224
225 fn kernelspec_legacy_by_lang_only(
226 &self,
227 worktree_id: WorktreeId,
228 language_at_cursor: Arc<Language>,
229 cx: &App,
230 ) -> Option<KernelSpecification> {
231 let settings = JupyterSettings::get_global(cx);
232 let selected_kernel = settings
233 .kernel_selections
234 .get(language_at_cursor.code_fence_block_name().as_ref());
235
236 let found_by_name = self
237 .kernel_specifications_for_worktree(worktree_id)
238 .find(|runtime_specification| {
239 if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
240 (selected_kernel, runtime_specification)
241 {
242 // Top priority is the selected kernel
243 return runtime_specification.name.to_lowercase() == selected.to_lowercase();
244 }
245 false
246 })
247 .cloned();
248
249 if let Some(found_by_name) = found_by_name {
250 return Some(found_by_name);
251 }
252
253 self.kernel_specifications_for_worktree(worktree_id)
254 .find(|kernel_option| match kernel_option {
255 KernelSpecification::Jupyter(runtime_specification) => {
256 runtime_specification.kernelspec.language.to_lowercase()
257 == language_at_cursor.code_fence_block_name().to_lowercase()
258 }
259 KernelSpecification::PythonEnv(runtime_specification) => {
260 runtime_specification.kernelspec.language.to_lowercase()
261 == language_at_cursor.code_fence_block_name().to_lowercase()
262 }
263 KernelSpecification::Remote(remote_spec) => {
264 remote_spec.kernelspec.language.to_lowercase()
265 == language_at_cursor.code_fence_block_name().to_lowercase()
266 }
267 })
268 .cloned()
269 }
270
271 pub fn get_session(&self, entity_id: EntityId) -> Option<&Entity<Session>> {
272 self.sessions.get(&entity_id)
273 }
274
275 pub fn insert_session(&mut self, entity_id: EntityId, session: Entity<Session>) {
276 self.sessions.insert(entity_id, session);
277 }
278
279 pub fn remove_session(&mut self, entity_id: EntityId) {
280 self.sessions.remove(&entity_id);
281 }
282
283 #[cfg(test)]
284 pub fn set_kernel_specs_for_testing(
285 &mut self,
286 specs: Vec<KernelSpecification>,
287 cx: &mut Context<Self>,
288 ) {
289 self.kernel_specifications = specs;
290 cx.notify();
291 }
292}