1use anyhow::{Context as _, Result};
2use futures::{
3 channel::mpsc::{self},
4 io::BufReader,
5 stream::{SelectAll, StreamExt},
6 AsyncBufReadExt as _, SinkExt as _,
7};
8use gpui::{EntityId, Task, View, WindowContext};
9use jupyter_protocol::{JupyterKernelspec, JupyterMessage, JupyterMessageContent, KernelInfoReply};
10use project::Fs;
11use runtimelib::{dirs, ConnectionInfo, ExecutionState};
12use smol::{net::TcpListener, process::Command};
13use std::{
14 env,
15 fmt::Debug,
16 net::{IpAddr, Ipv4Addr, SocketAddr},
17 path::PathBuf,
18 sync::Arc,
19};
20use uuid::Uuid;
21
22use crate::Session;
23
24use super::RunningKernel;
25
/// A Jupyter kernelspec discovered on the local filesystem.
///
/// Equality (see the manual `PartialEq` impl in this file) compares only
/// `name` and `path`; the parsed `kernelspec` is not part of the comparison.
#[derive(Debug, Clone)]
pub struct LocalKernelSpecification {
    /// Kernel name, taken from the final component of the kernelspec directory.
    pub name: String,
    /// Directory that contains this kernel's `kernel.json`.
    pub path: PathBuf,
    /// Contents of `kernel.json`, deserialized.
    pub kernelspec: JupyterKernelspec,
}
32
33impl PartialEq for LocalKernelSpecification {
34 fn eq(&self, other: &Self) -> bool {
35 self.name == other.name && self.path == other.path
36 }
37}
38
39impl Eq for LocalKernelSpecification {}
40
41impl LocalKernelSpecification {
42 #[must_use]
43 fn command(&self, connection_path: &PathBuf) -> Result<Command> {
44 let argv = &self.kernelspec.argv;
45
46 anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
47 anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
48 anyhow::ensure!(
49 argv.iter().any(|arg| arg == "{connection_file}"),
50 "Missing 'connection_file' in argv in kernelspec {}",
51 self.name
52 );
53
54 let mut cmd = util::command::new_smol_command(&argv[0]);
55
56 for arg in &argv[1..] {
57 if arg == "{connection_file}" {
58 cmd.arg(connection_path);
59 } else {
60 cmd.arg(arg);
61 }
62 }
63
64 if let Some(env) = &self.kernelspec.env {
65 cmd.envs(env);
66 }
67
68 Ok(cmd)
69 }
70}
71
72// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
73// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
74async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
75 let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
76 addr_zeroport.set_port(0);
77 let mut ports: [u16; 5] = [0; 5];
78 for i in 0..5 {
79 let listener = TcpListener::bind(addr_zeroport).await?;
80 let addr = listener.local_addr()?;
81 ports[i] = addr.port();
82 }
83 Ok(ports)
84}
85
/// A Jupyter kernel running as a local child process, together with the
/// background tasks that move messages between the editor and the kernel.
pub struct NativeRunningKernel {
    /// The kernel's child process (spawned with `kill_on_drop`).
    pub process: smol::process::Child,
    /// Sends shell requests over the shell socket and relays each reply.
    /// NOTE(review): presumably stored so the task lives as long as the kernel — confirm.
    _shell_task: Task<Result<()>>,
    /// Sends control requests over the control socket and relays each reply.
    _control_task: Task<Result<()>>,
    /// Dispatches messages received from `request_tx` to the control or shell task.
    _routing_task: Task<Result<()>>,
    /// Location of the connection file written for the kernel; removed on drop.
    connection_path: PathBuf,
    /// Watches the process exit status and reports failures to the session.
    /// Taken (set to `None`) by `force_shutdown`.
    _process_status_task: Option<Task<()>>,
    /// Directory the kernel process was started in.
    pub working_directory: PathBuf,
    /// Channel for sending requests to the kernel.
    pub request_tx: mpsc::Sender<JupyterMessage>,
    /// Last recorded execution state; starts out as `ExecutionState::Idle`.
    pub execution_state: ExecutionState,
    /// Kernel info, `None` until `set_kernel_info` is called.
    pub kernel_info: Option<KernelInfoReply>,
}
98
99impl Debug for NativeRunningKernel {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 f.debug_struct("RunningKernel")
102 .field("process", &self.process)
103 .finish()
104 }
105}
106
107impl NativeRunningKernel {
108 pub fn new(
109 kernel_specification: LocalKernelSpecification,
110 entity_id: EntityId,
111 working_directory: PathBuf,
112 fs: Arc<dyn Fs>,
113 // todo: convert to weak view
114 session: View<Session>,
115 cx: &mut WindowContext,
116 ) -> Task<Result<Box<dyn RunningKernel>>> {
117 cx.spawn(|cx| async move {
118 let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
119 let ports = peek_ports(ip).await?;
120
121 let connection_info = ConnectionInfo {
122 transport: "tcp".to_string(),
123 ip: ip.to_string(),
124 stdin_port: ports[0],
125 control_port: ports[1],
126 hb_port: ports[2],
127 shell_port: ports[3],
128 iopub_port: ports[4],
129 signature_scheme: "hmac-sha256".to_string(),
130 key: uuid::Uuid::new_v4().to_string(),
131 kernel_name: Some(format!("zed-{}", kernel_specification.name)),
132 };
133
134 let runtime_dir = dirs::runtime_dir();
135 fs.create_dir(&runtime_dir)
136 .await
137 .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
138 let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
139 let content = serde_json::to_string(&connection_info)?;
140 fs.atomic_write(connection_path.clone(), content).await?;
141
142 let mut cmd = kernel_specification.command(&connection_path)?;
143
144 let mut process = cmd
145 .current_dir(&working_directory)
146 .stdout(std::process::Stdio::piped())
147 .stderr(std::process::Stdio::piped())
148 .stdin(std::process::Stdio::piped())
149 .kill_on_drop(true)
150 .spawn()
151 .context("failed to start the kernel process")?;
152
153 let session_id = Uuid::new_v4().to_string();
154
155 let mut iopub_socket =
156 runtimelib::create_client_iopub_connection(&connection_info, "", &session_id)
157 .await?;
158 let mut shell_socket =
159 runtimelib::create_client_shell_connection(&connection_info, &session_id).await?;
160 let mut control_socket =
161 runtimelib::create_client_control_connection(&connection_info, &session_id).await?;
162
163 let (request_tx, mut request_rx) =
164 futures::channel::mpsc::channel::<JupyterMessage>(100);
165
166 let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
167 let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);
168
169 let mut messages_rx = SelectAll::new();
170 messages_rx.push(control_reply_rx);
171 messages_rx.push(shell_reply_rx);
172
173 cx.spawn({
174 let session = session.clone();
175
176 |mut cx| async move {
177 while let Some(message) = messages_rx.next().await {
178 session
179 .update(&mut cx, |session, cx| {
180 session.route(&message, cx);
181 })
182 .ok();
183 }
184 anyhow::Ok(())
185 }
186 })
187 .detach();
188
189 // iopub task
190 cx.spawn({
191 let session = session.clone();
192
193 |mut cx| async move {
194 while let Ok(message) = iopub_socket.read().await {
195 session
196 .update(&mut cx, |session, cx| {
197 session.route(&message, cx);
198 })
199 .ok();
200 }
201 anyhow::Ok(())
202 }
203 })
204 .detach();
205
206 let (mut control_request_tx, mut control_request_rx) =
207 futures::channel::mpsc::channel(100);
208 let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);
209
210 let routing_task = cx.background_executor().spawn({
211 async move {
212 while let Some(message) = request_rx.next().await {
213 match message.content {
214 JupyterMessageContent::DebugRequest(_)
215 | JupyterMessageContent::InterruptRequest(_)
216 | JupyterMessageContent::ShutdownRequest(_) => {
217 control_request_tx.send(message).await?;
218 }
219 _ => {
220 shell_request_tx.send(message).await?;
221 }
222 }
223 }
224 anyhow::Ok(())
225 }
226 });
227
228 let shell_task = cx.background_executor().spawn({
229 async move {
230 while let Some(message) = shell_request_rx.next().await {
231 shell_socket.send(message).await.ok();
232 let reply = shell_socket.read().await?;
233 shell_reply_tx.send(reply).await?;
234 }
235 anyhow::Ok(())
236 }
237 });
238
239 let control_task = cx.background_executor().spawn({
240 async move {
241 while let Some(message) = control_request_rx.next().await {
242 control_socket.send(message).await.ok();
243 let reply = control_socket.read().await?;
244 control_reply_tx.send(reply).await?;
245 }
246 anyhow::Ok(())
247 }
248 });
249
250 let stderr = process.stderr.take();
251
252 cx.spawn(|mut _cx| async move {
253 if stderr.is_none() {
254 return;
255 }
256 let reader = BufReader::new(stderr.unwrap());
257 let mut lines = reader.lines();
258 while let Some(Ok(line)) = lines.next().await {
259 log::error!("kernel: {}", line);
260 }
261 })
262 .detach();
263
264 let stdout = process.stdout.take();
265
266 cx.spawn(|mut _cx| async move {
267 if stdout.is_none() {
268 return;
269 }
270 let reader = BufReader::new(stdout.unwrap());
271 let mut lines = reader.lines();
272 while let Some(Ok(line)) = lines.next().await {
273 log::info!("kernel: {}", line);
274 }
275 })
276 .detach();
277
278 let status = process.status();
279
280 let process_status_task = cx.spawn(|mut cx| async move {
281 let error_message = match status.await {
282 Ok(status) => {
283 if status.success() {
284 log::info!("kernel process exited successfully");
285 return;
286 }
287
288 format!("kernel process exited with status: {:?}", status)
289 }
290 Err(err) => {
291 format!("kernel process exited with error: {:?}", err)
292 }
293 };
294
295 log::error!("{}", error_message);
296
297 session
298 .update(&mut cx, |session, cx| {
299 session.kernel_errored(error_message, cx);
300
301 cx.notify();
302 })
303 .ok();
304 });
305
306 anyhow::Ok(Box::new(Self {
307 process,
308 request_tx,
309 working_directory,
310 _process_status_task: Some(process_status_task),
311 _shell_task: shell_task,
312 _control_task: control_task,
313 _routing_task: routing_task,
314 connection_path,
315 execution_state: ExecutionState::Idle,
316 kernel_info: None,
317 }) as Box<dyn RunningKernel>)
318 })
319 }
320}
321
322impl RunningKernel for NativeRunningKernel {
323 fn request_tx(&self) -> mpsc::Sender<JupyterMessage> {
324 self.request_tx.clone()
325 }
326
327 fn working_directory(&self) -> &PathBuf {
328 &self.working_directory
329 }
330
331 fn execution_state(&self) -> &ExecutionState {
332 &self.execution_state
333 }
334
335 fn set_execution_state(&mut self, state: ExecutionState) {
336 self.execution_state = state;
337 }
338
339 fn kernel_info(&self) -> Option<&KernelInfoReply> {
340 self.kernel_info.as_ref()
341 }
342
343 fn set_kernel_info(&mut self, info: KernelInfoReply) {
344 self.kernel_info = Some(info);
345 }
346
347 fn force_shutdown(&mut self, _cx: &mut WindowContext) -> Task<anyhow::Result<()>> {
348 self._process_status_task.take();
349 self.request_tx.close_channel();
350
351 Task::ready(match self.process.kill() {
352 Ok(_) => Ok(()),
353 Err(error) => Err(anyhow::anyhow!(
354 "Failed to kill the kernel process: {}",
355 error
356 )),
357 })
358 }
359}
360
361impl Drop for NativeRunningKernel {
362 fn drop(&mut self) {
363 std::fs::remove_file(&self.connection_path).ok();
364 self.request_tx.close_channel();
365 self.process.kill().ok();
366 }
367}
368
369async fn read_kernelspec_at(
370 // Path should be a directory to a jupyter kernelspec, as in
371 // /usr/local/share/jupyter/kernels/python3
372 kernel_dir: PathBuf,
373 fs: &dyn Fs,
374) -> Result<LocalKernelSpecification> {
375 let path = kernel_dir;
376 let kernel_name = if let Some(kernel_name) = path.file_name() {
377 kernel_name.to_string_lossy().to_string()
378 } else {
379 anyhow::bail!("Invalid kernelspec directory: {path:?}");
380 };
381
382 if !fs.is_dir(path.as_path()).await {
383 anyhow::bail!("Not a directory: {path:?}");
384 }
385
386 let expected_kernel_json = path.join("kernel.json");
387 let spec = fs.load(expected_kernel_json.as_path()).await?;
388 let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
389
390 Ok(LocalKernelSpecification {
391 name: kernel_name,
392 path,
393 kernelspec: spec,
394 })
395}
396
397/// Read a directory of kernelspec directories
398async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<LocalKernelSpecification>> {
399 let mut kernelspec_dirs = fs.read_dir(&path).await?;
400
401 let mut valid_kernelspecs = Vec::new();
402 while let Some(path) = kernelspec_dirs.next().await {
403 match path {
404 Ok(path) => {
405 if fs.is_dir(path.as_path()).await {
406 if let Ok(kernelspec) = read_kernelspec_at(path, fs).await {
407 valid_kernelspecs.push(kernelspec);
408 }
409 }
410 }
411 Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
412 }
413 }
414
415 Ok(valid_kernelspecs)
416}
417
418pub async fn local_kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<LocalKernelSpecification>> {
419 let mut data_dirs = dirs::data_dirs();
420
421 // Pick up any kernels from conda or conda environment
422 if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
423 let conda_prefix = PathBuf::from(conda_prefix);
424 let conda_data_dir = conda_prefix.join("share").join("jupyter");
425 data_dirs.push(conda_data_dir);
426 }
427
428 // Search for kernels inside the base python environment
429 let command = util::command::new_smol_command("python")
430 .arg("-c")
431 .arg("import sys; print(sys.prefix)")
432 .output()
433 .await;
434
435 if let Ok(command) = command {
436 if command.status.success() {
437 let python_prefix = String::from_utf8(command.stdout);
438 if let Ok(python_prefix) = python_prefix {
439 let python_prefix = PathBuf::from(python_prefix.trim());
440 let python_data_dir = python_prefix.join("share").join("jupyter");
441 data_dirs.push(python_data_dir);
442 }
443 }
444 }
445
446 let kernel_dirs = data_dirs
447 .iter()
448 .map(|dir| dir.join("kernels"))
449 .map(|path| read_kernels_dir(path, fs.as_ref()))
450 .collect::<Vec<_>>();
451
452 let kernel_dirs = futures::future::join_all(kernel_dirs).await;
453 let kernel_dirs = kernel_dirs
454 .into_iter()
455 .filter_map(Result::ok)
456 .flatten()
457 .collect::<Vec<_>>();
458
459 Ok(kernel_dirs)
460}
461
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// Reading a kernels directory on a fake filesystem yields one kernelspec
    /// per sub-directory, named after that sub-directory.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());

        let tree = json!({
            ".zed": {
                "settings.json": r#"{ "tab_size": 8 }"#,
                "tasks.json": r#"[{
                    "label": "cargo check",
                    "command": "cargo",
                    "args": ["check", "--all"]
                },]"#,
            },
            "kernels": {
                "python": {
                    "kernel.json": r#"{
                        "display_name": "Python 3",
                        "language": "python",
                        "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                        "env": {}
                    }"#
                },
                "deno": {
                    "kernel.json": r#"{
                        "display_name": "Deno",
                        "language": "typescript",
                        "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                        "env": {}
                    }"#
                }
            },
        });
        fs.insert_tree("/jupyter", tree).await;

        let kernels = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();

        // Directory listing order is not part of the contract, so compare
        // the names in sorted order.
        let mut kernel_names: Vec<String> =
            kernels.into_iter().map(|kernel| kernel.name).collect();
        kernel_names.sort();

        assert_eq!(kernel_names, vec!["deno", "python"]);
    }
}