1use std::sync::Arc;
2
3use anyhow::Result;
4use client::telemetry::Telemetry;
5use collections::HashMap;
6use command_palette_hooks::CommandPaletteFilter;
7use gpui::{
8 prelude::*, AppContext, EntityId, Global, Model, ModelContext, Subscription, Task, View,
9};
10use jupyter_websocket_client::RemoteServer;
11use language::Language;
12use project::{Fs, Project, WorktreeId};
13use settings::{Settings, SettingsStore};
14
15use crate::kernels::{
16 list_remote_kernelspecs, local_kernel_specifications, python_env_kernel_specifications,
17};
18use crate::{JupyterSettings, KernelSpecification, Session};
19
20struct GlobalReplStore(Model<ReplStore>);
21
22impl Global for GlobalReplStore {}
23
24pub struct ReplStore {
25 fs: Arc<dyn Fs>,
26 enabled: bool,
27 sessions: HashMap<EntityId, View<Session>>,
28 kernel_specifications: Vec<KernelSpecification>,
29 selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
30 kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
31 telemetry: Arc<Telemetry>,
32 _subscriptions: Vec<Subscription>,
33}
34
35impl ReplStore {
36 const NAMESPACE: &'static str = "repl";
37
38 pub(crate) fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
39 let store = cx.new_model(move |cx| Self::new(fs, telemetry, cx));
40
41 store
42 .update(cx, |store, cx| store.refresh_kernelspecs(cx))
43 .detach_and_log_err(cx);
44
45 cx.set_global(GlobalReplStore(store))
46 }
47
48 pub fn global(cx: &AppContext) -> Model<Self> {
49 cx.global::<GlobalReplStore>().0.clone()
50 }
51
52 pub fn new(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut ModelContext<Self>) -> Self {
53 let subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
54 this.set_enabled(JupyterSettings::enabled(cx), cx);
55 })];
56
57 let this = Self {
58 fs,
59 telemetry,
60 enabled: JupyterSettings::enabled(cx),
61 sessions: HashMap::default(),
62 kernel_specifications: Vec::new(),
63 _subscriptions: subscriptions,
64 kernel_specifications_for_worktree: HashMap::default(),
65 selected_kernel_for_worktree: HashMap::default(),
66 };
67 this.on_enabled_changed(cx);
68 this
69 }
70
71 pub fn fs(&self) -> &Arc<dyn Fs> {
72 &self.fs
73 }
74
75 pub fn telemetry(&self) -> &Arc<Telemetry> {
76 &self.telemetry
77 }
78
79 pub fn is_enabled(&self) -> bool {
80 self.enabled
81 }
82
83 pub fn kernel_specifications_for_worktree(
84 &self,
85 worktree_id: WorktreeId,
86 ) -> impl Iterator<Item = &KernelSpecification> {
87 self.kernel_specifications_for_worktree
88 .get(&worktree_id)
89 .into_iter()
90 .flat_map(|specs| specs.iter())
91 .chain(self.kernel_specifications.iter())
92 }
93
94 pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
95 self.kernel_specifications.iter()
96 }
97
98 pub fn sessions(&self) -> impl Iterator<Item = &View<Session>> {
99 self.sessions.values()
100 }
101
102 fn set_enabled(&mut self, enabled: bool, cx: &mut ModelContext<Self>) {
103 if self.enabled == enabled {
104 return;
105 }
106
107 self.enabled = enabled;
108 self.on_enabled_changed(cx);
109 }
110
111 fn on_enabled_changed(&self, cx: &mut ModelContext<Self>) {
112 if !self.enabled {
113 CommandPaletteFilter::update_global(cx, |filter, _cx| {
114 filter.hide_namespace(Self::NAMESPACE);
115 });
116
117 return;
118 }
119
120 CommandPaletteFilter::update_global(cx, |filter, _cx| {
121 filter.show_namespace(Self::NAMESPACE);
122 });
123
124 cx.notify();
125 }
126
127 pub fn refresh_python_kernelspecs(
128 &mut self,
129 worktree_id: WorktreeId,
130 project: &Model<Project>,
131 cx: &mut ModelContext<Self>,
132 ) -> Task<Result<()>> {
133 let kernel_specifications = python_env_kernel_specifications(project, worktree_id, cx);
134 cx.spawn(move |this, mut cx| async move {
135 let kernel_specifications = kernel_specifications
136 .await
137 .map_err(|e| anyhow::anyhow!("Failed to get python kernelspecs: {:?}", e))?;
138
139 this.update(&mut cx, |this, cx| {
140 this.kernel_specifications_for_worktree
141 .insert(worktree_id, kernel_specifications);
142 cx.notify();
143 })
144 })
145 }
146
147 fn get_remote_kernel_specifications(
148 &self,
149 cx: &mut ModelContext<Self>,
150 ) -> Option<Task<Result<Vec<KernelSpecification>>>> {
151 match (
152 std::env::var("JUPYTER_SERVER"),
153 std::env::var("JUPYTER_TOKEN"),
154 ) {
155 (Ok(server), Ok(token)) => {
156 let remote_server = RemoteServer {
157 base_url: server,
158 token,
159 };
160 let http_client = cx.http_client();
161 Some(cx.spawn(|_, _| async move {
162 list_remote_kernelspecs(remote_server, http_client)
163 .await
164 .map(|specs| specs.into_iter().map(KernelSpecification::Remote).collect())
165 }))
166 }
167 _ => None,
168 }
169 }
170
171 pub fn refresh_kernelspecs(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
172 let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
173
174 let remote_kernel_specifications = self.get_remote_kernel_specifications(cx);
175
176 cx.spawn(|this, mut cx| async move {
177 let mut all_specs = local_kernel_specifications
178 .await?
179 .into_iter()
180 .map(KernelSpecification::Jupyter)
181 .collect::<Vec<_>>();
182
183 if let Some(remote_task) = remote_kernel_specifications {
184 if let Ok(remote_specs) = remote_task.await {
185 all_specs.extend(remote_specs);
186 }
187 }
188
189 this.update(&mut cx, |this, cx| {
190 this.kernel_specifications = all_specs;
191 cx.notify();
192 })
193 })
194 }
195
196 pub fn set_active_kernelspec(
197 &mut self,
198 worktree_id: WorktreeId,
199 kernelspec: KernelSpecification,
200 _cx: &mut ModelContext<Self>,
201 ) {
202 self.selected_kernel_for_worktree
203 .insert(worktree_id, kernelspec);
204 }
205
206 pub fn active_kernelspec(
207 &self,
208 worktree_id: WorktreeId,
209 language_at_cursor: Option<Arc<Language>>,
210 cx: &AppContext,
211 ) -> Option<KernelSpecification> {
212 let selected_kernelspec = self.selected_kernel_for_worktree.get(&worktree_id).cloned();
213
214 if let Some(language_at_cursor) = language_at_cursor {
215 selected_kernelspec
216 .or_else(|| self.kernelspec_legacy_by_lang_only(language_at_cursor, cx))
217 } else {
218 selected_kernelspec
219 }
220 }
221
222 fn kernelspec_legacy_by_lang_only(
223 &self,
224 language_at_cursor: Arc<Language>,
225 cx: &AppContext,
226 ) -> Option<KernelSpecification> {
227 let settings = JupyterSettings::get_global(cx);
228 let selected_kernel = settings
229 .kernel_selections
230 .get(language_at_cursor.code_fence_block_name().as_ref());
231
232 let found_by_name = self
233 .kernel_specifications
234 .iter()
235 .find(|runtime_specification| {
236 if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
237 (selected_kernel, runtime_specification)
238 {
239 // Top priority is the selected kernel
240 return runtime_specification.name.to_lowercase() == selected.to_lowercase();
241 }
242 false
243 })
244 .cloned();
245
246 if let Some(found_by_name) = found_by_name {
247 return Some(found_by_name);
248 }
249
250 self.kernel_specifications
251 .iter()
252 .find(|kernel_option| match kernel_option {
253 KernelSpecification::Jupyter(runtime_specification) => {
254 runtime_specification.kernelspec.language.to_lowercase()
255 == language_at_cursor.code_fence_block_name().to_lowercase()
256 }
257 KernelSpecification::PythonEnv(runtime_specification) => {
258 runtime_specification.kernelspec.language.to_lowercase()
259 == language_at_cursor.code_fence_block_name().to_lowercase()
260 }
261 KernelSpecification::Remote(remote_spec) => {
262 remote_spec.kernelspec.language.to_lowercase()
263 == language_at_cursor.code_fence_block_name().to_lowercase()
264 }
265 })
266 .cloned()
267 }
268
269 pub fn get_session(&self, entity_id: EntityId) -> Option<&View<Session>> {
270 self.sessions.get(&entity_id)
271 }
272
273 pub fn insert_session(&mut self, entity_id: EntityId, session: View<Session>) {
274 self.sessions.insert(entity_id, session);
275 }
276
277 pub fn remove_session(&mut self, entity_id: EntityId) {
278 self.sessions.remove(&entity_id);
279 }
280}