repl_store.rs

  1use std::future::Future;
  2use std::sync::Arc;
  3
  4use anyhow::{Context as _, Result};
  5use collections::{HashMap, HashSet};
  6use command_palette_hooks::CommandPaletteFilter;
  7use gpui::{App, Context, Entity, EntityId, Global, SharedString, Subscription, Task, prelude::*};
  8use jupyter_websocket_client::RemoteServer;
  9use language::{Language, LanguageName};
 10use project::{Fs, Project, ProjectPath, WorktreeId};
 11use remote::RemoteConnectionOptions;
 12use settings::{Settings, SettingsStore};
 13use util::rel_path::RelPath;
 14
 15use crate::kernels::{
 16    Kernel, PythonEnvKernelSpecification, list_remote_kernelspecs, local_kernel_specifications,
 17    python_env_kernel_specifications, wsl_kernel_specifications,
 18};
 19use crate::{JupyterSettings, KernelSpecification, Session};
 20
 21struct GlobalReplStore(Entity<ReplStore>);
 22
 23impl Global for GlobalReplStore {}
 24
 25pub struct ReplStore {
 26    fs: Arc<dyn Fs>,
 27    enabled: bool,
 28    sessions: HashMap<EntityId, Entity<Session>>,
 29    kernel_specifications: Vec<KernelSpecification>,
 30    kernelspecs_initialized: bool,
 31    selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
 32    kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
 33    active_python_toolchain_for_worktree: HashMap<WorktreeId, SharedString>,
 34    remote_worktrees: HashSet<WorktreeId>,
 35    fetching_python_kernelspecs: HashSet<WorktreeId>,
 36    _subscriptions: Vec<Subscription>,
 37}
 38
 39impl ReplStore {
 40    const NAMESPACE: &'static str = "repl";
 41
 42    pub(crate) fn init(fs: Arc<dyn Fs>, cx: &mut App) {
 43        let store = cx.new(move |cx| Self::new(fs, cx));
 44        cx.set_global(GlobalReplStore(store))
 45    }
 46
 47    pub fn global(cx: &App) -> Entity<Self> {
 48        cx.global::<GlobalReplStore>().0.clone()
 49    }
 50
 51    pub fn new(fs: Arc<dyn Fs>, cx: &mut Context<Self>) -> Self {
 52        let subscriptions = vec![
 53            cx.observe_global::<SettingsStore>(move |this, cx| {
 54                this.set_enabled(JupyterSettings::enabled(cx), cx);
 55            }),
 56            cx.on_app_quit(Self::shutdown_all_sessions),
 57        ];
 58
 59        let this = Self {
 60            fs,
 61            enabled: JupyterSettings::enabled(cx),
 62            sessions: HashMap::default(),
 63            kernel_specifications: Vec::new(),
 64            kernelspecs_initialized: false,
 65            _subscriptions: subscriptions,
 66            kernel_specifications_for_worktree: HashMap::default(),
 67            selected_kernel_for_worktree: HashMap::default(),
 68            active_python_toolchain_for_worktree: HashMap::default(),
 69            remote_worktrees: HashSet::default(),
 70            fetching_python_kernelspecs: HashSet::default(),
 71        };
 72        this.on_enabled_changed(cx);
 73        this
 74    }
 75
 76    pub fn fs(&self) -> &Arc<dyn Fs> {
 77        &self.fs
 78    }
 79
 80    pub fn is_enabled(&self) -> bool {
 81        self.enabled
 82    }
 83
 84    pub fn has_python_kernelspecs(&self, worktree_id: WorktreeId) -> bool {
 85        self.kernel_specifications_for_worktree
 86            .contains_key(&worktree_id)
 87    }
 88
 89    pub fn kernel_specifications_for_worktree(
 90        &self,
 91        worktree_id: WorktreeId,
 92    ) -> impl Iterator<Item = &KernelSpecification> {
 93        let global_specs = if self.remote_worktrees.contains(&worktree_id) {
 94            None
 95        } else {
 96            Some(self.kernel_specifications.iter())
 97        };
 98
 99        self.kernel_specifications_for_worktree
100            .get(&worktree_id)
101            .into_iter()
102            .flat_map(|specs| specs.iter())
103            .chain(global_specs.into_iter().flatten())
104    }
105
106    pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
107        self.kernel_specifications.iter()
108    }
109
110    pub fn sessions(&self) -> impl Iterator<Item = &Entity<Session>> {
111        self.sessions.values()
112    }
113
114    fn set_enabled(&mut self, enabled: bool, cx: &mut Context<Self>) {
115        if self.enabled == enabled {
116            return;
117        }
118
119        self.enabled = enabled;
120        self.on_enabled_changed(cx);
121    }
122
123    fn on_enabled_changed(&self, cx: &mut Context<Self>) {
124        if !self.enabled {
125            CommandPaletteFilter::update_global(cx, |filter, _cx| {
126                filter.hide_namespace(Self::NAMESPACE);
127            });
128
129            return;
130        }
131
132        CommandPaletteFilter::update_global(cx, |filter, _cx| {
133            filter.show_namespace(Self::NAMESPACE);
134        });
135
136        cx.notify();
137    }
138
139    pub fn mark_ipykernel_installed(
140        &mut self,
141        cx: &mut Context<Self>,
142        spec: &PythonEnvKernelSpecification,
143    ) {
144        for specs in self.kernel_specifications_for_worktree.values_mut() {
145            for kernel_spec in specs.iter_mut() {
146                if let KernelSpecification::PythonEnv(env_spec) = kernel_spec {
147                    if env_spec == spec {
148                        env_spec.has_ipykernel = true;
149                    }
150                }
151            }
152        }
153        cx.notify();
154    }
155
156    pub fn refresh_python_kernelspecs(
157        &mut self,
158        worktree_id: WorktreeId,
159        project: &Entity<Project>,
160        cx: &mut Context<Self>,
161    ) -> Task<Result<()>> {
162        if !self.fetching_python_kernelspecs.insert(worktree_id) {
163            return Task::ready(Ok(()));
164        }
165
166        let is_remote = project.read(cx).is_remote();
167        // WSL does require access to global kernel specs, so we only exclude remote worktrees that aren't WSL.
168        // TODO: a better way to handle WSL vs SSH/remote projects,
169        let is_wsl_remote = project
170            .read(cx)
171            .remote_connection_options(cx)
172            .map_or(false, |opts| {
173                matches!(opts, RemoteConnectionOptions::Wsl(_))
174            });
175        let kernel_specifications_task = python_env_kernel_specifications(project, worktree_id, cx);
176        let active_toolchain = project.read(cx).active_toolchain(
177            ProjectPath {
178                worktree_id,
179                path: RelPath::empty().into(),
180            },
181            LanguageName::new_static("Python"),
182            cx,
183        );
184
185        cx.spawn(async move |this, cx| {
186            let kernel_specifications_res = kernel_specifications_task.await;
187
188            this.update(cx, |this, _cx| {
189                this.fetching_python_kernelspecs.remove(&worktree_id);
190            })
191            .ok();
192
193            let kernel_specifications =
194                kernel_specifications_res.context("getting python kernelspecs")?;
195
196            let active_toolchain_path = active_toolchain.await.map(|toolchain| toolchain.path);
197
198            this.update(cx, |this, cx| {
199                this.kernel_specifications_for_worktree
200                    .insert(worktree_id, kernel_specifications);
201                if let Some(path) = active_toolchain_path {
202                    this.active_python_toolchain_for_worktree
203                        .insert(worktree_id, path);
204                }
205                if is_remote && !is_wsl_remote {
206                    this.remote_worktrees.insert(worktree_id);
207                } else {
208                    this.remote_worktrees.remove(&worktree_id);
209                }
210                cx.notify();
211            })
212        })
213    }
214
215    fn get_remote_kernel_specifications(
216        &self,
217        cx: &mut Context<Self>,
218    ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
219        match (
220            std::env::var("JUPYTER_SERVER"),
221            std::env::var("JUPYTER_TOKEN"),
222        ) {
223            (Ok(server), Ok(token)) => {
224                let remote_server = RemoteServer {
225                    base_url: server,
226                    token,
227                };
228                let http_client = cx.http_client();
229                Some(cx.spawn(async move |_, _| {
230                    list_remote_kernelspecs(remote_server, http_client)
231                        .await
232                        .map(|specs| {
233                            specs
234                                .into_iter()
235                                .map(KernelSpecification::JupyterServer)
236                                .collect()
237                        })
238                }))
239            }
240            _ => None,
241        }
242    }
243
244    pub fn ensure_kernelspecs(&mut self, cx: &mut Context<Self>) {
245        if self.kernelspecs_initialized {
246            return;
247        }
248        self.kernelspecs_initialized = true;
249        self.refresh_kernelspecs(cx).detach_and_log_err(cx);
250    }
251
252    pub fn refresh_kernelspecs(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
253        let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
254        let wsl_kernel_specifications = wsl_kernel_specifications(cx.background_executor().clone());
255        let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
256
257        let all_specs = cx.background_spawn(async move {
258            let mut all_specs = local_kernel_specifications
259                .await?
260                .into_iter()
261                .map(KernelSpecification::Jupyter)
262                .collect::<Vec<_>>();
263
264            if let Ok(wsl_specs) = wsl_kernel_specifications.await {
265                all_specs.extend(wsl_specs);
266            }
267
268            if let Some(remote_task) = remote_kernel_specifications
269                && let Ok(remote_specs) = remote_task.await
270            {
271                all_specs.extend(remote_specs);
272            }
273
274            anyhow::Ok(all_specs)
275        });
276
277        cx.spawn(async move |this, cx| {
278            let all_specs = all_specs.await;
279
280            if let Ok(specs) = all_specs {
281                this.update(cx, |this, cx| {
282                    this.kernel_specifications = specs;
283                    cx.notify();
284                })
285                .ok();
286            }
287
288            anyhow::Ok(())
289        })
290    }
291
292    pub fn set_active_kernelspec(
293        &mut self,
294        worktree_id: WorktreeId,
295        kernelspec: KernelSpecification,
296        _cx: &mut Context<Self>,
297    ) {
298        self.selected_kernel_for_worktree
299            .insert(worktree_id, kernelspec);
300    }
301
302    pub fn active_python_toolchain_path(&self, worktree_id: WorktreeId) -> Option<&SharedString> {
303        self.active_python_toolchain_for_worktree.get(&worktree_id)
304    }
305
306    pub fn selected_kernel(&self, worktree_id: WorktreeId) -> Option<&KernelSpecification> {
307        self.selected_kernel_for_worktree.get(&worktree_id)
308    }
309
310    pub fn is_recommended_kernel(
311        &self,
312        worktree_id: WorktreeId,
313        spec: &KernelSpecification,
314    ) -> bool {
315        if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
316            spec.path().as_ref() == active_path.as_ref()
317        } else {
318            false
319        }
320    }
321
322    pub fn active_kernelspec(
323        &self,
324        worktree_id: WorktreeId,
325        language_at_cursor: Option<Arc<Language>>,
326        cx: &App,
327    ) -> Option<KernelSpecification> {
328        if let Some(selected) = self.selected_kernel_for_worktree.get(&worktree_id).cloned() {
329            return Some(selected);
330        }
331
332        let language_at_cursor = language_at_cursor?;
333
334        // Prefer the recommended (active toolchain) kernel if it has ipykernel
335        if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
336            let recommended = self
337                .kernel_specifications_for_worktree(worktree_id)
338                .find(|spec| {
339                    spec.has_ipykernel()
340                        && language_at_cursor.matches_kernel_language(spec.language().as_ref())
341                        && spec.path().as_ref() == active_path.as_ref()
342                })
343                .cloned();
344            if recommended.is_some() {
345                return recommended;
346            }
347        }
348
349        // Then try the first PythonEnv with ipykernel matching the language
350        let python_env = self
351            .kernel_specifications_for_worktree(worktree_id)
352            .find(|spec| {
353                matches!(spec, KernelSpecification::PythonEnv(_))
354                    && spec.has_ipykernel()
355                    && language_at_cursor.matches_kernel_language(spec.language().as_ref())
356            })
357            .cloned();
358        if python_env.is_some() {
359            return python_env;
360        }
361
362        // Fall back to legacy name-based and language-based matching
363        self.kernelspec_legacy_by_lang_only(worktree_id, language_at_cursor, cx)
364    }
365
366    fn kernelspec_legacy_by_lang_only(
367        &self,
368        worktree_id: WorktreeId,
369        language_at_cursor: Arc<Language>,
370        cx: &App,
371    ) -> Option<KernelSpecification> {
372        let settings = JupyterSettings::get_global(cx);
373        let selected_kernel = settings
374            .kernel_selections
375            .get(language_at_cursor.code_fence_block_name().as_ref());
376
377        let found_by_name = self
378            .kernel_specifications_for_worktree(worktree_id)
379            .find(|runtime_specification| {
380                if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
381                    (selected_kernel, runtime_specification)
382                {
383                    return runtime_specification.name.to_lowercase() == selected.to_lowercase();
384                }
385                false
386            })
387            .cloned();
388
389        if let Some(found_by_name) = found_by_name {
390            return Some(found_by_name);
391        }
392
393        self.kernel_specifications_for_worktree(worktree_id)
394            .find(|spec| {
395                spec.has_ipykernel()
396                    && language_at_cursor.matches_kernel_language(spec.language().as_ref())
397            })
398            .cloned()
399    }
400
401    pub fn get_session(&self, entity_id: EntityId) -> Option<&Entity<Session>> {
402        self.sessions.get(&entity_id)
403    }
404
405    pub fn insert_session(&mut self, entity_id: EntityId, session: Entity<Session>) {
406        self.sessions.insert(entity_id, session);
407    }
408
409    pub fn remove_session(&mut self, entity_id: EntityId) {
410        self.sessions.remove(&entity_id);
411    }
412
413    fn shutdown_all_sessions(
414        &mut self,
415        cx: &mut Context<Self>,
416    ) -> impl Future<Output = ()> + use<> {
417        for session in self.sessions.values() {
418            session.update(cx, |session, _cx| {
419                if let Kernel::RunningKernel(mut kernel) =
420                    std::mem::replace(&mut session.kernel, Kernel::Shutdown)
421                {
422                    kernel.kill();
423                }
424            });
425        }
426        self.sessions.clear();
427        futures::future::ready(())
428    }
429
430    #[cfg(test)]
431    pub fn set_kernel_specs_for_testing(
432        &mut self,
433        specs: Vec<KernelSpecification>,
434        cx: &mut Context<Self>,
435    ) {
436        self.kernel_specifications = specs;
437        cx.notify();
438    }
439}