// native_kernel.rs

use anyhow::{Context as _, Result};
use futures::{
    AsyncBufReadExt as _, FutureExt as _, StreamExt as _,
    channel::mpsc::{self},
    io::BufReader,
    stream::FuturesUnordered,
};
use gpui::{App, AppContext as _, ClipboardItem, Entity, EntityId, Task, Window};
use jupyter_protocol::{
    ExecutionState, JupyterKernelspec, JupyterMessage, JupyterMessageContent, KernelInfoReply,
    connection_info::{ConnectionInfo, Transport},
};
use project::Fs;
use runtimelib::{RuntimeError, dirs};
use smol::net::TcpListener;
use std::{
    env,
    ffi::OsString,
    fmt::Debug,
    net::{IpAddr, Ipv4Addr, SocketAddr},
    path::{Path, PathBuf},
    sync::Arc,
};
use util::command::Command;
use uuid::Uuid;

use super::{KernelSession, RunningKernel};
 27
/// A Jupyter kernelspec found on the local filesystem.
///
/// Equality (see the `PartialEq` impl) compares only `name` and `path`.
#[derive(Debug, Clone)]
pub struct LocalKernelSpecification {
    /// Kernel name; `read_kernelspec_at` sets this to the kernelspec directory's file name.
    pub name: String,
    /// The kernelspec directory, i.e. the directory that contains `kernel.json`.
    pub path: PathBuf,
    /// The parsed contents of `kernel.json`.
    pub kernelspec: JupyterKernelspec,
}
 34
 35impl PartialEq for LocalKernelSpecification {
 36    fn eq(&self, other: &Self) -> bool {
 37        self.name == other.name && self.path == other.path
 38    }
 39}
 40
 41impl Eq for LocalKernelSpecification {}
 42
 43impl LocalKernelSpecification {
 44    #[must_use]
 45    fn command(&self, connection_path: &PathBuf) -> Result<Command> {
 46        let argv = &self.kernelspec.argv;
 47
 48        anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
 49        anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
 50        anyhow::ensure!(
 51            argv.iter().any(|arg| arg == "{connection_file}"),
 52            "Missing 'connection_file' in argv in kernelspec {}",
 53            self.name
 54        );
 55
 56        let mut cmd = util::command::new_command(&argv[0]);
 57
 58        for arg in &argv[1..] {
 59            if arg == "{connection_file}" {
 60                cmd.arg(connection_path);
 61            } else {
 62                cmd.arg(arg);
 63            }
 64        }
 65
 66        if let Some(env) = &self.kernelspec.env {
 67            cmd.envs(env);
 68        }
 69
 70        Ok(cmd)
 71    }
 72}
 73
 74// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
 75// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
 76async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
 77    let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
 78    addr_zeroport.set_port(0);
 79    let mut ports: [u16; 5] = [0; 5];
 80    for i in 0..5 {
 81        let listener = TcpListener::bind(addr_zeroport).await?;
 82        let addr = listener.local_addr()?;
 83        ports[i] = addr.port();
 84    }
 85    Ok(ports)
 86}
 87
/// A Jupyter kernel running as a local child process. The client talks to it over the
/// sockets described by a connection file written to disk.
pub struct NativeRunningKernel {
    /// The kernel child process (spawned with `kill_on_drop(true)` in `new`).
    pub process: util::command::Child,
    /// Path of the connection file written for this kernel; removed in `Drop`.
    connection_path: PathBuf,
    /// Watches the process exit status and reports a failed exit to the session.
    /// `kill` takes and drops it.
    _process_status_task: Option<Task<()>>,
    /// Directory the kernel process was started in.
    pub working_directory: PathBuf,
    /// Outgoing requests. The routing task in `new` sends debug/interrupt/shutdown
    /// requests to the control socket and everything else to the shell socket.
    pub request_tx: mpsc::Sender<JupyterMessage>,
    /// Outgoing messages forwarded to the kernel's stdin socket.
    pub stdin_tx: mpsc::Sender<JupyterMessage>,
    /// Last state stored via `set_execution_state` (initialised to `Idle` in `new`).
    pub execution_state: ExecutionState,
    /// Kernel info reply, once stored via `set_kernel_info`.
    pub kernel_info: Option<KernelInfoReply>,
}
 98
 99impl Debug for NativeRunningKernel {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        f.debug_struct("RunningKernel")
102            .field("process", &self.process)
103            .finish()
104    }
105}
106
impl NativeRunningKernel {
    /// Launches the kernel described by `kernel_specification` and connects to it.
    ///
    /// The returned task (spawned on `window`) does the following:
    /// 1. Picks five free ports on 127.0.0.1 and writes a connection file named
    ///    `kernel-zed-{entity_id}.json` into the Jupyter runtime directory through `fs`.
    /// 2. Spawns the kernel process in `working_directory` with piped stdio and
    ///    `kill_on_drop(true)`.
    /// 3. Opens the iopub, control, shell and stdin client connections.
    /// 4. Spawns tasks that:
    ///    - route incoming messages to `session`,
    ///    - forward outgoing requests,
    ///    - log the process's stdout/stderr,
    ///    - report failures to `session` through `kernel_errored`.
    ///
    /// Any setup failure resolves the task with an error.
    pub fn new<S: KernelSession + 'static>(
        kernel_specification: LocalKernelSpecification,
        entity_id: EntityId,
        working_directory: PathBuf,
        fs: Arc<dyn Fs>,
        // todo: convert to weak view
        session: Entity<S>,
        window: &mut Window,
        cx: &mut App,
    ) -> Task<Result<Box<dyn RunningKernel>>> {
        window.spawn(cx, async move |cx| {
            let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
            let ports = peek_ports(ip).await?;

            // Every launch gets a fresh random signing key.
            let connection_info = ConnectionInfo {
                transport: Transport::TCP,
                ip: ip.to_string(),
                stdin_port: ports[0],
                control_port: ports[1],
                hb_port: ports[2],
                shell_port: ports[3],
                iopub_port: ports[4],
                signature_scheme: "hmac-sha256".to_string(),
                key: uuid::Uuid::new_v4().to_string(),
                kernel_name: Some(format!("zed-{}", kernel_specification.name)),
            };

            let runtime_dir = dirs::runtime_dir();
            fs.create_dir(&runtime_dir)
                .await
                .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
            let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
            let content = serde_json::to_string(&connection_info)?;
            fs.atomic_write(connection_path.clone(), content).await?;
            // NOTE(review): from here on, an early `?` return (spawn or socket setup failing)
            // happens before `Self` is built. `Drop` therefore never runs, and the connection
            // file written above appears to be left on disk. Consider removing it on those paths.

            let mut cmd = kernel_specification.command(&connection_path)?;

            let mut process = cmd
                .current_dir(&working_directory)
                .stdout(util::command::Stdio::piped())
                .stderr(util::command::Stdio::piped())
                .stdin(util::command::Stdio::piped())
                .kill_on_drop(true)
                .spawn()
                .context("failed to start the kernel process")?;

            let session_id = Uuid::new_v4().to_string();

            let iopub_socket =
                runtimelib::create_client_iopub_connection(&connection_info, "", &session_id)
                    .await?;
            let control_socket =
                runtimelib::create_client_control_connection(&connection_info, &session_id).await?;

            // Shell and stdin connect with the same peer identity.
            // NOTE(review): presumably required so the kernel can direct stdin requests to the
            // client that issued the shell request. Confirm against the Jupyter messaging spec.
            let peer_identity = runtimelib::peer_identity_for_session(&session_id)?;
            let shell_socket =
                runtimelib::create_client_shell_connection_with_identity(
                    &connection_info,
                    &session_id,
                    peer_identity.clone(),
                )
                .await?;
            let stdin_socket = runtimelib::create_client_stdin_connection_with_identity(
                &connection_info,
                &session_id,
                peer_identity,
            )
            .await?;

            let (mut shell_send, shell_recv) = shell_socket.split();
            let (mut control_send, control_recv) = control_socket.split();
            let (mut stdin_send, stdin_recv) = stdin_socket.split();

            let (request_tx, mut request_rx) =
                futures::channel::mpsc::channel::<JupyterMessage>(100);
            let (stdin_tx, mut stdin_rx) =
                futures::channel::mpsc::channel::<JupyterMessage>(100);

            // Receive loop: reads from all four sockets and passes each message to
            // `session.route`. Parse/serde errors are logged and shown as a toast (with a
            // "Copy Error" action), and the loop keeps going. Any other read error ends the
            // task with an error.
            let recv_task = cx.spawn({
                let session = session.clone();
                let mut iopub = iopub_socket;
                let mut shell = shell_recv;
                let mut control = control_recv;
                let mut stdin = stdin_recv;

                async move |cx| -> anyhow::Result<()> {
                    loop {
                        // NOTE(review): fresh read futures are created each iteration, and
                        // `select!` drops the ones that lose. Confirm that runtimelib's
                        // `read` is cancel-safe so no partially received message is lost.
                        let (channel, result) = futures::select! {
                            msg = iopub.read().fuse() => ("iopub", msg),
                            msg = shell.read().fuse() => ("shell", msg),
                            msg = control.read().fuse() => ("control", msg),
                            msg = stdin.read().fuse() => ("stdin", msg),
                        };
                        match result {
                            Ok(message) => {
                                session
                                    .update_in(cx, |session, window, cx| {
                                        session.route(&message, window, cx);
                                    })
                                    .ok();
                            }
                            Err(
                                ref err @ (RuntimeError::ParseError { .. }
                                | RuntimeError::SerdeError(_)),
                            ) => {
                                let error_detail =
                                    format!("Kernel issue on {channel} channel\n\n{err}");
                                log::warn!("kernel: {error_detail}");
                                // Find the workspace that owns the session's window so the
                                // error can be shown there as a toast.
                                let workspace_window = session
                                    .update_in(cx, |_, window, _cx| {
                                        window
                                            .window_handle()
                                            .downcast::<workspace::Workspace>()
                                    })
                                    .ok()
                                    .flatten();
                                if let Some(workspace_window) = workspace_window {
                                    workspace_window
                                        .update(cx, |workspace, _window, cx| {
                                            struct KernelReadError;
                                            workspace.show_toast(
                                                workspace::Toast::new(
                                                    workspace::notifications::NotificationId::unique::<KernelReadError>(),
                                                    error_detail.clone(),
                                                )
                                                .on_click(
                                                    "Copy Error",
                                                    move |_window, cx| {
                                                        cx.write_to_clipboard(
                                                            ClipboardItem::new_string(
                                                                error_detail.clone(),
                                                            ),
                                                        );
                                                    },
                                                ),
                                                cx,
                                            );
                                        })
                                        .ok();
                                }
                            }
                            Err(err) => {
                                anyhow::bail!("{channel} recv: {err}");
                            }
                        }
                    }
                }
            });

            // Sends outgoing requests to the kernel. Debug, interrupt and shutdown requests go
            // to the control socket; everything else goes to the shell socket. The task ends
            // when the request channel closes, or with an error if a send fails.
            let routing_task = cx.background_spawn({
                async move {
                    while let Some(message) = request_rx.next().await {
                        match message.content {
                            JupyterMessageContent::DebugRequest(_)
                            | JupyterMessageContent::InterruptRequest(_)
                            | JupyterMessageContent::ShutdownRequest(_) => {
                                control_send.send(message).await?;
                            }
                            _ => {
                                shell_send.send(message).await?;
                            }
                        }
                    }
                    anyhow::Ok(())
                }
            });

            // Sends messages from `stdin_tx` to the kernel's stdin socket.
            let stdin_routing_task = cx.background_spawn({
                async move {
                    while let Some(message) = stdin_rx.next().await {
                        stdin_send.send(message).await?;
                    }
                    anyhow::Ok(())
                }
            });

            let stderr = process.stderr.take();
            let stdout = process.stdout.take();

            // Logs kernel stderr lines at error level and stdout lines at info level. Stops
            // when the merged stream ends or yields a read error.
            cx.spawn(async move |_cx| {
                use futures::future::Either;

                let stderr_lines = match stderr {
                    Some(s) => Either::Left(
                        BufReader::new(s)
                            .lines()
                            .map(|line| (log::Level::Error, line)),
                    ),
                    None => Either::Right(futures::stream::empty()),
                };
                let stdout_lines = match stdout {
                    Some(s) => Either::Left(
                        BufReader::new(s)
                            .lines()
                            .map(|line| (log::Level::Info, line)),
                    ),
                    None => Either::Right(futures::stream::empty()),
                };
                let mut lines = futures::stream::select(stderr_lines, stdout_lines);
                while let Some((level, Ok(line))) = lines.next().await {
                    log::log!(level, "kernel: {}", line);
                }
            })
            .detach();

            // Supervisor: waits for the three I/O tasks to finish. A task that fails is logged
            // and reported to the session through `kernel_errored`.
            cx.spawn({
                let session = session.clone();
                async move |cx| {
                    async fn with_name(
                        name: &'static str,
                        task: Task<Result<()>>,
                    ) -> (&'static str, Result<()>) {
                        (name, task.await)
                    }

                    let mut tasks = FuturesUnordered::new();
                    tasks.push(with_name("recv task", recv_task));
                    tasks.push(with_name("routing task", routing_task));
                    tasks.push(with_name("stdin routing task", stdin_routing_task));

                    while let Some((name, result)) = tasks.next().await {
                        if let Err(err) = result {
                            log::error!("kernel: handling failed for {name}: {err:?}");

                            session.update(cx, |session, cx| {
                                session.kernel_errored(
                                    format!("handling failed for {name}: {err}"),
                                    cx,
                                );
                                cx.notify();
                            });
                        }
                    }
                }
            })
            .detach();

            let status = process.status();

            // Reports an unsuccessful exit (or a failure to wait on the process) to the session.
            // This task is stored in the struct rather than detached so that `kill` can drop it.
            let process_status_task = cx.spawn(async move |cx| {
                let error_message = match status.await {
                    Ok(status) => {
                        if status.success() {
                            log::info!("kernel process exited successfully");
                            return;
                        }

                        format!("kernel process exited with status: {:?}", status)
                    }
                    Err(err) => {
                        format!("kernel process exited with error: {:?}", err)
                    }
                };

                log::error!("{}", error_message);

                session.update(cx, |session, cx| {
                    session.kernel_errored(error_message, cx);

                    cx.notify();
                });
            });

            anyhow::Ok(Box::new(Self {
                process,
                request_tx,
                stdin_tx,
                working_directory,
                _process_status_task: Some(process_status_task),
                connection_path,
                execution_state: ExecutionState::Idle,
                kernel_info: None,
            }) as Box<dyn RunningKernel>)
        })
    }
}
384
385impl RunningKernel for NativeRunningKernel {
386    fn request_tx(&self) -> mpsc::Sender<JupyterMessage> {
387        self.request_tx.clone()
388    }
389
390    fn stdin_tx(&self) -> mpsc::Sender<JupyterMessage> {
391        self.stdin_tx.clone()
392    }
393
394    fn working_directory(&self) -> &PathBuf {
395        &self.working_directory
396    }
397
398    fn execution_state(&self) -> &ExecutionState {
399        &self.execution_state
400    }
401
402    fn set_execution_state(&mut self, state: ExecutionState) {
403        self.execution_state = state;
404    }
405
406    fn kernel_info(&self) -> Option<&KernelInfoReply> {
407        self.kernel_info.as_ref()
408    }
409
410    fn set_kernel_info(&mut self, info: KernelInfoReply) {
411        self.kernel_info = Some(info);
412    }
413
414    fn force_shutdown(&mut self, _window: &mut Window, _cx: &mut App) -> Task<anyhow::Result<()>> {
415        self.kill();
416        Task::ready(Ok(()))
417    }
418
419    fn kill(&mut self) {
420        self._process_status_task.take();
421        self.request_tx.close_channel();
422        self.stdin_tx.close_channel();
423        self.process.kill().ok();
424    }
425}
426
427impl Drop for NativeRunningKernel {
428    fn drop(&mut self) {
429        std::fs::remove_file(&self.connection_path).ok();
430        self.kill();
431    }
432}
433
434async fn read_kernelspec_at(
435    // Path should be a directory to a jupyter kernelspec, as in
436    // /usr/local/share/jupyter/kernels/python3
437    kernel_dir: PathBuf,
438    fs: &dyn Fs,
439) -> Result<LocalKernelSpecification> {
440    let path = kernel_dir;
441    let kernel_name = if let Some(kernel_name) = path.file_name() {
442        kernel_name.to_string_lossy().into_owned()
443    } else {
444        anyhow::bail!("Invalid kernelspec directory: {path:?}");
445    };
446
447    if !fs.is_dir(path.as_path()).await {
448        anyhow::bail!("Not a directory: {path:?}");
449    }
450
451    let expected_kernel_json = path.join("kernel.json");
452    let spec = fs.load(expected_kernel_json.as_path()).await?;
453    let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
454
455    Ok(LocalKernelSpecification {
456        name: kernel_name,
457        path,
458        kernelspec: spec,
459    })
460}
461
462/// Read a directory of kernelspec directories
463async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<LocalKernelSpecification>> {
464    let mut kernelspec_dirs = fs.read_dir(&path).await?;
465
466    let mut valid_kernelspecs = Vec::new();
467    while let Some(path) = kernelspec_dirs.next().await {
468        match path {
469            Ok(path) => {
470                if fs.is_dir(path.as_path()).await
471                    && let Ok(kernelspec) = read_kernelspec_at(path, fs).await
472                {
473                    valid_kernelspecs.push(kernelspec);
474                }
475            }
476            Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
477        }
478    }
479
480    Ok(valid_kernelspecs)
481}
482
483pub async fn local_kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<LocalKernelSpecification>> {
484    let mut data_dirs = dirs::data_dirs();
485
486    // Pick up any kernels from conda or conda environment
487    if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
488        let conda_prefix = PathBuf::from(conda_prefix);
489        let conda_data_dir = conda_prefix.join("share").join("jupyter");
490        data_dirs.push(conda_data_dir);
491    }
492
493    // Search for kernels inside the base python environment
494    let command = util::command::new_command("python")
495        .arg("-c")
496        .arg("import sys; print(sys.prefix)")
497        .output()
498        .await;
499
500    if let Ok(command) = command
501        && command.status.success()
502    {
503        let python_prefix = String::from_utf8(command.stdout);
504        if let Ok(python_prefix) = python_prefix {
505            let python_prefix = PathBuf::from(python_prefix.trim());
506            let python_data_dir = python_prefix.join("share").join("jupyter");
507            data_dirs.push(python_data_dir);
508        }
509    }
510
511    let kernel_dirs = data_dirs
512        .iter()
513        .map(|dir| dir.join("kernels"))
514        .map(|path| read_kernels_dir(path, fs.as_ref()))
515        .collect::<Vec<_>>();
516
517    let kernel_dirs = futures::future::join_all(kernel_dirs).await;
518    let kernel_dirs = kernel_dirs
519        .into_iter()
520        .filter_map(Result::ok)
521        .flatten()
522        .collect::<Vec<_>>();
523
524    Ok(kernel_dirs)
525}
526
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// Every kernelspec directory under `/jupyter/kernels` should be discovered by
    /// `read_kernels_dir`, named after its directory.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                ".zed": {
                    "settings.json": r#"{ "tab_size": 8 }"#,
                    "tasks.json": r#"[{
                        "label": "cargo check",
                        "command": "cargo",
                        "args": ["check", "--all"]
                    },]"#,
                },
                "kernels": {
                    "python": {
                        "kernel.json": r#"{
                            "display_name": "Python 3",
                            "language": "python",
                            "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "deno": {
                        "kernel.json": r#"{
                            "display_name": "Deno",
                            "language": "typescript",
                            "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                            "env": {}
                        }"#
                    }
                },
            }),
        )
        .await;

        let kernels_dir = PathBuf::from("/jupyter/kernels");
        let kernels = read_kernels_dir(kernels_dir, fs.as_ref()).await.unwrap();

        // Directory listing order is not guaranteed, so compare sorted names.
        let mut names: Vec<String> = kernels.into_iter().map(|kernel| kernel.name).collect();
        names.sort();
        assert_eq!(names, ["deno", "python"]);
    }
}