1use std::future::Future;
2use std::sync::Arc;
3
4use anyhow::{Context as _, Result};
5use collections::{HashMap, HashSet};
6use command_palette_hooks::CommandPaletteFilter;
7use gpui::{App, Context, Entity, EntityId, Global, SharedString, Subscription, Task, prelude::*};
8use jupyter_websocket_client::RemoteServer;
9use language::{Language, LanguageName};
10use project::{Fs, Project, ProjectPath, WorktreeId};
11use remote::RemoteConnectionOptions;
12use settings::{Settings, SettingsStore};
13use util::rel_path::RelPath;
14
15use crate::kernels::{
16 Kernel, list_remote_kernelspecs, local_kernel_specifications, python_env_kernel_specifications,
17 wsl_kernel_specifications,
18};
19use crate::{JupyterSettings, KernelSpecification, Session};
20
21struct GlobalReplStore(Entity<ReplStore>);
22
23impl Global for GlobalReplStore {}
24
25pub struct ReplStore {
26 fs: Arc<dyn Fs>,
27 enabled: bool,
28 sessions: HashMap<EntityId, Entity<Session>>,
29 kernel_specifications: Vec<KernelSpecification>,
30 kernelspecs_initialized: bool,
31 selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
32 kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
33 active_python_toolchain_for_worktree: HashMap<WorktreeId, SharedString>,
34 remote_worktrees: HashSet<WorktreeId>,
35 _subscriptions: Vec<Subscription>,
36}
37
38impl ReplStore {
39 const NAMESPACE: &'static str = "repl";
40
41 pub(crate) fn init(fs: Arc<dyn Fs>, cx: &mut App) {
42 let store = cx.new(move |cx| Self::new(fs, cx));
43 cx.set_global(GlobalReplStore(store))
44 }
45
46 pub fn global(cx: &App) -> Entity<Self> {
47 cx.global::<GlobalReplStore>().0.clone()
48 }
49
50 pub fn new(fs: Arc<dyn Fs>, cx: &mut Context<Self>) -> Self {
51 let subscriptions = vec![
52 cx.observe_global::<SettingsStore>(move |this, cx| {
53 this.set_enabled(JupyterSettings::enabled(cx), cx);
54 }),
55 cx.on_app_quit(Self::shutdown_all_sessions),
56 ];
57
58 let this = Self {
59 fs,
60 enabled: JupyterSettings::enabled(cx),
61 sessions: HashMap::default(),
62 kernel_specifications: Vec::new(),
63 kernelspecs_initialized: false,
64 _subscriptions: subscriptions,
65 kernel_specifications_for_worktree: HashMap::default(),
66 selected_kernel_for_worktree: HashMap::default(),
67 active_python_toolchain_for_worktree: HashMap::default(),
68 remote_worktrees: HashSet::default(),
69 };
70 this.on_enabled_changed(cx);
71 this
72 }
73
74 pub fn fs(&self) -> &Arc<dyn Fs> {
75 &self.fs
76 }
77
78 pub fn is_enabled(&self) -> bool {
79 self.enabled
80 }
81
82 pub fn has_python_kernelspecs(&self, worktree_id: WorktreeId) -> bool {
83 self.kernel_specifications_for_worktree
84 .contains_key(&worktree_id)
85 }
86
87 pub fn kernel_specifications_for_worktree(
88 &self,
89 worktree_id: WorktreeId,
90 ) -> impl Iterator<Item = &KernelSpecification> {
91 let global_specs = if self.remote_worktrees.contains(&worktree_id) {
92 None
93 } else {
94 Some(self.kernel_specifications.iter())
95 };
96
97 self.kernel_specifications_for_worktree
98 .get(&worktree_id)
99 .into_iter()
100 .flat_map(|specs| specs.iter())
101 .chain(global_specs.into_iter().flatten())
102 }
103
104 pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
105 self.kernel_specifications.iter()
106 }
107
108 pub fn sessions(&self) -> impl Iterator<Item = &Entity<Session>> {
109 self.sessions.values()
110 }
111
112 fn set_enabled(&mut self, enabled: bool, cx: &mut Context<Self>) {
113 if self.enabled == enabled {
114 return;
115 }
116
117 self.enabled = enabled;
118 self.on_enabled_changed(cx);
119 }
120
121 fn on_enabled_changed(&self, cx: &mut Context<Self>) {
122 if !self.enabled {
123 CommandPaletteFilter::update_global(cx, |filter, _cx| {
124 filter.hide_namespace(Self::NAMESPACE);
125 });
126
127 return;
128 }
129
130 CommandPaletteFilter::update_global(cx, |filter, _cx| {
131 filter.show_namespace(Self::NAMESPACE);
132 });
133
134 cx.notify();
135 }
136
137 pub fn refresh_python_kernelspecs(
138 &mut self,
139 worktree_id: WorktreeId,
140 project: &Entity<Project>,
141 cx: &mut Context<Self>,
142 ) -> Task<Result<()>> {
143 let is_remote = project.read(cx).is_remote();
144 // WSL does require access to global kernel specs, so we only exclude remote worktrees that aren't WSL.
145 // TODO: a better way to handle WSL vs SSH/remote projects,
146 let is_wsl_remote = project
147 .read(cx)
148 .remote_connection_options(cx)
149 .map_or(false, |opts| {
150 matches!(opts, RemoteConnectionOptions::Wsl(_))
151 });
152 let kernel_specifications = python_env_kernel_specifications(project, worktree_id, cx);
153 let active_toolchain = project.read(cx).active_toolchain(
154 ProjectPath {
155 worktree_id,
156 path: RelPath::empty().into(),
157 },
158 LanguageName::new_static("Python"),
159 cx,
160 );
161
162 cx.spawn(async move |this, cx| {
163 let kernel_specifications = kernel_specifications
164 .await
165 .context("getting python kernelspecs")?;
166
167 let active_toolchain_path = active_toolchain.await.map(|toolchain| toolchain.path);
168
169 this.update(cx, |this, cx| {
170 this.kernel_specifications_for_worktree
171 .insert(worktree_id, kernel_specifications);
172 if let Some(path) = active_toolchain_path {
173 this.active_python_toolchain_for_worktree
174 .insert(worktree_id, path);
175 }
176 if is_remote && !is_wsl_remote {
177 this.remote_worktrees.insert(worktree_id);
178 } else {
179 this.remote_worktrees.remove(&worktree_id);
180 }
181 cx.notify();
182 })
183 })
184 }
185
186 fn get_remote_kernel_specifications(
187 &self,
188 cx: &mut Context<Self>,
189 ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
190 match (
191 std::env::var("JUPYTER_SERVER"),
192 std::env::var("JUPYTER_TOKEN"),
193 ) {
194 (Ok(server), Ok(token)) => {
195 let remote_server = RemoteServer {
196 base_url: server,
197 token,
198 };
199 let http_client = cx.http_client();
200 Some(cx.spawn(async move |_, _| {
201 list_remote_kernelspecs(remote_server, http_client)
202 .await
203 .map(|specs| {
204 specs
205 .into_iter()
206 .map(KernelSpecification::JupyterServer)
207 .collect()
208 })
209 }))
210 }
211 _ => None,
212 }
213 }
214
215 pub fn ensure_kernelspecs(&mut self, cx: &mut Context<Self>) {
216 if self.kernelspecs_initialized {
217 return;
218 }
219 self.kernelspecs_initialized = true;
220 self.refresh_kernelspecs(cx).detach_and_log_err(cx);
221 }
222
223 pub fn refresh_kernelspecs(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
224 let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
225 let wsl_kernel_specifications = wsl_kernel_specifications(cx.background_executor().clone());
226 let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
227
228 let all_specs = cx.background_spawn(async move {
229 let mut all_specs = local_kernel_specifications
230 .await?
231 .into_iter()
232 .map(KernelSpecification::Jupyter)
233 .collect::<Vec<_>>();
234
235 if let Ok(wsl_specs) = wsl_kernel_specifications.await {
236 all_specs.extend(wsl_specs);
237 }
238
239 if let Some(remote_task) = remote_kernel_specifications
240 && let Ok(remote_specs) = remote_task.await
241 {
242 all_specs.extend(remote_specs);
243 }
244
245 anyhow::Ok(all_specs)
246 });
247
248 cx.spawn(async move |this, cx| {
249 let all_specs = all_specs.await;
250
251 if let Ok(specs) = all_specs {
252 this.update(cx, |this, cx| {
253 this.kernel_specifications = specs;
254 cx.notify();
255 })
256 .ok();
257 }
258
259 anyhow::Ok(())
260 })
261 }
262
263 pub fn set_active_kernelspec(
264 &mut self,
265 worktree_id: WorktreeId,
266 kernelspec: KernelSpecification,
267 _cx: &mut Context<Self>,
268 ) {
269 self.selected_kernel_for_worktree
270 .insert(worktree_id, kernelspec);
271 }
272
273 pub fn active_python_toolchain_path(&self, worktree_id: WorktreeId) -> Option<&SharedString> {
274 self.active_python_toolchain_for_worktree.get(&worktree_id)
275 }
276
277 pub fn selected_kernel(&self, worktree_id: WorktreeId) -> Option<&KernelSpecification> {
278 self.selected_kernel_for_worktree.get(&worktree_id)
279 }
280
281 pub fn is_recommended_kernel(
282 &self,
283 worktree_id: WorktreeId,
284 spec: &KernelSpecification,
285 ) -> bool {
286 if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
287 spec.path().as_ref() == active_path.as_ref()
288 } else {
289 false
290 }
291 }
292
293 pub fn active_kernelspec(
294 &self,
295 worktree_id: WorktreeId,
296 language_at_cursor: Option<Arc<Language>>,
297 cx: &App,
298 ) -> Option<KernelSpecification> {
299 if let Some(selected) = self.selected_kernel_for_worktree.get(&worktree_id).cloned() {
300 return Some(selected);
301 }
302
303 let language_at_cursor = language_at_cursor?;
304
305 // Prefer the recommended (active toolchain) kernel if it has ipykernel
306 if let Some(active_path) = self.active_python_toolchain_path(worktree_id) {
307 let recommended = self
308 .kernel_specifications_for_worktree(worktree_id)
309 .find(|spec| {
310 spec.has_ipykernel()
311 && language_at_cursor.matches_kernel_language(spec.language().as_ref())
312 && spec.path().as_ref() == active_path.as_ref()
313 })
314 .cloned();
315 if recommended.is_some() {
316 return recommended;
317 }
318 }
319
320 // Then try the first PythonEnv with ipykernel matching the language
321 let python_env = self
322 .kernel_specifications_for_worktree(worktree_id)
323 .find(|spec| {
324 matches!(spec, KernelSpecification::PythonEnv(_))
325 && spec.has_ipykernel()
326 && language_at_cursor.matches_kernel_language(spec.language().as_ref())
327 })
328 .cloned();
329 if python_env.is_some() {
330 return python_env;
331 }
332
333 // Fall back to legacy name-based and language-based matching
334 self.kernelspec_legacy_by_lang_only(worktree_id, language_at_cursor, cx)
335 }
336
337 fn kernelspec_legacy_by_lang_only(
338 &self,
339 worktree_id: WorktreeId,
340 language_at_cursor: Arc<Language>,
341 cx: &App,
342 ) -> Option<KernelSpecification> {
343 let settings = JupyterSettings::get_global(cx);
344 let selected_kernel = settings
345 .kernel_selections
346 .get(language_at_cursor.code_fence_block_name().as_ref());
347
348 let found_by_name = self
349 .kernel_specifications_for_worktree(worktree_id)
350 .find(|runtime_specification| {
351 if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
352 (selected_kernel, runtime_specification)
353 {
354 return runtime_specification.name.to_lowercase() == selected.to_lowercase();
355 }
356 false
357 })
358 .cloned();
359
360 if let Some(found_by_name) = found_by_name {
361 return Some(found_by_name);
362 }
363
364 self.kernel_specifications_for_worktree(worktree_id)
365 .find(|spec| {
366 spec.has_ipykernel()
367 && language_at_cursor.matches_kernel_language(spec.language().as_ref())
368 })
369 .cloned()
370 }
371
372 pub fn get_session(&self, entity_id: EntityId) -> Option<&Entity<Session>> {
373 self.sessions.get(&entity_id)
374 }
375
376 pub fn insert_session(&mut self, entity_id: EntityId, session: Entity<Session>) {
377 self.sessions.insert(entity_id, session);
378 }
379
380 pub fn remove_session(&mut self, entity_id: EntityId) {
381 self.sessions.remove(&entity_id);
382 }
383
384 fn shutdown_all_sessions(
385 &mut self,
386 cx: &mut Context<Self>,
387 ) -> impl Future<Output = ()> + use<> {
388 for session in self.sessions.values() {
389 session.update(cx, |session, _cx| {
390 if let Kernel::RunningKernel(mut kernel) =
391 std::mem::replace(&mut session.kernel, Kernel::Shutdown)
392 {
393 kernel.kill();
394 }
395 });
396 }
397 self.sessions.clear();
398 futures::future::ready(())
399 }
400
401 #[cfg(test)]
402 pub fn set_kernel_specs_for_testing(
403 &mut self,
404 specs: Vec<KernelSpecification>,
405 cx: &mut Context<Self>,
406 ) {
407 self.kernel_specifications = specs;
408 cx.notify();
409 }
410}