1use anyhow::{Context as _, Result};
2use futures::{
3 channel::mpsc::{self},
4 stream::{SelectAll, StreamExt},
5 SinkExt as _,
6};
7use gpui::{AppContext, EntityId, Task};
8use jupyter_protocol::{JupyterMessage, JupyterMessageContent, KernelInfoReply};
9use project::Fs;
10use runtimelib::{dirs, ConnectionInfo, ExecutionState, JupyterKernelspec};
11use smol::{net::TcpListener, process::Command};
12use std::{
13 env,
14 fmt::Debug,
15 net::{IpAddr, Ipv4Addr, SocketAddr},
16 path::PathBuf,
17 sync::Arc,
18};
19use uuid::Uuid;
20
21use super::{JupyterMessageChannel, RunningKernel};
22
23#[derive(Debug, Clone)]
24pub struct LocalKernelSpecification {
25 pub name: String,
26 pub path: PathBuf,
27 pub kernelspec: JupyterKernelspec,
28}
29
30impl PartialEq for LocalKernelSpecification {
31 fn eq(&self, other: &Self) -> bool {
32 self.name == other.name && self.path == other.path
33 }
34}
35
36impl Eq for LocalKernelSpecification {}
37
38impl LocalKernelSpecification {
39 #[must_use]
40 fn command(&self, connection_path: &PathBuf) -> Result<Command> {
41 let argv = &self.kernelspec.argv;
42
43 anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
44 anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
45 anyhow::ensure!(
46 argv.iter().any(|arg| arg == "{connection_file}"),
47 "Missing 'connection_file' in argv in kernelspec {}",
48 self.name
49 );
50
51 let mut cmd = util::command::new_smol_command(&argv[0]);
52
53 for arg in &argv[1..] {
54 if arg == "{connection_file}" {
55 cmd.arg(connection_path);
56 } else {
57 cmd.arg(arg);
58 }
59 }
60
61 if let Some(env) = &self.kernelspec.env {
62 cmd.envs(env);
63 }
64
65 Ok(cmd)
66 }
67}
68
69// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
70// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
71async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
72 let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
73 addr_zeroport.set_port(0);
74 let mut ports: [u16; 5] = [0; 5];
75 for i in 0..5 {
76 let listener = TcpListener::bind(addr_zeroport).await?;
77 let addr = listener.local_addr()?;
78 ports[i] = addr.port();
79 }
80 Ok(ports)
81}
82
83pub struct NativeRunningKernel {
84 pub process: smol::process::Child,
85 _shell_task: Task<Result<()>>,
86 _iopub_task: Task<Result<()>>,
87 _control_task: Task<Result<()>>,
88 _routing_task: Task<Result<()>>,
89 connection_path: PathBuf,
90 pub working_directory: PathBuf,
91 pub request_tx: mpsc::Sender<JupyterMessage>,
92 pub execution_state: ExecutionState,
93 pub kernel_info: Option<KernelInfoReply>,
94}
95
96impl Debug for NativeRunningKernel {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 f.debug_struct("RunningKernel")
99 .field("process", &self.process)
100 .finish()
101 }
102}
103
104impl NativeRunningKernel {
105 pub fn new(
106 kernel_specification: LocalKernelSpecification,
107 entity_id: EntityId,
108 working_directory: PathBuf,
109 fs: Arc<dyn Fs>,
110 cx: &mut AppContext,
111 ) -> Task<Result<(Self, JupyterMessageChannel)>> {
112 cx.spawn(|cx| async move {
113 let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
114 let ports = peek_ports(ip).await?;
115
116 let connection_info = ConnectionInfo {
117 transport: "tcp".to_string(),
118 ip: ip.to_string(),
119 stdin_port: ports[0],
120 control_port: ports[1],
121 hb_port: ports[2],
122 shell_port: ports[3],
123 iopub_port: ports[4],
124 signature_scheme: "hmac-sha256".to_string(),
125 key: uuid::Uuid::new_v4().to_string(),
126 kernel_name: Some(format!("zed-{}", kernel_specification.name)),
127 };
128
129 let runtime_dir = dirs::runtime_dir();
130 fs.create_dir(&runtime_dir)
131 .await
132 .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
133 let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
134 let content = serde_json::to_string(&connection_info)?;
135 fs.atomic_write(connection_path.clone(), content).await?;
136
137 let mut cmd = kernel_specification.command(&connection_path)?;
138
139 let process = cmd
140 .current_dir(&working_directory)
141 .stdout(std::process::Stdio::piped())
142 .stderr(std::process::Stdio::piped())
143 .stdin(std::process::Stdio::piped())
144 .kill_on_drop(true)
145 .spawn()
146 .context("failed to start the kernel process")?;
147
148 let session_id = Uuid::new_v4().to_string();
149
150 let mut iopub_socket =
151 runtimelib::create_client_iopub_connection(&connection_info, "", &session_id)
152 .await?;
153 let mut shell_socket =
154 runtimelib::create_client_shell_connection(&connection_info, &session_id).await?;
155 let mut control_socket =
156 runtimelib::create_client_control_connection(&connection_info, &session_id).await?;
157
158 let (mut iopub, iosub) = futures::channel::mpsc::channel(100);
159
160 let (request_tx, mut request_rx) =
161 futures::channel::mpsc::channel::<JupyterMessage>(100);
162
163 let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
164 let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);
165
166 let mut messages_rx = SelectAll::new();
167 messages_rx.push(iosub);
168 messages_rx.push(control_reply_rx);
169 messages_rx.push(shell_reply_rx);
170
171 let iopub_task = cx.background_executor().spawn({
172 async move {
173 while let Ok(message) = iopub_socket.read().await {
174 iopub.send(message).await?;
175 }
176 anyhow::Ok(())
177 }
178 });
179
180 let (mut control_request_tx, mut control_request_rx) =
181 futures::channel::mpsc::channel(100);
182 let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);
183
184 let routing_task = cx.background_executor().spawn({
185 async move {
186 while let Some(message) = request_rx.next().await {
187 match message.content {
188 JupyterMessageContent::DebugRequest(_)
189 | JupyterMessageContent::InterruptRequest(_)
190 | JupyterMessageContent::ShutdownRequest(_) => {
191 control_request_tx.send(message).await?;
192 }
193 _ => {
194 shell_request_tx.send(message).await?;
195 }
196 }
197 }
198 anyhow::Ok(())
199 }
200 });
201
202 let shell_task = cx.background_executor().spawn({
203 async move {
204 while let Some(message) = shell_request_rx.next().await {
205 shell_socket.send(message).await.ok();
206 let reply = shell_socket.read().await?;
207 shell_reply_tx.send(reply).await?;
208 }
209 anyhow::Ok(())
210 }
211 });
212
213 let control_task = cx.background_executor().spawn({
214 async move {
215 while let Some(message) = control_request_rx.next().await {
216 control_socket.send(message).await.ok();
217 let reply = control_socket.read().await?;
218 control_reply_tx.send(reply).await?;
219 }
220 anyhow::Ok(())
221 }
222 });
223
224 anyhow::Ok((
225 Self {
226 process,
227 request_tx,
228 working_directory,
229 _shell_task: shell_task,
230 _iopub_task: iopub_task,
231 _control_task: control_task,
232 _routing_task: routing_task,
233 connection_path,
234 execution_state: ExecutionState::Idle,
235 kernel_info: None,
236 },
237 messages_rx,
238 ))
239 })
240 }
241}
242
243impl RunningKernel for NativeRunningKernel {
244 fn request_tx(&self) -> mpsc::Sender<JupyterMessage> {
245 self.request_tx.clone()
246 }
247
248 fn working_directory(&self) -> &PathBuf {
249 &self.working_directory
250 }
251
252 fn execution_state(&self) -> &ExecutionState {
253 &self.execution_state
254 }
255
256 fn set_execution_state(&mut self, state: ExecutionState) {
257 self.execution_state = state;
258 }
259
260 fn kernel_info(&self) -> Option<&KernelInfoReply> {
261 self.kernel_info.as_ref()
262 }
263
264 fn set_kernel_info(&mut self, info: KernelInfoReply) {
265 self.kernel_info = Some(info);
266 }
267
268 fn force_shutdown(&mut self) -> anyhow::Result<()> {
269 match self.process.kill() {
270 Ok(_) => Ok(()),
271 Err(error) => Err(anyhow::anyhow!(
272 "Failed to kill the kernel process: {}",
273 error
274 )),
275 }
276 }
277}
278
279impl Drop for NativeRunningKernel {
280 fn drop(&mut self) {
281 std::fs::remove_file(&self.connection_path).ok();
282 self.request_tx.close_channel();
283 self.process.kill().ok();
284 }
285}
286
287async fn read_kernelspec_at(
288 // Path should be a directory to a jupyter kernelspec, as in
289 // /usr/local/share/jupyter/kernels/python3
290 kernel_dir: PathBuf,
291 fs: &dyn Fs,
292) -> Result<LocalKernelSpecification> {
293 let path = kernel_dir;
294 let kernel_name = if let Some(kernel_name) = path.file_name() {
295 kernel_name.to_string_lossy().to_string()
296 } else {
297 anyhow::bail!("Invalid kernelspec directory: {path:?}");
298 };
299
300 if !fs.is_dir(path.as_path()).await {
301 anyhow::bail!("Not a directory: {path:?}");
302 }
303
304 let expected_kernel_json = path.join("kernel.json");
305 let spec = fs.load(expected_kernel_json.as_path()).await?;
306 let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
307
308 Ok(LocalKernelSpecification {
309 name: kernel_name,
310 path,
311 kernelspec: spec,
312 })
313}
314
315/// Read a directory of kernelspec directories
316async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<LocalKernelSpecification>> {
317 let mut kernelspec_dirs = fs.read_dir(&path).await?;
318
319 let mut valid_kernelspecs = Vec::new();
320 while let Some(path) = kernelspec_dirs.next().await {
321 match path {
322 Ok(path) => {
323 if fs.is_dir(path.as_path()).await {
324 if let Ok(kernelspec) = read_kernelspec_at(path, fs).await {
325 valid_kernelspecs.push(kernelspec);
326 }
327 }
328 }
329 Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
330 }
331 }
332
333 Ok(valid_kernelspecs)
334}
335
336pub async fn local_kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<LocalKernelSpecification>> {
337 let mut data_dirs = dirs::data_dirs();
338
339 // Pick up any kernels from conda or conda environment
340 if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
341 let conda_prefix = PathBuf::from(conda_prefix);
342 let conda_data_dir = conda_prefix.join("share").join("jupyter");
343 data_dirs.push(conda_data_dir);
344 }
345
346 // Search for kernels inside the base python environment
347 let command = util::command::new_smol_command("python")
348 .arg("-c")
349 .arg("import sys; print(sys.prefix)")
350 .output()
351 .await;
352
353 if let Ok(command) = command {
354 if command.status.success() {
355 let python_prefix = String::from_utf8(command.stdout);
356 if let Ok(python_prefix) = python_prefix {
357 let python_prefix = PathBuf::from(python_prefix.trim());
358 let python_data_dir = python_prefix.join("share").join("jupyter");
359 data_dirs.push(python_data_dir);
360 }
361 }
362 }
363
364 let kernel_dirs = data_dirs
365 .iter()
366 .map(|dir| dir.join("kernels"))
367 .map(|path| read_kernels_dir(path, fs.as_ref()))
368 .collect::<Vec<_>>();
369
370 let kernel_dirs = futures::future::join_all(kernel_dirs).await;
371 let kernel_dirs = kernel_dirs
372 .into_iter()
373 .filter_map(Result::ok)
374 .flatten()
375 .collect::<Vec<_>>();
376
377 Ok(kernel_dirs)
378}
379
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// `read_kernels_dir` should return one specification per kernelspec
    /// directory under `kernels`, named after the directory, and ignore
    /// unrelated siblings of `kernels` such as `.zed`.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fake_fs = FakeFs::new(cx.executor());
        let tree = json!({
            ".zed": {
                "settings.json": r#"{ "tab_size": 8 }"#,
                "tasks.json": r#"[{
                    "label": "cargo check",
                    "command": "cargo",
                    "args": ["check", "--all"]
                },]"#,
            },
            "kernels": {
                "python": {
                    "kernel.json": r#"{
                        "display_name": "Python 3",
                        "language": "python",
                        "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                        "env": {}
                    }"#
                },
                "deno": {
                    "kernel.json": r#"{
                        "display_name": "Deno",
                        "language": "typescript",
                        "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                        "env": {}
                    }"#
                }
            },
        });
        fake_fs.insert_tree("/jupyter", tree).await;

        let mut found = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fake_fs.as_ref())
            .await
            .unwrap();

        // Directory listing order is not part of the contract; sort by name.
        found.sort_by(|left, right| left.name.cmp(&right.name));

        let names = found
            .iter()
            .map(|spec| spec.name.clone())
            .collect::<Vec<_>>();
        assert_eq!(names, vec!["deno", "python"]);
    }
}