repl_store.rs

  1use std::future::Future;
  2use std::sync::Arc;
  3
  4use anyhow::{Context as _, Result};
  5use collections::{HashMap, HashSet};
  6use command_palette_hooks::CommandPaletteFilter;
  7use gpui::{App, Context, Entity, EntityId, Global, SharedString, Subscription, Task, prelude::*};
  8use jupyter_websocket_client::RemoteServer;
  9use language::{Language, LanguageName};
 10use project::{Fs, Project, ProjectPath, WorktreeId};
 11use settings::{Settings, SettingsStore};
 12use util::rel_path::RelPath;
 13
 14use crate::kernels::{
 15    Kernel, list_remote_kernelspecs, local_kernel_specifications, python_env_kernel_specifications,
 16    wsl_kernel_specifications,
 17};
 18use crate::{JupyterSettings, KernelSpecification, Session};
 19
/// Newtype around the shared [`ReplStore`] entity so that it can be stored
/// as a GPUI global (see [`ReplStore::init`] and [`ReplStore::global`]).
struct GlobalReplStore(Entity<ReplStore>);

// Marker impl: the wrapper carries no behavior beyond being a global.
impl Global for GlobalReplStore {}
 23
/// Central registry for REPL state: known kernel specifications, the kernel
/// chosen per worktree, and the running sessions.
pub struct ReplStore {
    /// Filesystem handle handed to local kernelspec discovery.
    fs: Arc<dyn Fs>,
    /// Cached value of `JupyterSettings::enabled`, kept in sync by a settings observer.
    enabled: bool,
    /// Sessions keyed by the `EntityId` supplied to `insert_session`
    /// (presumably the owning editor's id; confirm against callers).
    sessions: HashMap<EntityId, Entity<Session>>,
    /// Kernelspecs not tied to a worktree, replaced wholesale by `refresh_kernelspecs`.
    kernel_specifications: Vec<KernelSpecification>,
    /// Kernel explicitly chosen for a worktree via `set_active_kernelspec`.
    selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
    /// Python-environment kernelspecs per worktree, filled by `refresh_python_kernelspecs`.
    kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
    /// Path of the active Python toolchain last observed for each worktree.
    active_python_toolchain_for_worktree: HashMap<WorktreeId, SharedString>,
    /// Worktrees whose project reported `is_remote()`; the worktree-independent
    /// kernelspecs are not offered for these.
    remote_worktrees: HashSet<WorktreeId>,
    /// Held only to keep the registered observers alive for the store's lifetime.
    _subscriptions: Vec<Subscription>,
}
 35
 36impl ReplStore {
    /// Command-palette namespace that is shown or hidden as the REPL feature
    /// is enabled or disabled (see `on_enabled_changed`).
    const NAMESPACE: &'static str = "repl";
 38
 39    pub(crate) fn init(fs: Arc<dyn Fs>, cx: &mut App) {
 40        let store = cx.new(move |cx| Self::new(fs, cx));
 41
 42        #[cfg(not(feature = "test-support"))]
 43        store
 44            .update(cx, |store, cx| store.refresh_kernelspecs(cx))
 45            .detach_and_log_err(cx);
 46
 47        cx.set_global(GlobalReplStore(store))
 48    }
 49
 50    pub fn global(cx: &App) -> Entity<Self> {
 51        cx.global::<GlobalReplStore>().0.clone()
 52    }
 53
 54    pub fn new(fs: Arc<dyn Fs>, cx: &mut Context<Self>) -> Self {
 55        let subscriptions = vec![
 56            cx.observe_global::<SettingsStore>(move |this, cx| {
 57                this.set_enabled(JupyterSettings::enabled(cx), cx);
 58            }),
 59            cx.on_app_quit(Self::shutdown_all_sessions),
 60        ];
 61
 62        let this = Self {
 63            fs,
 64            enabled: JupyterSettings::enabled(cx),
 65            sessions: HashMap::default(),
 66            kernel_specifications: Vec::new(),
 67            _subscriptions: subscriptions,
 68            kernel_specifications_for_worktree: HashMap::default(),
 69            selected_kernel_for_worktree: HashMap::default(),
 70            active_python_toolchain_for_worktree: HashMap::default(),
 71            remote_worktrees: HashSet::default(),
 72        };
 73        this.on_enabled_changed(cx);
 74        this
 75    }
 76
    /// Returns the filesystem handle this store was created with.
    pub fn fs(&self) -> &Arc<dyn Fs> {
        &self.fs
    }
 80
    /// Returns the cached `JupyterSettings::enabled` value.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }
 84
    /// Returns `true` once a per-worktree kernelspec entry exists for
    /// `worktree_id`, i.e. after `refresh_python_kernelspecs` has stored a
    /// result for it. The stored list itself may be empty.
    pub fn has_python_kernelspecs(&self, worktree_id: WorktreeId) -> bool {
        self.kernel_specifications_for_worktree
            .contains_key(&worktree_id)
    }
 89
 90    pub fn kernel_specifications_for_worktree(
 91        &self,
 92        worktree_id: WorktreeId,
 93    ) -> impl Iterator<Item = &KernelSpecification> {
 94        let global_specs = if self.remote_worktrees.contains(&worktree_id) {
 95            None
 96        } else {
 97            Some(self.kernel_specifications.iter())
 98        };
 99
100        self.kernel_specifications_for_worktree
101            .get(&worktree_id)
102            .into_iter()
103            .flat_map(|specs| specs.iter())
104            .chain(global_specs.into_iter().flatten())
105    }
106
    /// Iterates over the worktree-independent kernelspecs only, i.e. the list
    /// last stored by `refresh_kernelspecs`, without any per-worktree entries.
    pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
        self.kernel_specifications.iter()
    }
110
    /// Iterates over all registered sessions, in unspecified (hash map) order.
    pub fn sessions(&self) -> impl Iterator<Item = &Entity<Session>> {
        self.sessions.values()
    }
114
115    fn set_enabled(&mut self, enabled: bool, cx: &mut Context<Self>) {
116        if self.enabled == enabled {
117            return;
118        }
119
120        self.enabled = enabled;
121        self.on_enabled_changed(cx);
122    }
123
124    fn on_enabled_changed(&self, cx: &mut Context<Self>) {
125        if !self.enabled {
126            CommandPaletteFilter::update_global(cx, |filter, _cx| {
127                filter.hide_namespace(Self::NAMESPACE);
128            });
129
130            return;
131        }
132
133        CommandPaletteFilter::update_global(cx, |filter, _cx| {
134            filter.show_namespace(Self::NAMESPACE);
135        });
136
137        cx.notify();
138    }
139
140    pub fn refresh_python_kernelspecs(
141        &mut self,
142        worktree_id: WorktreeId,
143        project: &Entity<Project>,
144        cx: &mut Context<Self>,
145    ) -> Task<Result<()>> {
146        let is_remote = project.read(cx).is_remote();
147        let kernel_specifications = python_env_kernel_specifications(project, worktree_id, cx);
148        let active_toolchain = project.read(cx).active_toolchain(
149            ProjectPath {
150                worktree_id,
151                path: RelPath::empty().into(),
152            },
153            LanguageName::new_static("Python"),
154            cx,
155        );
156
157        cx.spawn(async move |this, cx| {
158            let kernel_specifications = kernel_specifications
159                .await
160                .context("getting python kernelspecs")?;
161
162            let active_toolchain_path = active_toolchain.await.map(|toolchain| toolchain.path);
163
164            this.update(cx, |this, cx| {
165                this.kernel_specifications_for_worktree
166                    .insert(worktree_id, kernel_specifications);
167                if let Some(path) = active_toolchain_path {
168                    this.active_python_toolchain_for_worktree
169                        .insert(worktree_id, path);
170                }
171                if is_remote {
172                    this.remote_worktrees.insert(worktree_id);
173                } else {
174                    this.remote_worktrees.remove(&worktree_id);
175                }
176                cx.notify();
177            })
178        })
179    }
180
181    fn get_remote_kernel_specifications(
182        &self,
183        cx: &mut Context<Self>,
184    ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
185        match (
186            std::env::var("JUPYTER_SERVER"),
187            std::env::var("JUPYTER_TOKEN"),
188        ) {
189            (Ok(server), Ok(token)) => {
190                let remote_server = RemoteServer {
191                    base_url: server,
192                    token,
193                };
194                let http_client = cx.http_client();
195                Some(cx.spawn(async move |_, _| {
196                    list_remote_kernelspecs(remote_server, http_client)
197                        .await
198                        .map(|specs| {
199                            specs
200                                .into_iter()
201                                .map(KernelSpecification::JupyterServer)
202                                .collect()
203                        })
204                }))
205            }
206            _ => None,
207        }
208    }
209
210    pub fn refresh_kernelspecs(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
211        let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
212        let wsl_kernel_specifications = wsl_kernel_specifications(cx.background_executor().clone());
213
214        let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
215
216        let all_specs = cx.background_spawn(async move {
217            let mut all_specs = local_kernel_specifications
218                .await?
219                .into_iter()
220                .map(KernelSpecification::Jupyter)
221                .collect::<Vec<_>>();
222
223            if let Ok(wsl_specs) = wsl_kernel_specifications.await {
224                all_specs.extend(wsl_specs);
225            }
226
227            if let Some(remote_task) = remote_kernel_specifications
228                && let Ok(remote_specs) = remote_task.await
229            {
230                all_specs.extend(remote_specs);
231            }
232
233            anyhow::Ok(all_specs)
234        });
235
236        cx.spawn(async move |this, cx| {
237            let all_specs = all_specs.await;
238
239            if let Ok(specs) = all_specs {
240                this.update(cx, |this, cx| {
241                    this.kernel_specifications = specs;
242                    cx.notify();
243                })
244                .ok();
245            }
246
247            anyhow::Ok(())
248        })
249    }
250
    /// Records `kernelspec` as the explicit choice for `worktree_id`,
    /// replacing any earlier choice. Observers are not notified.
    pub fn set_active_kernelspec(
        &mut self,
        worktree_id: WorktreeId,
        kernelspec: KernelSpecification,
        _cx: &mut Context<Self>,
    ) {
        self.selected_kernel_for_worktree
            .insert(worktree_id, kernelspec);
    }
260
    /// Returns the active Python toolchain path recorded for `worktree_id`
    /// by `refresh_python_kernelspecs`, if any.
    pub fn active_python_toolchain_path(&self, worktree_id: WorktreeId) -> Option<&SharedString> {
        self.active_python_toolchain_for_worktree.get(&worktree_id)
    }
264
    /// Returns the kernel explicitly chosen for `worktree_id` through
    /// `set_active_kernelspec`, if any.
    pub fn selected_kernel(&self, worktree_id: WorktreeId) -> Option<&KernelSpecification> {
        self.selected_kernel_for_worktree.get(&worktree_id)
    }
268
269    pub fn is_recommended_kernel(
270        &self,
271        worktree_id: WorktreeId,
272        spec: &KernelSpecification,
273    ) -> bool {
274        if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
275            spec.path().as_ref() == active_path.as_ref()
276        } else {
277            false
278        }
279    }
280
281    pub fn active_kernelspec(
282        &self,
283        worktree_id: WorktreeId,
284        language_at_cursor: Option<Arc<Language>>,
285        cx: &App,
286    ) -> Option<KernelSpecification> {
287        if let Some(selected) = self.selected_kernel_for_worktree.get(&worktree_id).cloned() {
288            return Some(selected);
289        }
290
291        let language_at_cursor = language_at_cursor?;
292        let language_name = language_at_cursor.code_fence_block_name().to_lowercase();
293
294        // Prefer the recommended (active toolchain) kernel if it has ipykernel
295        if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
296            let recommended = self
297                .kernel_specifications_for_worktree(worktree_id)
298                .find(|spec| {
299                    spec.has_ipykernel()
300                        && spec.language().as_ref().to_lowercase() == language_name
301                        && spec.path().as_ref() == active_path.as_ref()
302                })
303                .cloned();
304            if recommended.is_some() {
305                return recommended;
306            }
307        }
308
309        // Then try the first PythonEnv with ipykernel matching the language
310        let python_env = self
311            .kernel_specifications_for_worktree(worktree_id)
312            .find(|spec| {
313                matches!(spec, KernelSpecification::PythonEnv(_))
314                    && spec.has_ipykernel()
315                    && spec.language().as_ref().to_lowercase() == language_name
316            })
317            .cloned();
318        if python_env.is_some() {
319            return python_env;
320        }
321
322        // Fall back to legacy name-based and language-based matching
323        self.kernelspec_legacy_by_lang_only(worktree_id, language_at_cursor, cx)
324    }
325
326    fn kernelspec_legacy_by_lang_only(
327        &self,
328        worktree_id: WorktreeId,
329        language_at_cursor: Arc<Language>,
330        cx: &App,
331    ) -> Option<KernelSpecification> {
332        let settings = JupyterSettings::get_global(cx);
333        let selected_kernel = settings
334            .kernel_selections
335            .get(language_at_cursor.code_fence_block_name().as_ref());
336
337        let found_by_name = self
338            .kernel_specifications_for_worktree(worktree_id)
339            .find(|runtime_specification| {
340                if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
341                    (selected_kernel, runtime_specification)
342                {
343                    return runtime_specification.name.to_lowercase() == selected.to_lowercase();
344                }
345                false
346            })
347            .cloned();
348
349        if let Some(found_by_name) = found_by_name {
350            return Some(found_by_name);
351        }
352
353        let language_name = language_at_cursor.code_fence_block_name().to_lowercase();
354        self.kernel_specifications_for_worktree(worktree_id)
355            .find(|spec| {
356                spec.has_ipykernel() && spec.language().as_ref().to_lowercase() == language_name
357            })
358            .cloned()
359    }
360
    /// Looks up the session registered under `entity_id`, if any.
    pub fn get_session(&self, entity_id: EntityId) -> Option<&Entity<Session>> {
        self.sessions.get(&entity_id)
    }
364
    /// Registers `session` under `entity_id`, replacing (and dropping the
    /// store's handle to) any session previously registered under that id.
    pub fn insert_session(&mut self, entity_id: EntityId, session: Entity<Session>) {
        self.sessions.insert(entity_id, session);
    }
368
    /// Forgets the session registered under `entity_id`, if any. The session's
    /// kernel is not touched here.
    pub fn remove_session(&mut self, entity_id: EntityId) {
        self.sessions.remove(&entity_id);
    }
372
373    fn shutdown_all_sessions(
374        &mut self,
375        cx: &mut Context<Self>,
376    ) -> impl Future<Output = ()> + use<> {
377        for session in self.sessions.values() {
378            session.update(cx, |session, _cx| {
379                if let Kernel::RunningKernel(mut kernel) =
380                    std::mem::replace(&mut session.kernel, Kernel::Shutdown)
381                {
382                    kernel.kill();
383                }
384            });
385        }
386        self.sessions.clear();
387        futures::future::ready(())
388    }
389
    /// Test-only: replaces the worktree-independent kernelspecs directly,
    /// bypassing discovery, and notifies observers.
    #[cfg(test)]
    pub fn set_kernel_specs_for_testing(
        &mut self,
        specs: Vec<KernelSpecification>,
        cx: &mut Context<Self>,
    ) {
        self.kernel_specifications = specs;
        cx.notify();
    }
399}