1use anyhow::{Context as _, Result};
2use futures::{
3 channel::mpsc::{self},
4 io::BufReader,
5 stream::{SelectAll, StreamExt},
6 AsyncBufReadExt as _, SinkExt as _,
7};
8use gpui::{EntityId, Task, View, WindowContext};
9use jupyter_protocol::{
10 connection_info::{ConnectionInfo, Transport},
11 ExecutionState, JupyterKernelspec, JupyterMessage, JupyterMessageContent, KernelInfoReply,
12};
13use project::Fs;
14use runtimelib::dirs;
15use smol::{net::TcpListener, process::Command};
16use std::{
17 env,
18 fmt::Debug,
19 net::{IpAddr, Ipv4Addr, SocketAddr},
20 path::PathBuf,
21 sync::Arc,
22};
23use uuid::Uuid;
24
25use crate::Session;
26
27use super::RunningKernel;
28
29#[derive(Debug, Clone)]
30pub struct LocalKernelSpecification {
31 pub name: String,
32 pub path: PathBuf,
33 pub kernelspec: JupyterKernelspec,
34}
35
36impl PartialEq for LocalKernelSpecification {
37 fn eq(&self, other: &Self) -> bool {
38 self.name == other.name && self.path == other.path
39 }
40}
41
42impl Eq for LocalKernelSpecification {}
43
44impl LocalKernelSpecification {
45 #[must_use]
46 fn command(&self, connection_path: &PathBuf) -> Result<Command> {
47 let argv = &self.kernelspec.argv;
48
49 anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
50 anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
51 anyhow::ensure!(
52 argv.iter().any(|arg| arg == "{connection_file}"),
53 "Missing 'connection_file' in argv in kernelspec {}",
54 self.name
55 );
56
57 let mut cmd = util::command::new_smol_command(&argv[0]);
58
59 for arg in &argv[1..] {
60 if arg == "{connection_file}" {
61 cmd.arg(connection_path);
62 } else {
63 cmd.arg(arg);
64 }
65 }
66
67 if let Some(env) = &self.kernelspec.env {
68 cmd.envs(env);
69 }
70
71 Ok(cmd)
72 }
73}
74
75// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
76// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
77async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
78 let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
79 addr_zeroport.set_port(0);
80 let mut ports: [u16; 5] = [0; 5];
81 for i in 0..5 {
82 let listener = TcpListener::bind(addr_zeroport).await?;
83 let addr = listener.local_addr()?;
84 ports[i] = addr.port();
85 }
86 Ok(ports)
87}
88
89pub struct NativeRunningKernel {
90 pub process: smol::process::Child,
91 _shell_task: Task<Result<()>>,
92 _control_task: Task<Result<()>>,
93 _routing_task: Task<Result<()>>,
94 connection_path: PathBuf,
95 _process_status_task: Option<Task<()>>,
96 pub working_directory: PathBuf,
97 pub request_tx: mpsc::Sender<JupyterMessage>,
98 pub execution_state: ExecutionState,
99 pub kernel_info: Option<KernelInfoReply>,
100}
101
102impl Debug for NativeRunningKernel {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.debug_struct("RunningKernel")
105 .field("process", &self.process)
106 .finish()
107 }
108}
109
110impl NativeRunningKernel {
111 pub fn new(
112 kernel_specification: LocalKernelSpecification,
113 entity_id: EntityId,
114 working_directory: PathBuf,
115 fs: Arc<dyn Fs>,
116 // todo: convert to weak view
117 session: View<Session>,
118 cx: &mut WindowContext,
119 ) -> Task<Result<Box<dyn RunningKernel>>> {
120 cx.spawn(|cx| async move {
121 let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
122 let ports = peek_ports(ip).await?;
123
124 let connection_info = ConnectionInfo {
125 transport: Transport::TCP,
126 ip: ip.to_string(),
127 stdin_port: ports[0],
128 control_port: ports[1],
129 hb_port: ports[2],
130 shell_port: ports[3],
131 iopub_port: ports[4],
132 signature_scheme: "hmac-sha256".to_string(),
133 key: uuid::Uuid::new_v4().to_string(),
134 kernel_name: Some(format!("zed-{}", kernel_specification.name)),
135 };
136
137 let runtime_dir = dirs::runtime_dir();
138 fs.create_dir(&runtime_dir)
139 .await
140 .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
141 let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
142 let content = serde_json::to_string(&connection_info)?;
143 fs.atomic_write(connection_path.clone(), content).await?;
144
145 let mut cmd = kernel_specification.command(&connection_path)?;
146
147 let mut process = cmd
148 .current_dir(&working_directory)
149 .stdout(std::process::Stdio::piped())
150 .stderr(std::process::Stdio::piped())
151 .stdin(std::process::Stdio::piped())
152 .kill_on_drop(true)
153 .spawn()
154 .context("failed to start the kernel process")?;
155
156 let session_id = Uuid::new_v4().to_string();
157
158 let mut iopub_socket =
159 runtimelib::create_client_iopub_connection(&connection_info, "", &session_id)
160 .await?;
161 let mut shell_socket =
162 runtimelib::create_client_shell_connection(&connection_info, &session_id).await?;
163 let mut control_socket =
164 runtimelib::create_client_control_connection(&connection_info, &session_id).await?;
165
166 let (request_tx, mut request_rx) =
167 futures::channel::mpsc::channel::<JupyterMessage>(100);
168
169 let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
170 let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);
171
172 let mut messages_rx = SelectAll::new();
173 messages_rx.push(control_reply_rx);
174 messages_rx.push(shell_reply_rx);
175
176 cx.spawn({
177 let session = session.clone();
178
179 |mut cx| async move {
180 while let Some(message) = messages_rx.next().await {
181 session
182 .update(&mut cx, |session, cx| {
183 session.route(&message, cx);
184 })
185 .ok();
186 }
187 anyhow::Ok(())
188 }
189 })
190 .detach();
191
192 // iopub task
193 cx.spawn({
194 let session = session.clone();
195
196 |mut cx| async move {
197 while let Ok(message) = iopub_socket.read().await {
198 session
199 .update(&mut cx, |session, cx| {
200 session.route(&message, cx);
201 })
202 .ok();
203 }
204 anyhow::Ok(())
205 }
206 })
207 .detach();
208
209 let (mut control_request_tx, mut control_request_rx) =
210 futures::channel::mpsc::channel(100);
211 let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);
212
213 let routing_task = cx.background_executor().spawn({
214 async move {
215 while let Some(message) = request_rx.next().await {
216 match message.content {
217 JupyterMessageContent::DebugRequest(_)
218 | JupyterMessageContent::InterruptRequest(_)
219 | JupyterMessageContent::ShutdownRequest(_) => {
220 control_request_tx.send(message).await?;
221 }
222 _ => {
223 shell_request_tx.send(message).await?;
224 }
225 }
226 }
227 anyhow::Ok(())
228 }
229 });
230
231 let shell_task = cx.background_executor().spawn({
232 async move {
233 while let Some(message) = shell_request_rx.next().await {
234 shell_socket.send(message).await.ok();
235 let reply = shell_socket.read().await?;
236 shell_reply_tx.send(reply).await?;
237 }
238 anyhow::Ok(())
239 }
240 });
241
242 let control_task = cx.background_executor().spawn({
243 async move {
244 while let Some(message) = control_request_rx.next().await {
245 control_socket.send(message).await.ok();
246 let reply = control_socket.read().await?;
247 control_reply_tx.send(reply).await?;
248 }
249 anyhow::Ok(())
250 }
251 });
252
253 let stderr = process.stderr.take();
254
255 cx.spawn(|mut _cx| async move {
256 if stderr.is_none() {
257 return;
258 }
259 let reader = BufReader::new(stderr.unwrap());
260 let mut lines = reader.lines();
261 while let Some(Ok(line)) = lines.next().await {
262 log::error!("kernel: {}", line);
263 }
264 })
265 .detach();
266
267 let stdout = process.stdout.take();
268
269 cx.spawn(|mut _cx| async move {
270 if stdout.is_none() {
271 return;
272 }
273 let reader = BufReader::new(stdout.unwrap());
274 let mut lines = reader.lines();
275 while let Some(Ok(line)) = lines.next().await {
276 log::info!("kernel: {}", line);
277 }
278 })
279 .detach();
280
281 let status = process.status();
282
283 let process_status_task = cx.spawn(|mut cx| async move {
284 let error_message = match status.await {
285 Ok(status) => {
286 if status.success() {
287 log::info!("kernel process exited successfully");
288 return;
289 }
290
291 format!("kernel process exited with status: {:?}", status)
292 }
293 Err(err) => {
294 format!("kernel process exited with error: {:?}", err)
295 }
296 };
297
298 log::error!("{}", error_message);
299
300 session
301 .update(&mut cx, |session, cx| {
302 session.kernel_errored(error_message, cx);
303
304 cx.notify();
305 })
306 .ok();
307 });
308
309 anyhow::Ok(Box::new(Self {
310 process,
311 request_tx,
312 working_directory,
313 _process_status_task: Some(process_status_task),
314 _shell_task: shell_task,
315 _control_task: control_task,
316 _routing_task: routing_task,
317 connection_path,
318 execution_state: ExecutionState::Idle,
319 kernel_info: None,
320 }) as Box<dyn RunningKernel>)
321 })
322 }
323}
324
325impl RunningKernel for NativeRunningKernel {
326 fn request_tx(&self) -> mpsc::Sender<JupyterMessage> {
327 self.request_tx.clone()
328 }
329
330 fn working_directory(&self) -> &PathBuf {
331 &self.working_directory
332 }
333
334 fn execution_state(&self) -> &ExecutionState {
335 &self.execution_state
336 }
337
338 fn set_execution_state(&mut self, state: ExecutionState) {
339 self.execution_state = state;
340 }
341
342 fn kernel_info(&self) -> Option<&KernelInfoReply> {
343 self.kernel_info.as_ref()
344 }
345
346 fn set_kernel_info(&mut self, info: KernelInfoReply) {
347 self.kernel_info = Some(info);
348 }
349
350 fn force_shutdown(&mut self, _cx: &mut WindowContext) -> Task<anyhow::Result<()>> {
351 self._process_status_task.take();
352 self.request_tx.close_channel();
353
354 Task::ready(match self.process.kill() {
355 Ok(_) => Ok(()),
356 Err(error) => Err(anyhow::anyhow!(
357 "Failed to kill the kernel process: {}",
358 error
359 )),
360 })
361 }
362}
363
364impl Drop for NativeRunningKernel {
365 fn drop(&mut self) {
366 std::fs::remove_file(&self.connection_path).ok();
367 self.request_tx.close_channel();
368 self.process.kill().ok();
369 }
370}
371
372async fn read_kernelspec_at(
373 // Path should be a directory to a jupyter kernelspec, as in
374 // /usr/local/share/jupyter/kernels/python3
375 kernel_dir: PathBuf,
376 fs: &dyn Fs,
377) -> Result<LocalKernelSpecification> {
378 let path = kernel_dir;
379 let kernel_name = if let Some(kernel_name) = path.file_name() {
380 kernel_name.to_string_lossy().to_string()
381 } else {
382 anyhow::bail!("Invalid kernelspec directory: {path:?}");
383 };
384
385 if !fs.is_dir(path.as_path()).await {
386 anyhow::bail!("Not a directory: {path:?}");
387 }
388
389 let expected_kernel_json = path.join("kernel.json");
390 let spec = fs.load(expected_kernel_json.as_path()).await?;
391 let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
392
393 Ok(LocalKernelSpecification {
394 name: kernel_name,
395 path,
396 kernelspec: spec,
397 })
398}
399
400/// Read a directory of kernelspec directories
401async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<LocalKernelSpecification>> {
402 let mut kernelspec_dirs = fs.read_dir(&path).await?;
403
404 let mut valid_kernelspecs = Vec::new();
405 while let Some(path) = kernelspec_dirs.next().await {
406 match path {
407 Ok(path) => {
408 if fs.is_dir(path.as_path()).await {
409 if let Ok(kernelspec) = read_kernelspec_at(path, fs).await {
410 valid_kernelspecs.push(kernelspec);
411 }
412 }
413 }
414 Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
415 }
416 }
417
418 Ok(valid_kernelspecs)
419}
420
421pub async fn local_kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<LocalKernelSpecification>> {
422 let mut data_dirs = dirs::data_dirs();
423
424 // Pick up any kernels from conda or conda environment
425 if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
426 let conda_prefix = PathBuf::from(conda_prefix);
427 let conda_data_dir = conda_prefix.join("share").join("jupyter");
428 data_dirs.push(conda_data_dir);
429 }
430
431 // Search for kernels inside the base python environment
432 let command = util::command::new_smol_command("python")
433 .arg("-c")
434 .arg("import sys; print(sys.prefix)")
435 .output()
436 .await;
437
438 if let Ok(command) = command {
439 if command.status.success() {
440 let python_prefix = String::from_utf8(command.stdout);
441 if let Ok(python_prefix) = python_prefix {
442 let python_prefix = PathBuf::from(python_prefix.trim());
443 let python_data_dir = python_prefix.join("share").join("jupyter");
444 data_dirs.push(python_data_dir);
445 }
446 }
447 }
448
449 let kernel_dirs = data_dirs
450 .iter()
451 .map(|dir| dir.join("kernels"))
452 .map(|path| read_kernels_dir(path, fs.as_ref()))
453 .collect::<Vec<_>>();
454
455 let kernel_dirs = futures::future::join_all(kernel_dirs).await;
456 let kernel_dirs = kernel_dirs
457 .into_iter()
458 .filter_map(Result::ok)
459 .flatten()
460 .collect::<Vec<_>>();
461
462 Ok(kernel_dirs)
463}
464
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// `read_kernels_dir` finds each kernelspec directory under `/jupyter/kernels`.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        let tree = json!({
            ".zed": {
                "settings.json": r#"{ "tab_size": 8 }"#,
                "tasks.json": r#"[{
                    "label": "cargo check",
                    "command": "cargo",
                    "args": ["check", "--all"]
                },]"#,
            },
            "kernels": {
                "python": {
                    "kernel.json": r#"{
                        "display_name": "Python 3",
                        "language": "python",
                        "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                        "env": {}
                    }"#
                },
                "deno": {
                    "kernel.json": r#"{
                        "display_name": "Deno",
                        "language": "typescript",
                        "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                        "env": {}
                    }"#
                }
            },
        });
        fs.insert_tree("/jupyter", tree).await;

        let found = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();

        let mut names: Vec<String> = found.into_iter().map(|spec| spec.name).collect();
        names.sort();

        assert_eq!(names, vec!["deno", "python"]);
    }
}