repl_store.rs

  1use std::future::Future;
  2use std::sync::Arc;
  3
  4use anyhow::{Context as _, Result};
  5use collections::{HashMap, HashSet};
  6use command_palette_hooks::CommandPaletteFilter;
  7use gpui::{App, Context, Entity, EntityId, Global, SharedString, Subscription, Task, prelude::*};
  8use jupyter_websocket_client::RemoteServer;
  9use language::{Language, LanguageName};
 10use project::{Fs, Project, ProjectPath, WorktreeId};
 11use remote::RemoteConnectionOptions;
 12use settings::{Settings, SettingsStore};
 13use util::rel_path::RelPath;
 14
 15use crate::kernels::{
 16    Kernel, list_remote_kernelspecs, local_kernel_specifications, python_env_kernel_specifications,
 17    wsl_kernel_specifications,
 18};
 19use crate::{JupyterSettings, KernelSpecification, Session};
 20
/// GPUI global holding the app-wide [`ReplStore`] entity.
///
/// Installed by `ReplStore::init` and read back by `ReplStore::global`.
struct GlobalReplStore(Entity<ReplStore>);

impl Global for GlobalReplStore {}
 24
/// App-wide REPL state: live sessions, discovered kernelspecs (global and per worktree),
/// the kernel selected for each worktree, and whether the REPL is enabled in settings.
pub struct ReplStore {
    /// Filesystem handle passed to local kernelspec discovery.
    fs: Arc<dyn Fs>,
    /// Cached value of `JupyterSettings::enabled`; toggles the `repl` command-palette namespace.
    enabled: bool,
    /// Sessions registered via `insert_session`, keyed by the given entity id.
    sessions: HashMap<EntityId, Entity<Session>>,
    /// Global kernelspecs: local Jupyter, WSL, and (if configured) `JUPYTER_SERVER` kernels.
    kernel_specifications: Vec<KernelSpecification>,
    /// Set once `ensure_kernelspecs` has started the initial global refresh.
    kernelspecs_initialized: bool,
    /// Kernel chosen via `set_active_kernelspec`, per worktree.
    selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
    /// Python-environment kernelspecs discovered per worktree.
    kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
    /// Path of the active Python toolchain per worktree; used to pick a recommended kernel.
    active_python_toolchain_for_worktree: HashMap<WorktreeId, SharedString>,
    /// Remote (non-WSL) worktrees, which are not offered the global kernelspecs.
    remote_worktrees: HashSet<WorktreeId>,
    /// Keeps the settings-observer and app-quit subscriptions alive for the store's lifetime.
    _subscriptions: Vec<Subscription>,
}
 37
 38impl ReplStore {
 39    const NAMESPACE: &'static str = "repl";
 40
 41    pub(crate) fn init(fs: Arc<dyn Fs>, cx: &mut App) {
 42        let store = cx.new(move |cx| Self::new(fs, cx));
 43        cx.set_global(GlobalReplStore(store))
 44    }
 45
 46    pub fn global(cx: &App) -> Entity<Self> {
 47        cx.global::<GlobalReplStore>().0.clone()
 48    }
 49
 50    pub fn new(fs: Arc<dyn Fs>, cx: &mut Context<Self>) -> Self {
 51        let subscriptions = vec![
 52            cx.observe_global::<SettingsStore>(move |this, cx| {
 53                this.set_enabled(JupyterSettings::enabled(cx), cx);
 54            }),
 55            cx.on_app_quit(Self::shutdown_all_sessions),
 56        ];
 57
 58        let this = Self {
 59            fs,
 60            enabled: JupyterSettings::enabled(cx),
 61            sessions: HashMap::default(),
 62            kernel_specifications: Vec::new(),
 63            kernelspecs_initialized: false,
 64            _subscriptions: subscriptions,
 65            kernel_specifications_for_worktree: HashMap::default(),
 66            selected_kernel_for_worktree: HashMap::default(),
 67            active_python_toolchain_for_worktree: HashMap::default(),
 68            remote_worktrees: HashSet::default(),
 69        };
 70        this.on_enabled_changed(cx);
 71        this
 72    }
 73
 74    pub fn fs(&self) -> &Arc<dyn Fs> {
 75        &self.fs
 76    }
 77
 78    pub fn is_enabled(&self) -> bool {
 79        self.enabled
 80    }
 81
 82    pub fn has_python_kernelspecs(&self, worktree_id: WorktreeId) -> bool {
 83        self.kernel_specifications_for_worktree
 84            .contains_key(&worktree_id)
 85    }
 86
 87    pub fn kernel_specifications_for_worktree(
 88        &self,
 89        worktree_id: WorktreeId,
 90    ) -> impl Iterator<Item = &KernelSpecification> {
 91        let global_specs = if self.remote_worktrees.contains(&worktree_id) {
 92            None
 93        } else {
 94            Some(self.kernel_specifications.iter())
 95        };
 96
 97        self.kernel_specifications_for_worktree
 98            .get(&worktree_id)
 99            .into_iter()
100            .flat_map(|specs| specs.iter())
101            .chain(global_specs.into_iter().flatten())
102    }
103
104    pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
105        self.kernel_specifications.iter()
106    }
107
108    pub fn sessions(&self) -> impl Iterator<Item = &Entity<Session>> {
109        self.sessions.values()
110    }
111
112    fn set_enabled(&mut self, enabled: bool, cx: &mut Context<Self>) {
113        if self.enabled == enabled {
114            return;
115        }
116
117        self.enabled = enabled;
118        self.on_enabled_changed(cx);
119    }
120
121    fn on_enabled_changed(&self, cx: &mut Context<Self>) {
122        if !self.enabled {
123            CommandPaletteFilter::update_global(cx, |filter, _cx| {
124                filter.hide_namespace(Self::NAMESPACE);
125            });
126
127            return;
128        }
129
130        CommandPaletteFilter::update_global(cx, |filter, _cx| {
131            filter.show_namespace(Self::NAMESPACE);
132        });
133
134        cx.notify();
135    }
136
137    pub fn refresh_python_kernelspecs(
138        &mut self,
139        worktree_id: WorktreeId,
140        project: &Entity<Project>,
141        cx: &mut Context<Self>,
142    ) -> Task<Result<()>> {
143        let is_remote = project.read(cx).is_remote();
144        // WSL does require access to global kernel specs, so we only exclude remote worktrees that aren't WSL.
145        // TODO: a better way to handle WSL vs SSH/remote projects,
146        let is_wsl_remote = project
147            .read(cx)
148            .remote_connection_options(cx)
149            .map_or(false, |opts| {
150                matches!(opts, RemoteConnectionOptions::Wsl(_))
151            });
152        let kernel_specifications = python_env_kernel_specifications(project, worktree_id, cx);
153        let active_toolchain = project.read(cx).active_toolchain(
154            ProjectPath {
155                worktree_id,
156                path: RelPath::empty().into(),
157            },
158            LanguageName::new_static("Python"),
159            cx,
160        );
161
162        cx.spawn(async move |this, cx| {
163            let kernel_specifications = kernel_specifications
164                .await
165                .context("getting python kernelspecs")?;
166
167            let active_toolchain_path = active_toolchain.await.map(|toolchain| toolchain.path);
168
169            this.update(cx, |this, cx| {
170                this.kernel_specifications_for_worktree
171                    .insert(worktree_id, kernel_specifications);
172                if let Some(path) = active_toolchain_path {
173                    this.active_python_toolchain_for_worktree
174                        .insert(worktree_id, path);
175                }
176                if is_remote && !is_wsl_remote {
177                    this.remote_worktrees.insert(worktree_id);
178                } else {
179                    this.remote_worktrees.remove(&worktree_id);
180                }
181                cx.notify();
182            })
183        })
184    }
185
186    fn get_remote_kernel_specifications(
187        &self,
188        cx: &mut Context<Self>,
189    ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
190        match (
191            std::env::var("JUPYTER_SERVER"),
192            std::env::var("JUPYTER_TOKEN"),
193        ) {
194            (Ok(server), Ok(token)) => {
195                let remote_server = RemoteServer {
196                    base_url: server,
197                    token,
198                };
199                let http_client = cx.http_client();
200                Some(cx.spawn(async move |_, _| {
201                    list_remote_kernelspecs(remote_server, http_client)
202                        .await
203                        .map(|specs| {
204                            specs
205                                .into_iter()
206                                .map(KernelSpecification::JupyterServer)
207                                .collect()
208                        })
209                }))
210            }
211            _ => None,
212        }
213    }
214
215    pub fn ensure_kernelspecs(&mut self, cx: &mut Context<Self>) {
216        if self.kernelspecs_initialized {
217            return;
218        }
219        self.kernelspecs_initialized = true;
220        self.refresh_kernelspecs(cx).detach_and_log_err(cx);
221    }
222
223    pub fn refresh_kernelspecs(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
224        let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
225        let wsl_kernel_specifications = wsl_kernel_specifications(cx.background_executor().clone());
226        let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
227
228        let all_specs = cx.background_spawn(async move {
229            let mut all_specs = local_kernel_specifications
230                .await?
231                .into_iter()
232                .map(KernelSpecification::Jupyter)
233                .collect::<Vec<_>>();
234
235            if let Ok(wsl_specs) = wsl_kernel_specifications.await {
236                all_specs.extend(wsl_specs);
237            }
238
239            if let Some(remote_task) = remote_kernel_specifications
240                && let Ok(remote_specs) = remote_task.await
241            {
242                all_specs.extend(remote_specs);
243            }
244
245            anyhow::Ok(all_specs)
246        });
247
248        cx.spawn(async move |this, cx| {
249            let all_specs = all_specs.await;
250
251            if let Ok(specs) = all_specs {
252                this.update(cx, |this, cx| {
253                    this.kernel_specifications = specs;
254                    cx.notify();
255                })
256                .ok();
257            }
258
259            anyhow::Ok(())
260        })
261    }
262
263    pub fn set_active_kernelspec(
264        &mut self,
265        worktree_id: WorktreeId,
266        kernelspec: KernelSpecification,
267        _cx: &mut Context<Self>,
268    ) {
269        self.selected_kernel_for_worktree
270            .insert(worktree_id, kernelspec);
271    }
272
273    pub fn active_python_toolchain_path(&self, worktree_id: WorktreeId) -> Option<&SharedString> {
274        self.active_python_toolchain_for_worktree.get(&worktree_id)
275    }
276
277    pub fn selected_kernel(&self, worktree_id: WorktreeId) -> Option<&KernelSpecification> {
278        self.selected_kernel_for_worktree.get(&worktree_id)
279    }
280
281    pub fn is_recommended_kernel(
282        &self,
283        worktree_id: WorktreeId,
284        spec: &KernelSpecification,
285    ) -> bool {
286        if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
287            spec.path().as_ref() == active_path.as_ref()
288        } else {
289            false
290        }
291    }
292
293    pub fn active_kernelspec(
294        &self,
295        worktree_id: WorktreeId,
296        language_at_cursor: Option<Arc<Language>>,
297        cx: &App,
298    ) -> Option<KernelSpecification> {
299        if let Some(selected) = self.selected_kernel_for_worktree.get(&worktree_id).cloned() {
300            return Some(selected);
301        }
302
303        let language_at_cursor = language_at_cursor?;
304
305        // Prefer the recommended (active toolchain) kernel if it has ipykernel
306        if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
307            let recommended = self
308                .kernel_specifications_for_worktree(worktree_id)
309                .find(|spec| {
310                    spec.has_ipykernel()
311                        && language_at_cursor.matches_kernel_language(spec.language().as_ref())
312                        && spec.path().as_ref() == active_path.as_ref()
313                })
314                .cloned();
315            if recommended.is_some() {
316                return recommended;
317            }
318        }
319
320        // Then try the first PythonEnv with ipykernel matching the language
321        let python_env = self
322            .kernel_specifications_for_worktree(worktree_id)
323            .find(|spec| {
324                matches!(spec, KernelSpecification::PythonEnv(_))
325                    && spec.has_ipykernel()
326                    && language_at_cursor.matches_kernel_language(spec.language().as_ref())
327            })
328            .cloned();
329        if python_env.is_some() {
330            return python_env;
331        }
332
333        // Fall back to legacy name-based and language-based matching
334        self.kernelspec_legacy_by_lang_only(worktree_id, language_at_cursor, cx)
335    }
336
337    fn kernelspec_legacy_by_lang_only(
338        &self,
339        worktree_id: WorktreeId,
340        language_at_cursor: Arc<Language>,
341        cx: &App,
342    ) -> Option<KernelSpecification> {
343        let settings = JupyterSettings::get_global(cx);
344        let selected_kernel = settings
345            .kernel_selections
346            .get(language_at_cursor.code_fence_block_name().as_ref());
347
348        let found_by_name = self
349            .kernel_specifications_for_worktree(worktree_id)
350            .find(|runtime_specification| {
351                if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
352                    (selected_kernel, runtime_specification)
353                {
354                    return runtime_specification.name.to_lowercase() == selected.to_lowercase();
355                }
356                false
357            })
358            .cloned();
359
360        if let Some(found_by_name) = found_by_name {
361            return Some(found_by_name);
362        }
363
364        self.kernel_specifications_for_worktree(worktree_id)
365            .find(|spec| {
366                spec.has_ipykernel()
367                    && language_at_cursor.matches_kernel_language(spec.language().as_ref())
368            })
369            .cloned()
370    }
371
372    pub fn get_session(&self, entity_id: EntityId) -> Option<&Entity<Session>> {
373        self.sessions.get(&entity_id)
374    }
375
376    pub fn insert_session(&mut self, entity_id: EntityId, session: Entity<Session>) {
377        self.sessions.insert(entity_id, session);
378    }
379
380    pub fn remove_session(&mut self, entity_id: EntityId) {
381        self.sessions.remove(&entity_id);
382    }
383
384    fn shutdown_all_sessions(
385        &mut self,
386        cx: &mut Context<Self>,
387    ) -> impl Future<Output = ()> + use<> {
388        for session in self.sessions.values() {
389            session.update(cx, |session, _cx| {
390                if let Kernel::RunningKernel(mut kernel) =
391                    std::mem::replace(&mut session.kernel, Kernel::Shutdown)
392                {
393                    kernel.kill();
394                }
395            });
396        }
397        self.sessions.clear();
398        futures::future::ready(())
399    }
400
401    #[cfg(test)]
402    pub fn set_kernel_specs_for_testing(
403        &mut self,
404        specs: Vec<KernelSpecification>,
405        cx: &mut Context<Self>,
406    ) {
407        self.kernel_specifications = specs;
408        cx.notify();
409    }
410}