1use std::future::Future;
2use std::sync::Arc;
3
4use anyhow::{Context as _, Result};
5use collections::{HashMap, HashSet};
6use command_palette_hooks::CommandPaletteFilter;
7use gpui::{App, Context, Entity, EntityId, Global, SharedString, Subscription, Task, prelude::*};
8use jupyter_websocket_client::RemoteServer;
9use language::{Language, LanguageName};
10use project::{Fs, Project, ProjectPath, WorktreeId};
11use remote::RemoteConnectionOptions;
12use settings::{Settings, SettingsStore};
13use util::rel_path::RelPath;
14
15use crate::kernels::{
16 Kernel, list_remote_kernelspecs, local_kernel_specifications, python_env_kernel_specifications,
17 wsl_kernel_specifications,
18};
19use crate::{JupyterSettings, KernelSpecification, Session};
20
/// Newtype wrapper that lets the single [`ReplStore`] entity be registered as a
/// gpui global; it is installed by `ReplStore::init` and read by
/// `ReplStore::global`.
struct GlobalReplStore(Entity<ReplStore>);

impl Global for GlobalReplStore {}
24
/// App-wide state for the REPL: discovered kernelspecs (global and per
/// worktree), per-worktree kernel selections, and the live REPL sessions.
pub struct ReplStore {
    /// Filesystem handle used to discover local Jupyter kernelspecs.
    fs: Arc<dyn Fs>,
    /// Mirrors `JupyterSettings::enabled`; kept in sync by a settings observer.
    enabled: bool,
    /// Live REPL sessions, keyed by the `EntityId` passed to `insert_session`
    /// (presumably the owning editor — confirm at call sites).
    sessions: HashMap<EntityId, Entity<Session>>,
    /// Global kernelspecs: local Jupyter, WSL, and (when `JUPYTER_SERVER` and
    /// `JUPYTER_TOKEN` are set) remote Jupyter-server kernels.
    kernel_specifications: Vec<KernelSpecification>,
    /// Set once `ensure_kernelspecs` has started the first global refresh.
    kernelspecs_initialized: bool,
    /// Kernel explicitly chosen by the user for each worktree.
    selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
    /// Python-environment kernelspecs discovered for each worktree.
    kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
    /// Path of the active Python toolchain for each worktree, when known.
    active_python_toolchain_for_worktree: HashMap<WorktreeId, SharedString>,
    /// Worktrees of non-WSL remote projects; global kernelspecs are not
    /// offered for these.
    remote_worktrees: HashSet<WorktreeId>,
    /// Worktrees with a `refresh_python_kernelspecs` call currently in flight.
    fetching_python_kernelspecs: HashSet<WorktreeId>,
    /// Settings observer and app-quit hook, stored here so they live as long
    /// as the store.
    _subscriptions: Vec<Subscription>,
}
38
39impl ReplStore {
40 const NAMESPACE: &'static str = "repl";
41
42 pub(crate) fn init(fs: Arc<dyn Fs>, cx: &mut App) {
43 let store = cx.new(move |cx| Self::new(fs, cx));
44 cx.set_global(GlobalReplStore(store))
45 }
46
47 pub fn global(cx: &App) -> Entity<Self> {
48 cx.global::<GlobalReplStore>().0.clone()
49 }
50
51 pub fn new(fs: Arc<dyn Fs>, cx: &mut Context<Self>) -> Self {
52 let subscriptions = vec![
53 cx.observe_global::<SettingsStore>(move |this, cx| {
54 this.set_enabled(JupyterSettings::enabled(cx), cx);
55 }),
56 cx.on_app_quit(Self::shutdown_all_sessions),
57 ];
58
59 let this = Self {
60 fs,
61 enabled: JupyterSettings::enabled(cx),
62 sessions: HashMap::default(),
63 kernel_specifications: Vec::new(),
64 kernelspecs_initialized: false,
65 _subscriptions: subscriptions,
66 kernel_specifications_for_worktree: HashMap::default(),
67 selected_kernel_for_worktree: HashMap::default(),
68 active_python_toolchain_for_worktree: HashMap::default(),
69 remote_worktrees: HashSet::default(),
70 fetching_python_kernelspecs: HashSet::default(),
71 };
72 this.on_enabled_changed(cx);
73 this
74 }
75
76 pub fn fs(&self) -> &Arc<dyn Fs> {
77 &self.fs
78 }
79
80 pub fn is_enabled(&self) -> bool {
81 self.enabled
82 }
83
84 pub fn has_python_kernelspecs(&self, worktree_id: WorktreeId) -> bool {
85 self.kernel_specifications_for_worktree
86 .contains_key(&worktree_id)
87 }
88
89 pub fn kernel_specifications_for_worktree(
90 &self,
91 worktree_id: WorktreeId,
92 ) -> impl Iterator<Item = &KernelSpecification> {
93 let global_specs = if self.remote_worktrees.contains(&worktree_id) {
94 None
95 } else {
96 Some(self.kernel_specifications.iter())
97 };
98
99 self.kernel_specifications_for_worktree
100 .get(&worktree_id)
101 .into_iter()
102 .flat_map(|specs| specs.iter())
103 .chain(global_specs.into_iter().flatten())
104 }
105
106 pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
107 self.kernel_specifications.iter()
108 }
109
110 pub fn sessions(&self) -> impl Iterator<Item = &Entity<Session>> {
111 self.sessions.values()
112 }
113
114 fn set_enabled(&mut self, enabled: bool, cx: &mut Context<Self>) {
115 if self.enabled == enabled {
116 return;
117 }
118
119 self.enabled = enabled;
120 self.on_enabled_changed(cx);
121 }
122
123 fn on_enabled_changed(&self, cx: &mut Context<Self>) {
124 if !self.enabled {
125 CommandPaletteFilter::update_global(cx, |filter, _cx| {
126 filter.hide_namespace(Self::NAMESPACE);
127 });
128
129 return;
130 }
131
132 CommandPaletteFilter::update_global(cx, |filter, _cx| {
133 filter.show_namespace(Self::NAMESPACE);
134 });
135
136 cx.notify();
137 }
138
139 pub fn refresh_python_kernelspecs(
140 &mut self,
141 worktree_id: WorktreeId,
142 project: &Entity<Project>,
143 cx: &mut Context<Self>,
144 ) -> Task<Result<()>> {
145 if !self.fetching_python_kernelspecs.insert(worktree_id) {
146 return Task::ready(Ok(()));
147 }
148
149 let is_remote = project.read(cx).is_remote();
150 // WSL does require access to global kernel specs, so we only exclude remote worktrees that aren't WSL.
151 // TODO: a better way to handle WSL vs SSH/remote projects,
152 let is_wsl_remote = project
153 .read(cx)
154 .remote_connection_options(cx)
155 .map_or(false, |opts| {
156 matches!(opts, RemoteConnectionOptions::Wsl(_))
157 });
158 let kernel_specifications_task = python_env_kernel_specifications(project, worktree_id, cx);
159 let active_toolchain = project.read(cx).active_toolchain(
160 ProjectPath {
161 worktree_id,
162 path: RelPath::empty().into(),
163 },
164 LanguageName::new_static("Python"),
165 cx,
166 );
167
168 cx.spawn(async move |this, cx| {
169 let kernel_specifications_res = kernel_specifications_task.await;
170
171 this.update(cx, |this, _cx| {
172 this.fetching_python_kernelspecs.remove(&worktree_id);
173 })
174 .ok();
175
176 let kernel_specifications =
177 kernel_specifications_res.context("getting python kernelspecs")?;
178
179 let active_toolchain_path = active_toolchain.await.map(|toolchain| toolchain.path);
180
181 this.update(cx, |this, cx| {
182 this.kernel_specifications_for_worktree
183 .insert(worktree_id, kernel_specifications);
184 if let Some(path) = active_toolchain_path {
185 this.active_python_toolchain_for_worktree
186 .insert(worktree_id, path);
187 }
188 if is_remote && !is_wsl_remote {
189 this.remote_worktrees.insert(worktree_id);
190 } else {
191 this.remote_worktrees.remove(&worktree_id);
192 }
193 cx.notify();
194 })
195 })
196 }
197
198 fn get_remote_kernel_specifications(
199 &self,
200 cx: &mut Context<Self>,
201 ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
202 match (
203 std::env::var("JUPYTER_SERVER"),
204 std::env::var("JUPYTER_TOKEN"),
205 ) {
206 (Ok(server), Ok(token)) => {
207 let remote_server = RemoteServer {
208 base_url: server,
209 token,
210 };
211 let http_client = cx.http_client();
212 Some(cx.spawn(async move |_, _| {
213 list_remote_kernelspecs(remote_server, http_client)
214 .await
215 .map(|specs| {
216 specs
217 .into_iter()
218 .map(KernelSpecification::JupyterServer)
219 .collect()
220 })
221 }))
222 }
223 _ => None,
224 }
225 }
226
227 pub fn ensure_kernelspecs(&mut self, cx: &mut Context<Self>) {
228 if self.kernelspecs_initialized {
229 return;
230 }
231 self.kernelspecs_initialized = true;
232 self.refresh_kernelspecs(cx).detach_and_log_err(cx);
233 }
234
235 pub fn refresh_kernelspecs(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
236 let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
237 let wsl_kernel_specifications = wsl_kernel_specifications(cx.background_executor().clone());
238 let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
239
240 let all_specs = cx.background_spawn(async move {
241 let mut all_specs = local_kernel_specifications
242 .await?
243 .into_iter()
244 .map(KernelSpecification::Jupyter)
245 .collect::<Vec<_>>();
246
247 if let Ok(wsl_specs) = wsl_kernel_specifications.await {
248 all_specs.extend(wsl_specs);
249 }
250
251 if let Some(remote_task) = remote_kernel_specifications
252 && let Ok(remote_specs) = remote_task.await
253 {
254 all_specs.extend(remote_specs);
255 }
256
257 anyhow::Ok(all_specs)
258 });
259
260 cx.spawn(async move |this, cx| {
261 let all_specs = all_specs.await;
262
263 if let Ok(specs) = all_specs {
264 this.update(cx, |this, cx| {
265 this.kernel_specifications = specs;
266 cx.notify();
267 })
268 .ok();
269 }
270
271 anyhow::Ok(())
272 })
273 }
274
275 pub fn set_active_kernelspec(
276 &mut self,
277 worktree_id: WorktreeId,
278 kernelspec: KernelSpecification,
279 _cx: &mut Context<Self>,
280 ) {
281 self.selected_kernel_for_worktree
282 .insert(worktree_id, kernelspec);
283 }
284
285 pub fn active_python_toolchain_path(&self, worktree_id: WorktreeId) -> Option<&SharedString> {
286 self.active_python_toolchain_for_worktree.get(&worktree_id)
287 }
288
289 pub fn selected_kernel(&self, worktree_id: WorktreeId) -> Option<&KernelSpecification> {
290 self.selected_kernel_for_worktree.get(&worktree_id)
291 }
292
293 pub fn is_recommended_kernel(
294 &self,
295 worktree_id: WorktreeId,
296 spec: &KernelSpecification,
297 ) -> bool {
298 if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
299 spec.path().as_ref() == active_path.as_ref()
300 } else {
301 false
302 }
303 }
304
305 pub fn active_kernelspec(
306 &self,
307 worktree_id: WorktreeId,
308 language_at_cursor: Option<Arc<Language>>,
309 cx: &App,
310 ) -> Option<KernelSpecification> {
311 if let Some(selected) = self.selected_kernel_for_worktree.get(&worktree_id).cloned() {
312 return Some(selected);
313 }
314
315 let language_at_cursor = language_at_cursor?;
316
317 // Prefer the recommended (active toolchain) kernel if it has ipykernel
318 if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
319 let recommended = self
320 .kernel_specifications_for_worktree(worktree_id)
321 .find(|spec| {
322 spec.has_ipykernel()
323 && language_at_cursor.matches_kernel_language(spec.language().as_ref())
324 && spec.path().as_ref() == active_path.as_ref()
325 })
326 .cloned();
327 if recommended.is_some() {
328 return recommended;
329 }
330 }
331
332 // Then try the first PythonEnv with ipykernel matching the language
333 let python_env = self
334 .kernel_specifications_for_worktree(worktree_id)
335 .find(|spec| {
336 matches!(spec, KernelSpecification::PythonEnv(_))
337 && spec.has_ipykernel()
338 && language_at_cursor.matches_kernel_language(spec.language().as_ref())
339 })
340 .cloned();
341 if python_env.is_some() {
342 return python_env;
343 }
344
345 // Fall back to legacy name-based and language-based matching
346 self.kernelspec_legacy_by_lang_only(worktree_id, language_at_cursor, cx)
347 }
348
349 fn kernelspec_legacy_by_lang_only(
350 &self,
351 worktree_id: WorktreeId,
352 language_at_cursor: Arc<Language>,
353 cx: &App,
354 ) -> Option<KernelSpecification> {
355 let settings = JupyterSettings::get_global(cx);
356 let selected_kernel = settings
357 .kernel_selections
358 .get(language_at_cursor.code_fence_block_name().as_ref());
359
360 let found_by_name = self
361 .kernel_specifications_for_worktree(worktree_id)
362 .find(|runtime_specification| {
363 if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
364 (selected_kernel, runtime_specification)
365 {
366 return runtime_specification.name.to_lowercase() == selected.to_lowercase();
367 }
368 false
369 })
370 .cloned();
371
372 if let Some(found_by_name) = found_by_name {
373 return Some(found_by_name);
374 }
375
376 self.kernel_specifications_for_worktree(worktree_id)
377 .find(|spec| {
378 spec.has_ipykernel()
379 && language_at_cursor.matches_kernel_language(spec.language().as_ref())
380 })
381 .cloned()
382 }
383
384 pub fn get_session(&self, entity_id: EntityId) -> Option<&Entity<Session>> {
385 self.sessions.get(&entity_id)
386 }
387
388 pub fn insert_session(&mut self, entity_id: EntityId, session: Entity<Session>) {
389 self.sessions.insert(entity_id, session);
390 }
391
392 pub fn remove_session(&mut self, entity_id: EntityId) {
393 self.sessions.remove(&entity_id);
394 }
395
396 fn shutdown_all_sessions(
397 &mut self,
398 cx: &mut Context<Self>,
399 ) -> impl Future<Output = ()> + use<> {
400 for session in self.sessions.values() {
401 session.update(cx, |session, _cx| {
402 if let Kernel::RunningKernel(mut kernel) =
403 std::mem::replace(&mut session.kernel, Kernel::Shutdown)
404 {
405 kernel.kill();
406 }
407 });
408 }
409 self.sessions.clear();
410 futures::future::ready(())
411 }
412
413 #[cfg(test)]
414 pub fn set_kernel_specs_for_testing(
415 &mut self,
416 specs: Vec<KernelSpecification>,
417 cx: &mut Context<Self>,
418 ) {
419 self.kernel_specifications = specs;
420 cx.notify();
421 }
422}