1use std::sync::Arc;
2
3use anyhow::{Context as _, Result};
4use collections::HashMap;
5use command_palette_hooks::CommandPaletteFilter;
6use gpui::{App, Context, Entity, EntityId, Global, Subscription, Task, prelude::*};
7use jupyter_websocket_client::RemoteServer;
8use language::Language;
9use project::{Fs, Project, WorktreeId};
10use settings::{Settings, SettingsStore};
11
12use crate::kernels::{
13 list_remote_kernelspecs, local_kernel_specifications, python_env_kernel_specifications,
14};
15use crate::{JupyterSettings, KernelSpecification, Session};
16
17struct GlobalReplStore(Entity<ReplStore>);
18
19impl Global for GlobalReplStore {}
20
21pub struct ReplStore {
22 fs: Arc<dyn Fs>,
23 enabled: bool,
24 sessions: HashMap<EntityId, Entity<Session>>,
25 kernel_specifications: Vec<KernelSpecification>,
26 selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
27 kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
28 _subscriptions: Vec<Subscription>,
29}
30
31impl ReplStore {
/// Command palette namespace that is shown or hidden as the REPL is enabled or disabled.
const NAMESPACE: &'static str = "repl";
33
34 pub(crate) fn init(fs: Arc<dyn Fs>, cx: &mut App) {
35 let store = cx.new(move |cx| Self::new(fs, cx));
36
37 #[cfg(not(feature = "test-support"))]
38 store
39 .update(cx, |store, cx| store.refresh_kernelspecs(cx))
40 .detach_and_log_err(cx);
41
42 cx.set_global(GlobalReplStore(store))
43 }
44
45 pub fn global(cx: &App) -> Entity<Self> {
46 cx.global::<GlobalReplStore>().0.clone()
47 }
48
49 pub fn new(fs: Arc<dyn Fs>, cx: &mut Context<Self>) -> Self {
50 let subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
51 this.set_enabled(JupyterSettings::enabled(cx), cx);
52 })];
53
54 let this = Self {
55 fs,
56 enabled: JupyterSettings::enabled(cx),
57 sessions: HashMap::default(),
58 kernel_specifications: Vec::new(),
59 _subscriptions: subscriptions,
60 kernel_specifications_for_worktree: HashMap::default(),
61 selected_kernel_for_worktree: HashMap::default(),
62 };
63 this.on_enabled_changed(cx);
64 this
65 }
66
67 pub fn fs(&self) -> &Arc<dyn Fs> {
68 &self.fs
69 }
70
71 pub fn is_enabled(&self) -> bool {
72 self.enabled
73 }
74
75 pub fn kernel_specifications_for_worktree(
76 &self,
77 worktree_id: WorktreeId,
78 ) -> impl Iterator<Item = &KernelSpecification> {
79 self.kernel_specifications_for_worktree
80 .get(&worktree_id)
81 .into_iter()
82 .flat_map(|specs| specs.iter())
83 .chain(self.kernel_specifications.iter())
84 }
85
86 pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
87 self.kernel_specifications.iter()
88 }
89
90 pub fn sessions(&self) -> impl Iterator<Item = &Entity<Session>> {
91 self.sessions.values()
92 }
93
94 fn set_enabled(&mut self, enabled: bool, cx: &mut Context<Self>) {
95 if self.enabled == enabled {
96 return;
97 }
98
99 self.enabled = enabled;
100 self.on_enabled_changed(cx);
101 }
102
103 fn on_enabled_changed(&self, cx: &mut Context<Self>) {
104 if !self.enabled {
105 CommandPaletteFilter::update_global(cx, |filter, _cx| {
106 filter.hide_namespace(Self::NAMESPACE);
107 });
108
109 return;
110 }
111
112 CommandPaletteFilter::update_global(cx, |filter, _cx| {
113 filter.show_namespace(Self::NAMESPACE);
114 });
115
116 cx.notify();
117 }
118
119 pub fn refresh_python_kernelspecs(
120 &mut self,
121 worktree_id: WorktreeId,
122 project: &Entity<Project>,
123 cx: &mut Context<Self>,
124 ) -> Task<Result<()>> {
125 let kernel_specifications = python_env_kernel_specifications(project, worktree_id, cx);
126 cx.spawn(async move |this, cx| {
127 let kernel_specifications = kernel_specifications
128 .await
129 .context("getting python kernelspecs")?;
130
131 this.update(cx, |this, cx| {
132 this.kernel_specifications_for_worktree
133 .insert(worktree_id, kernel_specifications);
134 cx.notify();
135 })
136 })
137 }
138
139 fn get_remote_kernel_specifications(
140 &self,
141 cx: &mut Context<Self>,
142 ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
143 match (
144 std::env::var("JUPYTER_SERVER"),
145 std::env::var("JUPYTER_TOKEN"),
146 ) {
147 (Ok(server), Ok(token)) => {
148 let remote_server = RemoteServer {
149 base_url: server,
150 token,
151 };
152 let http_client = cx.http_client();
153 Some(cx.spawn(async move |_, _| {
154 list_remote_kernelspecs(remote_server, http_client)
155 .await
156 .map(|specs| specs.into_iter().map(KernelSpecification::Remote).collect())
157 }))
158 }
159 _ => None,
160 }
161 }
162
163 pub fn refresh_kernelspecs(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
164 let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
165
166 let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
167
168 let all_specs = cx.background_spawn(async move {
169 let mut all_specs = local_kernel_specifications
170 .await?
171 .into_iter()
172 .map(KernelSpecification::Jupyter)
173 .collect::<Vec<_>>();
174
175 if let Some(remote_task) = remote_kernel_specifications
176 && let Ok(remote_specs) = remote_task.await
177 {
178 all_specs.extend(remote_specs);
179 }
180
181 anyhow::Ok(all_specs)
182 });
183
184 cx.spawn(async move |this, cx| {
185 let all_specs = all_specs.await;
186
187 if let Ok(specs) = all_specs {
188 this.update(cx, |this, cx| {
189 this.kernel_specifications = specs;
190 cx.notify();
191 })
192 .ok();
193 }
194
195 anyhow::Ok(())
196 })
197 }
198
199 pub fn set_active_kernelspec(
200 &mut self,
201 worktree_id: WorktreeId,
202 kernelspec: KernelSpecification,
203 _cx: &mut Context<Self>,
204 ) {
205 self.selected_kernel_for_worktree
206 .insert(worktree_id, kernelspec);
207 }
208
209 pub fn active_kernelspec(
210 &self,
211 worktree_id: WorktreeId,
212 language_at_cursor: Option<Arc<Language>>,
213 cx: &App,
214 ) -> Option<KernelSpecification> {
215 let selected_kernelspec = self.selected_kernel_for_worktree.get(&worktree_id).cloned();
216
217 if let Some(language_at_cursor) = language_at_cursor {
218 selected_kernelspec.or_else(|| {
219 self.kernelspec_legacy_by_lang_only(worktree_id, language_at_cursor, cx)
220 })
221 } else {
222 selected_kernelspec
223 }
224 }
225
226 fn kernelspec_legacy_by_lang_only(
227 &self,
228 worktree_id: WorktreeId,
229 language_at_cursor: Arc<Language>,
230 cx: &App,
231 ) -> Option<KernelSpecification> {
232 let settings = JupyterSettings::get_global(cx);
233 let selected_kernel = settings
234 .kernel_selections
235 .get(language_at_cursor.code_fence_block_name().as_ref());
236
237 let found_by_name = self
238 .kernel_specifications_for_worktree(worktree_id)
239 .find(|runtime_specification| {
240 if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
241 (selected_kernel, runtime_specification)
242 {
243 // Top priority is the selected kernel
244 return runtime_specification.name.to_lowercase() == selected.to_lowercase();
245 }
246 false
247 })
248 .cloned();
249
250 if let Some(found_by_name) = found_by_name {
251 return Some(found_by_name);
252 }
253
254 self.kernel_specifications_for_worktree(worktree_id)
255 .find(|kernel_option| match kernel_option {
256 KernelSpecification::Jupyter(runtime_specification) => {
257 runtime_specification.kernelspec.language.to_lowercase()
258 == language_at_cursor.code_fence_block_name().to_lowercase()
259 }
260 KernelSpecification::PythonEnv(runtime_specification) => {
261 runtime_specification.kernelspec.language.to_lowercase()
262 == language_at_cursor.code_fence_block_name().to_lowercase()
263 }
264 KernelSpecification::Remote(remote_spec) => {
265 remote_spec.kernelspec.language.to_lowercase()
266 == language_at_cursor.code_fence_block_name().to_lowercase()
267 }
268 })
269 .cloned()
270 }
271
272 pub fn get_session(&self, entity_id: EntityId) -> Option<&Entity<Session>> {
273 self.sessions.get(&entity_id)
274 }
275
276 pub fn insert_session(&mut self, entity_id: EntityId, session: Entity<Session>) {
277 self.sessions.insert(entity_id, session);
278 }
279
280 pub fn remove_session(&mut self, entity_id: EntityId) {
281 self.sessions.remove(&entity_id);
282 }
283
/// Test-only hook: replaces the store-wide kernel specifications with `specs`
/// and notifies observers, without running any discovery.
#[cfg(test)]
pub fn set_kernel_specs_for_testing(
    &mut self,
    specs: Vec<KernelSpecification>,
    cx: &mut Context<Self>,
) {
    let Self {
        kernel_specifications,
        ..
    } = self;
    *kernel_specifications = specs;

    cx.notify();
}
293}