native_kernel.rs

  1use anyhow::{Context as _, Result};
  2use futures::{
  3    AsyncBufReadExt as _, SinkExt as _,
  4    channel::mpsc::{self},
  5    io::BufReader,
  6    stream::{FuturesUnordered, SelectAll, StreamExt},
  7};
  8use gpui::{App, AppContext as _, Entity, EntityId, Task, Window};
  9use jupyter_protocol::{
 10    ExecutionState, JupyterKernelspec, JupyterMessage, JupyterMessageContent, KernelInfoReply,
 11    connection_info::{ConnectionInfo, Transport},
 12};
 13use project::Fs;
 14use runtimelib::dirs;
 15use smol::{net::TcpListener, process::Command};
 16use std::{
 17    env,
 18    fmt::Debug,
 19    net::{IpAddr, Ipv4Addr, SocketAddr},
 20    path::PathBuf,
 21    sync::Arc,
 22};
 23use uuid::Uuid;
 24
 25use super::{KernelSession, RunningKernel};
 26
/// A Jupyter kernelspec found in a kernelspec directory on the local filesystem.
#[derive(Debug, Clone)]
pub struct LocalKernelSpecification {
    /// Kernel name; `read_kernelspec_at` sets it to the kernelspec directory's file name.
    pub name: String,
    /// The kernelspec directory this specification was read from.
    pub path: PathBuf,
    /// Parsed contents of the directory's `kernel.json`.
    pub kernelspec: JupyterKernelspec,
}
 33
 34impl PartialEq for LocalKernelSpecification {
 35    fn eq(&self, other: &Self) -> bool {
 36        self.name == other.name && self.path == other.path
 37    }
 38}
 39
// Marker impl: equality only compares `name` (`String`) and `path` (`PathBuf`), both of which
// are `Eq`, so the relation is a full equivalence.
impl Eq for LocalKernelSpecification {}
 41
 42impl LocalKernelSpecification {
 43    #[must_use]
 44    fn command(&self, connection_path: &PathBuf) -> Result<Command> {
 45        let argv = &self.kernelspec.argv;
 46
 47        anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
 48        anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
 49        anyhow::ensure!(
 50            argv.iter().any(|arg| arg == "{connection_file}"),
 51            "Missing 'connection_file' in argv in kernelspec {}",
 52            self.name
 53        );
 54
 55        let mut cmd = util::command::new_smol_command(&argv[0]);
 56
 57        for arg in &argv[1..] {
 58            if arg == "{connection_file}" {
 59                cmd.arg(connection_path);
 60            } else {
 61                cmd.arg(arg);
 62            }
 63        }
 64
 65        if let Some(env) = &self.kernelspec.env {
 66            cmd.envs(env);
 67        }
 68
 69        Ok(cmd)
 70    }
 71}
 72
 73// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
 74// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
 75async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
 76    let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
 77    addr_zeroport.set_port(0);
 78    let mut ports: [u16; 5] = [0; 5];
 79    for i in 0..5 {
 80        let listener = TcpListener::bind(addr_zeroport).await?;
 81        let addr = listener.local_addr()?;
 82        ports[i] = addr.port();
 83    }
 84    Ok(ports)
 85}
 86
/// A Jupyter kernel running as a local child process, together with the handles needed to
/// talk to it and to clean up after it.
pub struct NativeRunningKernel {
    /// Handle to the spawned kernel process.
    pub process: smol::process::Child,
    /// Path of the connection-info JSON file written for this kernel; removed on drop.
    connection_path: PathBuf,
    /// Task that waits for the process to exit and reports unsuccessful exits to the session.
    /// `force_shutdown` takes it out of the option.
    _process_status_task: Option<Task<()>>,
    /// Directory the kernel process was started in.
    pub working_directory: PathBuf,
    /// Sending half of the request channel; messages sent here are routed to the kernel's
    /// control or shell socket.
    pub request_tx: mpsc::Sender<JupyterMessage>,
    /// Execution state last stored through `set_execution_state`; starts as `Idle`.
    pub execution_state: ExecutionState,
    /// Kernel info stored through `set_kernel_info`, `None` until then.
    pub kernel_info: Option<KernelInfoReply>,
}
 96
 97impl Debug for NativeRunningKernel {
 98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 99        f.debug_struct("RunningKernel")
100            .field("process", &self.process)
101            .finish()
102    }
103}
104
105impl NativeRunningKernel {
106    pub fn new<S: KernelSession + 'static>(
107        kernel_specification: LocalKernelSpecification,
108        entity_id: EntityId,
109        working_directory: PathBuf,
110        fs: Arc<dyn Fs>,
111        // todo: convert to weak view
112        session: Entity<S>,
113        window: &mut Window,
114        cx: &mut App,
115    ) -> Task<Result<Box<dyn RunningKernel>>> {
116        window.spawn(cx, async move |cx| {
117            let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
118            let ports = peek_ports(ip).await?;
119
120            let connection_info = ConnectionInfo {
121                transport: Transport::TCP,
122                ip: ip.to_string(),
123                stdin_port: ports[0],
124                control_port: ports[1],
125                hb_port: ports[2],
126                shell_port: ports[3],
127                iopub_port: ports[4],
128                signature_scheme: "hmac-sha256".to_string(),
129                key: uuid::Uuid::new_v4().to_string(),
130                kernel_name: Some(format!("zed-{}", kernel_specification.name)),
131            };
132
133            let runtime_dir = dirs::runtime_dir();
134            fs.create_dir(&runtime_dir)
135                .await
136                .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
137            let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
138            let content = serde_json::to_string(&connection_info)?;
139            fs.atomic_write(connection_path.clone(), content).await?;
140
141            let mut cmd = kernel_specification.command(&connection_path)?;
142
143            let mut process = cmd
144                .current_dir(&working_directory)
145                .stdout(std::process::Stdio::piped())
146                .stderr(std::process::Stdio::piped())
147                .stdin(std::process::Stdio::piped())
148                .kill_on_drop(true)
149                .spawn()
150                .context("failed to start the kernel process")?;
151
152            let session_id = Uuid::new_v4().to_string();
153
154            let mut iopub_socket =
155                runtimelib::create_client_iopub_connection(&connection_info, "", &session_id)
156                    .await?;
157            let mut shell_socket =
158                runtimelib::create_client_shell_connection(&connection_info, &session_id).await?;
159            let mut control_socket =
160                runtimelib::create_client_control_connection(&connection_info, &session_id).await?;
161
162            let (request_tx, mut request_rx) =
163                futures::channel::mpsc::channel::<JupyterMessage>(100);
164
165            let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
166            let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);
167
168            let mut messages_rx = SelectAll::new();
169            messages_rx.push(control_reply_rx);
170            messages_rx.push(shell_reply_rx);
171
172            cx.spawn({
173                let session = session.clone();
174
175                async move |cx| {
176                    while let Some(message) = messages_rx.next().await {
177                        session
178                            .update_in(cx, |session, window, cx| {
179                                session.route(&message, window, cx);
180                            })
181                            .ok();
182                    }
183                }
184            })
185            .detach();
186
187            // iopub task
188            let iopub_task = cx.spawn({
189                let session = session.clone();
190
191                async move |cx| -> anyhow::Result<()> {
192                    loop {
193                        let message = iopub_socket.read().await?;
194                        session
195                            .update_in(cx, |session, window, cx| {
196                                session.route(&message, window, cx);
197                            })
198                            .ok();
199                    }
200                }
201            });
202
203            let (mut control_request_tx, mut control_request_rx) =
204                futures::channel::mpsc::channel(100);
205            let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);
206
207            let routing_task = cx.background_spawn({
208                async move {
209                    while let Some(message) = request_rx.next().await {
210                        match message.content {
211                            JupyterMessageContent::DebugRequest(_)
212                            | JupyterMessageContent::InterruptRequest(_)
213                            | JupyterMessageContent::ShutdownRequest(_) => {
214                                control_request_tx.send(message).await?;
215                            }
216                            _ => {
217                                shell_request_tx.send(message).await?;
218                            }
219                        }
220                    }
221                    anyhow::Ok(())
222                }
223            });
224
225            let shell_task = cx.background_spawn({
226                async move {
227                    while let Some(message) = shell_request_rx.next().await {
228                        shell_socket.send(message).await.ok();
229                        let reply = shell_socket.read().await?;
230                        shell_reply_tx.send(reply).await?;
231                    }
232                    anyhow::Ok(())
233                }
234            });
235
236            let control_task = cx.background_spawn({
237                async move {
238                    while let Some(message) = control_request_rx.next().await {
239                        control_socket.send(message).await.ok();
240                        let reply = control_socket.read().await?;
241                        control_reply_tx.send(reply).await?;
242                    }
243                    anyhow::Ok(())
244                }
245            });
246
247            let stderr = process.stderr.take();
248
249            cx.spawn(async move |_cx| {
250                if stderr.is_none() {
251                    return;
252                }
253                let reader = BufReader::new(stderr.unwrap());
254                let mut lines = reader.lines();
255                while let Some(Ok(line)) = lines.next().await {
256                    log::error!("kernel: {}", line);
257                }
258            })
259            .detach();
260
261            let stdout = process.stdout.take();
262
263            cx.spawn(async move |_cx| {
264                if stdout.is_none() {
265                    return;
266                }
267                let reader = BufReader::new(stdout.unwrap());
268                let mut lines = reader.lines();
269                while let Some(Ok(line)) = lines.next().await {
270                    log::info!("kernel: {}", line);
271                }
272            })
273            .detach();
274
275            cx.spawn({
276                let session = session.clone();
277                async move |cx| {
278                    async fn with_name(
279                        name: &'static str,
280                        task: Task<Result<()>>,
281                    ) -> (&'static str, Result<()>) {
282                        (name, task.await)
283                    }
284
285                    let mut tasks = FuturesUnordered::new();
286                    tasks.push(with_name("iopub task", iopub_task));
287                    tasks.push(with_name("shell task", shell_task));
288                    tasks.push(with_name("control task", control_task));
289                    tasks.push(with_name("routing task", routing_task));
290
291                    while let Some((name, result)) = tasks.next().await {
292                        if let Err(err) = result {
293                            log::error!("kernel: handling failed for {name}: {err:?}");
294
295                            session.update(cx, |session, cx| {
296                                session.kernel_errored(
297                                    format!("handling failed for {name}: {err}"),
298                                    cx,
299                                );
300                                cx.notify();
301                            });
302                        }
303                    }
304                }
305            })
306            .detach();
307
308            let status = process.status();
309
310            let process_status_task = cx.spawn(async move |cx| {
311                let error_message = match status.await {
312                    Ok(status) => {
313                        if status.success() {
314                            log::info!("kernel process exited successfully");
315                            return;
316                        }
317
318                        format!("kernel process exited with status: {:?}", status)
319                    }
320                    Err(err) => {
321                        format!("kernel process exited with error: {:?}", err)
322                    }
323                };
324
325                log::error!("{}", error_message);
326
327                session.update(cx, |session, cx| {
328                    session.kernel_errored(error_message, cx);
329
330                    cx.notify();
331                });
332            });
333
334            anyhow::Ok(Box::new(Self {
335                process,
336                request_tx,
337                working_directory,
338                _process_status_task: Some(process_status_task),
339                connection_path,
340                execution_state: ExecutionState::Idle,
341                kernel_info: None,
342            }) as Box<dyn RunningKernel>)
343        })
344    }
345}
346
347impl RunningKernel for NativeRunningKernel {
348    fn request_tx(&self) -> mpsc::Sender<JupyterMessage> {
349        self.request_tx.clone()
350    }
351
352    fn working_directory(&self) -> &PathBuf {
353        &self.working_directory
354    }
355
356    fn execution_state(&self) -> &ExecutionState {
357        &self.execution_state
358    }
359
360    fn set_execution_state(&mut self, state: ExecutionState) {
361        self.execution_state = state;
362    }
363
364    fn kernel_info(&self) -> Option<&KernelInfoReply> {
365        self.kernel_info.as_ref()
366    }
367
368    fn set_kernel_info(&mut self, info: KernelInfoReply) {
369        self.kernel_info = Some(info);
370    }
371
372    fn force_shutdown(&mut self, _window: &mut Window, _cx: &mut App) -> Task<anyhow::Result<()>> {
373        self._process_status_task.take();
374        self.request_tx.close_channel();
375        Task::ready(self.process.kill().context("killing the kernel process"))
376    }
377}
378
379impl Drop for NativeRunningKernel {
380    fn drop(&mut self) {
381        std::fs::remove_file(&self.connection_path).ok();
382        self.request_tx.close_channel();
383        self.process.kill().ok();
384    }
385}
386
387async fn read_kernelspec_at(
388    // Path should be a directory to a jupyter kernelspec, as in
389    // /usr/local/share/jupyter/kernels/python3
390    kernel_dir: PathBuf,
391    fs: &dyn Fs,
392) -> Result<LocalKernelSpecification> {
393    let path = kernel_dir;
394    let kernel_name = if let Some(kernel_name) = path.file_name() {
395        kernel_name.to_string_lossy().into_owned()
396    } else {
397        anyhow::bail!("Invalid kernelspec directory: {path:?}");
398    };
399
400    if !fs.is_dir(path.as_path()).await {
401        anyhow::bail!("Not a directory: {path:?}");
402    }
403
404    let expected_kernel_json = path.join("kernel.json");
405    let spec = fs.load(expected_kernel_json.as_path()).await?;
406    let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
407
408    Ok(LocalKernelSpecification {
409        name: kernel_name,
410        path,
411        kernelspec: spec,
412    })
413}
414
415/// Read a directory of kernelspec directories
416async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<LocalKernelSpecification>> {
417    let mut kernelspec_dirs = fs.read_dir(&path).await?;
418
419    let mut valid_kernelspecs = Vec::new();
420    while let Some(path) = kernelspec_dirs.next().await {
421        match path {
422            Ok(path) => {
423                if fs.is_dir(path.as_path()).await
424                    && let Ok(kernelspec) = read_kernelspec_at(path, fs).await
425                {
426                    valid_kernelspecs.push(kernelspec);
427                }
428            }
429            Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
430        }
431    }
432
433    Ok(valid_kernelspecs)
434}
435
436pub async fn local_kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<LocalKernelSpecification>> {
437    let mut data_dirs = dirs::data_dirs();
438
439    // Pick up any kernels from conda or conda environment
440    if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
441        let conda_prefix = PathBuf::from(conda_prefix);
442        let conda_data_dir = conda_prefix.join("share").join("jupyter");
443        data_dirs.push(conda_data_dir);
444    }
445
446    // Search for kernels inside the base python environment
447    let command = util::command::new_smol_command("python")
448        .arg("-c")
449        .arg("import sys; print(sys.prefix)")
450        .output()
451        .await;
452
453    if let Ok(command) = command
454        && command.status.success()
455    {
456        let python_prefix = String::from_utf8(command.stdout);
457        if let Ok(python_prefix) = python_prefix {
458            let python_prefix = PathBuf::from(python_prefix.trim());
459            let python_data_dir = python_prefix.join("share").join("jupyter");
460            data_dirs.push(python_data_dir);
461        }
462    }
463
464    let kernel_dirs = data_dirs
465        .iter()
466        .map(|dir| dir.join("kernels"))
467        .map(|path| read_kernels_dir(path, fs.as_ref()))
468        .collect::<Vec<_>>();
469
470    let kernel_dirs = futures::future::join_all(kernel_dirs).await;
471    let kernel_dirs = kernel_dirs
472        .into_iter()
473        .filter_map(Result::ok)
474        .flatten()
475        .collect::<Vec<_>>();
476
477    Ok(kernel_dirs)
478}
479
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// `read_kernels_dir` should find both kernelspec directories under `/jupyter/kernels`
    /// in the fake filesystem and name each one after its directory.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                ".zed": {
                    "settings.json": r#"{ "tab_size": 8 }"#,
                    "tasks.json": r#"[{
                        "label": "cargo check",
                        "command": "cargo",
                        "args": ["check", "--all"]
                    },]"#,
                },
                "kernels": {
                    "python": {
                        "kernel.json": r#"{
                            "display_name": "Python 3",
                            "language": "python",
                            "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "deno": {
                        "kernel.json": r#"{
                            "display_name": "Deno",
                            "language": "typescript",
                            "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                            "env": {}
                        }"#
                    }
                },
            }),
        )
        .await;

        let kernels_root = PathBuf::from("/jupyter/kernels");
        let discovered = read_kernels_dir(kernels_root, fs.as_ref())
            .await
            .unwrap();

        // Directory listing order is not asserted, so compare the sorted names.
        let mut names = discovered
            .iter()
            .map(|spec| spec.name.clone())
            .collect::<Vec<_>>();
        names.sort();

        assert_eq!(names, vec!["deno", "python"]);
    }
}