native_kernel.rs

use anyhow::{Context as _, Result};
use futures::{
    AsyncBufReadExt as _, SinkExt as _,
    channel::mpsc::{self},
    io::BufReader,
    stream::{FuturesUnordered, SelectAll, StreamExt},
};
use gpui::{App, AppContext as _, Entity, EntityId, Task, Window};
use jupyter_protocol::{
    ExecutionState, JupyterKernelspec, JupyterMessage, JupyterMessageContent, KernelInfoReply,
    connection_info::{ConnectionInfo, Transport},
};
use project::Fs;
use runtimelib::dirs;
use smol::{net::TcpListener, process::Command};
use std::{
    collections::HashSet,
    env,
    ffi::{OsStr, OsString},
    fmt::Debug,
    net::{IpAddr, Ipv4Addr, SocketAddr},
    path::{Path, PathBuf},
    sync::Arc,
};
use uuid::Uuid;

use crate::Session;

use super::RunningKernel;
 28
/// A Jupyter kernelspec found on the local filesystem.
///
/// Equality (see the `PartialEq` impl) compares only `name` and `path`; the parsed
/// `kernelspec` contents are not compared.
#[derive(Debug, Clone)]
pub struct LocalKernelSpecification {
    /// Kernel name. `read_kernelspec_at` sets it to the final component of the
    /// kernelspec directory (e.g. `python3`).
    pub name: String,
    /// Kernelspec directory, i.e. the directory that contains `kernel.json`.
    pub path: PathBuf,
    /// Parsed contents of the directory's `kernel.json`.
    pub kernelspec: JupyterKernelspec,
}
 35
 36impl PartialEq for LocalKernelSpecification {
 37    fn eq(&self, other: &Self) -> bool {
 38        self.name == other.name && self.path == other.path
 39    }
 40}
 41
 42impl Eq for LocalKernelSpecification {}
 43
 44impl LocalKernelSpecification {
 45    #[must_use]
 46    fn command(&self, connection_path: &PathBuf) -> Result<Command> {
 47        let argv = &self.kernelspec.argv;
 48
 49        anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
 50        anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
 51        anyhow::ensure!(
 52            argv.iter().any(|arg| arg == "{connection_file}"),
 53            "Missing 'connection_file' in argv in kernelspec {}",
 54            self.name
 55        );
 56
 57        let mut cmd = util::command::new_smol_command(&argv[0]);
 58
 59        for arg in &argv[1..] {
 60            if arg == "{connection_file}" {
 61                cmd.arg(connection_path);
 62            } else {
 63                cmd.arg(arg);
 64            }
 65        }
 66
 67        if let Some(env) = &self.kernelspec.env {
 68            cmd.envs(env);
 69        }
 70
 71        Ok(cmd)
 72    }
 73}
 74
 75// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
 76// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
 77async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
 78    let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
 79    addr_zeroport.set_port(0);
 80    let mut ports: [u16; 5] = [0; 5];
 81    for i in 0..5 {
 82        let listener = TcpListener::bind(addr_zeroport).await?;
 83        let addr = listener.local_addr()?;
 84        ports[i] = addr.port();
 85    }
 86    Ok(ports)
 87}
 88
/// A Jupyter kernel running as a local child process, created by `NativeRunningKernel::new`.
pub struct NativeRunningKernel {
    /// The kernel's child process (spawned with `kill_on_drop(true)`).
    pub process: smol::process::Child,
    /// Connection file written for this kernel; deleted when this value is dropped.
    connection_path: PathBuf,
    /// Task that awaits the process exit status and reports an unsuccessful exit to the
    /// session. `force_shutdown` takes (and drops) it before killing the process.
    _process_status_task: Option<Task<()>>,
    /// Directory the kernel process was started in.
    pub working_directory: PathBuf,
    /// Sender for requests to the kernel. Messages are routed to the control socket
    /// (debug/interrupt/shutdown requests) or the shell socket (everything else).
    pub request_tx: mpsc::Sender<JupyterMessage>,
    /// Last known execution state; starts as `ExecutionState::Idle`.
    pub execution_state: ExecutionState,
    /// The kernel's info reply, once recorded via `set_kernel_info`.
    pub kernel_info: Option<KernelInfoReply>,
}
 98
 99impl Debug for NativeRunningKernel {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        f.debug_struct("RunningKernel")
102            .field("process", &self.process)
103            .finish()
104    }
105}
106
107impl NativeRunningKernel {
108    pub fn new(
109        kernel_specification: LocalKernelSpecification,
110        entity_id: EntityId,
111        working_directory: PathBuf,
112        fs: Arc<dyn Fs>,
113        // todo: convert to weak view
114        session: Entity<Session>,
115        window: &mut Window,
116        cx: &mut App,
117    ) -> Task<Result<Box<dyn RunningKernel>>> {
118        window.spawn(cx, async move |cx| {
119            let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
120            let ports = peek_ports(ip).await?;
121
122            let connection_info = ConnectionInfo {
123                transport: Transport::TCP,
124                ip: ip.to_string(),
125                stdin_port: ports[0],
126                control_port: ports[1],
127                hb_port: ports[2],
128                shell_port: ports[3],
129                iopub_port: ports[4],
130                signature_scheme: "hmac-sha256".to_string(),
131                key: uuid::Uuid::new_v4().to_string(),
132                kernel_name: Some(format!("zed-{}", kernel_specification.name)),
133            };
134
135            let runtime_dir = dirs::runtime_dir();
136            fs.create_dir(&runtime_dir)
137                .await
138                .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
139            let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
140            let content = serde_json::to_string(&connection_info)?;
141            fs.atomic_write(connection_path.clone(), content).await?;
142
143            let mut cmd = kernel_specification.command(&connection_path)?;
144
145            let mut process = cmd
146                .current_dir(&working_directory)
147                .stdout(std::process::Stdio::piped())
148                .stderr(std::process::Stdio::piped())
149                .stdin(std::process::Stdio::piped())
150                .kill_on_drop(true)
151                .spawn()
152                .context("failed to start the kernel process")?;
153
154            let session_id = Uuid::new_v4().to_string();
155
156            let mut iopub_socket =
157                runtimelib::create_client_iopub_connection(&connection_info, "", &session_id)
158                    .await?;
159            let mut shell_socket =
160                runtimelib::create_client_shell_connection(&connection_info, &session_id).await?;
161            let mut control_socket =
162                runtimelib::create_client_control_connection(&connection_info, &session_id).await?;
163
164            let (request_tx, mut request_rx) =
165                futures::channel::mpsc::channel::<JupyterMessage>(100);
166
167            let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
168            let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);
169
170            let mut messages_rx = SelectAll::new();
171            messages_rx.push(control_reply_rx);
172            messages_rx.push(shell_reply_rx);
173
174            cx.spawn({
175                let session = session.clone();
176
177                async move |cx| {
178                    while let Some(message) = messages_rx.next().await {
179                        session
180                            .update_in(cx, |session, window, cx| {
181                                session.route(&message, window, cx);
182                            })
183                            .ok();
184                    }
185                }
186            })
187            .detach();
188
189            // iopub task
190            let iopub_task = cx.spawn({
191                let session = session.clone();
192
193                async move |cx| -> anyhow::Result<()> {
194                    loop {
195                        let message = iopub_socket.read().await?;
196                        session
197                            .update_in(cx, |session, window, cx| {
198                                session.route(&message, window, cx);
199                            })
200                            .ok();
201                    }
202                }
203            });
204
205            let (mut control_request_tx, mut control_request_rx) =
206                futures::channel::mpsc::channel(100);
207            let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);
208
209            let routing_task = cx.background_spawn({
210                async move {
211                    while let Some(message) = request_rx.next().await {
212                        match message.content {
213                            JupyterMessageContent::DebugRequest(_)
214                            | JupyterMessageContent::InterruptRequest(_)
215                            | JupyterMessageContent::ShutdownRequest(_) => {
216                                control_request_tx.send(message).await?;
217                            }
218                            _ => {
219                                shell_request_tx.send(message).await?;
220                            }
221                        }
222                    }
223                    anyhow::Ok(())
224                }
225            });
226
227            let shell_task = cx.background_spawn({
228                async move {
229                    while let Some(message) = shell_request_rx.next().await {
230                        shell_socket.send(message).await.ok();
231                        let reply = shell_socket.read().await?;
232                        shell_reply_tx.send(reply).await?;
233                    }
234                    anyhow::Ok(())
235                }
236            });
237
238            let control_task = cx.background_spawn({
239                async move {
240                    while let Some(message) = control_request_rx.next().await {
241                        control_socket.send(message).await.ok();
242                        let reply = control_socket.read().await?;
243                        control_reply_tx.send(reply).await?;
244                    }
245                    anyhow::Ok(())
246                }
247            });
248
249            let stderr = process.stderr.take();
250
251            cx.spawn(async move |_cx| {
252                if stderr.is_none() {
253                    return;
254                }
255                let reader = BufReader::new(stderr.unwrap());
256                let mut lines = reader.lines();
257                while let Some(Ok(line)) = lines.next().await {
258                    log::error!("kernel: {}", line);
259                }
260            })
261            .detach();
262
263            let stdout = process.stdout.take();
264
265            cx.spawn(async move |_cx| {
266                if stdout.is_none() {
267                    return;
268                }
269                let reader = BufReader::new(stdout.unwrap());
270                let mut lines = reader.lines();
271                while let Some(Ok(line)) = lines.next().await {
272                    log::info!("kernel: {}", line);
273                }
274            })
275            .detach();
276
277            cx.spawn({
278                let session = session.clone();
279                async move |cx| {
280                    async fn with_name(
281                        name: &'static str,
282                        task: Task<Result<()>>,
283                    ) -> (&'static str, Result<()>) {
284                        (name, task.await)
285                    }
286
287                    let mut tasks = FuturesUnordered::new();
288                    tasks.push(with_name("iopub task", iopub_task));
289                    tasks.push(with_name("shell task", shell_task));
290                    tasks.push(with_name("control task", control_task));
291                    tasks.push(with_name("routing task", routing_task));
292
293                    while let Some((name, result)) = tasks.next().await {
294                        if let Err(err) = result {
295                            log::error!("kernel: handling failed for {name}: {err:?}");
296
297                            session.update(cx, |session, cx| {
298                                session.kernel_errored(
299                                    format!("handling failed for {name}: {err}"),
300                                    cx,
301                                );
302                                cx.notify();
303                            });
304                        }
305                    }
306                }
307            })
308            .detach();
309
310            let status = process.status();
311
312            let process_status_task = cx.spawn(async move |cx| {
313                let error_message = match status.await {
314                    Ok(status) => {
315                        if status.success() {
316                            log::info!("kernel process exited successfully");
317                            return;
318                        }
319
320                        format!("kernel process exited with status: {:?}", status)
321                    }
322                    Err(err) => {
323                        format!("kernel process exited with error: {:?}", err)
324                    }
325                };
326
327                log::error!("{}", error_message);
328
329                session.update(cx, |session, cx| {
330                    session.kernel_errored(error_message, cx);
331
332                    cx.notify();
333                });
334            });
335
336            anyhow::Ok(Box::new(Self {
337                process,
338                request_tx,
339                working_directory,
340                _process_status_task: Some(process_status_task),
341                connection_path,
342                execution_state: ExecutionState::Idle,
343                kernel_info: None,
344            }) as Box<dyn RunningKernel>)
345        })
346    }
347}
348
349impl RunningKernel for NativeRunningKernel {
350    fn request_tx(&self) -> mpsc::Sender<JupyterMessage> {
351        self.request_tx.clone()
352    }
353
354    fn working_directory(&self) -> &PathBuf {
355        &self.working_directory
356    }
357
358    fn execution_state(&self) -> &ExecutionState {
359        &self.execution_state
360    }
361
362    fn set_execution_state(&mut self, state: ExecutionState) {
363        self.execution_state = state;
364    }
365
366    fn kernel_info(&self) -> Option<&KernelInfoReply> {
367        self.kernel_info.as_ref()
368    }
369
370    fn set_kernel_info(&mut self, info: KernelInfoReply) {
371        self.kernel_info = Some(info);
372    }
373
374    fn force_shutdown(&mut self, _window: &mut Window, _cx: &mut App) -> Task<anyhow::Result<()>> {
375        self._process_status_task.take();
376        self.request_tx.close_channel();
377        Task::ready(self.process.kill().context("killing the kernel process"))
378    }
379}
380
381impl Drop for NativeRunningKernel {
382    fn drop(&mut self) {
383        std::fs::remove_file(&self.connection_path).ok();
384        self.request_tx.close_channel();
385        self.process.kill().ok();
386    }
387}
388
389async fn read_kernelspec_at(
390    // Path should be a directory to a jupyter kernelspec, as in
391    // /usr/local/share/jupyter/kernels/python3
392    kernel_dir: PathBuf,
393    fs: &dyn Fs,
394) -> Result<LocalKernelSpecification> {
395    let path = kernel_dir;
396    let kernel_name = if let Some(kernel_name) = path.file_name() {
397        kernel_name.to_string_lossy().into_owned()
398    } else {
399        anyhow::bail!("Invalid kernelspec directory: {path:?}");
400    };
401
402    if !fs.is_dir(path.as_path()).await {
403        anyhow::bail!("Not a directory: {path:?}");
404    }
405
406    let expected_kernel_json = path.join("kernel.json");
407    let spec = fs.load(expected_kernel_json.as_path()).await?;
408    let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
409
410    Ok(LocalKernelSpecification {
411        name: kernel_name,
412        path,
413        kernelspec: spec,
414    })
415}
416
417/// Read a directory of kernelspec directories
418async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<LocalKernelSpecification>> {
419    let mut kernelspec_dirs = fs.read_dir(&path).await?;
420
421    let mut valid_kernelspecs = Vec::new();
422    while let Some(path) = kernelspec_dirs.next().await {
423        match path {
424            Ok(path) => {
425                if fs.is_dir(path.as_path()).await
426                    && let Ok(kernelspec) = read_kernelspec_at(path, fs).await
427                {
428                    valid_kernelspecs.push(kernelspec);
429                }
430            }
431            Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
432        }
433    }
434
435    Ok(valid_kernelspecs)
436}
437
438pub async fn local_kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<LocalKernelSpecification>> {
439    let mut data_dirs = dirs::data_dirs();
440
441    // Pick up any kernels from conda or conda environment
442    if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
443        let conda_prefix = PathBuf::from(conda_prefix);
444        let conda_data_dir = conda_prefix.join("share").join("jupyter");
445        data_dirs.push(conda_data_dir);
446    }
447
448    // Search for kernels inside the base python environment
449    let command = util::command::new_smol_command("python")
450        .arg("-c")
451        .arg("import sys; print(sys.prefix)")
452        .output()
453        .await;
454
455    if let Ok(command) = command
456        && command.status.success()
457    {
458        let python_prefix = String::from_utf8(command.stdout);
459        if let Ok(python_prefix) = python_prefix {
460            let python_prefix = PathBuf::from(python_prefix.trim());
461            let python_data_dir = python_prefix.join("share").join("jupyter");
462            data_dirs.push(python_data_dir);
463        }
464    }
465
466    let kernel_dirs = data_dirs
467        .iter()
468        .map(|dir| dir.join("kernels"))
469        .map(|path| read_kernels_dir(path, fs.as_ref()))
470        .collect::<Vec<_>>();
471
472    let kernel_dirs = futures::future::join_all(kernel_dirs).await;
473    let kernel_dirs = kernel_dirs
474        .into_iter()
475        .filter_map(Result::ok)
476        .flatten()
477        .collect::<Vec<_>>();
478
479    Ok(kernel_dirs)
480}
481
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// Every valid kernelspec directory under `kernels` is discovered and named after
    /// its directory.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                "kernels": {
                    "python": {
                        "kernel.json": r#"{
                            "display_name": "Python 3",
                            "language": "python",
                            "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "deno": {
                        "kernel.json": r#"{
                            "display_name": "Deno",
                            "language": "typescript",
                            "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                            "env": {}
                        }"#
                    }
                },
            }),
        )
        .await;

        let mut kernels = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();

        kernels.sort_by(|a, b| a.name.cmp(&b.name));

        assert_eq!(
            kernels.iter().map(|c| c.name.clone()).collect::<Vec<_>>(),
            vec!["deno", "python"]
        );
    }

    /// Stray files, empty directories and malformed `kernel.json` files are skipped
    /// without hiding the valid kernelspecs next to them.
    #[gpui::test]
    async fn test_get_kernelspecs_skips_invalid_entries(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                "kernels": {
                    "README.md": "not a kernelspec",
                    "empty": {},
                    "broken": {
                        "kernel.json": "{ this is not json"
                    },
                    "python": {
                        "kernel.json": r#"{
                            "display_name": "Python 3",
                            "language": "python",
                            "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                            "env": {}
                        }"#
                    }
                },
            }),
        )
        .await;

        let kernels = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();

        assert_eq!(
            kernels.iter().map(|k| k.name.as_str()).collect::<Vec<_>>(),
            vec!["python"]
        );
        assert_eq!(kernels[0].kernelspec.argv[0], "python3");
    }
}