repl_store.rs

  1use std::future::Future;
  2use std::sync::Arc;
  3
  4use anyhow::{Context as _, Result};
  5use collections::HashMap;
  6use command_palette_hooks::CommandPaletteFilter;
  7use gpui::{App, Context, Entity, EntityId, Global, SharedString, Subscription, Task, prelude::*};
  8use jupyter_websocket_client::RemoteServer;
  9use language::{Language, LanguageName};
 10use project::{Fs, Project, ProjectPath, WorktreeId};
 11use settings::{Settings, SettingsStore};
 12use util::rel_path::RelPath;
 13
 14use crate::kernels::{
 15    Kernel, list_remote_kernelspecs, local_kernel_specifications, python_env_kernel_specifications,
 16};
 17use crate::{JupyterSettings, KernelSpecification, Session};
 18
 19struct GlobalReplStore(Entity<ReplStore>);
 20
 21impl Global for GlobalReplStore {}
 22
 23pub struct ReplStore {
 24    fs: Arc<dyn Fs>,
 25    enabled: bool,
 26    sessions: HashMap<EntityId, Entity<Session>>,
 27    kernel_specifications: Vec<KernelSpecification>,
 28    selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
 29    kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
 30    active_python_toolchain_for_worktree: HashMap<WorktreeId, SharedString>,
 31    _subscriptions: Vec<Subscription>,
 32}
 33
 34impl ReplStore {
 35    const NAMESPACE: &'static str = "repl";
 36
 37    pub(crate) fn init(fs: Arc<dyn Fs>, cx: &mut App) {
 38        let store = cx.new(move |cx| Self::new(fs, cx));
 39
 40        #[cfg(not(feature = "test-support"))]
 41        store
 42            .update(cx, |store, cx| store.refresh_kernelspecs(cx))
 43            .detach_and_log_err(cx);
 44
 45        cx.set_global(GlobalReplStore(store))
 46    }
 47
 48    pub fn global(cx: &App) -> Entity<Self> {
 49        cx.global::<GlobalReplStore>().0.clone()
 50    }
 51
 52    pub fn new(fs: Arc<dyn Fs>, cx: &mut Context<Self>) -> Self {
 53        let subscriptions = vec![
 54            cx.observe_global::<SettingsStore>(move |this, cx| {
 55                this.set_enabled(JupyterSettings::enabled(cx), cx);
 56            }),
 57            cx.on_app_quit(Self::shutdown_all_sessions),
 58        ];
 59
 60        let this = Self {
 61            fs,
 62            enabled: JupyterSettings::enabled(cx),
 63            sessions: HashMap::default(),
 64            kernel_specifications: Vec::new(),
 65            _subscriptions: subscriptions,
 66            kernel_specifications_for_worktree: HashMap::default(),
 67            selected_kernel_for_worktree: HashMap::default(),
 68            active_python_toolchain_for_worktree: HashMap::default(),
 69        };
 70        this.on_enabled_changed(cx);
 71        this
 72    }
 73
 74    pub fn fs(&self) -> &Arc<dyn Fs> {
 75        &self.fs
 76    }
 77
 78    pub fn is_enabled(&self) -> bool {
 79        self.enabled
 80    }
 81
 82    pub fn has_python_kernelspecs(&self, worktree_id: WorktreeId) -> bool {
 83        self.kernel_specifications_for_worktree
 84            .contains_key(&worktree_id)
 85    }
 86
 87    pub fn kernel_specifications_for_worktree(
 88        &self,
 89        worktree_id: WorktreeId,
 90    ) -> impl Iterator<Item = &KernelSpecification> {
 91        self.kernel_specifications_for_worktree
 92            .get(&worktree_id)
 93            .into_iter()
 94            .flat_map(|specs| specs.iter())
 95            .chain(self.kernel_specifications.iter())
 96    }
 97
 98    pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
 99        self.kernel_specifications.iter()
100    }
101
102    pub fn sessions(&self) -> impl Iterator<Item = &Entity<Session>> {
103        self.sessions.values()
104    }
105
106    fn set_enabled(&mut self, enabled: bool, cx: &mut Context<Self>) {
107        if self.enabled == enabled {
108            return;
109        }
110
111        self.enabled = enabled;
112        self.on_enabled_changed(cx);
113    }
114
115    fn on_enabled_changed(&self, cx: &mut Context<Self>) {
116        if !self.enabled {
117            CommandPaletteFilter::update_global(cx, |filter, _cx| {
118                filter.hide_namespace(Self::NAMESPACE);
119            });
120
121            return;
122        }
123
124        CommandPaletteFilter::update_global(cx, |filter, _cx| {
125            filter.show_namespace(Self::NAMESPACE);
126        });
127
128        cx.notify();
129    }
130
131    pub fn refresh_python_kernelspecs(
132        &mut self,
133        worktree_id: WorktreeId,
134        project: &Entity<Project>,
135        cx: &mut Context<Self>,
136    ) -> Task<Result<()>> {
137        let kernel_specifications = python_env_kernel_specifications(project, worktree_id, cx);
138        let active_toolchain = project.read(cx).active_toolchain(
139            ProjectPath {
140                worktree_id,
141                path: RelPath::empty().into(),
142            },
143            LanguageName::new_static("Python"),
144            cx,
145        );
146
147        cx.spawn(async move |this, cx| {
148            let kernel_specifications = kernel_specifications
149                .await
150                .context("getting python kernelspecs")?;
151
152            let active_toolchain_path = active_toolchain.await.map(|toolchain| toolchain.path);
153
154            this.update(cx, |this, cx| {
155                this.kernel_specifications_for_worktree
156                    .insert(worktree_id, kernel_specifications);
157                if let Some(path) = active_toolchain_path {
158                    this.active_python_toolchain_for_worktree
159                        .insert(worktree_id, path);
160                }
161                cx.notify();
162            })
163        })
164    }
165
166    fn get_remote_kernel_specifications(
167        &self,
168        cx: &mut Context<Self>,
169    ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
170        match (
171            std::env::var("JUPYTER_SERVER"),
172            std::env::var("JUPYTER_TOKEN"),
173        ) {
174            (Ok(server), Ok(token)) => {
175                let remote_server = RemoteServer {
176                    base_url: server,
177                    token,
178                };
179                let http_client = cx.http_client();
180                Some(cx.spawn(async move |_, _| {
181                    list_remote_kernelspecs(remote_server, http_client)
182                        .await
183                        .map(|specs| specs.into_iter().map(KernelSpecification::Remote).collect())
184                }))
185            }
186            _ => None,
187        }
188    }
189
190    pub fn refresh_kernelspecs(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
191        let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
192
193        let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
194
195        let all_specs = cx.background_spawn(async move {
196            let mut all_specs = local_kernel_specifications
197                .await?
198                .into_iter()
199                .map(KernelSpecification::Jupyter)
200                .collect::<Vec<_>>();
201
202            if let Some(remote_task) = remote_kernel_specifications
203                && let Ok(remote_specs) = remote_task.await
204            {
205                all_specs.extend(remote_specs);
206            }
207
208            anyhow::Ok(all_specs)
209        });
210
211        cx.spawn(async move |this, cx| {
212            let all_specs = all_specs.await;
213
214            if let Ok(specs) = all_specs {
215                this.update(cx, |this, cx| {
216                    this.kernel_specifications = specs;
217                    cx.notify();
218                })
219                .ok();
220            }
221
222            anyhow::Ok(())
223        })
224    }
225
226    pub fn set_active_kernelspec(
227        &mut self,
228        worktree_id: WorktreeId,
229        kernelspec: KernelSpecification,
230        _cx: &mut Context<Self>,
231    ) {
232        self.selected_kernel_for_worktree
233            .insert(worktree_id, kernelspec);
234    }
235
236    pub fn active_python_toolchain_path(&self, worktree_id: WorktreeId) -> Option<&SharedString> {
237        self.active_python_toolchain_for_worktree.get(&worktree_id)
238    }
239
240    pub fn is_recommended_kernel(
241        &self,
242        worktree_id: WorktreeId,
243        spec: &KernelSpecification,
244    ) -> bool {
245        if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
246            spec.path().as_ref() == active_path.as_ref()
247        } else {
248            false
249        }
250    }
251
252    pub fn active_kernelspec(
253        &self,
254        worktree_id: WorktreeId,
255        language_at_cursor: Option<Arc<Language>>,
256        cx: &App,
257    ) -> Option<KernelSpecification> {
258        if let Some(selected) = self.selected_kernel_for_worktree.get(&worktree_id).cloned() {
259            return Some(selected);
260        }
261
262        let language_at_cursor = language_at_cursor?;
263        let language_name = language_at_cursor.code_fence_block_name().to_lowercase();
264
265        // Prefer the recommended (active toolchain) kernel if it has ipykernel
266        if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
267            let recommended = self
268                .kernel_specifications_for_worktree(worktree_id)
269                .find(|spec| {
270                    spec.has_ipykernel()
271                        && spec.language().as_ref().to_lowercase() == language_name
272                        && spec.path().as_ref() == active_path.as_ref()
273                })
274                .cloned();
275            if recommended.is_some() {
276                return recommended;
277            }
278        }
279
280        // Then try the first PythonEnv with ipykernel matching the language
281        let python_env = self
282            .kernel_specifications_for_worktree(worktree_id)
283            .find(|spec| {
284                matches!(spec, KernelSpecification::PythonEnv(_))
285                    && spec.has_ipykernel()
286                    && spec.language().as_ref().to_lowercase() == language_name
287            })
288            .cloned();
289        if python_env.is_some() {
290            return python_env;
291        }
292
293        // Fall back to legacy name-based and language-based matching
294        self.kernelspec_legacy_by_lang_only(worktree_id, language_at_cursor, cx)
295    }
296
297    fn kernelspec_legacy_by_lang_only(
298        &self,
299        worktree_id: WorktreeId,
300        language_at_cursor: Arc<Language>,
301        cx: &App,
302    ) -> Option<KernelSpecification> {
303        let settings = JupyterSettings::get_global(cx);
304        let selected_kernel = settings
305            .kernel_selections
306            .get(language_at_cursor.code_fence_block_name().as_ref());
307
308        let found_by_name = self
309            .kernel_specifications_for_worktree(worktree_id)
310            .find(|runtime_specification| {
311                if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
312                    (selected_kernel, runtime_specification)
313                {
314                    return runtime_specification.name.to_lowercase() == selected.to_lowercase();
315                }
316                false
317            })
318            .cloned();
319
320        if let Some(found_by_name) = found_by_name {
321            return Some(found_by_name);
322        }
323
324        let language_name = language_at_cursor.code_fence_block_name().to_lowercase();
325        self.kernel_specifications_for_worktree(worktree_id)
326            .find(|spec| {
327                spec.has_ipykernel() && spec.language().as_ref().to_lowercase() == language_name
328            })
329            .cloned()
330    }
331
332    pub fn get_session(&self, entity_id: EntityId) -> Option<&Entity<Session>> {
333        self.sessions.get(&entity_id)
334    }
335
336    pub fn insert_session(&mut self, entity_id: EntityId, session: Entity<Session>) {
337        self.sessions.insert(entity_id, session);
338    }
339
340    pub fn remove_session(&mut self, entity_id: EntityId) {
341        self.sessions.remove(&entity_id);
342    }
343
344    fn shutdown_all_sessions(
345        &mut self,
346        cx: &mut Context<Self>,
347    ) -> impl Future<Output = ()> + use<> {
348        for session in self.sessions.values() {
349            session.update(cx, |session, _cx| {
350                if let Kernel::RunningKernel(mut kernel) =
351                    std::mem::replace(&mut session.kernel, Kernel::Shutdown)
352                {
353                    kernel.kill();
354                }
355            });
356        }
357        self.sessions.clear();
358        futures::future::ready(())
359    }
360
361    #[cfg(test)]
362    pub fn set_kernel_specs_for_testing(
363        &mut self,
364        specs: Vec<KernelSpecification>,
365        cx: &mut Context<Self>,
366    ) {
367        self.kernel_specifications = specs;
368        cx.notify();
369    }
370}