repl_store.rs

  1use std::future::Future;
  2use std::sync::Arc;
  3
  4use anyhow::{Context as _, Result};
  5use collections::{HashMap, HashSet};
  6use command_palette_hooks::CommandPaletteFilter;
  7use gpui::{App, Context, Entity, EntityId, Global, SharedString, Subscription, Task, prelude::*};
  8use jupyter_websocket_client::RemoteServer;
  9use language::{Language, LanguageName};
 10use project::{Fs, Project, ProjectPath, WorktreeId};
 11use remote::RemoteConnectionOptions;
 12use settings::{Settings, SettingsStore};
 13use util::rel_path::RelPath;
 14
 15use crate::kernels::{
 16    Kernel, list_remote_kernelspecs, local_kernel_specifications, python_env_kernel_specifications,
 17    wsl_kernel_specifications,
 18};
 19use crate::{JupyterSettings, KernelSpecification, Session};
 20
/// gpui global that holds the shared [`ReplStore`] entity.
///
/// Installed by `ReplStore::init` and read back by `ReplStore::global`.
struct GlobalReplStore(Entity<ReplStore>);

impl Global for GlobalReplStore {}
 24
/// App-wide REPL state.
///
/// Tracks whether the feature is enabled, the discovered kernelspecs (machine-wide
/// and per worktree), per-worktree kernel choices, and the live REPL sessions.
pub struct ReplStore {
    /// Filesystem handed to local Jupyter kernelspec discovery.
    fs: Arc<dyn Fs>,
    /// Mirrors `JupyterSettings::enabled`; re-read whenever the settings store changes.
    enabled: bool,
    /// Live REPL sessions keyed by an entity id (presumably the editor that owns
    /// the session — TODO confirm against callers).
    sessions: HashMap<EntityId, Entity<Session>>,
    /// Machine-wide kernelspecs: local Jupyter specs plus any WSL and
    /// `JUPYTER_SERVER` specs that could be listed.
    kernel_specifications: Vec<KernelSpecification>,
    /// Kernel explicitly chosen for a worktree; wins over automatic selection.
    selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
    /// Python-environment kernelspecs discovered for each worktree. An entry is
    /// present once discovery has completed, even if it found nothing.
    kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
    /// Path of the active Python toolchain at each worktree's root, used to
    /// recommend a kernel.
    active_python_toolchain_for_worktree: HashMap<WorktreeId, SharedString>,
    /// Remote (non-WSL) worktrees; these are not offered the machine-wide kernelspecs.
    remote_worktrees: HashSet<WorktreeId>,
    /// Keeps the settings observer and app-quit hook registered for the store's lifetime.
    _subscriptions: Vec<Subscription>,
}
 36
 37impl ReplStore {
 38    const NAMESPACE: &'static str = "repl";
 39
 40    pub(crate) fn init(fs: Arc<dyn Fs>, cx: &mut App) {
 41        let store = cx.new(move |cx| Self::new(fs, cx));
 42
 43        #[cfg(not(feature = "test-support"))]
 44        store
 45            .update(cx, |store, cx| store.refresh_kernelspecs(cx))
 46            .detach_and_log_err(cx);
 47
 48        cx.set_global(GlobalReplStore(store))
 49    }
 50
 51    pub fn global(cx: &App) -> Entity<Self> {
 52        cx.global::<GlobalReplStore>().0.clone()
 53    }
 54
 55    pub fn new(fs: Arc<dyn Fs>, cx: &mut Context<Self>) -> Self {
 56        let subscriptions = vec![
 57            cx.observe_global::<SettingsStore>(move |this, cx| {
 58                this.set_enabled(JupyterSettings::enabled(cx), cx);
 59            }),
 60            cx.on_app_quit(Self::shutdown_all_sessions),
 61        ];
 62
 63        let this = Self {
 64            fs,
 65            enabled: JupyterSettings::enabled(cx),
 66            sessions: HashMap::default(),
 67            kernel_specifications: Vec::new(),
 68            _subscriptions: subscriptions,
 69            kernel_specifications_for_worktree: HashMap::default(),
 70            selected_kernel_for_worktree: HashMap::default(),
 71            active_python_toolchain_for_worktree: HashMap::default(),
 72            remote_worktrees: HashSet::default(),
 73        };
 74        this.on_enabled_changed(cx);
 75        this
 76    }
 77
 78    pub fn fs(&self) -> &Arc<dyn Fs> {
 79        &self.fs
 80    }
 81
 82    pub fn is_enabled(&self) -> bool {
 83        self.enabled
 84    }
 85
 86    pub fn has_python_kernelspecs(&self, worktree_id: WorktreeId) -> bool {
 87        self.kernel_specifications_for_worktree
 88            .contains_key(&worktree_id)
 89    }
 90
 91    pub fn kernel_specifications_for_worktree(
 92        &self,
 93        worktree_id: WorktreeId,
 94    ) -> impl Iterator<Item = &KernelSpecification> {
 95        let global_specs = if self.remote_worktrees.contains(&worktree_id) {
 96            None
 97        } else {
 98            Some(self.kernel_specifications.iter())
 99        };
100
101        self.kernel_specifications_for_worktree
102            .get(&worktree_id)
103            .into_iter()
104            .flat_map(|specs| specs.iter())
105            .chain(global_specs.into_iter().flatten())
106    }
107
108    pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
109        self.kernel_specifications.iter()
110    }
111
112    pub fn sessions(&self) -> impl Iterator<Item = &Entity<Session>> {
113        self.sessions.values()
114    }
115
116    fn set_enabled(&mut self, enabled: bool, cx: &mut Context<Self>) {
117        if self.enabled == enabled {
118            return;
119        }
120
121        self.enabled = enabled;
122        self.on_enabled_changed(cx);
123    }
124
125    fn on_enabled_changed(&self, cx: &mut Context<Self>) {
126        if !self.enabled {
127            CommandPaletteFilter::update_global(cx, |filter, _cx| {
128                filter.hide_namespace(Self::NAMESPACE);
129            });
130
131            return;
132        }
133
134        CommandPaletteFilter::update_global(cx, |filter, _cx| {
135            filter.show_namespace(Self::NAMESPACE);
136        });
137
138        cx.notify();
139    }
140
141    pub fn refresh_python_kernelspecs(
142        &mut self,
143        worktree_id: WorktreeId,
144        project: &Entity<Project>,
145        cx: &mut Context<Self>,
146    ) -> Task<Result<()>> {
147        let is_remote = project.read(cx).is_remote();
148        // WSL does require access to global kernel specs, so we only exclude remote worktrees that aren't WSL.
149        // TODO: a better way to handle WSL vs SSH/remote projects,
150        let is_wsl_remote = project
151            .read(cx)
152            .remote_connection_options(cx)
153            .map_or(false, |opts| {
154                matches!(opts, RemoteConnectionOptions::Wsl(_))
155            });
156        let kernel_specifications = python_env_kernel_specifications(project, worktree_id, cx);
157        let active_toolchain = project.read(cx).active_toolchain(
158            ProjectPath {
159                worktree_id,
160                path: RelPath::empty().into(),
161            },
162            LanguageName::new_static("Python"),
163            cx,
164        );
165
166        cx.spawn(async move |this, cx| {
167            let kernel_specifications = kernel_specifications
168                .await
169                .context("getting python kernelspecs")?;
170
171            let active_toolchain_path = active_toolchain.await.map(|toolchain| toolchain.path);
172
173            this.update(cx, |this, cx| {
174                this.kernel_specifications_for_worktree
175                    .insert(worktree_id, kernel_specifications);
176                if let Some(path) = active_toolchain_path {
177                    this.active_python_toolchain_for_worktree
178                        .insert(worktree_id, path);
179                }
180                if is_remote && !is_wsl_remote {
181                    this.remote_worktrees.insert(worktree_id);
182                } else {
183                    this.remote_worktrees.remove(&worktree_id);
184                }
185                cx.notify();
186            })
187        })
188    }
189
190    fn get_remote_kernel_specifications(
191        &self,
192        cx: &mut Context<Self>,
193    ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
194        match (
195            std::env::var("JUPYTER_SERVER"),
196            std::env::var("JUPYTER_TOKEN"),
197        ) {
198            (Ok(server), Ok(token)) => {
199                let remote_server = RemoteServer {
200                    base_url: server,
201                    token,
202                };
203                let http_client = cx.http_client();
204                Some(cx.spawn(async move |_, _| {
205                    list_remote_kernelspecs(remote_server, http_client)
206                        .await
207                        .map(|specs| {
208                            specs
209                                .into_iter()
210                                .map(KernelSpecification::JupyterServer)
211                                .collect()
212                        })
213                }))
214            }
215            _ => None,
216        }
217    }
218
219    pub fn refresh_kernelspecs(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
220        let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
221        let wsl_kernel_specifications = wsl_kernel_specifications(cx.background_executor().clone());
222
223        let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
224
225        let all_specs = cx.background_spawn(async move {
226            let mut all_specs = local_kernel_specifications
227                .await?
228                .into_iter()
229                .map(KernelSpecification::Jupyter)
230                .collect::<Vec<_>>();
231
232            if let Ok(wsl_specs) = wsl_kernel_specifications.await {
233                all_specs.extend(wsl_specs);
234            }
235
236            if let Some(remote_task) = remote_kernel_specifications
237                && let Ok(remote_specs) = remote_task.await
238            {
239                all_specs.extend(remote_specs);
240            }
241
242            anyhow::Ok(all_specs)
243        });
244
245        cx.spawn(async move |this, cx| {
246            let all_specs = all_specs.await;
247
248            if let Ok(specs) = all_specs {
249                this.update(cx, |this, cx| {
250                    this.kernel_specifications = specs;
251                    cx.notify();
252                })
253                .ok();
254            }
255
256            anyhow::Ok(())
257        })
258    }
259
260    pub fn set_active_kernelspec(
261        &mut self,
262        worktree_id: WorktreeId,
263        kernelspec: KernelSpecification,
264        _cx: &mut Context<Self>,
265    ) {
266        self.selected_kernel_for_worktree
267            .insert(worktree_id, kernelspec);
268    }
269
270    pub fn active_python_toolchain_path(&self, worktree_id: WorktreeId) -> Option<&SharedString> {
271        self.active_python_toolchain_for_worktree.get(&worktree_id)
272    }
273
274    pub fn selected_kernel(&self, worktree_id: WorktreeId) -> Option<&KernelSpecification> {
275        self.selected_kernel_for_worktree.get(&worktree_id)
276    }
277
278    pub fn is_recommended_kernel(
279        &self,
280        worktree_id: WorktreeId,
281        spec: &KernelSpecification,
282    ) -> bool {
283        if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
284            spec.path().as_ref() == active_path.as_ref()
285        } else {
286            false
287        }
288    }
289
290    pub fn active_kernelspec(
291        &self,
292        worktree_id: WorktreeId,
293        language_at_cursor: Option<Arc<Language>>,
294        cx: &App,
295    ) -> Option<KernelSpecification> {
296        if let Some(selected) = self.selected_kernel_for_worktree.get(&worktree_id).cloned() {
297            return Some(selected);
298        }
299
300        let language_at_cursor = language_at_cursor?;
301
302        // Prefer the recommended (active toolchain) kernel if it has ipykernel
303        if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
304            let recommended = self
305                .kernel_specifications_for_worktree(worktree_id)
306                .find(|spec| {
307                    spec.has_ipykernel()
308                        && language_at_cursor.matches_kernel_language(spec.language().as_ref())
309                        && spec.path().as_ref() == active_path.as_ref()
310                })
311                .cloned();
312            if recommended.is_some() {
313                return recommended;
314            }
315        }
316
317        // Then try the first PythonEnv with ipykernel matching the language
318        let python_env = self
319            .kernel_specifications_for_worktree(worktree_id)
320            .find(|spec| {
321                matches!(spec, KernelSpecification::PythonEnv(_))
322                    && spec.has_ipykernel()
323                    && language_at_cursor.matches_kernel_language(spec.language().as_ref())
324            })
325            .cloned();
326        if python_env.is_some() {
327            return python_env;
328        }
329
330        // Fall back to legacy name-based and language-based matching
331        self.kernelspec_legacy_by_lang_only(worktree_id, language_at_cursor, cx)
332    }
333
334    fn kernelspec_legacy_by_lang_only(
335        &self,
336        worktree_id: WorktreeId,
337        language_at_cursor: Arc<Language>,
338        cx: &App,
339    ) -> Option<KernelSpecification> {
340        let settings = JupyterSettings::get_global(cx);
341        let selected_kernel = settings
342            .kernel_selections
343            .get(language_at_cursor.code_fence_block_name().as_ref());
344
345        let found_by_name = self
346            .kernel_specifications_for_worktree(worktree_id)
347            .find(|runtime_specification| {
348                if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
349                    (selected_kernel, runtime_specification)
350                {
351                    return runtime_specification.name.to_lowercase() == selected.to_lowercase();
352                }
353                false
354            })
355            .cloned();
356
357        if let Some(found_by_name) = found_by_name {
358            return Some(found_by_name);
359        }
360
361        self.kernel_specifications_for_worktree(worktree_id)
362            .find(|spec| {
363                spec.has_ipykernel()
364                    && language_at_cursor.matches_kernel_language(spec.language().as_ref())
365            })
366            .cloned()
367    }
368
369    pub fn get_session(&self, entity_id: EntityId) -> Option<&Entity<Session>> {
370        self.sessions.get(&entity_id)
371    }
372
373    pub fn insert_session(&mut self, entity_id: EntityId, session: Entity<Session>) {
374        self.sessions.insert(entity_id, session);
375    }
376
377    pub fn remove_session(&mut self, entity_id: EntityId) {
378        self.sessions.remove(&entity_id);
379    }
380
381    fn shutdown_all_sessions(
382        &mut self,
383        cx: &mut Context<Self>,
384    ) -> impl Future<Output = ()> + use<> {
385        for session in self.sessions.values() {
386            session.update(cx, |session, _cx| {
387                if let Kernel::RunningKernel(mut kernel) =
388                    std::mem::replace(&mut session.kernel, Kernel::Shutdown)
389                {
390                    kernel.kill();
391                }
392            });
393        }
394        self.sessions.clear();
395        futures::future::ready(())
396    }
397
398    #[cfg(test)]
399    pub fn set_kernel_specs_for_testing(
400        &mut self,
401        specs: Vec<KernelSpecification>,
402        cx: &mut Context<Self>,
403    ) {
404        self.kernel_specifications = specs;
405        cx.notify();
406    }
407}