1use std::future::Future;
2use std::sync::Arc;
3
4use anyhow::{Context as _, Result};
5use collections::{HashMap, HashSet};
6use command_palette_hooks::CommandPaletteFilter;
7use gpui::{App, Context, Entity, EntityId, Global, SharedString, Subscription, Task, prelude::*};
8use jupyter_websocket_client::RemoteServer;
9use language::{Language, LanguageName};
10use project::{Fs, Project, ProjectPath, WorktreeId};
11use settings::{Settings, SettingsStore};
12use util::rel_path::RelPath;
13
14use crate::kernels::{
15 Kernel, list_remote_kernelspecs, local_kernel_specifications, python_env_kernel_specifications,
16 wsl_kernel_specifications,
17};
18use crate::{JupyterSettings, KernelSpecification, Session};
19
/// gpui global wrapper around the application-wide [`ReplStore`] entity.
///
/// Installed by [`ReplStore::init`] and read back by [`ReplStore::global`].
struct GlobalReplStore(Entity<ReplStore>);

impl Global for GlobalReplStore {}
23
/// Application-wide REPL state: live sessions, the kernelspecs that can start
/// new sessions, and per-worktree kernel selection.
pub struct ReplStore {
    /// Filesystem handed to `local_kernel_specifications` when refreshing specs.
    fs: Arc<dyn Fs>,
    /// Mirrors `JupyterSettings::enabled`; controls whether the `repl`
    /// command-palette namespace is shown.
    enabled: bool,
    /// Sessions keyed by the `EntityId` passed to `insert_session`
    /// (presumably the owning editor — confirm against callers).
    sessions: HashMap<EntityId, Entity<Session>>,
    /// Worktree-independent kernelspecs (local Jupyter, WSL and
    /// `JUPYTER_SERVER` kernels) loaded by `refresh_kernelspecs`.
    kernel_specifications: Vec<KernelSpecification>,
    /// Kernel explicitly chosen via `set_active_kernelspec`; it takes precedence
    /// over every other candidate in `active_kernelspec`.
    selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
    /// Python-environment kernelspecs discovered per worktree by
    /// `refresh_python_kernelspecs`.
    kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
    /// Path of the active Python toolchain per worktree, used to identify the
    /// recommended kernel.
    active_python_toolchain_for_worktree: HashMap<WorktreeId, SharedString>,
    /// Worktrees that belong to remote projects; they only offer their own
    /// worktree-specific kernelspecs.
    remote_worktrees: HashSet<WorktreeId>,
    /// Keeps the settings-observer and app-quit subscriptions alive.
    _subscriptions: Vec<Subscription>,
}
35
36impl ReplStore {
37 const NAMESPACE: &'static str = "repl";
38
39 pub(crate) fn init(fs: Arc<dyn Fs>, cx: &mut App) {
40 let store = cx.new(move |cx| Self::new(fs, cx));
41
42 #[cfg(not(feature = "test-support"))]
43 store
44 .update(cx, |store, cx| store.refresh_kernelspecs(cx))
45 .detach_and_log_err(cx);
46
47 cx.set_global(GlobalReplStore(store))
48 }
49
50 pub fn global(cx: &App) -> Entity<Self> {
51 cx.global::<GlobalReplStore>().0.clone()
52 }
53
54 pub fn new(fs: Arc<dyn Fs>, cx: &mut Context<Self>) -> Self {
55 let subscriptions = vec![
56 cx.observe_global::<SettingsStore>(move |this, cx| {
57 this.set_enabled(JupyterSettings::enabled(cx), cx);
58 }),
59 cx.on_app_quit(Self::shutdown_all_sessions),
60 ];
61
62 let this = Self {
63 fs,
64 enabled: JupyterSettings::enabled(cx),
65 sessions: HashMap::default(),
66 kernel_specifications: Vec::new(),
67 _subscriptions: subscriptions,
68 kernel_specifications_for_worktree: HashMap::default(),
69 selected_kernel_for_worktree: HashMap::default(),
70 active_python_toolchain_for_worktree: HashMap::default(),
71 remote_worktrees: HashSet::default(),
72 };
73 this.on_enabled_changed(cx);
74 this
75 }
76
77 pub fn fs(&self) -> &Arc<dyn Fs> {
78 &self.fs
79 }
80
81 pub fn is_enabled(&self) -> bool {
82 self.enabled
83 }
84
85 pub fn has_python_kernelspecs(&self, worktree_id: WorktreeId) -> bool {
86 self.kernel_specifications_for_worktree
87 .contains_key(&worktree_id)
88 }
89
90 pub fn kernel_specifications_for_worktree(
91 &self,
92 worktree_id: WorktreeId,
93 ) -> impl Iterator<Item = &KernelSpecification> {
94 let global_specs = if self.remote_worktrees.contains(&worktree_id) {
95 None
96 } else {
97 Some(self.kernel_specifications.iter())
98 };
99
100 self.kernel_specifications_for_worktree
101 .get(&worktree_id)
102 .into_iter()
103 .flat_map(|specs| specs.iter())
104 .chain(global_specs.into_iter().flatten())
105 }
106
107 pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
108 self.kernel_specifications.iter()
109 }
110
111 pub fn sessions(&self) -> impl Iterator<Item = &Entity<Session>> {
112 self.sessions.values()
113 }
114
115 fn set_enabled(&mut self, enabled: bool, cx: &mut Context<Self>) {
116 if self.enabled == enabled {
117 return;
118 }
119
120 self.enabled = enabled;
121 self.on_enabled_changed(cx);
122 }
123
124 fn on_enabled_changed(&self, cx: &mut Context<Self>) {
125 if !self.enabled {
126 CommandPaletteFilter::update_global(cx, |filter, _cx| {
127 filter.hide_namespace(Self::NAMESPACE);
128 });
129
130 return;
131 }
132
133 CommandPaletteFilter::update_global(cx, |filter, _cx| {
134 filter.show_namespace(Self::NAMESPACE);
135 });
136
137 cx.notify();
138 }
139
140 pub fn refresh_python_kernelspecs(
141 &mut self,
142 worktree_id: WorktreeId,
143 project: &Entity<Project>,
144 cx: &mut Context<Self>,
145 ) -> Task<Result<()>> {
146 let is_remote = project.read(cx).is_remote();
147 let kernel_specifications = python_env_kernel_specifications(project, worktree_id, cx);
148 let active_toolchain = project.read(cx).active_toolchain(
149 ProjectPath {
150 worktree_id,
151 path: RelPath::empty().into(),
152 },
153 LanguageName::new_static("Python"),
154 cx,
155 );
156
157 cx.spawn(async move |this, cx| {
158 let kernel_specifications = kernel_specifications
159 .await
160 .context("getting python kernelspecs")?;
161
162 let active_toolchain_path = active_toolchain.await.map(|toolchain| toolchain.path);
163
164 this.update(cx, |this, cx| {
165 this.kernel_specifications_for_worktree
166 .insert(worktree_id, kernel_specifications);
167 if let Some(path) = active_toolchain_path {
168 this.active_python_toolchain_for_worktree
169 .insert(worktree_id, path);
170 }
171 if is_remote {
172 this.remote_worktrees.insert(worktree_id);
173 } else {
174 this.remote_worktrees.remove(&worktree_id);
175 }
176 cx.notify();
177 })
178 })
179 }
180
181 fn get_remote_kernel_specifications(
182 &self,
183 cx: &mut Context<Self>,
184 ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
185 match (
186 std::env::var("JUPYTER_SERVER"),
187 std::env::var("JUPYTER_TOKEN"),
188 ) {
189 (Ok(server), Ok(token)) => {
190 let remote_server = RemoteServer {
191 base_url: server,
192 token,
193 };
194 let http_client = cx.http_client();
195 Some(cx.spawn(async move |_, _| {
196 list_remote_kernelspecs(remote_server, http_client)
197 .await
198 .map(|specs| {
199 specs
200 .into_iter()
201 .map(KernelSpecification::JupyterServer)
202 .collect()
203 })
204 }))
205 }
206 _ => None,
207 }
208 }
209
210 pub fn refresh_kernelspecs(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
211 let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
212 let wsl_kernel_specifications = wsl_kernel_specifications(cx.background_executor().clone());
213
214 let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
215
216 let all_specs = cx.background_spawn(async move {
217 let mut all_specs = local_kernel_specifications
218 .await?
219 .into_iter()
220 .map(KernelSpecification::Jupyter)
221 .collect::<Vec<_>>();
222
223 if let Ok(wsl_specs) = wsl_kernel_specifications.await {
224 all_specs.extend(wsl_specs);
225 }
226
227 if let Some(remote_task) = remote_kernel_specifications
228 && let Ok(remote_specs) = remote_task.await
229 {
230 all_specs.extend(remote_specs);
231 }
232
233 anyhow::Ok(all_specs)
234 });
235
236 cx.spawn(async move |this, cx| {
237 let all_specs = all_specs.await;
238
239 if let Ok(specs) = all_specs {
240 this.update(cx, |this, cx| {
241 this.kernel_specifications = specs;
242 cx.notify();
243 })
244 .ok();
245 }
246
247 anyhow::Ok(())
248 })
249 }
250
251 pub fn set_active_kernelspec(
252 &mut self,
253 worktree_id: WorktreeId,
254 kernelspec: KernelSpecification,
255 _cx: &mut Context<Self>,
256 ) {
257 self.selected_kernel_for_worktree
258 .insert(worktree_id, kernelspec);
259 }
260
261 pub fn active_python_toolchain_path(&self, worktree_id: WorktreeId) -> Option<&SharedString> {
262 self.active_python_toolchain_for_worktree.get(&worktree_id)
263 }
264
265 pub fn selected_kernel(&self, worktree_id: WorktreeId) -> Option<&KernelSpecification> {
266 self.selected_kernel_for_worktree.get(&worktree_id)
267 }
268
269 pub fn is_recommended_kernel(
270 &self,
271 worktree_id: WorktreeId,
272 spec: &KernelSpecification,
273 ) -> bool {
274 if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
275 spec.path().as_ref() == active_path.as_ref()
276 } else {
277 false
278 }
279 }
280
281 pub fn active_kernelspec(
282 &self,
283 worktree_id: WorktreeId,
284 language_at_cursor: Option<Arc<Language>>,
285 cx: &App,
286 ) -> Option<KernelSpecification> {
287 if let Some(selected) = self.selected_kernel_for_worktree.get(&worktree_id).cloned() {
288 return Some(selected);
289 }
290
291 let language_at_cursor = language_at_cursor?;
292 let language_name = language_at_cursor.code_fence_block_name().to_lowercase();
293
294 // Prefer the recommended (active toolchain) kernel if it has ipykernel
295 if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
296 let recommended = self
297 .kernel_specifications_for_worktree(worktree_id)
298 .find(|spec| {
299 spec.has_ipykernel()
300 && spec.language().as_ref().to_lowercase() == language_name
301 && spec.path().as_ref() == active_path.as_ref()
302 })
303 .cloned();
304 if recommended.is_some() {
305 return recommended;
306 }
307 }
308
309 // Then try the first PythonEnv with ipykernel matching the language
310 let python_env = self
311 .kernel_specifications_for_worktree(worktree_id)
312 .find(|spec| {
313 matches!(spec, KernelSpecification::PythonEnv(_))
314 && spec.has_ipykernel()
315 && spec.language().as_ref().to_lowercase() == language_name
316 })
317 .cloned();
318 if python_env.is_some() {
319 return python_env;
320 }
321
322 // Fall back to legacy name-based and language-based matching
323 self.kernelspec_legacy_by_lang_only(worktree_id, language_at_cursor, cx)
324 }
325
326 fn kernelspec_legacy_by_lang_only(
327 &self,
328 worktree_id: WorktreeId,
329 language_at_cursor: Arc<Language>,
330 cx: &App,
331 ) -> Option<KernelSpecification> {
332 let settings = JupyterSettings::get_global(cx);
333 let selected_kernel = settings
334 .kernel_selections
335 .get(language_at_cursor.code_fence_block_name().as_ref());
336
337 let found_by_name = self
338 .kernel_specifications_for_worktree(worktree_id)
339 .find(|runtime_specification| {
340 if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
341 (selected_kernel, runtime_specification)
342 {
343 return runtime_specification.name.to_lowercase() == selected.to_lowercase();
344 }
345 false
346 })
347 .cloned();
348
349 if let Some(found_by_name) = found_by_name {
350 return Some(found_by_name);
351 }
352
353 let language_name = language_at_cursor.code_fence_block_name().to_lowercase();
354 self.kernel_specifications_for_worktree(worktree_id)
355 .find(|spec| {
356 spec.has_ipykernel() && spec.language().as_ref().to_lowercase() == language_name
357 })
358 .cloned()
359 }
360
361 pub fn get_session(&self, entity_id: EntityId) -> Option<&Entity<Session>> {
362 self.sessions.get(&entity_id)
363 }
364
365 pub fn insert_session(&mut self, entity_id: EntityId, session: Entity<Session>) {
366 self.sessions.insert(entity_id, session);
367 }
368
369 pub fn remove_session(&mut self, entity_id: EntityId) {
370 self.sessions.remove(&entity_id);
371 }
372
373 fn shutdown_all_sessions(
374 &mut self,
375 cx: &mut Context<Self>,
376 ) -> impl Future<Output = ()> + use<> {
377 for session in self.sessions.values() {
378 session.update(cx, |session, _cx| {
379 if let Kernel::RunningKernel(mut kernel) =
380 std::mem::replace(&mut session.kernel, Kernel::Shutdown)
381 {
382 kernel.kill();
383 }
384 });
385 }
386 self.sessions.clear();
387 futures::future::ready(())
388 }
389
390 #[cfg(test)]
391 pub fn set_kernel_specs_for_testing(
392 &mut self,
393 specs: Vec<KernelSpecification>,
394 cx: &mut Context<Self>,
395 ) {
396 self.kernel_specifications = specs;
397 cx.notify();
398 }
399}