1use std::sync::Arc;
2
3use anyhow::Result;
4use collections::HashMap;
5use command_palette_hooks::CommandPaletteFilter;
6use gpui::{
7 prelude::*, AppContext, EntityId, Global, Model, ModelContext, Subscription, Task, View,
8};
9use jupyter_websocket_client::RemoteServer;
10use language::Language;
11use project::{Fs, Project, WorktreeId};
12use settings::{Settings, SettingsStore};
13
14use crate::kernels::{
15 list_remote_kernelspecs, local_kernel_specifications, python_env_kernel_specifications,
16};
17use crate::{JupyterSettings, KernelSpecification, Session};
18
/// Newtype around the shared [`ReplStore`] model so it can be registered as a
/// gpui global (see `ReplStore::init` and `ReplStore::global`).
struct GlobalReplStore(Model<ReplStore>);

// Marker impl that allows `GlobalReplStore` to be used with `cx.set_global` / `cx.global`.
impl Global for GlobalReplStore {}
22
/// Application-wide state for the REPL: known kernel specifications, the
/// kernel selected per worktree, and the currently registered sessions.
pub struct ReplStore {
    /// Filesystem handle; passed to `local_kernel_specifications` when refreshing.
    fs: Arc<dyn Fs>,
    /// Mirrors `JupyterSettings::enabled`; kept in sync by the settings observer.
    enabled: bool,
    /// Sessions keyed by the `EntityId` supplied to `insert_session`.
    sessions: HashMap<EntityId, View<Session>>,
    /// Specifications shared by every worktree (filled by `refresh_kernelspecs`).
    kernel_specifications: Vec<KernelSpecification>,
    /// Kernel explicitly chosen for a worktree via `set_active_kernelspec`.
    selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
    /// Specifications discovered per worktree by `refresh_python_kernelspecs`.
    kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
    /// Held for the lifetime of the store; never read.
    /// NOTE(review): presumably dropping these would cancel the observers — confirm.
    _subscriptions: Vec<Subscription>,
}
32
33impl ReplStore {
    /// Command palette namespace that is shown while the REPL is enabled and
    /// hidden while it is disabled.
    const NAMESPACE: &'static str = "repl";
35
36 pub(crate) fn init(fs: Arc<dyn Fs>, cx: &mut AppContext) {
37 let store = cx.new_model(move |cx| Self::new(fs, cx));
38
39 store
40 .update(cx, |store, cx| store.refresh_kernelspecs(cx))
41 .detach_and_log_err(cx);
42
43 cx.set_global(GlobalReplStore(store))
44 }
45
46 pub fn global(cx: &AppContext) -> Model<Self> {
47 cx.global::<GlobalReplStore>().0.clone()
48 }
49
50 pub fn new(fs: Arc<dyn Fs>, cx: &mut ModelContext<Self>) -> Self {
51 let subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
52 this.set_enabled(JupyterSettings::enabled(cx), cx);
53 })];
54
55 let this = Self {
56 fs,
57 enabled: JupyterSettings::enabled(cx),
58 sessions: HashMap::default(),
59 kernel_specifications: Vec::new(),
60 _subscriptions: subscriptions,
61 kernel_specifications_for_worktree: HashMap::default(),
62 selected_kernel_for_worktree: HashMap::default(),
63 };
64 this.on_enabled_changed(cx);
65 this
66 }
67
    /// Returns the filesystem handle this store was created with.
    pub fn fs(&self) -> &Arc<dyn Fs> {
        &self.fs
    }
71
    /// Returns whether the REPL is currently enabled according to the last
    /// value read from `JupyterSettings`.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }
75
76 pub fn kernel_specifications_for_worktree(
77 &self,
78 worktree_id: WorktreeId,
79 ) -> impl Iterator<Item = &KernelSpecification> {
80 self.kernel_specifications_for_worktree
81 .get(&worktree_id)
82 .into_iter()
83 .flat_map(|specs| specs.iter())
84 .chain(self.kernel_specifications.iter())
85 }
86
    /// Iterates over the store-wide kernel specifications only, without any
    /// per-worktree entries.
    ///
    /// NOTE(review): despite the name, `refresh_kernelspecs` also appends
    /// remote specifications to this list when a remote server is configured.
    pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
        self.kernel_specifications.iter()
    }
90
    /// Iterates over all registered sessions (in the map's iteration order).
    pub fn sessions(&self) -> impl Iterator<Item = &View<Session>> {
        self.sessions.values()
    }
94
95 fn set_enabled(&mut self, enabled: bool, cx: &mut ModelContext<Self>) {
96 if self.enabled == enabled {
97 return;
98 }
99
100 self.enabled = enabled;
101 self.on_enabled_changed(cx);
102 }
103
104 fn on_enabled_changed(&self, cx: &mut ModelContext<Self>) {
105 if !self.enabled {
106 CommandPaletteFilter::update_global(cx, |filter, _cx| {
107 filter.hide_namespace(Self::NAMESPACE);
108 });
109
110 return;
111 }
112
113 CommandPaletteFilter::update_global(cx, |filter, _cx| {
114 filter.show_namespace(Self::NAMESPACE);
115 });
116
117 cx.notify();
118 }
119
120 pub fn refresh_python_kernelspecs(
121 &mut self,
122 worktree_id: WorktreeId,
123 project: &Model<Project>,
124 cx: &mut ModelContext<Self>,
125 ) -> Task<Result<()>> {
126 let kernel_specifications = python_env_kernel_specifications(project, worktree_id, cx);
127 cx.spawn(move |this, mut cx| async move {
128 let kernel_specifications = kernel_specifications
129 .await
130 .map_err(|e| anyhow::anyhow!("Failed to get python kernelspecs: {:?}", e))?;
131
132 this.update(&mut cx, |this, cx| {
133 this.kernel_specifications_for_worktree
134 .insert(worktree_id, kernel_specifications);
135 cx.notify();
136 })
137 })
138 }
139
140 fn get_remote_kernel_specifications(
141 &self,
142 cx: &mut ModelContext<Self>,
143 ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
144 match (
145 std::env::var("JUPYTER_SERVER"),
146 std::env::var("JUPYTER_TOKEN"),
147 ) {
148 (Ok(server), Ok(token)) => {
149 let remote_server = RemoteServer {
150 base_url: server,
151 token,
152 };
153 let http_client = cx.http_client();
154 Some(cx.spawn(|_, _| async move {
155 list_remote_kernelspecs(remote_server, http_client)
156 .await
157 .map(|specs| specs.into_iter().map(KernelSpecification::Remote).collect())
158 }))
159 }
160 _ => None,
161 }
162 }
163
164 pub fn refresh_kernelspecs(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
165 let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
166
167 let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
168
169 let all_specs = cx.background_executor().spawn(async move {
170 let mut all_specs = local_kernel_specifications
171 .await?
172 .into_iter()
173 .map(KernelSpecification::Jupyter)
174 .collect::<Vec<_>>();
175
176 if let Some(remote_task) = remote_kernel_specifications {
177 if let Ok(remote_specs) = remote_task.await {
178 all_specs.extend(remote_specs);
179 }
180 }
181
182 anyhow::Ok(all_specs)
183 });
184
185 cx.spawn(|this, mut cx| async move {
186 let all_specs = all_specs.await;
187
188 if let Ok(specs) = all_specs {
189 this.update(&mut cx, |this, cx| {
190 this.kernel_specifications = specs;
191 cx.notify();
192 })
193 .ok();
194 }
195
196 anyhow::Ok(())
197 })
198 }
199
    /// Records `kernelspec` as the selected kernel for `worktree_id`,
    /// replacing any previous selection.
    ///
    /// The context is currently unused: observers are not notified here.
    pub fn set_active_kernelspec(
        &mut self,
        worktree_id: WorktreeId,
        kernelspec: KernelSpecification,
        _cx: &mut ModelContext<Self>,
    ) {
        self.selected_kernel_for_worktree
            .insert(worktree_id, kernelspec);
    }
209
210 pub fn active_kernelspec(
211 &self,
212 worktree_id: WorktreeId,
213 language_at_cursor: Option<Arc<Language>>,
214 cx: &AppContext,
215 ) -> Option<KernelSpecification> {
216 let selected_kernelspec = self.selected_kernel_for_worktree.get(&worktree_id).cloned();
217
218 if let Some(language_at_cursor) = language_at_cursor {
219 selected_kernelspec
220 .or_else(|| self.kernelspec_legacy_by_lang_only(language_at_cursor, cx))
221 } else {
222 selected_kernelspec
223 }
224 }
225
226 fn kernelspec_legacy_by_lang_only(
227 &self,
228 language_at_cursor: Arc<Language>,
229 cx: &AppContext,
230 ) -> Option<KernelSpecification> {
231 let settings = JupyterSettings::get_global(cx);
232 let selected_kernel = settings
233 .kernel_selections
234 .get(language_at_cursor.code_fence_block_name().as_ref());
235
236 let found_by_name = self
237 .kernel_specifications
238 .iter()
239 .find(|runtime_specification| {
240 if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
241 (selected_kernel, runtime_specification)
242 {
243 // Top priority is the selected kernel
244 return runtime_specification.name.to_lowercase() == selected.to_lowercase();
245 }
246 false
247 })
248 .cloned();
249
250 if let Some(found_by_name) = found_by_name {
251 return Some(found_by_name);
252 }
253
254 self.kernel_specifications
255 .iter()
256 .find(|kernel_option| match kernel_option {
257 KernelSpecification::Jupyter(runtime_specification) => {
258 runtime_specification.kernelspec.language.to_lowercase()
259 == language_at_cursor.code_fence_block_name().to_lowercase()
260 }
261 KernelSpecification::PythonEnv(runtime_specification) => {
262 runtime_specification.kernelspec.language.to_lowercase()
263 == language_at_cursor.code_fence_block_name().to_lowercase()
264 }
265 KernelSpecification::Remote(remote_spec) => {
266 remote_spec.kernelspec.language.to_lowercase()
267 == language_at_cursor.code_fence_block_name().to_lowercase()
268 }
269 })
270 .cloned()
271 }
272
    /// Looks up the session registered under `entity_id`, if any.
    pub fn get_session(&self, entity_id: EntityId) -> Option<&View<Session>> {
        self.sessions.get(&entity_id)
    }
276
    /// Registers `session` under `entity_id`, replacing any session already
    /// stored for that id. The replaced session, if any, is discarded.
    pub fn insert_session(&mut self, entity_id: EntityId, session: View<Session>) {
        self.sessions.insert(entity_id, session);
    }
280
    /// Removes the session registered under `entity_id`; does nothing when
    /// no such session exists.
    pub fn remove_session(&mut self, entity_id: EntityId) {
        self.sessions.remove(&entity_id);
    }
284}