repl_store.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Context as _, Result};
  4use collections::HashMap;
  5use command_palette_hooks::CommandPaletteFilter;
  6use gpui::{App, Context, Entity, EntityId, Global, Subscription, Task, prelude::*};
  7use jupyter_websocket_client::RemoteServer;
  8use language::Language;
  9use project::{Fs, Project, WorktreeId};
 10use settings::{Settings, SettingsStore};
 11
 12use crate::kernels::{
 13    list_remote_kernelspecs, local_kernel_specifications, python_env_kernel_specifications,
 14};
 15use crate::{JupyterSettings, KernelSpecification, Session};
 16
/// Newtype wrapper that lets the single `ReplStore` entity be stored as a gpui
/// global: it is installed by `ReplStore::init` and read back by `ReplStore::global`.
struct GlobalReplStore(Entity<ReplStore>);

impl Global for GlobalReplStore {}
 20
/// Central state for the REPL feature: discovered kernelspecs, per-worktree
/// kernel choices, and the running sessions.
pub struct ReplStore {
    /// Filesystem handle passed to `local_kernel_specifications` when refreshing.
    fs: Arc<dyn Fs>,
    /// Mirrors `JupyterSettings::enabled`; controls visibility of the `repl`
    /// command-palette namespace.
    enabled: bool,
    /// Running REPL sessions, keyed by the `EntityId` given to `insert_session`.
    sessions: HashMap<EntityId, Entity<Session>>,
    /// Worktree-independent kernels: local Jupyter kernelspecs plus any kernels
    /// listed by the remote server configured via environment variables.
    kernel_specifications: Vec<KernelSpecification>,
    /// Kernel explicitly chosen for a worktree via `set_active_kernelspec`.
    selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
    /// Python-environment kernels discovered per worktree by
    /// `refresh_python_kernelspecs`.
    kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
    /// Holds the `SettingsStore` observer registered in `new` so it stays active.
    _subscriptions: Vec<Subscription>,
}
 30
 31impl ReplStore {
 32    const NAMESPACE: &'static str = "repl";
 33
 34    pub(crate) fn init(fs: Arc<dyn Fs>, cx: &mut App) {
 35        let store = cx.new(move |cx| Self::new(fs, cx));
 36
 37        store
 38            .update(cx, |store, cx| store.refresh_kernelspecs(cx))
 39            .detach_and_log_err(cx);
 40
 41        cx.set_global(GlobalReplStore(store))
 42    }
 43
 44    pub fn global(cx: &App) -> Entity<Self> {
 45        cx.global::<GlobalReplStore>().0.clone()
 46    }
 47
 48    pub fn new(fs: Arc<dyn Fs>, cx: &mut Context<Self>) -> Self {
 49        let subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
 50            this.set_enabled(JupyterSettings::enabled(cx), cx);
 51        })];
 52
 53        let this = Self {
 54            fs,
 55            enabled: JupyterSettings::enabled(cx),
 56            sessions: HashMap::default(),
 57            kernel_specifications: Vec::new(),
 58            _subscriptions: subscriptions,
 59            kernel_specifications_for_worktree: HashMap::default(),
 60            selected_kernel_for_worktree: HashMap::default(),
 61        };
 62        this.on_enabled_changed(cx);
 63        this
 64    }
 65
 66    pub fn fs(&self) -> &Arc<dyn Fs> {
 67        &self.fs
 68    }
 69
 70    pub fn is_enabled(&self) -> bool {
 71        self.enabled
 72    }
 73
 74    pub fn kernel_specifications_for_worktree(
 75        &self,
 76        worktree_id: WorktreeId,
 77    ) -> impl Iterator<Item = &KernelSpecification> {
 78        self.kernel_specifications_for_worktree
 79            .get(&worktree_id)
 80            .into_iter()
 81            .flat_map(|specs| specs.iter())
 82            .chain(self.kernel_specifications.iter())
 83    }
 84
 85    pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
 86        self.kernel_specifications.iter()
 87    }
 88
 89    pub fn sessions(&self) -> impl Iterator<Item = &Entity<Session>> {
 90        self.sessions.values()
 91    }
 92
 93    fn set_enabled(&mut self, enabled: bool, cx: &mut Context<Self>) {
 94        if self.enabled == enabled {
 95            return;
 96        }
 97
 98        self.enabled = enabled;
 99        self.on_enabled_changed(cx);
100    }
101
102    fn on_enabled_changed(&self, cx: &mut Context<Self>) {
103        if !self.enabled {
104            CommandPaletteFilter::update_global(cx, |filter, _cx| {
105                filter.hide_namespace(Self::NAMESPACE);
106            });
107
108            return;
109        }
110
111        CommandPaletteFilter::update_global(cx, |filter, _cx| {
112            filter.show_namespace(Self::NAMESPACE);
113        });
114
115        cx.notify();
116    }
117
118    pub fn refresh_python_kernelspecs(
119        &mut self,
120        worktree_id: WorktreeId,
121        project: &Entity<Project>,
122        cx: &mut Context<Self>,
123    ) -> Task<Result<()>> {
124        let kernel_specifications = python_env_kernel_specifications(project, worktree_id, cx);
125        cx.spawn(async move |this, cx| {
126            let kernel_specifications = kernel_specifications
127                .await
128                .context("getting python kernelspecs")?;
129
130            this.update(cx, |this, cx| {
131                this.kernel_specifications_for_worktree
132                    .insert(worktree_id, kernel_specifications);
133                cx.notify();
134            })
135        })
136    }
137
138    fn get_remote_kernel_specifications(
139        &self,
140        cx: &mut Context<Self>,
141    ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
142        match (
143            std::env::var("JUPYTER_SERVER"),
144            std::env::var("JUPYTER_TOKEN"),
145        ) {
146            (Ok(server), Ok(token)) => {
147                let remote_server = RemoteServer {
148                    base_url: server,
149                    token,
150                };
151                let http_client = cx.http_client();
152                Some(cx.spawn(async move |_, _| {
153                    list_remote_kernelspecs(remote_server, http_client)
154                        .await
155                        .map(|specs| specs.into_iter().map(KernelSpecification::Remote).collect())
156                }))
157            }
158            _ => None,
159        }
160    }
161
162    pub fn refresh_kernelspecs(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
163        let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
164
165        let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
166
167        let all_specs = cx.background_spawn(async move {
168            let mut all_specs = local_kernel_specifications
169                .await?
170                .into_iter()
171                .map(KernelSpecification::Jupyter)
172                .collect::<Vec<_>>();
173
174            if let Some(remote_task) = remote_kernel_specifications
175                && let Ok(remote_specs) = remote_task.await {
176                    all_specs.extend(remote_specs);
177                }
178
179            anyhow::Ok(all_specs)
180        });
181
182        cx.spawn(async move |this, cx| {
183            let all_specs = all_specs.await;
184
185            if let Ok(specs) = all_specs {
186                this.update(cx, |this, cx| {
187                    this.kernel_specifications = specs;
188                    cx.notify();
189                })
190                .ok();
191            }
192
193            anyhow::Ok(())
194        })
195    }
196
197    pub fn set_active_kernelspec(
198        &mut self,
199        worktree_id: WorktreeId,
200        kernelspec: KernelSpecification,
201        _cx: &mut Context<Self>,
202    ) {
203        self.selected_kernel_for_worktree
204            .insert(worktree_id, kernelspec);
205    }
206
207    pub fn active_kernelspec(
208        &self,
209        worktree_id: WorktreeId,
210        language_at_cursor: Option<Arc<Language>>,
211        cx: &App,
212    ) -> Option<KernelSpecification> {
213        let selected_kernelspec = self.selected_kernel_for_worktree.get(&worktree_id).cloned();
214
215        if let Some(language_at_cursor) = language_at_cursor {
216            selected_kernelspec
217                .or_else(|| self.kernelspec_legacy_by_lang_only(language_at_cursor, cx))
218        } else {
219            selected_kernelspec
220        }
221    }
222
223    fn kernelspec_legacy_by_lang_only(
224        &self,
225        language_at_cursor: Arc<Language>,
226        cx: &App,
227    ) -> Option<KernelSpecification> {
228        let settings = JupyterSettings::get_global(cx);
229        let selected_kernel = settings
230            .kernel_selections
231            .get(language_at_cursor.code_fence_block_name().as_ref());
232
233        let found_by_name = self
234            .kernel_specifications
235            .iter()
236            .find(|runtime_specification| {
237                if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
238                    (selected_kernel, runtime_specification)
239                {
240                    // Top priority is the selected kernel
241                    return runtime_specification.name.to_lowercase() == selected.to_lowercase();
242                }
243                false
244            })
245            .cloned();
246
247        if let Some(found_by_name) = found_by_name {
248            return Some(found_by_name);
249        }
250
251        self.kernel_specifications
252            .iter()
253            .find(|kernel_option| match kernel_option {
254                KernelSpecification::Jupyter(runtime_specification) => {
255                    runtime_specification.kernelspec.language.to_lowercase()
256                        == language_at_cursor.code_fence_block_name().to_lowercase()
257                }
258                KernelSpecification::PythonEnv(runtime_specification) => {
259                    runtime_specification.kernelspec.language.to_lowercase()
260                        == language_at_cursor.code_fence_block_name().to_lowercase()
261                }
262                KernelSpecification::Remote(remote_spec) => {
263                    remote_spec.kernelspec.language.to_lowercase()
264                        == language_at_cursor.code_fence_block_name().to_lowercase()
265                }
266            })
267            .cloned()
268    }
269
270    pub fn get_session(&self, entity_id: EntityId) -> Option<&Entity<Session>> {
271        self.sessions.get(&entity_id)
272    }
273
274    pub fn insert_session(&mut self, entity_id: EntityId, session: Entity<Session>) {
275        self.sessions.insert(entity_id, session);
276    }
277
278    pub fn remove_session(&mut self, entity_id: EntityId) {
279        self.sessions.remove(&entity_id);
280    }
281
282    #[cfg(test)]
283    pub fn set_kernel_specs_for_testing(
284        &mut self,
285        specs: Vec<KernelSpecification>,
286        cx: &mut Context<Self>,
287    ) {
288        self.kernel_specifications = specs;
289        cx.notify();
290    }
291}