native_kernel.rs

  1use anyhow::{Context as _, Result};
  2use futures::{
  3    AsyncBufReadExt as _, FutureExt as _, StreamExt as _,
  4    channel::mpsc::{self},
  5    io::BufReader,
  6    stream::FuturesUnordered,
  7};
  8use gpui::{App, AppContext as _, ClipboardItem, Entity, EntityId, Task, Window};
  9use jupyter_protocol::{
 10    ExecutionState, JupyterKernelspec, JupyterMessage, JupyterMessageContent, KernelInfoReply,
 11    connection_info::{ConnectionInfo, Transport},
 12};
 13use project::Fs;
 14use runtimelib::{RuntimeError, dirs};
 15use smol::{net::TcpListener, process::Command};
 16use std::{
 17    env,
 18    fmt::Debug,
 19    net::{IpAddr, Ipv4Addr, SocketAddr},
 20    path::PathBuf,
 21    sync::Arc,
 22};
 23use uuid::Uuid;
 24
 25use super::{KernelSession, RunningKernel};
 26
 27#[derive(Debug, Clone)]
 28pub struct LocalKernelSpecification {
 29    pub name: String,
 30    pub path: PathBuf,
 31    pub kernelspec: JupyterKernelspec,
 32}
 33
 34impl PartialEq for LocalKernelSpecification {
 35    fn eq(&self, other: &Self) -> bool {
 36        self.name == other.name && self.path == other.path
 37    }
 38}
 39
 40impl Eq for LocalKernelSpecification {}
 41
 42impl LocalKernelSpecification {
 43    #[must_use]
 44    fn command(&self, connection_path: &PathBuf) -> Result<Command> {
 45        let argv = &self.kernelspec.argv;
 46
 47        anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
 48        anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
 49        anyhow::ensure!(
 50            argv.iter().any(|arg| arg == "{connection_file}"),
 51            "Missing 'connection_file' in argv in kernelspec {}",
 52            self.name
 53        );
 54
 55        let mut cmd = util::command::new_smol_command(&argv[0]);
 56
 57        for arg in &argv[1..] {
 58            if arg == "{connection_file}" {
 59                cmd.arg(connection_path);
 60            } else {
 61                cmd.arg(arg);
 62            }
 63        }
 64
 65        if let Some(env) = &self.kernelspec.env {
 66            cmd.envs(env);
 67        }
 68
 69        Ok(cmd)
 70    }
 71}
 72
 73// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
 74// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
 75async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
 76    let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
 77    addr_zeroport.set_port(0);
 78    let mut ports: [u16; 5] = [0; 5];
 79    for i in 0..5 {
 80        let listener = TcpListener::bind(addr_zeroport).await?;
 81        let addr = listener.local_addr()?;
 82        ports[i] = addr.port();
 83    }
 84    Ok(ports)
 85}
 86
 87pub struct NativeRunningKernel {
 88    pub process: smol::process::Child,
 89    connection_path: PathBuf,
 90    _process_status_task: Option<Task<()>>,
 91    pub working_directory: PathBuf,
 92    pub request_tx: mpsc::Sender<JupyterMessage>,
 93    pub execution_state: ExecutionState,
 94    pub kernel_info: Option<KernelInfoReply>,
 95}
 96
 97impl Debug for NativeRunningKernel {
 98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 99        f.debug_struct("RunningKernel")
100            .field("process", &self.process)
101            .finish()
102    }
103}
104
105impl NativeRunningKernel {
106    pub fn new<S: KernelSession + 'static>(
107        kernel_specification: LocalKernelSpecification,
108        entity_id: EntityId,
109        working_directory: PathBuf,
110        fs: Arc<dyn Fs>,
111        // todo: convert to weak view
112        session: Entity<S>,
113        window: &mut Window,
114        cx: &mut App,
115    ) -> Task<Result<Box<dyn RunningKernel>>> {
116        window.spawn(cx, async move |cx| {
117            let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
118            let ports = peek_ports(ip).await?;
119
120            let connection_info = ConnectionInfo {
121                transport: Transport::TCP,
122                ip: ip.to_string(),
123                stdin_port: ports[0],
124                control_port: ports[1],
125                hb_port: ports[2],
126                shell_port: ports[3],
127                iopub_port: ports[4],
128                signature_scheme: "hmac-sha256".to_string(),
129                key: uuid::Uuid::new_v4().to_string(),
130                kernel_name: Some(format!("zed-{}", kernel_specification.name)),
131            };
132
133            let runtime_dir = dirs::runtime_dir();
134            fs.create_dir(&runtime_dir)
135                .await
136                .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
137            let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
138            let content = serde_json::to_string(&connection_info)?;
139            fs.atomic_write(connection_path.clone(), content).await?;
140
141            let mut cmd = kernel_specification.command(&connection_path)?;
142
143            let mut process = cmd
144                .current_dir(&working_directory)
145                .stdout(std::process::Stdio::piped())
146                .stderr(std::process::Stdio::piped())
147                .stdin(std::process::Stdio::piped())
148                .kill_on_drop(true)
149                .spawn()
150                .context("failed to start the kernel process")?;
151
152            let session_id = Uuid::new_v4().to_string();
153
154            let iopub_socket =
155                runtimelib::create_client_iopub_connection(&connection_info, "", &session_id)
156                    .await?;
157            let shell_socket =
158                runtimelib::create_client_shell_connection(&connection_info, &session_id).await?;
159            let control_socket =
160                runtimelib::create_client_control_connection(&connection_info, &session_id).await?;
161
162            let (mut shell_send, shell_recv) = shell_socket.split();
163            let (mut control_send, control_recv) = control_socket.split();
164
165            let (request_tx, mut request_rx) =
166                futures::channel::mpsc::channel::<JupyterMessage>(100);
167
168            let recv_task = cx.spawn({
169                let session = session.clone();
170                let mut iopub = iopub_socket;
171                let mut shell = shell_recv;
172                let mut control = control_recv;
173
174                async move |cx| -> anyhow::Result<()> {
175                    loop {
176                        let (channel, result) = futures::select! {
177                            msg = iopub.read().fuse() => ("iopub", msg),
178                            msg = shell.read().fuse() => ("shell", msg),
179                            msg = control.read().fuse() => ("control", msg),
180                        };
181                        match result {
182                            Ok(message) => {
183                                session
184                                    .update_in(cx, |session, window, cx| {
185                                        session.route(&message, window, cx);
186                                    })
187                                    .ok();
188                            }
189                            Err(
190                                ref err @ (RuntimeError::ParseError { .. }
191                                | RuntimeError::SerdeError(_)),
192                            ) => {
193                                let error_detail =
194                                    format!("Kernel issue on {channel} channel\n\n{err}");
195                                log::warn!("kernel: {error_detail}");
196                                let workspace_window = session
197                                    .update_in(cx, |_, window, _cx| {
198                                        window
199                                            .window_handle()
200                                            .downcast::<workspace::Workspace>()
201                                    })
202                                    .ok()
203                                    .flatten();
204                                if let Some(workspace_window) = workspace_window {
205                                    workspace_window
206                                        .update(cx, |workspace, _window, cx| {
207                                            struct KernelReadError;
208                                            workspace.show_toast(
209                                                workspace::Toast::new(
210                                                    workspace::notifications::NotificationId::unique::<KernelReadError>(),
211                                                    error_detail.clone(),
212                                                )
213                                                .on_click(
214                                                    "Copy Error",
215                                                    move |_window, cx| {
216                                                        cx.write_to_clipboard(
217                                                            ClipboardItem::new_string(
218                                                                error_detail.clone(),
219                                                            ),
220                                                        );
221                                                    },
222                                                ),
223                                                cx,
224                                            );
225                                        })
226                                        .ok();
227                                }
228                            }
229                            Err(err) => {
230                                anyhow::bail!("{channel} recv: {err}");
231                            }
232                        }
233                    }
234                }
235            });
236
237            let routing_task = cx.background_spawn({
238                async move {
239                    while let Some(message) = request_rx.next().await {
240                        match message.content {
241                            JupyterMessageContent::DebugRequest(_)
242                            | JupyterMessageContent::InterruptRequest(_)
243                            | JupyterMessageContent::ShutdownRequest(_) => {
244                                control_send.send(message).await?;
245                            }
246                            _ => {
247                                shell_send.send(message).await?;
248                            }
249                        }
250                    }
251                    anyhow::Ok(())
252                }
253            });
254
255            let stderr = process.stderr.take();
256            let stdout = process.stdout.take();
257
258            cx.spawn(async move |_cx| {
259                use futures::future::Either;
260
261                let stderr_lines = match stderr {
262                    Some(s) => Either::Left(
263                        BufReader::new(s)
264                            .lines()
265                            .map(|line| (log::Level::Error, line)),
266                    ),
267                    None => Either::Right(futures::stream::empty()),
268                };
269                let stdout_lines = match stdout {
270                    Some(s) => Either::Left(
271                        BufReader::new(s)
272                            .lines()
273                            .map(|line| (log::Level::Info, line)),
274                    ),
275                    None => Either::Right(futures::stream::empty()),
276                };
277                let mut lines = futures::stream::select(stderr_lines, stdout_lines);
278                while let Some((level, Ok(line))) = lines.next().await {
279                    log::log!(level, "kernel: {}", line);
280                }
281            })
282            .detach();
283
284            cx.spawn({
285                let session = session.clone();
286                async move |cx| {
287                    async fn with_name(
288                        name: &'static str,
289                        task: Task<Result<()>>,
290                    ) -> (&'static str, Result<()>) {
291                        (name, task.await)
292                    }
293
294                    let mut tasks = FuturesUnordered::new();
295                    tasks.push(with_name("recv task", recv_task));
296                    tasks.push(with_name("routing task", routing_task));
297
298                    while let Some((name, result)) = tasks.next().await {
299                        if let Err(err) = result {
300                            log::error!("kernel: handling failed for {name}: {err:?}");
301
302                            session.update(cx, |session, cx| {
303                                session.kernel_errored(
304                                    format!("handling failed for {name}: {err}"),
305                                    cx,
306                                );
307                                cx.notify();
308                            });
309                        }
310                    }
311                }
312            })
313            .detach();
314
315            let status = process.status();
316
317            let process_status_task = cx.spawn(async move |cx| {
318                let error_message = match status.await {
319                    Ok(status) => {
320                        if status.success() {
321                            log::info!("kernel process exited successfully");
322                            return;
323                        }
324
325                        format!("kernel process exited with status: {:?}", status)
326                    }
327                    Err(err) => {
328                        format!("kernel process exited with error: {:?}", err)
329                    }
330                };
331
332                log::error!("{}", error_message);
333
334                session.update(cx, |session, cx| {
335                    session.kernel_errored(error_message, cx);
336
337                    cx.notify();
338                });
339            });
340
341            anyhow::Ok(Box::new(Self {
342                process,
343                request_tx,
344                working_directory,
345                _process_status_task: Some(process_status_task),
346                connection_path,
347                execution_state: ExecutionState::Idle,
348                kernel_info: None,
349            }) as Box<dyn RunningKernel>)
350        })
351    }
352}
353
354impl RunningKernel for NativeRunningKernel {
355    fn request_tx(&self) -> mpsc::Sender<JupyterMessage> {
356        self.request_tx.clone()
357    }
358
359    fn working_directory(&self) -> &PathBuf {
360        &self.working_directory
361    }
362
363    fn execution_state(&self) -> &ExecutionState {
364        &self.execution_state
365    }
366
367    fn set_execution_state(&mut self, state: ExecutionState) {
368        self.execution_state = state;
369    }
370
371    fn kernel_info(&self) -> Option<&KernelInfoReply> {
372        self.kernel_info.as_ref()
373    }
374
375    fn set_kernel_info(&mut self, info: KernelInfoReply) {
376        self.kernel_info = Some(info);
377    }
378
379    fn force_shutdown(&mut self, _window: &mut Window, _cx: &mut App) -> Task<anyhow::Result<()>> {
380        self.kill();
381        Task::ready(Ok(()))
382    }
383
384    fn kill(&mut self) {
385        self._process_status_task.take();
386        self.request_tx.close_channel();
387        self.process.kill().ok();
388    }
389}
390
391impl Drop for NativeRunningKernel {
392    fn drop(&mut self) {
393        std::fs::remove_file(&self.connection_path).ok();
394        self.kill();
395    }
396}
397
398async fn read_kernelspec_at(
399    // Path should be a directory to a jupyter kernelspec, as in
400    // /usr/local/share/jupyter/kernels/python3
401    kernel_dir: PathBuf,
402    fs: &dyn Fs,
403) -> Result<LocalKernelSpecification> {
404    let path = kernel_dir;
405    let kernel_name = if let Some(kernel_name) = path.file_name() {
406        kernel_name.to_string_lossy().into_owned()
407    } else {
408        anyhow::bail!("Invalid kernelspec directory: {path:?}");
409    };
410
411    if !fs.is_dir(path.as_path()).await {
412        anyhow::bail!("Not a directory: {path:?}");
413    }
414
415    let expected_kernel_json = path.join("kernel.json");
416    let spec = fs.load(expected_kernel_json.as_path()).await?;
417    let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
418
419    Ok(LocalKernelSpecification {
420        name: kernel_name,
421        path,
422        kernelspec: spec,
423    })
424}
425
426/// Read a directory of kernelspec directories
427async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<LocalKernelSpecification>> {
428    let mut kernelspec_dirs = fs.read_dir(&path).await?;
429
430    let mut valid_kernelspecs = Vec::new();
431    while let Some(path) = kernelspec_dirs.next().await {
432        match path {
433            Ok(path) => {
434                if fs.is_dir(path.as_path()).await
435                    && let Ok(kernelspec) = read_kernelspec_at(path, fs).await
436                {
437                    valid_kernelspecs.push(kernelspec);
438                }
439            }
440            Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
441        }
442    }
443
444    Ok(valid_kernelspecs)
445}
446
447pub async fn local_kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<LocalKernelSpecification>> {
448    let mut data_dirs = dirs::data_dirs();
449
450    // Pick up any kernels from conda or conda environment
451    if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
452        let conda_prefix = PathBuf::from(conda_prefix);
453        let conda_data_dir = conda_prefix.join("share").join("jupyter");
454        data_dirs.push(conda_data_dir);
455    }
456
457    // Search for kernels inside the base python environment
458    let command = util::command::new_smol_command("python")
459        .arg("-c")
460        .arg("import sys; print(sys.prefix)")
461        .output()
462        .await;
463
464    if let Ok(command) = command
465        && command.status.success()
466    {
467        let python_prefix = String::from_utf8(command.stdout);
468        if let Ok(python_prefix) = python_prefix {
469            let python_prefix = PathBuf::from(python_prefix.trim());
470            let python_data_dir = python_prefix.join("share").join("jupyter");
471            data_dirs.push(python_data_dir);
472        }
473    }
474
475    let kernel_dirs = data_dirs
476        .iter()
477        .map(|dir| dir.join("kernels"))
478        .map(|path| read_kernels_dir(path, fs.as_ref()))
479        .collect::<Vec<_>>();
480
481    let kernel_dirs = futures::future::join_all(kernel_dirs).await;
482    let kernel_dirs = kernel_dirs
483        .into_iter()
484        .filter_map(Result::ok)
485        .flatten()
486        .collect::<Vec<_>>();
487
488    Ok(kernel_dirs)
489}
490
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// `read_kernels_dir` returns one entry per valid kernelspec directory. It must skip
    /// stray files and directories that have no `kernel.json`.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                "kernels": {
                    // Neither of these is a kernelspec.
                    "README.md": "not a kernelspec",
                    "incomplete": {},
                    "python": {
                        "kernel.json": r#"{
                            "display_name": "Python 3",
                            "language": "python",
                            "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "deno": {
                        "kernel.json": r#"{
                            "display_name": "Deno",
                            "language": "typescript",
                            "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                            "env": {}
                        }"#
                    }
                },
            }),
        )
        .await;

        let mut kernels = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();

        kernels.sort_by(|a, b| a.name.cmp(&b.name));

        assert_eq!(
            kernels.iter().map(|c| c.name.clone()).collect::<Vec<_>>(),
            vec!["deno", "python"]
        );

        assert_eq!(kernels[0].kernelspec.display_name, "Deno");
        assert_eq!(kernels[0].kernelspec.language, "typescript");
        assert_eq!(kernels[1].kernelspec.display_name, "Python 3");
        assert_eq!(kernels[1].kernelspec.language, "python");
        assert_eq!(
            kernels[1].kernelspec.argv.last().map(String::as_str),
            Some("{connection_file}")
        );
    }
}