repl_store.rs

  1use std::future::Future;
  2use std::sync::Arc;
  3
  4use anyhow::{Context as _, Result};
  5use collections::{HashMap, HashSet};
  6use command_palette_hooks::CommandPaletteFilter;
  7use gpui::{
  8    App, Context, Entity, EntityId, Global, SharedString, Subscription, Task, TaskExt, prelude::*,
  9};
 10use jupyter_websocket_client::RemoteServer;
 11use language::{Language, LanguageName};
 12use project::{Fs, Project, ProjectPath, WorktreeId};
 13use remote::RemoteConnectionOptions;
 14use settings::{Settings, SettingsStore};
 15use util::rel_path::RelPath;
 16
 17use crate::kernels::{
 18    Kernel, PythonEnvKernelSpecification, list_remote_kernelspecs, local_kernel_specifications,
 19    python_env_kernel_specifications, wsl_kernel_specifications,
 20};
 21use crate::{JupyterSettings, KernelSpecification, Session};
 22
 23struct GlobalReplStore(Entity<ReplStore>);
 24
 25impl Global for GlobalReplStore {}
 26
 27pub struct ReplStore {
 28    fs: Arc<dyn Fs>,
 29    enabled: bool,
 30    sessions: HashMap<EntityId, Entity<Session>>,
 31    kernel_specifications: Vec<KernelSpecification>,
 32    kernelspecs_initialized: bool,
 33    selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
 34    kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
 35    active_python_toolchain_for_worktree: HashMap<WorktreeId, SharedString>,
 36    remote_worktrees: HashSet<WorktreeId>,
 37    fetching_python_kernelspecs: HashSet<WorktreeId>,
 38    _subscriptions: Vec<Subscription>,
 39}
 40
 41impl ReplStore {
 42    const NAMESPACE: &'static str = "repl";
 43
 44    pub(crate) fn init(fs: Arc<dyn Fs>, cx: &mut App) {
 45        let store = cx.new(move |cx| Self::new(fs, cx));
 46        cx.set_global(GlobalReplStore(store))
 47    }
 48
 49    pub fn global(cx: &App) -> Entity<Self> {
 50        cx.global::<GlobalReplStore>().0.clone()
 51    }
 52
 53    pub fn new(fs: Arc<dyn Fs>, cx: &mut Context<Self>) -> Self {
 54        let subscriptions = vec![
 55            cx.observe_global::<SettingsStore>(move |this, cx| {
 56                this.set_enabled(JupyterSettings::enabled(cx), cx);
 57            }),
 58            cx.on_app_quit(Self::shutdown_all_sessions),
 59        ];
 60
 61        let this = Self {
 62            fs,
 63            enabled: JupyterSettings::enabled(cx),
 64            sessions: HashMap::default(),
 65            kernel_specifications: Vec::new(),
 66            kernelspecs_initialized: false,
 67            _subscriptions: subscriptions,
 68            kernel_specifications_for_worktree: HashMap::default(),
 69            selected_kernel_for_worktree: HashMap::default(),
 70            active_python_toolchain_for_worktree: HashMap::default(),
 71            remote_worktrees: HashSet::default(),
 72            fetching_python_kernelspecs: HashSet::default(),
 73        };
 74        this.on_enabled_changed(cx);
 75        this
 76    }
 77
 78    pub fn fs(&self) -> &Arc<dyn Fs> {
 79        &self.fs
 80    }
 81
 82    pub fn is_enabled(&self) -> bool {
 83        self.enabled
 84    }
 85
 86    pub fn has_python_kernelspecs(&self, worktree_id: WorktreeId) -> bool {
 87        self.kernel_specifications_for_worktree
 88            .contains_key(&worktree_id)
 89    }
 90
 91    pub fn kernel_specifications_for_worktree(
 92        &self,
 93        worktree_id: WorktreeId,
 94    ) -> impl Iterator<Item = &KernelSpecification> {
 95        let global_specs = if self.remote_worktrees.contains(&worktree_id) {
 96            None
 97        } else {
 98            Some(self.kernel_specifications.iter())
 99        };
100
101        self.kernel_specifications_for_worktree
102            .get(&worktree_id)
103            .into_iter()
104            .flat_map(|specs| specs.iter())
105            .chain(global_specs.into_iter().flatten())
106    }
107
108    pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
109        self.kernel_specifications.iter()
110    }
111
112    pub fn sessions(&self) -> impl Iterator<Item = &Entity<Session>> {
113        self.sessions.values()
114    }
115
116    fn set_enabled(&mut self, enabled: bool, cx: &mut Context<Self>) {
117        if self.enabled == enabled {
118            return;
119        }
120
121        self.enabled = enabled;
122        self.on_enabled_changed(cx);
123    }
124
125    fn on_enabled_changed(&self, cx: &mut Context<Self>) {
126        if !self.enabled {
127            CommandPaletteFilter::update_global(cx, |filter, _cx| {
128                filter.hide_namespace(Self::NAMESPACE);
129            });
130
131            return;
132        }
133
134        CommandPaletteFilter::update_global(cx, |filter, _cx| {
135            filter.show_namespace(Self::NAMESPACE);
136        });
137
138        cx.notify();
139    }
140
141    pub fn mark_ipykernel_installed(
142        &mut self,
143        cx: &mut Context<Self>,
144        spec: &PythonEnvKernelSpecification,
145    ) {
146        for specs in self.kernel_specifications_for_worktree.values_mut() {
147            for kernel_spec in specs.iter_mut() {
148                if let KernelSpecification::PythonEnv(env_spec) = kernel_spec {
149                    if env_spec == spec {
150                        env_spec.has_ipykernel = true;
151                    }
152                }
153            }
154        }
155        cx.notify();
156    }
157
158    pub fn refresh_python_kernelspecs(
159        &mut self,
160        worktree_id: WorktreeId,
161        project: &Entity<Project>,
162        cx: &mut Context<Self>,
163    ) -> Task<Result<()>> {
164        if !self.fetching_python_kernelspecs.insert(worktree_id) {
165            return Task::ready(Ok(()));
166        }
167
168        let is_remote = project.read(cx).is_remote();
169        // WSL does require access to global kernel specs, so we only exclude remote worktrees that aren't WSL.
170        // TODO: a better way to handle WSL vs SSH/remote projects,
171        let is_wsl_remote = project
172            .read(cx)
173            .remote_connection_options(cx)
174            .map_or(false, |opts| {
175                matches!(opts, RemoteConnectionOptions::Wsl(_))
176            });
177        let kernel_specifications_task = python_env_kernel_specifications(project, worktree_id, cx);
178        let active_toolchain = project.read(cx).active_toolchain(
179            ProjectPath {
180                worktree_id,
181                path: RelPath::empty().into(),
182            },
183            LanguageName::new_static("Python"),
184            cx,
185        );
186
187        cx.spawn(async move |this, cx| {
188            let kernel_specifications_res = kernel_specifications_task.await;
189
190            this.update(cx, |this, _cx| {
191                this.fetching_python_kernelspecs.remove(&worktree_id);
192            })
193            .ok();
194
195            let kernel_specifications =
196                kernel_specifications_res.context("getting python kernelspecs")?;
197
198            let active_toolchain_path = active_toolchain.await.map(|toolchain| toolchain.path);
199
200            this.update(cx, |this, cx| {
201                this.kernel_specifications_for_worktree
202                    .insert(worktree_id, kernel_specifications);
203                if let Some(path) = active_toolchain_path {
204                    this.active_python_toolchain_for_worktree
205                        .insert(worktree_id, path);
206                }
207                if is_remote && !is_wsl_remote {
208                    this.remote_worktrees.insert(worktree_id);
209                } else {
210                    this.remote_worktrees.remove(&worktree_id);
211                }
212                cx.notify();
213            })
214        })
215    }
216
217    fn get_remote_kernel_specifications(
218        &self,
219        cx: &mut Context<Self>,
220    ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
221        match (
222            std::env::var("JUPYTER_SERVER"),
223            std::env::var("JUPYTER_TOKEN"),
224        ) {
225            (Ok(server), Ok(token)) => {
226                let remote_server = RemoteServer {
227                    base_url: server,
228                    token,
229                };
230                let http_client = cx.http_client();
231                Some(cx.spawn(async move |_, _| {
232                    list_remote_kernelspecs(remote_server, http_client)
233                        .await
234                        .map(|specs| {
235                            specs
236                                .into_iter()
237                                .map(KernelSpecification::JupyterServer)
238                                .collect()
239                        })
240                }))
241            }
242            _ => None,
243        }
244    }
245
246    pub fn ensure_kernelspecs(&mut self, cx: &mut Context<Self>) {
247        if self.kernelspecs_initialized {
248            return;
249        }
250        self.kernelspecs_initialized = true;
251        self.refresh_kernelspecs(cx).detach_and_log_err(cx);
252    }
253
254    pub fn refresh_kernelspecs(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
255        let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
256        let wsl_kernel_specifications = wsl_kernel_specifications(cx.background_executor().clone());
257        let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
258
259        let all_specs = cx.background_spawn(async move {
260            let mut all_specs = local_kernel_specifications
261                .await?
262                .into_iter()
263                .map(KernelSpecification::Jupyter)
264                .collect::<Vec<_>>();
265
266            if let Ok(wsl_specs) = wsl_kernel_specifications.await {
267                all_specs.extend(wsl_specs);
268            }
269
270            if let Some(remote_task) = remote_kernel_specifications
271                && let Ok(remote_specs) = remote_task.await
272            {
273                all_specs.extend(remote_specs);
274            }
275
276            anyhow::Ok(all_specs)
277        });
278
279        cx.spawn(async move |this, cx| {
280            let all_specs = all_specs.await;
281
282            if let Ok(specs) = all_specs {
283                this.update(cx, |this, cx| {
284                    this.kernel_specifications = specs;
285                    cx.notify();
286                })
287                .ok();
288            }
289
290            anyhow::Ok(())
291        })
292    }
293
294    pub fn set_active_kernelspec(
295        &mut self,
296        worktree_id: WorktreeId,
297        kernelspec: KernelSpecification,
298        _cx: &mut Context<Self>,
299    ) {
300        self.selected_kernel_for_worktree
301            .insert(worktree_id, kernelspec);
302    }
303
304    pub fn active_python_toolchain_path(&self, worktree_id: WorktreeId) -> Option<&SharedString> {
305        self.active_python_toolchain_for_worktree.get(&worktree_id)
306    }
307
308    pub fn selected_kernel(&self, worktree_id: WorktreeId) -> Option<&KernelSpecification> {
309        self.selected_kernel_for_worktree.get(&worktree_id)
310    }
311
312    pub fn is_recommended_kernel(
313        &self,
314        worktree_id: WorktreeId,
315        spec: &KernelSpecification,
316    ) -> bool {
317        if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
318            spec.path().as_ref() == active_path.as_ref()
319        } else {
320            false
321        }
322    }
323
324    pub fn active_kernelspec(
325        &self,
326        worktree_id: WorktreeId,
327        language_at_cursor: Option<Arc<Language>>,
328        cx: &App,
329    ) -> Option<KernelSpecification> {
330        if let Some(selected) = self.selected_kernel_for_worktree.get(&worktree_id).cloned() {
331            return Some(selected);
332        }
333
334        let language_at_cursor = language_at_cursor?;
335
336        // Prefer the recommended (active toolchain) kernel if it has ipykernel
337        if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
338            let recommended = self
339                .kernel_specifications_for_worktree(worktree_id)
340                .find(|spec| {
341                    spec.has_ipykernel()
342                        && language_at_cursor.matches_kernel_language(spec.language().as_ref())
343                        && spec.path().as_ref() == active_path.as_ref()
344                })
345                .cloned();
346            if recommended.is_some() {
347                return recommended;
348            }
349        }
350
351        // Then try the first PythonEnv with ipykernel matching the language
352        let python_env = self
353            .kernel_specifications_for_worktree(worktree_id)
354            .find(|spec| {
355                matches!(spec, KernelSpecification::PythonEnv(_))
356                    && spec.has_ipykernel()
357                    && language_at_cursor.matches_kernel_language(spec.language().as_ref())
358            })
359            .cloned();
360        if python_env.is_some() {
361            return python_env;
362        }
363
364        // Fall back to legacy name-based and language-based matching
365        self.kernelspec_legacy_by_lang_only(worktree_id, language_at_cursor, cx)
366    }
367
368    fn kernelspec_legacy_by_lang_only(
369        &self,
370        worktree_id: WorktreeId,
371        language_at_cursor: Arc<Language>,
372        cx: &App,
373    ) -> Option<KernelSpecification> {
374        let settings = JupyterSettings::get_global(cx);
375        let selected_kernel = settings
376            .kernel_selections
377            .get(language_at_cursor.code_fence_block_name().as_ref());
378
379        let found_by_name = self
380            .kernel_specifications_for_worktree(worktree_id)
381            .find(|runtime_specification| {
382                if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
383                    (selected_kernel, runtime_specification)
384                {
385                    return runtime_specification.name.to_lowercase() == selected.to_lowercase();
386                }
387                false
388            })
389            .cloned();
390
391        if let Some(found_by_name) = found_by_name {
392            return Some(found_by_name);
393        }
394
395        self.kernel_specifications_for_worktree(worktree_id)
396            .find(|spec| {
397                spec.has_ipykernel()
398                    && language_at_cursor.matches_kernel_language(spec.language().as_ref())
399            })
400            .cloned()
401    }
402
403    pub fn get_session(&self, entity_id: EntityId) -> Option<&Entity<Session>> {
404        self.sessions.get(&entity_id)
405    }
406
407    pub fn insert_session(&mut self, entity_id: EntityId, session: Entity<Session>) {
408        self.sessions.insert(entity_id, session);
409    }
410
411    pub fn remove_session(&mut self, entity_id: EntityId) {
412        self.sessions.remove(&entity_id);
413    }
414
415    fn shutdown_all_sessions(
416        &mut self,
417        cx: &mut Context<Self>,
418    ) -> impl Future<Output = ()> + use<> {
419        for session in self.sessions.values() {
420            session.update(cx, |session, _cx| {
421                if let Kernel::RunningKernel(mut kernel) =
422                    std::mem::replace(&mut session.kernel, Kernel::Shutdown)
423                {
424                    kernel.kill();
425                }
426            });
427        }
428        self.sessions.clear();
429        futures::future::ready(())
430    }
431
432    #[cfg(test)]
433    pub fn set_kernel_specs_for_testing(
434        &mut self,
435        specs: Vec<KernelSpecification>,
436        cx: &mut Context<Self>,
437    ) {
438        self.kernel_specifications = specs;
439        cx.notify();
440    }
441}