repl_store.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use client::telemetry::Telemetry;
  5use collections::HashMap;
  6use command_palette_hooks::CommandPaletteFilter;
  7use gpui::{
  8    prelude::*, AppContext, EntityId, Global, Model, ModelContext, Subscription, Task, View,
  9};
 10use jupyter_websocket_client::RemoteServer;
 11use language::Language;
 12use project::{Fs, Project, WorktreeId};
 13use settings::{Settings, SettingsStore};
 14
 15use crate::kernels::{
 16    list_remote_kernelspecs, local_kernel_specifications, python_env_kernel_specifications,
 17};
 18use crate::{JupyterSettings, KernelSpecification, Session};
 19
 20struct GlobalReplStore(Model<ReplStore>);
 21
 22impl Global for GlobalReplStore {}
 23
 24pub struct ReplStore {
 25    fs: Arc<dyn Fs>,
 26    enabled: bool,
 27    sessions: HashMap<EntityId, View<Session>>,
 28    kernel_specifications: Vec<KernelSpecification>,
 29    selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
 30    kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
 31    telemetry: Arc<Telemetry>,
 32    _subscriptions: Vec<Subscription>,
 33}
 34
 35impl ReplStore {
 36    const NAMESPACE: &'static str = "repl";
 37
 38    pub(crate) fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
 39        let store = cx.new_model(move |cx| Self::new(fs, telemetry, cx));
 40
 41        store
 42            .update(cx, |store, cx| store.refresh_kernelspecs(cx))
 43            .detach_and_log_err(cx);
 44
 45        cx.set_global(GlobalReplStore(store))
 46    }
 47
 48    pub fn global(cx: &AppContext) -> Model<Self> {
 49        cx.global::<GlobalReplStore>().0.clone()
 50    }
 51
 52    pub fn new(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut ModelContext<Self>) -> Self {
 53        let subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
 54            this.set_enabled(JupyterSettings::enabled(cx), cx);
 55        })];
 56
 57        let this = Self {
 58            fs,
 59            telemetry,
 60            enabled: JupyterSettings::enabled(cx),
 61            sessions: HashMap::default(),
 62            kernel_specifications: Vec::new(),
 63            _subscriptions: subscriptions,
 64            kernel_specifications_for_worktree: HashMap::default(),
 65            selected_kernel_for_worktree: HashMap::default(),
 66        };
 67        this.on_enabled_changed(cx);
 68        this
 69    }
 70
 71    pub fn fs(&self) -> &Arc<dyn Fs> {
 72        &self.fs
 73    }
 74
 75    pub fn telemetry(&self) -> &Arc<Telemetry> {
 76        &self.telemetry
 77    }
 78
 79    pub fn is_enabled(&self) -> bool {
 80        self.enabled
 81    }
 82
 83    pub fn kernel_specifications_for_worktree(
 84        &self,
 85        worktree_id: WorktreeId,
 86    ) -> impl Iterator<Item = &KernelSpecification> {
 87        self.kernel_specifications_for_worktree
 88            .get(&worktree_id)
 89            .into_iter()
 90            .flat_map(|specs| specs.iter())
 91            .chain(self.kernel_specifications.iter())
 92    }
 93
 94    pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
 95        self.kernel_specifications.iter()
 96    }
 97
 98    pub fn sessions(&self) -> impl Iterator<Item = &View<Session>> {
 99        self.sessions.values()
100    }
101
102    fn set_enabled(&mut self, enabled: bool, cx: &mut ModelContext<Self>) {
103        if self.enabled == enabled {
104            return;
105        }
106
107        self.enabled = enabled;
108        self.on_enabled_changed(cx);
109    }
110
111    fn on_enabled_changed(&self, cx: &mut ModelContext<Self>) {
112        if !self.enabled {
113            CommandPaletteFilter::update_global(cx, |filter, _cx| {
114                filter.hide_namespace(Self::NAMESPACE);
115            });
116
117            return;
118        }
119
120        CommandPaletteFilter::update_global(cx, |filter, _cx| {
121            filter.show_namespace(Self::NAMESPACE);
122        });
123
124        cx.notify();
125    }
126
127    pub fn refresh_python_kernelspecs(
128        &mut self,
129        worktree_id: WorktreeId,
130        project: &Model<Project>,
131        cx: &mut ModelContext<Self>,
132    ) -> Task<Result<()>> {
133        let kernel_specifications = python_env_kernel_specifications(project, worktree_id, cx);
134        cx.spawn(move |this, mut cx| async move {
135            let kernel_specifications = kernel_specifications
136                .await
137                .map_err(|e| anyhow::anyhow!("Failed to get python kernelspecs: {:?}", e))?;
138
139            this.update(&mut cx, |this, cx| {
140                this.kernel_specifications_for_worktree
141                    .insert(worktree_id, kernel_specifications);
142                cx.notify();
143            })
144        })
145    }
146
147    fn get_remote_kernel_specifications(
148        &self,
149        cx: &mut ModelContext<Self>,
150    ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
151        match (
152            std::env::var("JUPYTER_SERVER"),
153            std::env::var("JUPYTER_TOKEN"),
154        ) {
155            (Ok(server), Ok(token)) => {
156                let remote_server = RemoteServer {
157                    base_url: server,
158                    token,
159                };
160                let http_client = cx.http_client();
161                Some(cx.spawn(|_, _| async move {
162                    list_remote_kernelspecs(remote_server, http_client)
163                        .await
164                        .map(|specs| specs.into_iter().map(KernelSpecification::Remote).collect())
165                }))
166            }
167            _ => None,
168        }
169    }
170
171    pub fn refresh_kernelspecs(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
172        let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
173
174        let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
175
176        let all_specs = cx.background_executor().spawn(async move {
177            let mut all_specs = local_kernel_specifications
178                .await?
179                .into_iter()
180                .map(KernelSpecification::Jupyter)
181                .collect::<Vec<_>>();
182
183            if let Some(remote_task) = remote_kernel_specifications {
184                if let Ok(remote_specs) = remote_task.await {
185                    all_specs.extend(remote_specs);
186                }
187            }
188
189            anyhow::Ok(all_specs)
190        });
191
192        cx.spawn(|this, mut cx| async move {
193            let all_specs = all_specs.await;
194
195            if let Ok(specs) = all_specs {
196                this.update(&mut cx, |this, cx| {
197                    this.kernel_specifications = specs;
198                    cx.notify();
199                })
200                .ok();
201            }
202
203            anyhow::Ok(())
204        })
205    }
206
207    pub fn set_active_kernelspec(
208        &mut self,
209        worktree_id: WorktreeId,
210        kernelspec: KernelSpecification,
211        _cx: &mut ModelContext<Self>,
212    ) {
213        self.selected_kernel_for_worktree
214            .insert(worktree_id, kernelspec);
215    }
216
217    pub fn active_kernelspec(
218        &self,
219        worktree_id: WorktreeId,
220        language_at_cursor: Option<Arc<Language>>,
221        cx: &AppContext,
222    ) -> Option<KernelSpecification> {
223        let selected_kernelspec = self.selected_kernel_for_worktree.get(&worktree_id).cloned();
224
225        if let Some(language_at_cursor) = language_at_cursor {
226            selected_kernelspec
227                .or_else(|| self.kernelspec_legacy_by_lang_only(language_at_cursor, cx))
228        } else {
229            selected_kernelspec
230        }
231    }
232
233    fn kernelspec_legacy_by_lang_only(
234        &self,
235        language_at_cursor: Arc<Language>,
236        cx: &AppContext,
237    ) -> Option<KernelSpecification> {
238        let settings = JupyterSettings::get_global(cx);
239        let selected_kernel = settings
240            .kernel_selections
241            .get(language_at_cursor.code_fence_block_name().as_ref());
242
243        let found_by_name = self
244            .kernel_specifications
245            .iter()
246            .find(|runtime_specification| {
247                if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
248                    (selected_kernel, runtime_specification)
249                {
250                    // Top priority is the selected kernel
251                    return runtime_specification.name.to_lowercase() == selected.to_lowercase();
252                }
253                false
254            })
255            .cloned();
256
257        if let Some(found_by_name) = found_by_name {
258            return Some(found_by_name);
259        }
260
261        self.kernel_specifications
262            .iter()
263            .find(|kernel_option| match kernel_option {
264                KernelSpecification::Jupyter(runtime_specification) => {
265                    runtime_specification.kernelspec.language.to_lowercase()
266                        == language_at_cursor.code_fence_block_name().to_lowercase()
267                }
268                KernelSpecification::PythonEnv(runtime_specification) => {
269                    runtime_specification.kernelspec.language.to_lowercase()
270                        == language_at_cursor.code_fence_block_name().to_lowercase()
271                }
272                KernelSpecification::Remote(remote_spec) => {
273                    remote_spec.kernelspec.language.to_lowercase()
274                        == language_at_cursor.code_fence_block_name().to_lowercase()
275                }
276            })
277            .cloned()
278    }
279
280    pub fn get_session(&self, entity_id: EntityId) -> Option<&View<Session>> {
281        self.sessions.get(&entity_id)
282    }
283
284    pub fn insert_session(&mut self, entity_id: EntityId, session: View<Session>) {
285        self.sessions.insert(entity_id, session);
286    }
287
288    pub fn remove_session(&mut self, entity_id: EntityId) {
289        self.sessions.remove(&entity_id);
290    }
291}