1use std::future::Future;
2use std::sync::Arc;
3
4use anyhow::{Context as _, Result};
5use collections::{HashMap, HashSet};
6use command_palette_hooks::CommandPaletteFilter;
7use gpui::{App, Context, Entity, EntityId, Global, SharedString, Subscription, Task, prelude::*};
8use jupyter_websocket_client::RemoteServer;
9use language::{Language, LanguageName};
10use project::{Fs, Project, ProjectPath, WorktreeId};
11use remote::RemoteConnectionOptions;
12use settings::{Settings, SettingsStore};
13use util::rel_path::RelPath;
14
15use crate::kernels::{
16 Kernel, list_remote_kernelspecs, local_kernel_specifications, python_env_kernel_specifications,
17 wsl_kernel_specifications,
18};
19use crate::{JupyterSettings, KernelSpecification, Session};
20
/// App-global wrapper holding the single [`ReplStore`] entity.
///
/// Installed by `ReplStore::init` via `cx.set_global` and read back by
/// `ReplStore::global`.
struct GlobalReplStore(Entity<ReplStore>);

// Marker impl so the wrapper can be stored with `cx.set_global`.
impl Global for GlobalReplStore {}
24
25pub struct ReplStore {
26 fs: Arc<dyn Fs>,
27 enabled: bool,
28 sessions: HashMap<EntityId, Entity<Session>>,
29 kernel_specifications: Vec<KernelSpecification>,
30 selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
31 kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
32 active_python_toolchain_for_worktree: HashMap<WorktreeId, SharedString>,
33 remote_worktrees: HashSet<WorktreeId>,
34 _subscriptions: Vec<Subscription>,
35}
36
37impl ReplStore {
38 const NAMESPACE: &'static str = "repl";
39
40 pub(crate) fn init(fs: Arc<dyn Fs>, cx: &mut App) {
41 let store = cx.new(move |cx| Self::new(fs, cx));
42
43 #[cfg(not(feature = "test-support"))]
44 store
45 .update(cx, |store, cx| store.refresh_kernelspecs(cx))
46 .detach_and_log_err(cx);
47
48 cx.set_global(GlobalReplStore(store))
49 }
50
51 pub fn global(cx: &App) -> Entity<Self> {
52 cx.global::<GlobalReplStore>().0.clone()
53 }
54
55 pub fn new(fs: Arc<dyn Fs>, cx: &mut Context<Self>) -> Self {
56 let subscriptions = vec![
57 cx.observe_global::<SettingsStore>(move |this, cx| {
58 this.set_enabled(JupyterSettings::enabled(cx), cx);
59 }),
60 cx.on_app_quit(Self::shutdown_all_sessions),
61 ];
62
63 let this = Self {
64 fs,
65 enabled: JupyterSettings::enabled(cx),
66 sessions: HashMap::default(),
67 kernel_specifications: Vec::new(),
68 _subscriptions: subscriptions,
69 kernel_specifications_for_worktree: HashMap::default(),
70 selected_kernel_for_worktree: HashMap::default(),
71 active_python_toolchain_for_worktree: HashMap::default(),
72 remote_worktrees: HashSet::default(),
73 };
74 this.on_enabled_changed(cx);
75 this
76 }
77
78 pub fn fs(&self) -> &Arc<dyn Fs> {
79 &self.fs
80 }
81
82 pub fn is_enabled(&self) -> bool {
83 self.enabled
84 }
85
86 pub fn has_python_kernelspecs(&self, worktree_id: WorktreeId) -> bool {
87 self.kernel_specifications_for_worktree
88 .contains_key(&worktree_id)
89 }
90
91 pub fn kernel_specifications_for_worktree(
92 &self,
93 worktree_id: WorktreeId,
94 ) -> impl Iterator<Item = &KernelSpecification> {
95 let global_specs = if self.remote_worktrees.contains(&worktree_id) {
96 None
97 } else {
98 Some(self.kernel_specifications.iter())
99 };
100
101 self.kernel_specifications_for_worktree
102 .get(&worktree_id)
103 .into_iter()
104 .flat_map(|specs| specs.iter())
105 .chain(global_specs.into_iter().flatten())
106 }
107
108 pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
109 self.kernel_specifications.iter()
110 }
111
112 pub fn sessions(&self) -> impl Iterator<Item = &Entity<Session>> {
113 self.sessions.values()
114 }
115
116 fn set_enabled(&mut self, enabled: bool, cx: &mut Context<Self>) {
117 if self.enabled == enabled {
118 return;
119 }
120
121 self.enabled = enabled;
122 self.on_enabled_changed(cx);
123 }
124
125 fn on_enabled_changed(&self, cx: &mut Context<Self>) {
126 if !self.enabled {
127 CommandPaletteFilter::update_global(cx, |filter, _cx| {
128 filter.hide_namespace(Self::NAMESPACE);
129 });
130
131 return;
132 }
133
134 CommandPaletteFilter::update_global(cx, |filter, _cx| {
135 filter.show_namespace(Self::NAMESPACE);
136 });
137
138 cx.notify();
139 }
140
141 pub fn refresh_python_kernelspecs(
142 &mut self,
143 worktree_id: WorktreeId,
144 project: &Entity<Project>,
145 cx: &mut Context<Self>,
146 ) -> Task<Result<()>> {
147 let is_remote = project.read(cx).is_remote();
148 // WSL does require access to global kernel specs, so we only exclude remote worktrees that aren't WSL.
149 // TODO: a better way to handle WSL vs SSH/remote projects,
150 let is_wsl_remote = project
151 .read(cx)
152 .remote_connection_options(cx)
153 .map_or(false, |opts| {
154 matches!(opts, RemoteConnectionOptions::Wsl(_))
155 });
156 let kernel_specifications = python_env_kernel_specifications(project, worktree_id, cx);
157 let active_toolchain = project.read(cx).active_toolchain(
158 ProjectPath {
159 worktree_id,
160 path: RelPath::empty().into(),
161 },
162 LanguageName::new_static("Python"),
163 cx,
164 );
165
166 cx.spawn(async move |this, cx| {
167 let kernel_specifications = kernel_specifications
168 .await
169 .context("getting python kernelspecs")?;
170
171 let active_toolchain_path = active_toolchain.await.map(|toolchain| toolchain.path);
172
173 this.update(cx, |this, cx| {
174 this.kernel_specifications_for_worktree
175 .insert(worktree_id, kernel_specifications);
176 if let Some(path) = active_toolchain_path {
177 this.active_python_toolchain_for_worktree
178 .insert(worktree_id, path);
179 }
180 if is_remote && !is_wsl_remote {
181 this.remote_worktrees.insert(worktree_id);
182 } else {
183 this.remote_worktrees.remove(&worktree_id);
184 }
185 cx.notify();
186 })
187 })
188 }
189
190 fn get_remote_kernel_specifications(
191 &self,
192 cx: &mut Context<Self>,
193 ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
194 match (
195 std::env::var("JUPYTER_SERVER"),
196 std::env::var("JUPYTER_TOKEN"),
197 ) {
198 (Ok(server), Ok(token)) => {
199 let remote_server = RemoteServer {
200 base_url: server,
201 token,
202 };
203 let http_client = cx.http_client();
204 Some(cx.spawn(async move |_, _| {
205 list_remote_kernelspecs(remote_server, http_client)
206 .await
207 .map(|specs| {
208 specs
209 .into_iter()
210 .map(KernelSpecification::JupyterServer)
211 .collect()
212 })
213 }))
214 }
215 _ => None,
216 }
217 }
218
219 pub fn refresh_kernelspecs(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
220 let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
221 let wsl_kernel_specifications = wsl_kernel_specifications(cx.background_executor().clone());
222
223 let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
224
225 let all_specs = cx.background_spawn(async move {
226 let mut all_specs = local_kernel_specifications
227 .await?
228 .into_iter()
229 .map(KernelSpecification::Jupyter)
230 .collect::<Vec<_>>();
231
232 if let Ok(wsl_specs) = wsl_kernel_specifications.await {
233 all_specs.extend(wsl_specs);
234 }
235
236 if let Some(remote_task) = remote_kernel_specifications
237 && let Ok(remote_specs) = remote_task.await
238 {
239 all_specs.extend(remote_specs);
240 }
241
242 anyhow::Ok(all_specs)
243 });
244
245 cx.spawn(async move |this, cx| {
246 let all_specs = all_specs.await;
247
248 if let Ok(specs) = all_specs {
249 this.update(cx, |this, cx| {
250 this.kernel_specifications = specs;
251 cx.notify();
252 })
253 .ok();
254 }
255
256 anyhow::Ok(())
257 })
258 }
259
260 pub fn set_active_kernelspec(
261 &mut self,
262 worktree_id: WorktreeId,
263 kernelspec: KernelSpecification,
264 _cx: &mut Context<Self>,
265 ) {
266 self.selected_kernel_for_worktree
267 .insert(worktree_id, kernelspec);
268 }
269
270 pub fn active_python_toolchain_path(&self, worktree_id: WorktreeId) -> Option<&SharedString> {
271 self.active_python_toolchain_for_worktree.get(&worktree_id)
272 }
273
274 pub fn selected_kernel(&self, worktree_id: WorktreeId) -> Option<&KernelSpecification> {
275 self.selected_kernel_for_worktree.get(&worktree_id)
276 }
277
278 pub fn is_recommended_kernel(
279 &self,
280 worktree_id: WorktreeId,
281 spec: &KernelSpecification,
282 ) -> bool {
283 if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
284 spec.path().as_ref() == active_path.as_ref()
285 } else {
286 false
287 }
288 }
289
290 pub fn active_kernelspec(
291 &self,
292 worktree_id: WorktreeId,
293 language_at_cursor: Option<Arc<Language>>,
294 cx: &App,
295 ) -> Option<KernelSpecification> {
296 if let Some(selected) = self.selected_kernel_for_worktree.get(&worktree_id).cloned() {
297 return Some(selected);
298 }
299
300 let language_at_cursor = language_at_cursor?;
301
302 // Prefer the recommended (active toolchain) kernel if it has ipykernel
303 if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
304 let recommended = self
305 .kernel_specifications_for_worktree(worktree_id)
306 .find(|spec| {
307 spec.has_ipykernel()
308 && language_at_cursor.matches_kernel_language(spec.language().as_ref())
309 && spec.path().as_ref() == active_path.as_ref()
310 })
311 .cloned();
312 if recommended.is_some() {
313 return recommended;
314 }
315 }
316
317 // Then try the first PythonEnv with ipykernel matching the language
318 let python_env = self
319 .kernel_specifications_for_worktree(worktree_id)
320 .find(|spec| {
321 matches!(spec, KernelSpecification::PythonEnv(_))
322 && spec.has_ipykernel()
323 && language_at_cursor.matches_kernel_language(spec.language().as_ref())
324 })
325 .cloned();
326 if python_env.is_some() {
327 return python_env;
328 }
329
330 // Fall back to legacy name-based and language-based matching
331 self.kernelspec_legacy_by_lang_only(worktree_id, language_at_cursor, cx)
332 }
333
334 fn kernelspec_legacy_by_lang_only(
335 &self,
336 worktree_id: WorktreeId,
337 language_at_cursor: Arc<Language>,
338 cx: &App,
339 ) -> Option<KernelSpecification> {
340 let settings = JupyterSettings::get_global(cx);
341 let selected_kernel = settings
342 .kernel_selections
343 .get(language_at_cursor.code_fence_block_name().as_ref());
344
345 let found_by_name = self
346 .kernel_specifications_for_worktree(worktree_id)
347 .find(|runtime_specification| {
348 if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
349 (selected_kernel, runtime_specification)
350 {
351 return runtime_specification.name.to_lowercase() == selected.to_lowercase();
352 }
353 false
354 })
355 .cloned();
356
357 if let Some(found_by_name) = found_by_name {
358 return Some(found_by_name);
359 }
360
361 self.kernel_specifications_for_worktree(worktree_id)
362 .find(|spec| {
363 spec.has_ipykernel()
364 && language_at_cursor.matches_kernel_language(spec.language().as_ref())
365 })
366 .cloned()
367 }
368
369 pub fn get_session(&self, entity_id: EntityId) -> Option<&Entity<Session>> {
370 self.sessions.get(&entity_id)
371 }
372
373 pub fn insert_session(&mut self, entity_id: EntityId, session: Entity<Session>) {
374 self.sessions.insert(entity_id, session);
375 }
376
377 pub fn remove_session(&mut self, entity_id: EntityId) {
378 self.sessions.remove(&entity_id);
379 }
380
381 fn shutdown_all_sessions(
382 &mut self,
383 cx: &mut Context<Self>,
384 ) -> impl Future<Output = ()> + use<> {
385 for session in self.sessions.values() {
386 session.update(cx, |session, _cx| {
387 if let Kernel::RunningKernel(mut kernel) =
388 std::mem::replace(&mut session.kernel, Kernel::Shutdown)
389 {
390 kernel.kill();
391 }
392 });
393 }
394 self.sessions.clear();
395 futures::future::ready(())
396 }
397
398 #[cfg(test)]
399 pub fn set_kernel_specs_for_testing(
400 &mut self,
401 specs: Vec<KernelSpecification>,
402 cx: &mut Context<Self>,
403 ) {
404 self.kernel_specifications = specs;
405 cx.notify();
406 }
407}