kernels.rs

  1use anyhow::{Context as _, Result};
  2use futures::{
  3    channel::mpsc::{self, Receiver},
  4    future::Shared,
  5    stream::{self, SelectAll, StreamExt},
  6    SinkExt as _,
  7};
  8use gpui::{AppContext, EntityId, Task};
  9use project::Fs;
 10use runtimelib::{
 11    dirs, ConnectionInfo, ExecutionState, JupyterKernelspec, JupyterMessage, JupyterMessageContent,
 12    KernelInfoReply,
 13};
 14use smol::{net::TcpListener, process::Command};
 15use std::{
 16    env,
 17    fmt::Debug,
 18    net::{IpAddr, Ipv4Addr, SocketAddr},
 19    path::PathBuf,
 20    sync::Arc,
 21};
 22use uuid::Uuid;
 23
 24#[derive(Debug, Clone)]
 25pub struct KernelSpecification {
 26    pub name: String,
 27    pub path: PathBuf,
 28    pub kernelspec: JupyterKernelspec,
 29}
 30
 31impl KernelSpecification {
 32    #[must_use]
 33    fn command(&self, connection_path: &PathBuf) -> Result<Command> {
 34        let argv = &self.kernelspec.argv;
 35
 36        anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
 37        anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
 38        anyhow::ensure!(
 39            argv.iter().any(|arg| arg == "{connection_file}"),
 40            "Missing 'connection_file' in argv in kernelspec {}",
 41            self.name
 42        );
 43
 44        let mut cmd = Command::new(&argv[0]);
 45
 46        for arg in &argv[1..] {
 47            if arg == "{connection_file}" {
 48                cmd.arg(connection_path);
 49            } else {
 50                cmd.arg(arg);
 51            }
 52        }
 53
 54        if let Some(env) = &self.kernelspec.env {
 55            cmd.envs(env);
 56        }
 57
 58        #[cfg(windows)]
 59        {
 60            use smol::process::windows::CommandExt;
 61            cmd.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0);
 62        }
 63
 64        Ok(cmd)
 65    }
 66}
 67
 68// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
 69// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
 70async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
 71    let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
 72    addr_zeroport.set_port(0);
 73    let mut ports: [u16; 5] = [0; 5];
 74    for i in 0..5 {
 75        let listener = TcpListener::bind(addr_zeroport).await?;
 76        let addr = listener.local_addr()?;
 77        ports[i] = addr.port();
 78    }
 79    Ok(ports)
 80}
 81
 82#[derive(Debug, Clone)]
 83pub enum KernelStatus {
 84    Idle,
 85    Busy,
 86    Starting,
 87    Error,
 88    ShuttingDown,
 89    Shutdown,
 90    Restarting,
 91}
 92
 93impl KernelStatus {
 94    pub fn is_connected(&self) -> bool {
 95        match self {
 96            KernelStatus::Idle | KernelStatus::Busy => true,
 97            _ => false,
 98        }
 99    }
100}
101
102impl ToString for KernelStatus {
103    fn to_string(&self) -> String {
104        match self {
105            KernelStatus::Idle => "Idle".to_string(),
106            KernelStatus::Busy => "Busy".to_string(),
107            KernelStatus::Starting => "Starting".to_string(),
108            KernelStatus::Error => "Error".to_string(),
109            KernelStatus::ShuttingDown => "Shutting Down".to_string(),
110            KernelStatus::Shutdown => "Shutdown".to_string(),
111            KernelStatus::Restarting => "Restarting".to_string(),
112        }
113    }
114}
115
116impl From<&Kernel> for KernelStatus {
117    fn from(kernel: &Kernel) -> Self {
118        match kernel {
119            Kernel::RunningKernel(kernel) => match kernel.execution_state {
120                ExecutionState::Idle => KernelStatus::Idle,
121                ExecutionState::Busy => KernelStatus::Busy,
122            },
123            Kernel::StartingKernel(_) => KernelStatus::Starting,
124            Kernel::ErroredLaunch(_) => KernelStatus::Error,
125            Kernel::ShuttingDown => KernelStatus::ShuttingDown,
126            Kernel::Shutdown => KernelStatus::Shutdown,
127            Kernel::Restarting => KernelStatus::Restarting,
128        }
129    }
130}
131
132#[derive(Debug)]
133pub enum Kernel {
134    RunningKernel(RunningKernel),
135    StartingKernel(Shared<Task<()>>),
136    ErroredLaunch(String),
137    ShuttingDown,
138    Shutdown,
139    Restarting,
140}
141
142impl Kernel {
143    pub fn status(&self) -> KernelStatus {
144        self.into()
145    }
146
147    pub fn set_execution_state(&mut self, status: &ExecutionState) {
148        match self {
149            Kernel::RunningKernel(running_kernel) => {
150                running_kernel.execution_state = status.clone();
151            }
152            _ => {}
153        }
154    }
155
156    pub fn set_kernel_info(&mut self, kernel_info: &KernelInfoReply) {
157        match self {
158            Kernel::RunningKernel(running_kernel) => {
159                running_kernel.kernel_info = Some(kernel_info.clone());
160            }
161            _ => {}
162        }
163    }
164
165    pub fn is_shutting_down(&self) -> bool {
166        match self {
167            Kernel::Restarting | Kernel::ShuttingDown => true,
168            Kernel::RunningKernel(_)
169            | Kernel::StartingKernel(_)
170            | Kernel::ErroredLaunch(_)
171            | Kernel::Shutdown => false,
172        }
173    }
174}
175
176pub struct RunningKernel {
177    pub process: smol::process::Child,
178    _shell_task: Task<Result<()>>,
179    _iopub_task: Task<Result<()>>,
180    _control_task: Task<Result<()>>,
181    _routing_task: Task<Result<()>>,
182    connection_path: PathBuf,
183    pub working_directory: PathBuf,
184    pub request_tx: mpsc::Sender<JupyterMessage>,
185    pub execution_state: ExecutionState,
186    pub kernel_info: Option<KernelInfoReply>,
187}
188
189type JupyterMessageChannel = stream::SelectAll<Receiver<JupyterMessage>>;
190
191impl Debug for RunningKernel {
192    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193        f.debug_struct("RunningKernel")
194            .field("process", &self.process)
195            .finish()
196    }
197}
198
199impl RunningKernel {
200    pub fn new(
201        kernel_specification: KernelSpecification,
202        entity_id: EntityId,
203        working_directory: PathBuf,
204        fs: Arc<dyn Fs>,
205        cx: &mut AppContext,
206    ) -> Task<Result<(Self, JupyterMessageChannel)>> {
207        cx.spawn(|cx| async move {
208            let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
209            let ports = peek_ports(ip).await?;
210
211            let connection_info = ConnectionInfo {
212                transport: "tcp".to_string(),
213                ip: ip.to_string(),
214                stdin_port: ports[0],
215                control_port: ports[1],
216                hb_port: ports[2],
217                shell_port: ports[3],
218                iopub_port: ports[4],
219                signature_scheme: "hmac-sha256".to_string(),
220                key: uuid::Uuid::new_v4().to_string(),
221                kernel_name: Some(format!("zed-{}", kernel_specification.name)),
222            };
223
224            let runtime_dir = dirs::runtime_dir();
225            fs.create_dir(&runtime_dir)
226                .await
227                .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
228            let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
229            let content = serde_json::to_string(&connection_info)?;
230            fs.atomic_write(connection_path.clone(), content).await?;
231
232            let mut cmd = kernel_specification.command(&connection_path)?;
233
234            let process = cmd
235                .current_dir(&working_directory)
236                .stdout(std::process::Stdio::piped())
237                .stderr(std::process::Stdio::piped())
238                .stdin(std::process::Stdio::piped())
239                .kill_on_drop(true)
240                .spawn()
241                .context("failed to start the kernel process")?;
242
243            let session_id = Uuid::new_v4().to_string();
244
245            let mut iopub_socket = connection_info
246                .create_client_iopub_connection("", &session_id)
247                .await?;
248            let mut shell_socket = connection_info
249                .create_client_shell_connection(&session_id)
250                .await?;
251            let mut control_socket = connection_info
252                .create_client_control_connection(&session_id)
253                .await?;
254
255            let (mut iopub, iosub) = futures::channel::mpsc::channel(100);
256
257            let (request_tx, mut request_rx) =
258                futures::channel::mpsc::channel::<JupyterMessage>(100);
259
260            let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
261            let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);
262
263            let mut messages_rx = SelectAll::new();
264            messages_rx.push(iosub);
265            messages_rx.push(control_reply_rx);
266            messages_rx.push(shell_reply_rx);
267
268            let iopub_task = cx.background_executor().spawn({
269                async move {
270                    while let Ok(message) = iopub_socket.read().await {
271                        iopub.send(message).await?;
272                    }
273                    anyhow::Ok(())
274                }
275            });
276
277            let (mut control_request_tx, mut control_request_rx) =
278                futures::channel::mpsc::channel(100);
279            let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);
280
281            let routing_task = cx.background_executor().spawn({
282                async move {
283                    while let Some(message) = request_rx.next().await {
284                        match message.content {
285                            JupyterMessageContent::DebugRequest(_)
286                            | JupyterMessageContent::InterruptRequest(_)
287                            | JupyterMessageContent::ShutdownRequest(_) => {
288                                control_request_tx.send(message).await?;
289                            }
290                            _ => {
291                                shell_request_tx.send(message).await?;
292                            }
293                        }
294                    }
295                    anyhow::Ok(())
296                }
297            });
298
299            let shell_task = cx.background_executor().spawn({
300                async move {
301                    while let Some(message) = shell_request_rx.next().await {
302                        shell_socket.send(message).await.ok();
303                        let reply = shell_socket.read().await?;
304                        shell_reply_tx.send(reply).await?;
305                    }
306                    anyhow::Ok(())
307                }
308            });
309
310            let control_task = cx.background_executor().spawn({
311                async move {
312                    while let Some(message) = control_request_rx.next().await {
313                        control_socket.send(message).await.ok();
314                        let reply = control_socket.read().await?;
315                        control_reply_tx.send(reply).await?;
316                    }
317                    anyhow::Ok(())
318                }
319            });
320
321            anyhow::Ok((
322                Self {
323                    process,
324                    request_tx,
325                    working_directory,
326                    _shell_task: shell_task,
327                    _iopub_task: iopub_task,
328                    _control_task: control_task,
329                    _routing_task: routing_task,
330                    connection_path,
331                    execution_state: ExecutionState::Idle,
332                    kernel_info: None,
333                },
334                messages_rx,
335            ))
336        })
337    }
338}
339
340impl Drop for RunningKernel {
341    fn drop(&mut self) {
342        std::fs::remove_file(&self.connection_path).ok();
343        self.request_tx.close_channel();
344        self.process.kill().ok();
345    }
346}
347
348async fn read_kernelspec_at(
349    // Path should be a directory to a jupyter kernelspec, as in
350    // /usr/local/share/jupyter/kernels/python3
351    kernel_dir: PathBuf,
352    fs: &dyn Fs,
353) -> Result<KernelSpecification> {
354    let path = kernel_dir;
355    let kernel_name = if let Some(kernel_name) = path.file_name() {
356        kernel_name.to_string_lossy().to_string()
357    } else {
358        anyhow::bail!("Invalid kernelspec directory: {path:?}");
359    };
360
361    if !fs.is_dir(path.as_path()).await {
362        anyhow::bail!("Not a directory: {path:?}");
363    }
364
365    let expected_kernel_json = path.join("kernel.json");
366    let spec = fs.load(expected_kernel_json.as_path()).await?;
367    let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
368
369    Ok(KernelSpecification {
370        name: kernel_name,
371        path,
372        kernelspec: spec,
373    })
374}
375
376/// Read a directory of kernelspec directories
377async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<KernelSpecification>> {
378    let mut kernelspec_dirs = fs.read_dir(&path).await?;
379
380    let mut valid_kernelspecs = Vec::new();
381    while let Some(path) = kernelspec_dirs.next().await {
382        match path {
383            Ok(path) => {
384                if fs.is_dir(path.as_path()).await {
385                    if let Ok(kernelspec) = read_kernelspec_at(path, fs).await {
386                        valid_kernelspecs.push(kernelspec);
387                    }
388                }
389            }
390            Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
391        }
392    }
393
394    Ok(valid_kernelspecs)
395}
396
397pub async fn kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<KernelSpecification>> {
398    let mut data_dirs = dirs::data_dirs();
399
400    // Pick up any kernels from conda or conda environment
401    if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
402        let conda_prefix = PathBuf::from(conda_prefix);
403        let conda_data_dir = conda_prefix.join("share").join("jupyter");
404        data_dirs.push(conda_data_dir);
405    }
406
407    // Search for kernels inside the base python environment
408    let mut command = Command::new("python");
409    command.arg("-c");
410    command.arg("import sys; print(sys.prefix)");
411
412    #[cfg(windows)]
413    {
414        use smol::process::windows::CommandExt;
415        command.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0);
416    }
417
418    let command = command.output().await;
419
420    if let Ok(command) = command {
421        if command.status.success() {
422            let python_prefix = String::from_utf8(command.stdout);
423            if let Ok(python_prefix) = python_prefix {
424                let python_prefix = PathBuf::from(python_prefix.trim());
425                let python_data_dir = python_prefix.join("share").join("jupyter");
426                data_dirs.push(python_data_dir);
427            }
428        }
429    }
430
431    let kernel_dirs = data_dirs
432        .iter()
433        .map(|dir| dir.join("kernels"))
434        .map(|path| read_kernels_dir(path, fs.as_ref()))
435        .collect::<Vec<_>>();
436
437    let kernel_dirs = futures::future::join_all(kernel_dirs).await;
438    let kernel_dirs = kernel_dirs
439        .into_iter()
440        .filter_map(Result::ok)
441        .flatten()
442        .collect::<Vec<_>>();
443
444    Ok(kernel_dirs)
445}
446
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// Valid kernelspecs are returned by directory name; broken kernelspec directories and
    /// stray files in the `kernels` directory are skipped without failing the whole read.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                "kernels": {
                    "python": {
                        "kernel.json": r#"{
                            "display_name": "Python 3",
                            "language": "python",
                            "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "deno": {
                        "kernel.json": r#"{
                            "display_name": "Deno",
                            "language": "typescript",
                            "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    // kernel.json that is not valid JSON.
                    "broken": {
                        "kernel.json": "{ not json"
                    },
                    // Directory without a kernel.json.
                    "empty": {},
                    // A plain file is not a kernelspec directory.
                    "README.md": "not a kernelspec",
                },
            }),
        )
        .await;

        let mut kernels = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();

        kernels.sort_by(|a, b| a.name.cmp(&b.name));

        assert_eq!(
            kernels.iter().map(|c| c.name.clone()).collect::<Vec<_>>(),
            vec!["deno", "python"]
        );
    }
}