1use std::sync::Arc;
2
3use anyhow::Result;
4use client::telemetry::Telemetry;
5use collections::HashMap;
6use command_palette_hooks::CommandPaletteFilter;
7use gpui::{
8 prelude::*, AppContext, EntityId, Global, Model, ModelContext, Subscription, Task, View,
9};
10use language::Language;
11use project::{Fs, Project, WorktreeId};
12use settings::{Settings, SettingsStore};
13
14use crate::kernels::{local_kernel_specifications, python_env_kernel_specifications};
15use crate::{JupyterSettings, KernelSpecification, Session};
16
/// Newtype around the shared `Model<ReplStore>` so that it can be registered as a gpui
/// global (written by `ReplStore::init`, read by `ReplStore::global`).
struct GlobalReplStore(Model<ReplStore>);

// Marker impl: `GlobalReplStore` is the type stored via `cx.set_global` / read via `cx.global`.
impl Global for GlobalReplStore {}
20
/// Shared state for the REPL feature: the enabled flag, known kernel specifications,
/// per-worktree kernel selections, and the live sessions.
pub struct ReplStore {
    // Filesystem handle; handed to `local_kernel_specifications` when refreshing kernelspecs.
    fs: Arc<dyn Fs>,
    // Cached value of `JupyterSettings::enabled`, updated by the `SettingsStore` observer.
    enabled: bool,
    // Sessions keyed by entity id.
    // NOTE(review): presumably the id of the editor that owns the session — confirm with callers.
    sessions: HashMap<EntityId, View<Session>>,
    // Store-wide specifications, replaced wholesale by `refresh_kernelspecs`.
    kernel_specifications: Vec<KernelSpecification>,
    // Kernel explicitly chosen for a worktree through `set_active_kernelspec`.
    selected_kernel_for_worktree: HashMap<WorktreeId, KernelSpecification>,
    // Per-worktree specifications, written by `refresh_python_kernelspecs`.
    kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
    // Telemetry handle, exposed through `telemetry()`.
    telemetry: Arc<Telemetry>,
    // Holds the subscription returned by `observe_global::<SettingsStore>`.
    // NOTE(review): presumably dropping it would cancel the observer — confirm against gpui.
    _subscriptions: Vec<Subscription>,
}
31
32impl ReplStore {
/// Command-palette namespace that `on_enabled_changed` shows or hides.
const NAMESPACE: &'static str = "repl";
34
35 pub(crate) fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
36 let store = cx.new_model(move |cx| Self::new(fs, telemetry, cx));
37
38 store
39 .update(cx, |store, cx| store.refresh_kernelspecs(cx))
40 .detach_and_log_err(cx);
41
42 cx.set_global(GlobalReplStore(store))
43 }
44
45 pub fn global(cx: &AppContext) -> Model<Self> {
46 cx.global::<GlobalReplStore>().0.clone()
47 }
48
49 pub fn new(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut ModelContext<Self>) -> Self {
50 let subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
51 this.set_enabled(JupyterSettings::enabled(cx), cx);
52 })];
53
54 let this = Self {
55 fs,
56 telemetry,
57 enabled: JupyterSettings::enabled(cx),
58 sessions: HashMap::default(),
59 kernel_specifications: Vec::new(),
60 _subscriptions: subscriptions,
61 kernel_specifications_for_worktree: HashMap::default(),
62 selected_kernel_for_worktree: HashMap::default(),
63 };
64 this.on_enabled_changed(cx);
65 this
66 }
67
/// Returns the filesystem handle this store was constructed with.
pub fn fs(&self) -> &Arc<dyn Fs> {
    &self.fs
}
71
/// Returns the telemetry handle this store was constructed with.
pub fn telemetry(&self) -> &Arc<Telemetry> {
    &self.telemetry
}
75
/// Returns the cached enabled flag (last value read from `JupyterSettings::enabled`).
pub fn is_enabled(&self) -> bool {
    self.enabled
}
79
80 pub fn kernel_specifications_for_worktree(
81 &self,
82 worktree_id: WorktreeId,
83 ) -> impl Iterator<Item = &KernelSpecification> {
84 self.kernel_specifications_for_worktree
85 .get(&worktree_id)
86 .into_iter()
87 .flat_map(|specs| specs.iter())
88 .chain(self.kernel_specifications.iter())
89 }
90
/// Iterates over the store-wide `kernel_specifications` only, without any per-worktree
/// entries. `refresh_kernelspecs` fills this list with `KernelSpecification::Jupyter` values.
pub fn pure_jupyter_kernel_specifications(&self) -> impl Iterator<Item = &KernelSpecification> {
    self.kernel_specifications.iter()
}
94
/// Iterates over every session currently held by the store.
pub fn sessions(&self) -> impl Iterator<Item = &View<Session>> {
    self.sessions.values()
}
98
99 fn set_enabled(&mut self, enabled: bool, cx: &mut ModelContext<Self>) {
100 if self.enabled == enabled {
101 return;
102 }
103
104 self.enabled = enabled;
105 self.on_enabled_changed(cx);
106 }
107
108 fn on_enabled_changed(&self, cx: &mut ModelContext<Self>) {
109 if !self.enabled {
110 CommandPaletteFilter::update_global(cx, |filter, _cx| {
111 filter.hide_namespace(Self::NAMESPACE);
112 });
113
114 return;
115 }
116
117 CommandPaletteFilter::update_global(cx, |filter, _cx| {
118 filter.show_namespace(Self::NAMESPACE);
119 });
120
121 cx.notify();
122 }
123
124 pub fn refresh_python_kernelspecs(
125 &mut self,
126 worktree_id: WorktreeId,
127 project: &Model<Project>,
128 cx: &mut ModelContext<Self>,
129 ) -> Task<Result<()>> {
130 let kernel_specifications = python_env_kernel_specifications(project, worktree_id, cx);
131 cx.spawn(move |this, mut cx| async move {
132 let kernel_specifications = kernel_specifications
133 .await
134 .map_err(|e| anyhow::anyhow!("Failed to get python kernelspecs: {:?}", e))?;
135
136 this.update(&mut cx, |this, cx| {
137 this.kernel_specifications_for_worktree
138 .insert(worktree_id, kernel_specifications);
139 cx.notify();
140 })
141 })
142 }
143
144 pub fn refresh_kernelspecs(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
145 let local_kernel_specifications = local_kernel_specifications(self.fs.clone());
146
147 cx.spawn(|this, mut cx| async move {
148 let local_kernel_specifications = local_kernel_specifications.await?;
149
150 let mut kernel_options = Vec::new();
151 for kernel_specification in local_kernel_specifications {
152 kernel_options.push(KernelSpecification::Jupyter(kernel_specification));
153 }
154
155 this.update(&mut cx, |this, cx| {
156 this.kernel_specifications = kernel_options;
157 cx.notify();
158 })
159 })
160 }
161
/// Records `kernelspec` as the selected kernel for `worktree_id`, replacing any
/// previous selection for that worktree.
///
/// NOTE(review): unlike the refresh methods, this does not call `cx.notify()`;
/// `_cx` is currently unused — confirm that observers do not need to react.
pub fn set_active_kernelspec(
    &mut self,
    worktree_id: WorktreeId,
    kernelspec: KernelSpecification,
    _cx: &mut ModelContext<Self>,
) {
    self.selected_kernel_for_worktree
        .insert(worktree_id, kernelspec);
}
171
172 pub fn active_kernelspec(
173 &self,
174 worktree_id: WorktreeId,
175 language_at_cursor: Option<Arc<Language>>,
176 cx: &AppContext,
177 ) -> Option<KernelSpecification> {
178 let selected_kernelspec = self.selected_kernel_for_worktree.get(&worktree_id).cloned();
179
180 if let Some(language_at_cursor) = language_at_cursor {
181 selected_kernelspec
182 .or_else(|| self.kernelspec_legacy_by_lang_only(language_at_cursor, cx))
183 } else {
184 selected_kernelspec
185 }
186 }
187
188 fn kernelspec_legacy_by_lang_only(
189 &self,
190 language_at_cursor: Arc<Language>,
191 cx: &AppContext,
192 ) -> Option<KernelSpecification> {
193 let settings = JupyterSettings::get_global(cx);
194 let selected_kernel = settings
195 .kernel_selections
196 .get(language_at_cursor.code_fence_block_name().as_ref());
197
198 let found_by_name = self
199 .kernel_specifications
200 .iter()
201 .find(|runtime_specification| {
202 if let (Some(selected), KernelSpecification::Jupyter(runtime_specification)) =
203 (selected_kernel, runtime_specification)
204 {
205 // Top priority is the selected kernel
206 return runtime_specification.name.to_lowercase() == selected.to_lowercase();
207 }
208 false
209 })
210 .cloned();
211
212 if let Some(found_by_name) = found_by_name {
213 return Some(found_by_name);
214 }
215
216 self.kernel_specifications
217 .iter()
218 .find(|kernel_option| match kernel_option {
219 KernelSpecification::Jupyter(runtime_specification) => {
220 runtime_specification.kernelspec.language.to_lowercase()
221 == language_at_cursor.code_fence_block_name().to_lowercase()
222 }
223 KernelSpecification::PythonEnv(runtime_specification) => {
224 runtime_specification.kernelspec.language.to_lowercase()
225 == language_at_cursor.code_fence_block_name().to_lowercase()
226 }
227 KernelSpecification::Remote(_) => {
228 unimplemented!()
229 }
230 })
231 .cloned()
232 }
233
/// Looks up the session stored under `entity_id`, if any.
pub fn get_session(&self, entity_id: EntityId) -> Option<&View<Session>> {
    self.sessions.get(&entity_id)
}
237
/// Stores `session` under `entity_id`, replacing any session already stored for that id.
pub fn insert_session(&mut self, entity_id: EntityId, session: View<Session>) {
    self.sessions.insert(entity_id, session);
}
241
/// Removes the session stored under `entity_id`; does nothing if there is none.
pub fn remove_session(&mut self, entity_id: EntityId) {
    self.sessions.remove(&entity_id);
}
245}