// repl_store.rs

  1use std::future::Future;
  2use std::sync::Arc;
  3
  4use anyhow::{Context as _, Result};
  5use collections::{HashMap, HashSet};
  6use command_palette_hooks::CommandPaletteFilter;
  7use gpui::{App, Context, Entity, EntityId, Global, SharedString, Subscription, Task, prelude::*};
  8use jupyter_websocket_client::RemoteServer;
  9use language::{Language, LanguageName};
 10use project::{Fs, Project, ProjectPath, WorktreeId};
 11use remote::RemoteConnectionOptions;
 12use settings::{Settings, SettingsStore};
 13use util::rel_path::RelPath;
 14
 15use crate::kernels::{
 16    Kernel, list_remote_kernelspecs, local_kernel_specifications, python_env_kernel_specifications,
 17    wsl_kernel_specifications,
 18};
 19use crate::{JupyterSettings, KernelSpecification, Session};
 20
 21struct GlobalReplStore(Entity<ReplStore>);
 22
 23impl Global for GlobalReplStore {}
 24
 25pub struct ReplStore {
 26    fs: Arc<dyn Fs>,
 27    enabled: bool,
 28    sessions: HashMap<EntityId, Entity<Session>>,
 29    kernel_specifications: Vec<KernelSpecification>,
 30    kernelspecs_initialized: bool,
 31    selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
 32    kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
 33    active_python_toolchain_for_worktree: HashMap<WorktreeId, SharedString>,
 34    remote_worktrees: HashSet<WorktreeId>,
 35    fetching_python_kernelspecs: HashSet<WorktreeId>,
 36    _subscriptions: Vec<Subscription>,
 37}
 38
 39impl ReplStore {
 40    const NAMESPACE: &'static str = "repl";
 41
 42    pub(crate) fn init(fs: Arc<dyn Fs>, cx: &mut App) {
 43        let store = cx.new(move |cx| Self::new(fs, cx));
 44        cx.set_global(GlobalReplStore(store))
 45    }
 46
 47    pub fn global(cx: &App) -> Entity<Self> {
 48        cx.global::<GlobalReplStore>().0.clone()
 49    }
 50
 51    pub fn new(fs: Arc<dyn Fs>, cx: &mut Context<Self>) -> Self {
 52        let subscriptions = vec![
 53            cx.observe_global::<SettingsStore>(move |this, cx| {
 54                this.set_enabled(JupyterSettings::enabled(cx), cx);
 55            }),
 56            cx.on_app_quit(Self::shutdown_all_sessions),
 57        ];
 58
 59        let this = Self {
 60            fs,
 61            enabled: JupyterSettings::enabled(cx),
 62            sessions: HashMap::default(),
 63            kernel_specifications: Vec::new(),
 64            kernelspecs_initialized: false,
 65            _subscriptions: subscriptions,
 66            kernel_specifications_for_worktree: HashMap::default(),
 67            selected_kernel_for_worktree: HashMap::default(),
 68            active_python_toolchain_for_worktree: HashMap::default(),
 69            remote_worktrees: HashSet::default(),
 70            fetching_python_kernelspecs: HashSet::default(),
 71        };
 72        this.on_enabled_changed(cx);
 73        this
 74    }
 75
 76    pub fn fs(&self) -> &Arc<dyn Fs> {
 77        &self.fs
 78    }
 79
 80    pub fn is_enabled(&self) -> bool {
 81        self.enabled
 82    }
 83
 84    pub fn has_python_kernelspecs(&self, worktree_id: WorktreeId) -> bool {
 85        self.kernel_specifications_for_worktree
 86            .contains_key(&worktree_id)
 87    }
 88
 89    pub fn kernel_specifications_for_worktree(
 90        &self,
 91        worktree_id: WorktreeId,
 92    ) -> impl Iterator<Item = &KernelSpecification> {
 93        let global_specs = if self.remote_worktrees.contains(&worktree_id) {
 94            None
 95        } else {
 96            Some(self.kernel_specifications.iter())
 97        };
 98
 99        self.kernel_specifications_for_worktree
100            .get(&worktree_id)
101            .into_iter()
102            .flat_map(|specs| specs.iter())
103            .chain(global_specs.into_iter().flatten())
104    }
105
106    pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
107        self.kernel_specifications.iter()
108    }
109
110    pub fn sessions(&self) -> impl Iterator<Item = &Entity<Session>> {
111        self.sessions.values()
112    }
113
114    fn set_enabled(&mut self, enabled: bool, cx: &mut Context<Self>) {
115        if self.enabled == enabled {
116            return;
117        }
118
119        self.enabled = enabled;
120        self.on_enabled_changed(cx);
121    }
122
123    fn on_enabled_changed(&self, cx: &mut Context<Self>) {
124        if !self.enabled {
125            CommandPaletteFilter::update_global(cx, |filter, _cx| {
126                filter.hide_namespace(Self::NAMESPACE);
127            });
128
129            return;
130        }
131
132        CommandPaletteFilter::update_global(cx, |filter, _cx| {
133            filter.show_namespace(Self::NAMESPACE);
134        });
135
136        cx.notify();
137    }
138
139    pub fn refresh_python_kernelspecs(
140        &mut self,
141        worktree_id: WorktreeId,
142        project: &Entity<Project>,
143        cx: &mut Context<Self>,
144    ) -> Task<Result<()>> {
145        if !self.fetching_python_kernelspecs.insert(worktree_id) {
146            return Task::ready(Ok(()));
147        }
148
149        let is_remote = project.read(cx).is_remote();
150        // WSL does require access to global kernel specs, so we only exclude remote worktrees that aren't WSL.
151        // TODO: a better way to handle WSL vs SSH/remote projects,
152        let is_wsl_remote = project
153            .read(cx)
154            .remote_connection_options(cx)
155            .map_or(false, |opts| {
156                matches!(opts, RemoteConnectionOptions::Wsl(_))
157            });
158        let kernel_specifications_task = python_env_kernel_specifications(project, worktree_id, cx);
159        let active_toolchain = project.read(cx).active_toolchain(
160            ProjectPath {
161                worktree_id,
162                path: RelPath::empty().into(),
163            },
164            LanguageName::new_static("Python"),
165            cx,
166        );
167
168        cx.spawn(async move |this, cx| {
169            let kernel_specifications_res = kernel_specifications_task.await;
170
171            this.update(cx, |this, _cx| {
172                this.fetching_python_kernelspecs.remove(&worktree_id);
173            })
174            .ok();
175
176            let kernel_specifications =
177                kernel_specifications_res.context("getting python kernelspecs")?;
178
179            let active_toolchain_path = active_toolchain.await.map(|toolchain| toolchain.path);
180
181            this.update(cx, |this, cx| {
182                this.kernel_specifications_for_worktree
183                    .insert(worktree_id, kernel_specifications);
184                if let Some(path) = active_toolchain_path {
185                    this.active_python_toolchain_for_worktree
186                        .insert(worktree_id, path);
187                }
188                if is_remote && !is_wsl_remote {
189                    this.remote_worktrees.insert(worktree_id);
190                } else {
191                    this.remote_worktrees.remove(&worktree_id);
192                }
193                cx.notify();
194            })
195        })
196    }
197
198    fn get_remote_kernel_specifications(
199        &self,
200        cx: &mut Context<Self>,
201    ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
202        match (
203            std::env::var("JUPYTER_SERVER"),
204            std::env::var("JUPYTER_TOKEN"),
205        ) {
206            (Ok(server), Ok(token)) => {
207                let remote_server = RemoteServer {
208                    base_url: server,
209                    token,
210                };
211                let http_client = cx.http_client();
212                Some(cx.spawn(async move |_, _| {
213                    list_remote_kernelspecs(remote_server, http_client)
214                        .await
215                        .map(|specs| {
216                            specs
217                                .into_iter()
218                                .map(KernelSpecification::JupyterServer)
219                                .collect()
220                        })
221                }))
222            }
223            _ => None,
224        }
225    }
226
227    pub fn ensure_kernelspecs(&mut self, cx: &mut Context<Self>) {
228        if self.kernelspecs_initialized {
229            return;
230        }
231        self.kernelspecs_initialized = true;
232        self.refresh_kernelspecs(cx).detach_and_log_err(cx);
233    }
234
235    pub fn refresh_kernelspecs(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
236        let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
237        let wsl_kernel_specifications = wsl_kernel_specifications(cx.background_executor().clone());
238        let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
239
240        let all_specs = cx.background_spawn(async move {
241            let mut all_specs = local_kernel_specifications
242                .await?
243                .into_iter()
244                .map(KernelSpecification::Jupyter)
245                .collect::<Vec<_>>();
246
247            if let Ok(wsl_specs) = wsl_kernel_specifications.await {
248                all_specs.extend(wsl_specs);
249            }
250
251            if let Some(remote_task) = remote_kernel_specifications
252                && let Ok(remote_specs) = remote_task.await
253            {
254                all_specs.extend(remote_specs);
255            }
256
257            anyhow::Ok(all_specs)
258        });
259
260        cx.spawn(async move |this, cx| {
261            let all_specs = all_specs.await;
262
263            if let Ok(specs) = all_specs {
264                this.update(cx, |this, cx| {
265                    this.kernel_specifications = specs;
266                    cx.notify();
267                })
268                .ok();
269            }
270
271            anyhow::Ok(())
272        })
273    }
274
275    pub fn set_active_kernelspec(
276        &mut self,
277        worktree_id: WorktreeId,
278        kernelspec: KernelSpecification,
279        _cx: &mut Context<Self>,
280    ) {
281        self.selected_kernel_for_worktree
282            .insert(worktree_id, kernelspec);
283    }
284
285    pub fn active_python_toolchain_path(&self, worktree_id: WorktreeId) -> Option<&SharedString> {
286        self.active_python_toolchain_for_worktree.get(&worktree_id)
287    }
288
289    pub fn selected_kernel(&self, worktree_id: WorktreeId) -> Option<&KernelSpecification> {
290        self.selected_kernel_for_worktree.get(&worktree_id)
291    }
292
293    pub fn is_recommended_kernel(
294        &self,
295        worktree_id: WorktreeId,
296        spec: &KernelSpecification,
297    ) -> bool {
298        if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
299            spec.path().as_ref() == active_path.as_ref()
300        } else {
301            false
302        }
303    }
304
305    pub fn active_kernelspec(
306        &self,
307        worktree_id: WorktreeId,
308        language_at_cursor: Option<Arc<Language>>,
309        cx: &App,
310    ) -> Option<KernelSpecification> {
311        if let Some(selected) = self.selected_kernel_for_worktree.get(&worktree_id).cloned() {
312            return Some(selected);
313        }
314
315        let language_at_cursor = language_at_cursor?;
316
317        // Prefer the recommended (active toolchain) kernel if it has ipykernel
318        if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
319            let recommended = self
320                .kernel_specifications_for_worktree(worktree_id)
321                .find(|spec| {
322                    spec.has_ipykernel()
323                        && language_at_cursor.matches_kernel_language(spec.language().as_ref())
324                        && spec.path().as_ref() == active_path.as_ref()
325                })
326                .cloned();
327            if recommended.is_some() {
328                return recommended;
329            }
330        }
331
332        // Then try the first PythonEnv with ipykernel matching the language
333        let python_env = self
334            .kernel_specifications_for_worktree(worktree_id)
335            .find(|spec| {
336                matches!(spec, KernelSpecification::PythonEnv(_))
337                    && spec.has_ipykernel()
338                    && language_at_cursor.matches_kernel_language(spec.language().as_ref())
339            })
340            .cloned();
341        if python_env.is_some() {
342            return python_env;
343        }
344
345        // Fall back to legacy name-based and language-based matching
346        self.kernelspec_legacy_by_lang_only(worktree_id, language_at_cursor, cx)
347    }
348
349    fn kernelspec_legacy_by_lang_only(
350        &self,
351        worktree_id: WorktreeId,
352        language_at_cursor: Arc<Language>,
353        cx: &App,
354    ) -> Option<KernelSpecification> {
355        let settings = JupyterSettings::get_global(cx);
356        let selected_kernel = settings
357            .kernel_selections
358            .get(language_at_cursor.code_fence_block_name().as_ref());
359
360        let found_by_name = self
361            .kernel_specifications_for_worktree(worktree_id)
362            .find(|runtime_specification| {
363                if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
364                    (selected_kernel, runtime_specification)
365                {
366                    return runtime_specification.name.to_lowercase() == selected.to_lowercase();
367                }
368                false
369            })
370            .cloned();
371
372        if let Some(found_by_name) = found_by_name {
373            return Some(found_by_name);
374        }
375
376        self.kernel_specifications_for_worktree(worktree_id)
377            .find(|spec| {
378                spec.has_ipykernel()
379                    && language_at_cursor.matches_kernel_language(spec.language().as_ref())
380            })
381            .cloned()
382    }
383
384    pub fn get_session(&self, entity_id: EntityId) -> Option<&Entity<Session>> {
385        self.sessions.get(&entity_id)
386    }
387
388    pub fn insert_session(&mut self, entity_id: EntityId, session: Entity<Session>) {
389        self.sessions.insert(entity_id, session);
390    }
391
392    pub fn remove_session(&mut self, entity_id: EntityId) {
393        self.sessions.remove(&entity_id);
394    }
395
396    fn shutdown_all_sessions(
397        &mut self,
398        cx: &mut Context<Self>,
399    ) -> impl Future<Output = ()> + use<> {
400        for session in self.sessions.values() {
401            session.update(cx, |session, _cx| {
402                if let Kernel::RunningKernel(mut kernel) =
403                    std::mem::replace(&mut session.kernel, Kernel::Shutdown)
404                {
405                    kernel.kill();
406                }
407            });
408        }
409        self.sessions.clear();
410        futures::future::ready(())
411    }
412
413    #[cfg(test)]
414    pub fn set_kernel_specs_for_testing(
415        &mut self,
416        specs: Vec<KernelSpecification>,
417        cx: &mut Context<Self>,
418    ) {
419        self.kernel_specifications = specs;
420        cx.notify();
421    }
422}