native_kernel.rs

  1use anyhow::{Context as _, Result};
  2use futures::{
  3    AsyncBufReadExt as _, FutureExt as _, StreamExt as _,
  4    channel::mpsc::{self},
  5    io::BufReader,
  6    stream::FuturesUnordered,
  7};
  8use gpui::{App, AppContext as _, Entity, EntityId, Task, Window};
  9use jupyter_protocol::{
 10    ExecutionState, JupyterKernelspec, JupyterMessage, JupyterMessageContent, KernelInfoReply,
 11    connection_info::{ConnectionInfo, Transport},
 12};
 13use project::Fs;
 14use runtimelib::dirs;
 15use smol::{net::TcpListener, process::Command};
 16use std::{
 17    env,
 18    fmt::Debug,
 19    net::{IpAddr, Ipv4Addr, SocketAddr},
 20    path::PathBuf,
 21    sync::Arc,
 22};
 23use uuid::Uuid;
 24
 25use super::{KernelSession, RunningKernel};
 26
 27#[derive(Debug, Clone)]
 28pub struct LocalKernelSpecification {
 29    pub name: String,
 30    pub path: PathBuf,
 31    pub kernelspec: JupyterKernelspec,
 32}
 33
 34impl PartialEq for LocalKernelSpecification {
 35    fn eq(&self, other: &Self) -> bool {
 36        self.name == other.name && self.path == other.path
 37    }
 38}
 39
 40impl Eq for LocalKernelSpecification {}
 41
 42impl LocalKernelSpecification {
 43    #[must_use]
 44    fn command(&self, connection_path: &PathBuf) -> Result<Command> {
 45        let argv = &self.kernelspec.argv;
 46
 47        anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
 48        anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
 49        anyhow::ensure!(
 50            argv.iter().any(|arg| arg == "{connection_file}"),
 51            "Missing 'connection_file' in argv in kernelspec {}",
 52            self.name
 53        );
 54
 55        let mut cmd = util::command::new_smol_command(&argv[0]);
 56
 57        for arg in &argv[1..] {
 58            if arg == "{connection_file}" {
 59                cmd.arg(connection_path);
 60            } else {
 61                cmd.arg(arg);
 62            }
 63        }
 64
 65        if let Some(env) = &self.kernelspec.env {
 66            cmd.envs(env);
 67        }
 68
 69        Ok(cmd)
 70    }
 71}
 72
 73// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
 74// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
 75async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
 76    let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
 77    addr_zeroport.set_port(0);
 78    let mut ports: [u16; 5] = [0; 5];
 79    for i in 0..5 {
 80        let listener = TcpListener::bind(addr_zeroport).await?;
 81        let addr = listener.local_addr()?;
 82        ports[i] = addr.port();
 83    }
 84    Ok(ports)
 85}
 86
/// A Jupyter kernel running as a local child process.
///
/// Created by [`NativeRunningKernel::new`]. Dropping it deletes the connection file and
/// kills the process.
// NOTE: fields drop in declaration order after `Drop::drop` runs; keep the order stable.
pub struct NativeRunningKernel {
    /// The spawned kernel process (started with `kill_on_drop(true)` in `new`).
    pub process: smol::process::Child,
    /// Connection file written for this kernel by `new`; removed on drop.
    connection_path: PathBuf,
    /// Waits for the process to exit and reports unsuccessful exits to the session.
    /// `kill` takes (drops) it before killing the process.
    _process_status_task: Option<Task<()>>,
    /// Directory the kernel process was started in.
    pub working_directory: PathBuf,
    /// Outgoing requests. The routing task in `new` sends debug, interrupt and shutdown
    /// requests over the control socket and everything else over the shell socket.
    pub request_tx: mpsc::Sender<JupyterMessage>,
    /// Current execution state; initialised to `ExecutionState::Idle` by `new`.
    pub execution_state: ExecutionState,
    /// Kernel info reply, once one has been stored via `set_kernel_info`.
    pub kernel_info: Option<KernelInfoReply>,
}
 96
 97impl Debug for NativeRunningKernel {
 98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 99        f.debug_struct("RunningKernel")
100            .field("process", &self.process)
101            .finish()
102    }
103}
104
impl NativeRunningKernel {
    /// Launches a local Jupyter kernel described by `kernel_specification` and connects to it.
    ///
    /// Steps, in order:
    /// 1. Picks five loopback ports and writes a connection file named
    ///    `kernel-zed-{entity_id}.json` into the Jupyter runtime directory. The file uses
    ///    a freshly generated key and the `hmac-sha256` signature scheme.
    /// 2. Spawns the kernel process in `working_directory` with piped stdio. The child is
    ///    created with `kill_on_drop(true)`.
    /// 3. Opens iopub, shell and control client connections under a new session id.
    /// 4. Starts tasks that:
    ///    - route incoming messages to `session`;
    ///    - forward outgoing requests;
    ///    - log the kernel's stdout/stderr;
    ///    - report task failures or unsuccessful process exits via `kernel_errored`.
    ///
    /// # Errors
    ///
    /// The returned task fails if any of these steps fails:
    /// - port selection;
    /// - creating the runtime directory;
    /// - serialising or writing the connection file;
    /// - building the command;
    /// - spawning the process;
    /// - opening any of the three sockets.
    ///
    /// NOTE(review): on those early-return paths the connection file from step 1 is not
    /// removed. Cleanup only happens in `Drop`, which needs a constructed kernel.
    pub fn new<S: KernelSession + 'static>(
        kernel_specification: LocalKernelSpecification,
        entity_id: EntityId,
        working_directory: PathBuf,
        fs: Arc<dyn Fs>,
        // todo: convert to weak view
        session: Entity<S>,
        window: &mut Window,
        cx: &mut App,
    ) -> Task<Result<Box<dyn RunningKernel>>> {
        window.spawn(cx, async move |cx| {
            // All kernel sockets are advertised on loopback only.
            let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
            let ports = peek_ports(ip).await?;

            let connection_info = ConnectionInfo {
                transport: Transport::TCP,
                ip: ip.to_string(),
                stdin_port: ports[0],
                control_port: ports[1],
                hb_port: ports[2],
                shell_port: ports[3],
                iopub_port: ports[4],
                signature_scheme: "hmac-sha256".to_string(),
                // Fresh per-launch key used for message signing.
                key: uuid::Uuid::new_v4().to_string(),
                kernel_name: Some(format!("zed-{}", kernel_specification.name)),
            };

            // Write the connection file the kernel reads via `{connection_file}` in its argv.
            let runtime_dir = dirs::runtime_dir();
            fs.create_dir(&runtime_dir)
                .await
                .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
            let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
            let content = serde_json::to_string(&connection_info)?;
            fs.atomic_write(connection_path.clone(), content).await?;

            let mut cmd = kernel_specification.command(&connection_path)?;

            // NOTE(review): stdin is piped but never written to or taken in this function.
            let mut process = cmd
                .current_dir(&working_directory)
                .stdout(std::process::Stdio::piped())
                .stderr(std::process::Stdio::piped())
                .stdin(std::process::Stdio::piped())
                .kill_on_drop(true)
                .spawn()
                .context("failed to start the kernel process")?;

            let session_id = Uuid::new_v4().to_string();

            // Only iopub, shell and control clients are opened here; the stdin and heartbeat
            // ports are written to the connection file but not connected to by this code.
            let iopub_socket =
                runtimelib::create_client_iopub_connection(&connection_info, "", &session_id)
                    .await?;
            let shell_socket =
                runtimelib::create_client_shell_connection(&connection_info, &session_id).await?;
            let control_socket =
                runtimelib::create_client_control_connection(&connection_info, &session_id).await?;

            let (mut shell_send, shell_recv) = shell_socket.split();
            let (mut control_send, control_recv) = control_socket.split();

            // Bounded queue (buffer 100) of requests from the editor to the kernel.
            let (request_tx, mut request_rx) =
                futures::channel::mpsc::channel::<JupyterMessage>(100);

            // Receive loop: waits on all three sockets and hands each message to the session.
            // The first read error ends the task with that error; the supervisor below
            // reports it to the session.
            let recv_task = cx.spawn({
                let session = session.clone();
                let mut iopub = iopub_socket;
                let mut shell = shell_recv;
                let mut control = control_recv;

                async move |cx| -> anyhow::Result<()> {
                    loop {
                        let message = futures::select! {
                            msg = iopub.read().fuse() => msg.context("iopub recv")?,
                            msg = shell.read().fuse() => msg.context("shell recv")?,
                            msg = control.read().fuse() => msg.context("control recv")?,
                        };
                        // Errors from `update_in` are ignored; presumably this only happens
                        // once the session/window is gone (TODO confirm).
                        session
                            .update_in(cx, |session, window, cx| {
                                session.route(&message, window, cx);
                            })
                            .ok();
                    }
                }
            });

            // Send loop: debug, interrupt and shutdown requests go over the control
            // channel; everything else goes over shell. The loop ends when `request_rx`
            // yields `None` (channel closed), and a failed send ends it with an error.
            let routing_task = cx.background_spawn({
                async move {
                    while let Some(message) = request_rx.next().await {
                        match message.content {
                            JupyterMessageContent::DebugRequest(_)
                            | JupyterMessageContent::InterruptRequest(_)
                            | JupyterMessageContent::ShutdownRequest(_) => {
                                control_send.send(message).await?;
                            }
                            _ => {
                                shell_send.send(message).await?;
                            }
                        }
                    }
                    anyhow::Ok(())
                }
            });

            let stderr = process.stderr.take();
            let stdout = process.stdout.take();

            // Forward the kernel's output to the log: stderr lines at error level, stdout
            // lines at info level. Forwarding stops at the first line that fails to read.
            cx.spawn(async move |_cx| {
                use futures::future::Either;

                let stderr_lines = match stderr {
                    Some(s) => Either::Left(
                        BufReader::new(s)
                            .lines()
                            .map(|line| (log::Level::Error, line)),
                    ),
                    None => Either::Right(futures::stream::empty()),
                };
                let stdout_lines = match stdout {
                    Some(s) => Either::Left(
                        BufReader::new(s)
                            .lines()
                            .map(|line| (log::Level::Info, line)),
                    ),
                    None => Either::Right(futures::stream::empty()),
                };
                let mut lines = futures::stream::select(stderr_lines, stdout_lines);
                while let Some((level, Ok(line))) = lines.next().await {
                    log::log!(level, "kernel: {}", line);
                }
            })
            .detach();

            // Supervisor: waits for the receive and routing tasks. Each failure is logged
            // and reported to the session via `kernel_errored`.
            cx.spawn({
                let session = session.clone();
                async move |cx| {
                    // Pairs a task's result with a label for error messages.
                    async fn with_name(
                        name: &'static str,
                        task: Task<Result<()>>,
                    ) -> (&'static str, Result<()>) {
                        (name, task.await)
                    }

                    let mut tasks = FuturesUnordered::new();
                    tasks.push(with_name("recv task", recv_task));
                    tasks.push(with_name("routing task", routing_task));

                    while let Some((name, result)) = tasks.next().await {
                        if let Err(err) = result {
                            log::error!("kernel: handling failed for {name}: {err:?}");

                            session.update(cx, |session, cx| {
                                session.kernel_errored(
                                    format!("handling failed for {name}: {err}"),
                                    cx,
                                );
                                cx.notify();
                            });
                        }
                    }
                }
            })
            .detach();

            let status = process.status();

            // Exit watcher: a successful exit is only logged; a failing exit status or a
            // wait error is reported to the session. Stored in `_process_status_task`
            // (not detached) so `kill` can drop it.
            let process_status_task = cx.spawn(async move |cx| {
                let error_message = match status.await {
                    Ok(status) => {
                        if status.success() {
                            log::info!("kernel process exited successfully");
                            return;
                        }

                        format!("kernel process exited with status: {:?}", status)
                    }
                    Err(err) => {
                        format!("kernel process exited with error: {:?}", err)
                    }
                };

                log::error!("{}", error_message);

                session.update(cx, |session, cx| {
                    session.kernel_errored(error_message, cx);

                    cx.notify();
                });
            });

            anyhow::Ok(Box::new(Self {
                process,
                request_tx,
                working_directory,
                _process_status_task: Some(process_status_task),
                connection_path,
                execution_state: ExecutionState::Idle,
                kernel_info: None,
            }) as Box<dyn RunningKernel>)
        })
    }
}
306
307impl RunningKernel for NativeRunningKernel {
308    fn request_tx(&self) -> mpsc::Sender<JupyterMessage> {
309        self.request_tx.clone()
310    }
311
312    fn working_directory(&self) -> &PathBuf {
313        &self.working_directory
314    }
315
316    fn execution_state(&self) -> &ExecutionState {
317        &self.execution_state
318    }
319
320    fn set_execution_state(&mut self, state: ExecutionState) {
321        self.execution_state = state;
322    }
323
324    fn kernel_info(&self) -> Option<&KernelInfoReply> {
325        self.kernel_info.as_ref()
326    }
327
328    fn set_kernel_info(&mut self, info: KernelInfoReply) {
329        self.kernel_info = Some(info);
330    }
331
332    fn force_shutdown(&mut self, _window: &mut Window, _cx: &mut App) -> Task<anyhow::Result<()>> {
333        self.kill();
334        Task::ready(Ok(()))
335    }
336
337    fn kill(&mut self) {
338        self._process_status_task.take();
339        self.request_tx.close_channel();
340        self.process.kill().ok();
341    }
342}
343
344impl Drop for NativeRunningKernel {
345    fn drop(&mut self) {
346        std::fs::remove_file(&self.connection_path).ok();
347        self.kill();
348    }
349}
350
351async fn read_kernelspec_at(
352    // Path should be a directory to a jupyter kernelspec, as in
353    // /usr/local/share/jupyter/kernels/python3
354    kernel_dir: PathBuf,
355    fs: &dyn Fs,
356) -> Result<LocalKernelSpecification> {
357    let path = kernel_dir;
358    let kernel_name = if let Some(kernel_name) = path.file_name() {
359        kernel_name.to_string_lossy().into_owned()
360    } else {
361        anyhow::bail!("Invalid kernelspec directory: {path:?}");
362    };
363
364    if !fs.is_dir(path.as_path()).await {
365        anyhow::bail!("Not a directory: {path:?}");
366    }
367
368    let expected_kernel_json = path.join("kernel.json");
369    let spec = fs.load(expected_kernel_json.as_path()).await?;
370    let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
371
372    Ok(LocalKernelSpecification {
373        name: kernel_name,
374        path,
375        kernelspec: spec,
376    })
377}
378
379/// Read a directory of kernelspec directories
380async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<LocalKernelSpecification>> {
381    let mut kernelspec_dirs = fs.read_dir(&path).await?;
382
383    let mut valid_kernelspecs = Vec::new();
384    while let Some(path) = kernelspec_dirs.next().await {
385        match path {
386            Ok(path) => {
387                if fs.is_dir(path.as_path()).await
388                    && let Ok(kernelspec) = read_kernelspec_at(path, fs).await
389                {
390                    valid_kernelspecs.push(kernelspec);
391                }
392            }
393            Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
394        }
395    }
396
397    Ok(valid_kernelspecs)
398}
399
400pub async fn local_kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<LocalKernelSpecification>> {
401    let mut data_dirs = dirs::data_dirs();
402
403    // Pick up any kernels from conda or conda environment
404    if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
405        let conda_prefix = PathBuf::from(conda_prefix);
406        let conda_data_dir = conda_prefix.join("share").join("jupyter");
407        data_dirs.push(conda_data_dir);
408    }
409
410    // Search for kernels inside the base python environment
411    let command = util::command::new_smol_command("python")
412        .arg("-c")
413        .arg("import sys; print(sys.prefix)")
414        .output()
415        .await;
416
417    if let Ok(command) = command
418        && command.status.success()
419    {
420        let python_prefix = String::from_utf8(command.stdout);
421        if let Ok(python_prefix) = python_prefix {
422            let python_prefix = PathBuf::from(python_prefix.trim());
423            let python_data_dir = python_prefix.join("share").join("jupyter");
424            data_dirs.push(python_data_dir);
425        }
426    }
427
428    let kernel_dirs = data_dirs
429        .iter()
430        .map(|dir| dir.join("kernels"))
431        .map(|path| read_kernels_dir(path, fs.as_ref()))
432        .collect::<Vec<_>>();
433
434    let kernel_dirs = futures::future::join_all(kernel_dirs).await;
435    let kernel_dirs = kernel_dirs
436        .into_iter()
437        .filter_map(Result::ok)
438        .flatten()
439        .collect::<Vec<_>>();
440
441    Ok(kernel_dirs)
442}
443
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// Valid kernelspec directories are returned. Stray files, directories without a
    /// `kernel.json`, and kernelspecs with unparsable JSON are skipped without failing
    /// the whole read.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                ".zed": {
                    "settings.json": r#"{ "tab_size": 8 }"#,
                    "tasks.json": r#"[{
                        "label": "cargo check",
                        "command": "cargo",
                        "args": ["check", "--all"]
                    },]"#,
                },
                "kernels": {
                    "python": {
                        "kernel.json": r#"{
                            "display_name": "Python 3",
                            "language": "python",
                            "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "deno": {
                        "kernel.json": r#"{
                            "display_name": "Deno",
                            "language": "typescript",
                            "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    // A plain file next to the kernelspec directories must be ignored.
                    "README.md": "not a kernelspec",
                    // A directory without kernel.json is not a kernelspec.
                    "empty": {},
                    // Unparsable kernel.json must be skipped rather than abort the read.
                    "broken": {
                        "kernel.json": "{ this is not json"
                    }
                },
            }),
        )
        .await;

        let mut kernels = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();

        kernels.sort_by(|a, b| a.name.cmp(&b.name));

        assert_eq!(
            kernels.iter().map(|c| c.name.clone()).collect::<Vec<_>>(),
            vec!["deno", "python"]
        );
    }
}