repl_store.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use collections::HashMap;
  5use command_palette_hooks::CommandPaletteFilter;
  6use gpui::{
  7    prelude::*, AppContext, EntityId, Global, Model, ModelContext, Subscription, Task, View,
  8};
  9use jupyter_websocket_client::RemoteServer;
 10use language::Language;
 11use project::{Fs, Project, WorktreeId};
 12use settings::{Settings, SettingsStore};
 13
 14use crate::kernels::{
 15    list_remote_kernelspecs, local_kernel_specifications, python_env_kernel_specifications,
 16};
 17use crate::{JupyterSettings, KernelSpecification, Session};
 18
/// gpui global wrapper that makes the [`ReplStore`] model created in
/// [`ReplStore::init`] reachable through [`ReplStore::global`].
struct GlobalReplStore(Model<ReplStore>);

// Marker impl so the wrapper can be stored with `cx.set_global` and read back with `cx.global`.
impl Global for GlobalReplStore {}
 22
/// State for the REPL feature: discovered kernelspecs, the user's kernel
/// choice per worktree, and the live REPL sessions.
pub struct ReplStore {
    /// Filesystem handle used to scan for local Jupyter kernelspecs.
    fs: Arc<dyn Fs>,
    /// Mirrors `JupyterSettings::enabled`; drives visibility of the "repl"
    /// command-palette namespace.
    enabled: bool,
    /// Live REPL sessions keyed by the id of the entity they belong to
    /// (presumably an editor — confirm at the call sites of `insert_session`).
    sessions: HashMap<EntityId, View<Session>>,
    /// Kernelspecs not tied to a worktree: local Jupyter kernels plus any remote
    /// kernels listed by the server named in `JUPYTER_SERVER`.
    kernel_specifications: Vec<KernelSpecification>,
    /// Kernel explicitly chosen for each worktree via `set_active_kernelspec`.
    selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
    /// Python-environment kernelspecs discovered per worktree by
    /// `refresh_python_kernelspecs`.
    kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
    /// Never read; held so the `SettingsStore` observer registered in `new`
    /// lives as long as the store does.
    _subscriptions: Vec<Subscription>,
}
 32
 33impl ReplStore {
 34    const NAMESPACE: &'static str = "repl";
 35
 36    pub(crate) fn init(fs: Arc<dyn Fs>, cx: &mut AppContext) {
 37        let store = cx.new_model(move |cx| Self::new(fs, cx));
 38
 39        store
 40            .update(cx, |store, cx| store.refresh_kernelspecs(cx))
 41            .detach_and_log_err(cx);
 42
 43        cx.set_global(GlobalReplStore(store))
 44    }
 45
 46    pub fn global(cx: &AppContext) -> Model<Self> {
 47        cx.global::<GlobalReplStore>().0.clone()
 48    }
 49
 50    pub fn new(fs: Arc<dyn Fs>, cx: &mut ModelContext<Self>) -> Self {
 51        let subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
 52            this.set_enabled(JupyterSettings::enabled(cx), cx);
 53        })];
 54
 55        let this = Self {
 56            fs,
 57            enabled: JupyterSettings::enabled(cx),
 58            sessions: HashMap::default(),
 59            kernel_specifications: Vec::new(),
 60            _subscriptions: subscriptions,
 61            kernel_specifications_for_worktree: HashMap::default(),
 62            selected_kernel_for_worktree: HashMap::default(),
 63        };
 64        this.on_enabled_changed(cx);
 65        this
 66    }
 67
 68    pub fn fs(&self) -> &Arc<dyn Fs> {
 69        &self.fs
 70    }
 71
 72    pub fn is_enabled(&self) -> bool {
 73        self.enabled
 74    }
 75
 76    pub fn kernel_specifications_for_worktree(
 77        &self,
 78        worktree_id: WorktreeId,
 79    ) -> impl Iterator<Item = &KernelSpecification> {
 80        self.kernel_specifications_for_worktree
 81            .get(&worktree_id)
 82            .into_iter()
 83            .flat_map(|specs| specs.iter())
 84            .chain(self.kernel_specifications.iter())
 85    }
 86
 87    pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
 88        self.kernel_specifications.iter()
 89    }
 90
 91    pub fn sessions(&self) -> impl Iterator<Item = &View<Session>> {
 92        self.sessions.values()
 93    }
 94
 95    fn set_enabled(&mut self, enabled: bool, cx: &mut ModelContext<Self>) {
 96        if self.enabled == enabled {
 97            return;
 98        }
 99
100        self.enabled = enabled;
101        self.on_enabled_changed(cx);
102    }
103
104    fn on_enabled_changed(&self, cx: &mut ModelContext<Self>) {
105        if !self.enabled {
106            CommandPaletteFilter::update_global(cx, |filter, _cx| {
107                filter.hide_namespace(Self::NAMESPACE);
108            });
109
110            return;
111        }
112
113        CommandPaletteFilter::update_global(cx, |filter, _cx| {
114            filter.show_namespace(Self::NAMESPACE);
115        });
116
117        cx.notify();
118    }
119
120    pub fn refresh_python_kernelspecs(
121        &mut self,
122        worktree_id: WorktreeId,
123        project: &Model<Project>,
124        cx: &mut ModelContext<Self>,
125    ) -> Task<Result<()>> {
126        let kernel_specifications = python_env_kernel_specifications(project, worktree_id, cx);
127        cx.spawn(move |this, mut cx| async move {
128            let kernel_specifications = kernel_specifications
129                .await
130                .map_err(|e| anyhow::anyhow!("Failed to get python kernelspecs: {:?}", e))?;
131
132            this.update(&mut cx, |this, cx| {
133                this.kernel_specifications_for_worktree
134                    .insert(worktree_id, kernel_specifications);
135                cx.notify();
136            })
137        })
138    }
139
140    fn get_remote_kernel_specifications(
141        &self,
142        cx: &mut ModelContext<Self>,
143    ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
144        match (
145            std::env::var("JUPYTER_SERVER"),
146            std::env::var("JUPYTER_TOKEN"),
147        ) {
148            (Ok(server), Ok(token)) => {
149                let remote_server = RemoteServer {
150                    base_url: server,
151                    token,
152                };
153                let http_client = cx.http_client();
154                Some(cx.spawn(|_, _| async move {
155                    list_remote_kernelspecs(remote_server, http_client)
156                        .await
157                        .map(|specs| specs.into_iter().map(KernelSpecification::Remote).collect())
158                }))
159            }
160            _ => None,
161        }
162    }
163
164    pub fn refresh_kernelspecs(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
165        let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
166
167        let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
168
169        let all_specs = cx.background_executor().spawn(async move {
170            let mut all_specs = local_kernel_specifications
171                .await?
172                .into_iter()
173                .map(KernelSpecification::Jupyter)
174                .collect::<Vec<_>>();
175
176            if let Some(remote_task) = remote_kernel_specifications {
177                if let Ok(remote_specs) = remote_task.await {
178                    all_specs.extend(remote_specs);
179                }
180            }
181
182            anyhow::Ok(all_specs)
183        });
184
185        cx.spawn(|this, mut cx| async move {
186            let all_specs = all_specs.await;
187
188            if let Ok(specs) = all_specs {
189                this.update(&mut cx, |this, cx| {
190                    this.kernel_specifications = specs;
191                    cx.notify();
192                })
193                .ok();
194            }
195
196            anyhow::Ok(())
197        })
198    }
199
200    pub fn set_active_kernelspec(
201        &mut self,
202        worktree_id: WorktreeId,
203        kernelspec: KernelSpecification,
204        _cx: &mut ModelContext<Self>,
205    ) {
206        self.selected_kernel_for_worktree
207            .insert(worktree_id, kernelspec);
208    }
209
210    pub fn active_kernelspec(
211        &self,
212        worktree_id: WorktreeId,
213        language_at_cursor: Option<Arc<Language>>,
214        cx: &AppContext,
215    ) -> Option<KernelSpecification> {
216        let selected_kernelspec = self.selected_kernel_for_worktree.get(&worktree_id).cloned();
217
218        if let Some(language_at_cursor) = language_at_cursor {
219            selected_kernelspec
220                .or_else(|| self.kernelspec_legacy_by_lang_only(language_at_cursor, cx))
221        } else {
222            selected_kernelspec
223        }
224    }
225
226    fn kernelspec_legacy_by_lang_only(
227        &self,
228        language_at_cursor: Arc<Language>,
229        cx: &AppContext,
230    ) -> Option<KernelSpecification> {
231        let settings = JupyterSettings::get_global(cx);
232        let selected_kernel = settings
233            .kernel_selections
234            .get(language_at_cursor.code_fence_block_name().as_ref());
235
236        let found_by_name = self
237            .kernel_specifications
238            .iter()
239            .find(|runtime_specification| {
240                if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
241                    (selected_kernel, runtime_specification)
242                {
243                    // Top priority is the selected kernel
244                    return runtime_specification.name.to_lowercase() == selected.to_lowercase();
245                }
246                false
247            })
248            .cloned();
249
250        if let Some(found_by_name) = found_by_name {
251            return Some(found_by_name);
252        }
253
254        self.kernel_specifications
255            .iter()
256            .find(|kernel_option| match kernel_option {
257                KernelSpecification::Jupyter(runtime_specification) => {
258                    runtime_specification.kernelspec.language.to_lowercase()
259                        == language_at_cursor.code_fence_block_name().to_lowercase()
260                }
261                KernelSpecification::PythonEnv(runtime_specification) => {
262                    runtime_specification.kernelspec.language.to_lowercase()
263                        == language_at_cursor.code_fence_block_name().to_lowercase()
264                }
265                KernelSpecification::Remote(remote_spec) => {
266                    remote_spec.kernelspec.language.to_lowercase()
267                        == language_at_cursor.code_fence_block_name().to_lowercase()
268                }
269            })
270            .cloned()
271    }
272
273    pub fn get_session(&self, entity_id: EntityId) -> Option<&View<Session>> {
274        self.sessions.get(&entity_id)
275    }
276
277    pub fn insert_session(&mut self, entity_id: EntityId, session: View<Session>) {
278        self.sessions.insert(entity_id, session);
279    }
280
281    pub fn remove_session(&mut self, entity_id: EntityId) {
282        self.sessions.remove(&entity_id);
283    }
284}