1use std::sync::Arc;
  2
  3use anyhow::{Context as _, Result};
  4use collections::HashMap;
  5use command_palette_hooks::CommandPaletteFilter;
  6use gpui::{App, Context, Entity, EntityId, Global, Subscription, Task, prelude::*};
  7use jupyter_websocket_client::RemoteServer;
  8use language::Language;
  9use project::{Fs, Project, WorktreeId};
 10use settings::{Settings, SettingsStore};
 11
 12use crate::kernels::{
 13    list_remote_kernelspecs, local_kernel_specifications, python_env_kernel_specifications,
 14};
 15use crate::{JupyterSettings, KernelSpecification, Session};
 16
/// Newtype wrapper that lets the shared [`ReplStore`] entity be stored as an app global.
///
/// It is installed by `ReplStore::init` and read back by `ReplStore::global`.
struct GlobalReplStore(Entity<ReplStore>);

// Marker impl: registers the wrapper as a type usable with `set_global` / `global`.
impl Global for GlobalReplStore {}
 20
/// Central registry for REPL state: known kernel specifications, the per-worktree kernel
/// selection, and the currently registered sessions.
pub struct ReplStore {
    /// Filesystem handle handed to local kernelspec discovery in `refresh_kernelspecs`.
    fs: Arc<dyn Fs>,
    /// Cached value of `JupyterSettings::enabled`, kept up to date by a settings observer.
    enabled: bool,
    /// Registered sessions, keyed by the `EntityId` supplied to `insert_session`.
    // NOTE(review): presumably the id of the editor that owns the session — confirm with callers.
    sessions: HashMap<EntityId, Entity<Session>>,
    /// Store-wide specs produced by `refresh_kernelspecs` (local specs plus any remote ones).
    kernel_specifications: Vec<KernelSpecification>,
    /// Kernel explicitly chosen for a worktree via `set_active_kernelspec`.
    selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
    /// Per-worktree specs produced by `refresh_python_kernelspecs`.
    kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
    /// Holds the settings observer created in `new`.
    // NOTE(review): presumably retained so the subscription is not cancelled on drop — confirm.
    _subscriptions: Vec<Subscription>,
}
 30
 31impl ReplStore {
    /// Command-palette namespace that is shown or hidden as the REPL is enabled or disabled.
    const NAMESPACE: &'static str = "repl";
 33
 34    pub(crate) fn init(fs: Arc<dyn Fs>, cx: &mut App) {
 35        let store = cx.new(move |cx| Self::new(fs, cx));
 36
 37        store
 38            .update(cx, |store, cx| store.refresh_kernelspecs(cx))
 39            .detach_and_log_err(cx);
 40
 41        cx.set_global(GlobalReplStore(store))
 42    }
 43
 44    pub fn global(cx: &App) -> Entity<Self> {
 45        cx.global::<GlobalReplStore>().0.clone()
 46    }
 47
 48    pub fn new(fs: Arc<dyn Fs>, cx: &mut Context<Self>) -> Self {
 49        let subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
 50            this.set_enabled(JupyterSettings::enabled(cx), cx);
 51        })];
 52
 53        let this = Self {
 54            fs,
 55            enabled: JupyterSettings::enabled(cx),
 56            sessions: HashMap::default(),
 57            kernel_specifications: Vec::new(),
 58            _subscriptions: subscriptions,
 59            kernel_specifications_for_worktree: HashMap::default(),
 60            selected_kernel_for_worktree: HashMap::default(),
 61        };
 62        this.on_enabled_changed(cx);
 63        this
 64    }
 65
 66    pub fn fs(&self) -> &Arc<dyn Fs> {
 67        &self.fs
 68    }
 69
 70    pub fn is_enabled(&self) -> bool {
 71        self.enabled
 72    }
 73
 74    pub fn kernel_specifications_for_worktree(
 75        &self,
 76        worktree_id: WorktreeId,
 77    ) -> impl Iterator<Item = &KernelSpecification> {
 78        self.kernel_specifications_for_worktree
 79            .get(&worktree_id)
 80            .into_iter()
 81            .flat_map(|specs| specs.iter())
 82            .chain(self.kernel_specifications.iter())
 83    }
 84
 85    pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
 86        self.kernel_specifications.iter()
 87    }
 88
 89    pub fn sessions(&self) -> impl Iterator<Item = &Entity<Session>> {
 90        self.sessions.values()
 91    }
 92
 93    fn set_enabled(&mut self, enabled: bool, cx: &mut Context<Self>) {
 94        if self.enabled == enabled {
 95            return;
 96        }
 97
 98        self.enabled = enabled;
 99        self.on_enabled_changed(cx);
100    }
101
102    fn on_enabled_changed(&self, cx: &mut Context<Self>) {
103        if !self.enabled {
104            CommandPaletteFilter::update_global(cx, |filter, _cx| {
105                filter.hide_namespace(Self::NAMESPACE);
106            });
107
108            return;
109        }
110
111        CommandPaletteFilter::update_global(cx, |filter, _cx| {
112            filter.show_namespace(Self::NAMESPACE);
113        });
114
115        cx.notify();
116    }
117
118    pub fn refresh_python_kernelspecs(
119        &mut self,
120        worktree_id: WorktreeId,
121        project: &Entity<Project>,
122        cx: &mut Context<Self>,
123    ) -> Task<Result<()>> {
124        let kernel_specifications = python_env_kernel_specifications(project, worktree_id, cx);
125        cx.spawn(async move |this, cx| {
126            let kernel_specifications = kernel_specifications
127                .await
128                .context("getting python kernelspecs")?;
129
130            this.update(cx, |this, cx| {
131                this.kernel_specifications_for_worktree
132                    .insert(worktree_id, kernel_specifications);
133                cx.notify();
134            })
135        })
136    }
137
138    fn get_remote_kernel_specifications(
139        &self,
140        cx: &mut Context<Self>,
141    ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
142        match (
143            std::env::var("JUPYTER_SERVER"),
144            std::env::var("JUPYTER_TOKEN"),
145        ) {
146            (Ok(server), Ok(token)) => {
147                let remote_server = RemoteServer {
148                    base_url: server,
149                    token,
150                };
151                let http_client = cx.http_client();
152                Some(cx.spawn(async move |_, _| {
153                    list_remote_kernelspecs(remote_server, http_client)
154                        .await
155                        .map(|specs| specs.into_iter().map(KernelSpecification::Remote).collect())
156                }))
157            }
158            _ => None,
159        }
160    }
161
162    pub fn refresh_kernelspecs(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
163        let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
164
165        let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
166
167        let all_specs = cx.background_spawn(async move {
168            let mut all_specs = local_kernel_specifications
169                .await?
170                .into_iter()
171                .map(KernelSpecification::Jupyter)
172                .collect::<Vec<_>>();
173
174            if let Some(remote_task) = remote_kernel_specifications
175                && let Ok(remote_specs) = remote_task.await
176            {
177                all_specs.extend(remote_specs);
178            }
179
180            anyhow::Ok(all_specs)
181        });
182
183        cx.spawn(async move |this, cx| {
184            let all_specs = all_specs.await;
185
186            if let Ok(specs) = all_specs {
187                this.update(cx, |this, cx| {
188                    this.kernel_specifications = specs;
189                    cx.notify();
190                })
191                .ok();
192            }
193
194            anyhow::Ok(())
195        })
196    }
197
198    pub fn set_active_kernelspec(
199        &mut self,
200        worktree_id: WorktreeId,
201        kernelspec: KernelSpecification,
202        _cx: &mut Context<Self>,
203    ) {
204        self.selected_kernel_for_worktree
205            .insert(worktree_id, kernelspec);
206    }
207
208    pub fn active_kernelspec(
209        &self,
210        worktree_id: WorktreeId,
211        language_at_cursor: Option<Arc<Language>>,
212        cx: &App,
213    ) -> Option<KernelSpecification> {
214        let selected_kernelspec = self.selected_kernel_for_worktree.get(&worktree_id).cloned();
215
216        if let Some(language_at_cursor) = language_at_cursor {
217            selected_kernelspec.or_else(|| {
218                self.kernelspec_legacy_by_lang_only(worktree_id, language_at_cursor, cx)
219            })
220        } else {
221            selected_kernelspec
222        }
223    }
224
225    fn kernelspec_legacy_by_lang_only(
226        &self,
227        worktree_id: WorktreeId,
228        language_at_cursor: Arc<Language>,
229        cx: &App,
230    ) -> Option<KernelSpecification> {
231        let settings = JupyterSettings::get_global(cx);
232        let selected_kernel = settings
233            .kernel_selections
234            .get(language_at_cursor.code_fence_block_name().as_ref());
235
236        let found_by_name = self
237            .kernel_specifications_for_worktree(worktree_id)
238            .find(|runtime_specification| {
239                if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
240                    (selected_kernel, runtime_specification)
241                {
242                    // Top priority is the selected kernel
243                    return runtime_specification.name.to_lowercase() == selected.to_lowercase();
244                }
245                false
246            })
247            .cloned();
248
249        if let Some(found_by_name) = found_by_name {
250            return Some(found_by_name);
251        }
252
253        self.kernel_specifications_for_worktree(worktree_id)
254            .find(|kernel_option| match kernel_option {
255                KernelSpecification::Jupyter(runtime_specification) => {
256                    runtime_specification.kernelspec.language.to_lowercase()
257                        == language_at_cursor.code_fence_block_name().to_lowercase()
258                }
259                KernelSpecification::PythonEnv(runtime_specification) => {
260                    runtime_specification.kernelspec.language.to_lowercase()
261                        == language_at_cursor.code_fence_block_name().to_lowercase()
262                }
263                KernelSpecification::Remote(remote_spec) => {
264                    remote_spec.kernelspec.language.to_lowercase()
265                        == language_at_cursor.code_fence_block_name().to_lowercase()
266                }
267            })
268            .cloned()
269    }
270
271    pub fn get_session(&self, entity_id: EntityId) -> Option<&Entity<Session>> {
272        self.sessions.get(&entity_id)
273    }
274
275    pub fn insert_session(&mut self, entity_id: EntityId, session: Entity<Session>) {
276        self.sessions.insert(entity_id, session);
277    }
278
279    pub fn remove_session(&mut self, entity_id: EntityId) {
280        self.sessions.remove(&entity_id);
281    }
282
283    #[cfg(test)]
284    pub fn set_kernel_specs_for_testing(
285        &mut self,
286        specs: Vec<KernelSpecification>,
287        cx: &mut Context<Self>,
288    ) {
289        self.kernel_specifications = specs;
290        cx.notify();
291    }
292}