native_kernel.rs

  1use anyhow::{Context as _, Result};
  2use futures::{
  3    AsyncBufReadExt as _, FutureExt as _, StreamExt as _,
  4    channel::mpsc::{self},
  5    io::BufReader,
  6    stream::FuturesUnordered,
  7};
  8use gpui::{App, AppContext as _, ClipboardItem, Entity, EntityId, Task, Window};
  9use jupyter_protocol::{
 10    ExecutionState, JupyterKernelspec, JupyterMessage, JupyterMessageContent, KernelInfoReply,
 11    connection_info::{ConnectionInfo, Transport},
 12};
 13use project::Fs;
 14use runtimelib::{RuntimeError, dirs};
 15use smol::{net::TcpListener, process::Command};
 16use std::{
 17    env,
 18    fmt::Debug,
 19    net::{IpAddr, Ipv4Addr, SocketAddr},
 20    path::PathBuf,
 21    sync::Arc,
 22};
 23use uuid::Uuid;
 24
 25use super::{KernelSession, RunningKernel};
 26
/// A Jupyter kernelspec found on the local filesystem.
#[derive(Debug, Clone)]
pub struct LocalKernelSpecification {
    /// Name of the kernelspec. `read_kernelspec_at` fills this in from the
    /// file name of the kernelspec directory.
    pub name: String,
    /// Path of the kernelspec directory (the directory holding `kernel.json`).
    pub path: PathBuf,
    /// Parsed contents of the kernelspec's `kernel.json`.
    pub kernelspec: JupyterKernelspec,
}
 33
 34impl PartialEq for LocalKernelSpecification {
 35    fn eq(&self, other: &Self) -> bool {
 36        self.name == other.name && self.path == other.path
 37    }
 38}
 39
// Marker impl: the `PartialEq` implementation for this type compares the
// `name` (`String`) and `path` (`PathBuf`) fields with `==`.
impl Eq for LocalKernelSpecification {}
 41
 42impl LocalKernelSpecification {
 43    #[must_use]
 44    fn command(&self, connection_path: &PathBuf) -> Result<Command> {
 45        let argv = &self.kernelspec.argv;
 46
 47        anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
 48        anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
 49        anyhow::ensure!(
 50            argv.iter().any(|arg| arg == "{connection_file}"),
 51            "Missing 'connection_file' in argv in kernelspec {}",
 52            self.name
 53        );
 54
 55        let mut cmd = util::command::new_smol_command(&argv[0]);
 56
 57        for arg in &argv[1..] {
 58            if arg == "{connection_file}" {
 59                cmd.arg(connection_path);
 60            } else {
 61                cmd.arg(arg);
 62            }
 63        }
 64
 65        if let Some(env) = &self.kernelspec.env {
 66            cmd.envs(env);
 67        }
 68
 69        Ok(cmd)
 70    }
 71}
 72
 73// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
 74// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
 75async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
 76    let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
 77    addr_zeroport.set_port(0);
 78    let mut ports: [u16; 5] = [0; 5];
 79    for i in 0..5 {
 80        let listener = TcpListener::bind(addr_zeroport).await?;
 81        let addr = listener.local_addr()?;
 82        ports[i] = addr.port();
 83    }
 84    Ok(ports)
 85}
 86
/// A kernel running as a local child process, together with the channels used
/// to send it Jupyter messages.
pub struct NativeRunningKernel {
    /// The spawned kernel process.
    pub process: smol::process::Child,
    /// Path of the connection file written by `new`; removed when this value
    /// is dropped.
    connection_path: PathBuf,
    /// Task that awaits the process exit status and reports a failed exit to
    /// the session. `kill` takes it out of the option and drops it.
    _process_status_task: Option<Task<()>>,
    /// Directory the kernel process was started in.
    pub working_directory: PathBuf,
    /// Sender for requests that `new`'s routing task forwards to the control
    /// or shell channel.
    pub request_tx: mpsc::Sender<JupyterMessage>,
    /// Sender for messages that are forwarded to the stdin channel.
    pub stdin_tx: mpsc::Sender<JupyterMessage>,
    /// Last execution state recorded through `set_execution_state`; `new`
    /// initializes it to `ExecutionState::Idle`.
    pub execution_state: ExecutionState,
    /// Kernel info recorded through `set_kernel_info`; `None` until then.
    pub kernel_info: Option<KernelInfoReply>,
}
 97
 98impl Debug for NativeRunningKernel {
 99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        f.debug_struct("RunningKernel")
101            .field("process", &self.process)
102            .finish()
103    }
104}
105
impl NativeRunningKernel {
    /// Launches the kernel described by `kernel_specification` as a child
    /// process and connects to it.
    ///
    /// The returned task, in order:
    /// 1. picks five ports on 127.0.0.1 and writes a connection file named
    ///    `kernel-zed-{entity_id}.json` into the Jupyter runtime directory;
    /// 2. spawns the kernel process in `working_directory` with piped stdio;
    /// 3. opens the iopub, control, shell and stdin client connections;
    /// 4. spawns tasks that route incoming messages to `session`, forward
    ///    outgoing messages from the `request_tx` / `stdin_tx` channels, log
    ///    the process output, and report task or process failures to `session`
    ///    through `kernel_errored`.
    ///
    /// # Errors
    ///
    /// The task resolves to an error when port discovery, writing the
    /// connection file, building or spawning the command, or creating any of
    /// the connections fails.
    pub fn new<S: KernelSession + 'static>(
        kernel_specification: LocalKernelSpecification,
        entity_id: EntityId,
        working_directory: PathBuf,
        fs: Arc<dyn Fs>,
        // todo: convert to weak view
        session: Entity<S>,
        window: &mut Window,
        cx: &mut App,
    ) -> Task<Result<Box<dyn RunningKernel>>> {
        window.spawn(cx, async move |cx| {
            // The kernel is only reachable over the loopback interface.
            let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
            let ports = peek_ports(ip).await?;

            // A heartbeat port is advertised to the kernel, but this function
            // does not open a heartbeat connection itself.
            let connection_info = ConnectionInfo {
                transport: Transport::TCP,
                ip: ip.to_string(),
                stdin_port: ports[0],
                control_port: ports[1],
                hb_port: ports[2],
                shell_port: ports[3],
                iopub_port: ports[4],
                signature_scheme: "hmac-sha256".to_string(),
                // Fresh random signing key for this kernel.
                key: uuid::Uuid::new_v4().to_string(),
                kernel_name: Some(format!("zed-{}", kernel_specification.name)),
            };

            // Write the connection file the kernel process reads on startup.
            let runtime_dir = dirs::runtime_dir();
            fs.create_dir(&runtime_dir)
                .await
                .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
            let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
            let content = serde_json::to_string(&connection_info)?;
            fs.atomic_write(connection_path.clone(), content).await?;

            let mut cmd = kernel_specification.command(&connection_path)?;

            // `kill_on_drop` makes sure the child does not outlive its handle,
            // including when one of the `?` below returns early.
            // NOTE(review): on those early returns the connection file written
            // above is not removed, since no `NativeRunningKernel` exists yet
            // to clean it up in `Drop` — confirm whether that matters.
            let mut process = cmd
                .current_dir(&working_directory)
                .stdout(std::process::Stdio::piped())
                .stderr(std::process::Stdio::piped())
                .stdin(std::process::Stdio::piped())
                .kill_on_drop(true)
                .spawn()
                .context("failed to start the kernel process")?;

            let session_id = Uuid::new_v4().to_string();

            // NOTE(review): the empty string is presumably the iopub topic
            // filter (subscribe to everything) — confirm against runtimelib.
            let iopub_socket =
                runtimelib::create_client_iopub_connection(&connection_info, "", &session_id)
                    .await?;
            let control_socket =
                runtimelib::create_client_control_connection(&connection_info, &session_id).await?;

            // The shell and stdin connections are created with the same peer
            // identity.
            let peer_identity = runtimelib::peer_identity_for_session(&session_id)?;
            let shell_socket =
                runtimelib::create_client_shell_connection_with_identity(
                    &connection_info,
                    &session_id,
                    peer_identity.clone(),
                )
                .await?;
            let stdin_socket = runtimelib::create_client_stdin_connection_with_identity(
                &connection_info,
                &session_id,
                peer_identity,
            )
            .await?;

            // Split each bidirectional connection so that sending and
            // receiving can live in different tasks.
            let (mut shell_send, shell_recv) = shell_socket.split();
            let (mut control_send, control_recv) = control_socket.split();
            let (mut stdin_send, stdin_recv) = stdin_socket.split();

            // Outgoing messages are queued on these channels (capacity 100)
            // and drained by the routing tasks below.
            let (request_tx, mut request_rx) =
                futures::channel::mpsc::channel::<JupyterMessage>(100);
            let (stdin_tx, mut stdin_rx) =
                futures::channel::mpsc::channel::<JupyterMessage>(100);

            // Receives messages from all four channels and hands them to the
            // session.
            let recv_task = cx.spawn({
                let session = session.clone();
                let mut iopub = iopub_socket;
                let mut shell = shell_recv;
                let mut control = control_recv;
                let mut stdin = stdin_recv;

                async move |cx| -> anyhow::Result<()> {
                    loop {
                        // Wait for whichever channel yields first; the label is
                        // only used in error reports.
                        let (channel, result) = futures::select! {
                            msg = iopub.read().fuse() => ("iopub", msg),
                            msg = shell.read().fuse() => ("shell", msg),
                            msg = control.read().fuse() => ("control", msg),
                            msg = stdin.read().fuse() => ("stdin", msg),
                        };
                        match result {
                            Ok(message) => {
                                session
                                    .update_in(cx, |session, window, cx| {
                                        session.route(&message, window, cx);
                                    })
                                    .ok();
                            }
                            // A message that cannot be parsed is reported (log
                            // plus toast) and the loop keeps reading.
                            Err(
                                ref err @ (RuntimeError::ParseError { .. }
                                | RuntimeError::SerdeError(_)),
                            ) => {
                                let error_detail =
                                    format!("Kernel issue on {channel} channel\n\n{err}");
                                log::warn!("kernel: {error_detail}");
                                // The toast is only shown when the session's
                                // window is a workspace window.
                                let workspace_window = session
                                    .update_in(cx, |_, window, _cx| {
                                        window
                                            .window_handle()
                                            .downcast::<workspace::Workspace>()
                                    })
                                    .ok()
                                    .flatten();
                                if let Some(workspace_window) = workspace_window {
                                    workspace_window
                                        .update(cx, |workspace, _window, cx| {
                                            // Marker type that gives this toast
                                            // its notification id.
                                            struct KernelReadError;
                                            workspace.show_toast(
                                                workspace::Toast::new(
                                                    workspace::notifications::NotificationId::unique::<KernelReadError>(),
                                                    error_detail.clone(),
                                                )
                                                .on_click(
                                                    "Copy Error",
                                                    move |_window, cx| {
                                                        cx.write_to_clipboard(
                                                            ClipboardItem::new_string(
                                                                error_detail.clone(),
                                                            ),
                                                        );
                                                    },
                                                ),
                                                cx,
                                            );
                                        })
                                        .ok();
                                }
                            }
                            // Any other receive error ends the task; the
                            // supervising task below reports it.
                            Err(err) => {
                                anyhow::bail!("{channel} recv: {err}");
                            }
                        }
                    }
                }
            });

            // Forwards queued requests: debug, interrupt and shutdown requests
            // go to the control channel, everything else to the shell channel.
            // The loop ends when every `request_tx` sender is closed or dropped.
            let routing_task = cx.background_spawn({
                async move {
                    while let Some(message) = request_rx.next().await {
                        match message.content {
                            JupyterMessageContent::DebugRequest(_)
                            | JupyterMessageContent::InterruptRequest(_)
                            | JupyterMessageContent::ShutdownRequest(_) => {
                                control_send.send(message).await?;
                            }
                            _ => {
                                shell_send.send(message).await?;
                            }
                        }
                    }
                    anyhow::Ok(())
                }
            });

            // Forwards queued messages to the kernel's stdin channel.
            let stdin_routing_task = cx.background_spawn({
                async move {
                    while let Some(message) = stdin_rx.next().await {
                        stdin_send.send(message).await?;
                    }
                    anyhow::Ok(())
                }
            });

            let stderr = process.stderr.take();
            let stdout = process.stdout.take();

            // Mirrors the process output into the log: stderr lines at error
            // level, stdout lines at info level.
            cx.spawn(async move |_cx| {
                use futures::future::Either;

                let stderr_lines = match stderr {
                    Some(s) => Either::Left(
                        BufReader::new(s)
                            .lines()
                            .map(|line| (log::Level::Error, line)),
                    ),
                    None => Either::Right(futures::stream::empty()),
                };
                let stdout_lines = match stdout {
                    Some(s) => Either::Left(
                        BufReader::new(s)
                            .lines()
                            .map(|line| (log::Level::Info, line)),
                    ),
                    None => Either::Right(futures::stream::empty()),
                };
                let mut lines = futures::stream::select(stderr_lines, stdout_lines);
                // NOTE(review): the pattern only matches `Ok` lines, so the
                // first read error on either stream stops logging for both.
                while let Some((level, Ok(line))) = lines.next().await {
                    log::log!(level, "kernel: {}", line);
                }
            })
            .detach();

            // Supervises the three messaging tasks and reports every failure
            // to the session.
            cx.spawn({
                let session = session.clone();
                async move |cx| {
                    // Pairs a task's result with a label for error messages.
                    async fn with_name(
                        name: &'static str,
                        task: Task<Result<()>>,
                    ) -> (&'static str, Result<()>) {
                        (name, task.await)
                    }

                    let mut tasks = FuturesUnordered::new();
                    tasks.push(with_name("recv task", recv_task));
                    tasks.push(with_name("routing task", routing_task));
                    tasks.push(with_name("stdin routing task", stdin_routing_task));

                    while let Some((name, result)) = tasks.next().await {
                        if let Err(err) = result {
                            log::error!("kernel: handling failed for {name}: {err:?}");

                            session.update(cx, |session, cx| {
                                session.kernel_errored(
                                    format!("handling failed for {name}: {err}"),
                                    cx,
                                );
                                cx.notify();
                            });
                        }
                    }
                }
            })
            .detach();

            let status = process.status();

            // Watches for the process exiting. A successful exit is only
            // logged; a failed exit (or an error obtaining the status) is
            // reported to the session. The task is stored on the returned
            // value, and `kill` drops it.
            let process_status_task = cx.spawn(async move |cx| {
                let error_message = match status.await {
                    Ok(status) => {
                        if status.success() {
                            log::info!("kernel process exited successfully");
                            return;
                        }

                        format!("kernel process exited with status: {:?}", status)
                    }
                    Err(err) => {
                        format!("kernel process exited with error: {:?}", err)
                    }
                };

                log::error!("{}", error_message);

                session.update(cx, |session, cx| {
                    session.kernel_errored(error_message, cx);

                    cx.notify();
                });
            });

            anyhow::Ok(Box::new(Self {
                process,
                request_tx,
                stdin_tx,
                working_directory,
                _process_status_task: Some(process_status_task),
                connection_path,
                execution_state: ExecutionState::Idle,
                kernel_info: None,
            }) as Box<dyn RunningKernel>)
        })
    }
}
383
384impl RunningKernel for NativeRunningKernel {
385    fn request_tx(&self) -> mpsc::Sender<JupyterMessage> {
386        self.request_tx.clone()
387    }
388
389    fn stdin_tx(&self) -> mpsc::Sender<JupyterMessage> {
390        self.stdin_tx.clone()
391    }
392
393    fn working_directory(&self) -> &PathBuf {
394        &self.working_directory
395    }
396
397    fn execution_state(&self) -> &ExecutionState {
398        &self.execution_state
399    }
400
401    fn set_execution_state(&mut self, state: ExecutionState) {
402        self.execution_state = state;
403    }
404
405    fn kernel_info(&self) -> Option<&KernelInfoReply> {
406        self.kernel_info.as_ref()
407    }
408
409    fn set_kernel_info(&mut self, info: KernelInfoReply) {
410        self.kernel_info = Some(info);
411    }
412
413    fn force_shutdown(&mut self, _window: &mut Window, _cx: &mut App) -> Task<anyhow::Result<()>> {
414        self.kill();
415        Task::ready(Ok(()))
416    }
417
418    fn kill(&mut self) {
419        self._process_status_task.take();
420        self.request_tx.close_channel();
421        self.stdin_tx.close_channel();
422        self.process.kill().ok();
423    }
424}
425
426impl Drop for NativeRunningKernel {
427    fn drop(&mut self) {
428        std::fs::remove_file(&self.connection_path).ok();
429        self.kill();
430    }
431}
432
433async fn read_kernelspec_at(
434    // Path should be a directory to a jupyter kernelspec, as in
435    // /usr/local/share/jupyter/kernels/python3
436    kernel_dir: PathBuf,
437    fs: &dyn Fs,
438) -> Result<LocalKernelSpecification> {
439    let path = kernel_dir;
440    let kernel_name = if let Some(kernel_name) = path.file_name() {
441        kernel_name.to_string_lossy().into_owned()
442    } else {
443        anyhow::bail!("Invalid kernelspec directory: {path:?}");
444    };
445
446    if !fs.is_dir(path.as_path()).await {
447        anyhow::bail!("Not a directory: {path:?}");
448    }
449
450    let expected_kernel_json = path.join("kernel.json");
451    let spec = fs.load(expected_kernel_json.as_path()).await?;
452    let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
453
454    Ok(LocalKernelSpecification {
455        name: kernel_name,
456        path,
457        kernelspec: spec,
458    })
459}
460
461/// Read a directory of kernelspec directories
462async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<LocalKernelSpecification>> {
463    let mut kernelspec_dirs = fs.read_dir(&path).await?;
464
465    let mut valid_kernelspecs = Vec::new();
466    while let Some(path) = kernelspec_dirs.next().await {
467        match path {
468            Ok(path) => {
469                if fs.is_dir(path.as_path()).await
470                    && let Ok(kernelspec) = read_kernelspec_at(path, fs).await
471                {
472                    valid_kernelspecs.push(kernelspec);
473                }
474            }
475            Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
476        }
477    }
478
479    Ok(valid_kernelspecs)
480}
481
482pub async fn local_kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<LocalKernelSpecification>> {
483    let mut data_dirs = dirs::data_dirs();
484
485    // Pick up any kernels from conda or conda environment
486    if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
487        let conda_prefix = PathBuf::from(conda_prefix);
488        let conda_data_dir = conda_prefix.join("share").join("jupyter");
489        data_dirs.push(conda_data_dir);
490    }
491
492    // Search for kernels inside the base python environment
493    let command = util::command::new_smol_command("python")
494        .arg("-c")
495        .arg("import sys; print(sys.prefix)")
496        .output()
497        .await;
498
499    if let Ok(command) = command
500        && command.status.success()
501    {
502        let python_prefix = String::from_utf8(command.stdout);
503        if let Ok(python_prefix) = python_prefix {
504            let python_prefix = PathBuf::from(python_prefix.trim());
505            let python_data_dir = python_prefix.join("share").join("jupyter");
506            data_dirs.push(python_data_dir);
507        }
508    }
509
510    let kernel_dirs = data_dirs
511        .iter()
512        .map(|dir| dir.join("kernels"))
513        .map(|path| read_kernels_dir(path, fs.as_ref()))
514        .collect::<Vec<_>>();
515
516    let kernel_dirs = futures::future::join_all(kernel_dirs).await;
517    let kernel_dirs = kernel_dirs
518        .into_iter()
519        .filter_map(Result::ok)
520        .flatten()
521        .collect::<Vec<_>>();
522
523    Ok(kernel_dirs)
524}
525
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// `read_kernels_dir` finds every kernelspec directory below the given
    /// `kernels` directory and names each one after its directory.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                ".zed": {
                    "settings.json": r#"{ "tab_size": 8 }"#,
                    "tasks.json": r#"[{
                        "label": "cargo check",
                        "command": "cargo",
                        "args": ["check", "--all"]
                    },]"#,
                },
                "kernels": {
                    "python": {
                        "kernel.json": r#"{
                            "display_name": "Python 3",
                            "language": "python",
                            "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "deno": {
                        "kernel.json": r#"{
                            "display_name": "Deno",
                            "language": "typescript",
                            "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                            "env": {}
                        }"#
                    }
                },
            }),
        )
        .await;

        let discovered = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();

        // Directory listing order is not part of the contract, so compare the
        // names in sorted order.
        let mut names: Vec<String> = discovered.into_iter().map(|spec| spec.name).collect();
        names.sort();

        assert_eq!(names, vec!["deno", "python"]);
    }
}