kernels.rs

  1use anyhow::{Context as _, Result};
  2use futures::{
  3    channel::mpsc::{self, Receiver},
  4    future::Shared,
  5    stream::{self, SelectAll, StreamExt},
  6    SinkExt as _,
  7};
  8use gpui::{AppContext, EntityId, Task};
  9use project::Fs;
 10use runtimelib::{
 11    dirs, ConnectionInfo, ExecutionState, JupyterKernelspec, JupyterMessage, JupyterMessageContent,
 12    KernelInfoReply,
 13};
 14use smol::{net::TcpListener, process::Command};
 15use std::{
 16    env,
 17    fmt::Debug,
 18    net::{IpAddr, Ipv4Addr, SocketAddr},
 19    path::PathBuf,
 20    sync::Arc,
 21};
 22use uuid::Uuid;
 23
 24#[derive(Debug, Clone)]
 25pub struct KernelSpecification {
 26    pub name: String,
 27    pub path: PathBuf,
 28    pub kernelspec: JupyterKernelspec,
 29}
 30
 31impl KernelSpecification {
 32    #[must_use]
 33    fn command(&self, connection_path: &PathBuf) -> Result<Command> {
 34        let argv = &self.kernelspec.argv;
 35
 36        anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
 37        anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
 38        anyhow::ensure!(
 39            argv.iter().any(|arg| arg == "{connection_file}"),
 40            "Missing 'connection_file' in argv in kernelspec {}",
 41            self.name
 42        );
 43
 44        let mut cmd = Command::new(&argv[0]);
 45
 46        for arg in &argv[1..] {
 47            if arg == "{connection_file}" {
 48                cmd.arg(connection_path);
 49            } else {
 50                cmd.arg(arg);
 51            }
 52        }
 53
 54        if let Some(env) = &self.kernelspec.env {
 55            cmd.envs(env);
 56        }
 57
 58        #[cfg(windows)]
 59        {
 60            use smol::process::windows::CommandExt;
 61            cmd.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0);
 62        }
 63
 64        Ok(cmd)
 65    }
 66}
 67
 68// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
 69// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
 70async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
 71    let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
 72    addr_zeroport.set_port(0);
 73    let mut ports: [u16; 5] = [0; 5];
 74    for i in 0..5 {
 75        let listener = TcpListener::bind(addr_zeroport).await?;
 76        let addr = listener.local_addr()?;
 77        ports[i] = addr.port();
 78    }
 79    Ok(ports)
 80}
 81
 82#[derive(Debug, Clone)]
 83pub enum KernelStatus {
 84    Idle,
 85    Busy,
 86    Starting,
 87    Error,
 88    ShuttingDown,
 89    Shutdown,
 90}
 91
 92impl KernelStatus {
 93    pub fn is_connected(&self) -> bool {
 94        match self {
 95            KernelStatus::Idle | KernelStatus::Busy => true,
 96            _ => false,
 97        }
 98    }
 99}
100
101impl ToString for KernelStatus {
102    fn to_string(&self) -> String {
103        match self {
104            KernelStatus::Idle => "Idle".to_string(),
105            KernelStatus::Busy => "Busy".to_string(),
106            KernelStatus::Starting => "Starting".to_string(),
107            KernelStatus::Error => "Error".to_string(),
108            KernelStatus::ShuttingDown => "Shutting Down".to_string(),
109            KernelStatus::Shutdown => "Shutdown".to_string(),
110        }
111    }
112}
113
114impl From<&Kernel> for KernelStatus {
115    fn from(kernel: &Kernel) -> Self {
116        match kernel {
117            Kernel::RunningKernel(kernel) => match kernel.execution_state {
118                ExecutionState::Idle => KernelStatus::Idle,
119                ExecutionState::Busy => KernelStatus::Busy,
120            },
121            Kernel::StartingKernel(_) => KernelStatus::Starting,
122            Kernel::ErroredLaunch(_) => KernelStatus::Error,
123            Kernel::ShuttingDown => KernelStatus::ShuttingDown,
124            Kernel::Shutdown => KernelStatus::Shutdown,
125        }
126    }
127}
128
129#[derive(Debug)]
130pub enum Kernel {
131    RunningKernel(RunningKernel),
132    StartingKernel(Shared<Task<()>>),
133    ErroredLaunch(String),
134    ShuttingDown,
135    Shutdown,
136}
137
138impl Kernel {
139    pub fn status(&self) -> KernelStatus {
140        self.into()
141    }
142
143    pub fn set_execution_state(&mut self, status: &ExecutionState) {
144        match self {
145            Kernel::RunningKernel(running_kernel) => {
146                running_kernel.execution_state = status.clone();
147            }
148            _ => {}
149        }
150    }
151
152    pub fn set_kernel_info(&mut self, kernel_info: &KernelInfoReply) {
153        match self {
154            Kernel::RunningKernel(running_kernel) => {
155                running_kernel.kernel_info = Some(kernel_info.clone());
156            }
157            _ => {}
158        }
159    }
160
161    pub fn is_shutting_down(&self) -> bool {
162        match self {
163            Kernel::ShuttingDown => true,
164            Kernel::RunningKernel(_)
165            | Kernel::StartingKernel(_)
166            | Kernel::ErroredLaunch(_)
167            | Kernel::Shutdown => false,
168        }
169    }
170}
171
172pub struct RunningKernel {
173    pub process: smol::process::Child,
174    _shell_task: Task<Result<()>>,
175    _iopub_task: Task<Result<()>>,
176    _control_task: Task<Result<()>>,
177    _routing_task: Task<Result<()>>,
178    connection_path: PathBuf,
179    pub working_directory: PathBuf,
180    pub request_tx: mpsc::Sender<JupyterMessage>,
181    pub execution_state: ExecutionState,
182    pub kernel_info: Option<KernelInfoReply>,
183}
184
185type JupyterMessageChannel = stream::SelectAll<Receiver<JupyterMessage>>;
186
187impl Debug for RunningKernel {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        f.debug_struct("RunningKernel")
190            .field("process", &self.process)
191            .finish()
192    }
193}
194
195impl RunningKernel {
196    pub fn new(
197        kernel_specification: KernelSpecification,
198        entity_id: EntityId,
199        working_directory: PathBuf,
200        fs: Arc<dyn Fs>,
201        cx: &mut AppContext,
202    ) -> Task<Result<(Self, JupyterMessageChannel)>> {
203        cx.spawn(|cx| async move {
204            let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
205            let ports = peek_ports(ip).await?;
206
207            let connection_info = ConnectionInfo {
208                transport: "tcp".to_string(),
209                ip: ip.to_string(),
210                stdin_port: ports[0],
211                control_port: ports[1],
212                hb_port: ports[2],
213                shell_port: ports[3],
214                iopub_port: ports[4],
215                signature_scheme: "hmac-sha256".to_string(),
216                key: uuid::Uuid::new_v4().to_string(),
217                kernel_name: Some(format!("zed-{}", kernel_specification.name)),
218            };
219
220            let runtime_dir = dirs::runtime_dir();
221            fs.create_dir(&runtime_dir)
222                .await
223                .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
224            let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
225            let content = serde_json::to_string(&connection_info)?;
226            fs.atomic_write(connection_path.clone(), content).await?;
227
228            let mut cmd = kernel_specification.command(&connection_path)?;
229
230            let process = cmd
231                .current_dir(&working_directory)
232                .stdout(std::process::Stdio::piped())
233                .stderr(std::process::Stdio::piped())
234                .stdin(std::process::Stdio::piped())
235                .kill_on_drop(true)
236                .spawn()
237                .context("failed to start the kernel process")?;
238
239            let session_id = Uuid::new_v4().to_string();
240
241            let mut iopub_socket = connection_info
242                .create_client_iopub_connection("", &session_id)
243                .await?;
244            let mut shell_socket = connection_info
245                .create_client_shell_connection(&session_id)
246                .await?;
247            let mut control_socket = connection_info
248                .create_client_control_connection(&session_id)
249                .await?;
250
251            let (mut iopub, iosub) = futures::channel::mpsc::channel(100);
252
253            let (request_tx, mut request_rx) =
254                futures::channel::mpsc::channel::<JupyterMessage>(100);
255
256            let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
257            let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);
258
259            let mut messages_rx = SelectAll::new();
260            messages_rx.push(iosub);
261            messages_rx.push(control_reply_rx);
262            messages_rx.push(shell_reply_rx);
263
264            let _iopub_task = cx.background_executor().spawn({
265                async move {
266                    while let Ok(message) = iopub_socket.read().await {
267                        iopub.send(message).await?;
268                    }
269                    anyhow::Ok(())
270                }
271            });
272
273            let (mut control_request_tx, mut control_request_rx) =
274                futures::channel::mpsc::channel(100);
275            let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);
276
277            let _routing_task = cx.background_executor().spawn({
278                async move {
279                    while let Some(message) = request_rx.next().await {
280                        match message.content {
281                            JupyterMessageContent::DebugRequest(_)
282                            | JupyterMessageContent::InterruptRequest(_)
283                            | JupyterMessageContent::ShutdownRequest(_) => {
284                                control_request_tx.send(message).await?;
285                            }
286                            _ => {
287                                shell_request_tx.send(message).await?;
288                            }
289                        }
290                    }
291                    anyhow::Ok(())
292                }
293            });
294
295            let _shell_task = cx.background_executor().spawn({
296                async move {
297                    while let Some(message) = shell_request_rx.next().await {
298                        shell_socket.send(message).await.ok();
299                        let reply = shell_socket.read().await?;
300                        shell_reply_tx.send(reply).await?;
301                    }
302                    anyhow::Ok(())
303                }
304            });
305
306            let _control_task = cx.background_executor().spawn({
307                async move {
308                    while let Some(message) = control_request_rx.next().await {
309                        control_socket.send(message).await.ok();
310                        let reply = control_socket.read().await?;
311                        control_reply_tx.send(reply).await?;
312                    }
313                    anyhow::Ok(())
314                }
315            });
316
317            anyhow::Ok((
318                Self {
319                    process,
320                    request_tx,
321                    working_directory,
322                    _shell_task,
323                    _iopub_task,
324                    _control_task,
325                    _routing_task,
326                    connection_path,
327                    execution_state: ExecutionState::Busy,
328                    kernel_info: None,
329                },
330                messages_rx,
331            ))
332        })
333    }
334}
335
336impl Drop for RunningKernel {
337    fn drop(&mut self) {
338        std::fs::remove_file(&self.connection_path).ok();
339        self.request_tx.close_channel();
340        self.process.kill().ok();
341    }
342}
343
344async fn read_kernelspec_at(
345    // Path should be a directory to a jupyter kernelspec, as in
346    // /usr/local/share/jupyter/kernels/python3
347    kernel_dir: PathBuf,
348    fs: &dyn Fs,
349) -> Result<KernelSpecification> {
350    let path = kernel_dir;
351    let kernel_name = if let Some(kernel_name) = path.file_name() {
352        kernel_name.to_string_lossy().to_string()
353    } else {
354        anyhow::bail!("Invalid kernelspec directory: {path:?}");
355    };
356
357    if !fs.is_dir(path.as_path()).await {
358        anyhow::bail!("Not a directory: {path:?}");
359    }
360
361    let expected_kernel_json = path.join("kernel.json");
362    let spec = fs.load(expected_kernel_json.as_path()).await?;
363    let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
364
365    Ok(KernelSpecification {
366        name: kernel_name,
367        path,
368        kernelspec: spec,
369    })
370}
371
372/// Read a directory of kernelspec directories
373async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<KernelSpecification>> {
374    let mut kernelspec_dirs = fs.read_dir(&path).await?;
375
376    let mut valid_kernelspecs = Vec::new();
377    while let Some(path) = kernelspec_dirs.next().await {
378        match path {
379            Ok(path) => {
380                if fs.is_dir(path.as_path()).await {
381                    if let Ok(kernelspec) = read_kernelspec_at(path, fs).await {
382                        valid_kernelspecs.push(kernelspec);
383                    }
384                }
385            }
386            Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
387        }
388    }
389
390    Ok(valid_kernelspecs)
391}
392
393pub async fn kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<KernelSpecification>> {
394    let mut data_dirs = dirs::data_dirs();
395
396    // Pick up any kernels from conda or conda environment
397    if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
398        let conda_prefix = PathBuf::from(conda_prefix);
399        let conda_data_dir = conda_prefix.join("share").join("jupyter");
400        data_dirs.push(conda_data_dir);
401    }
402
403    // Search for kernels inside the base python environment
404    let mut command = Command::new("python");
405    command.arg("-c");
406    command.arg("import sys; print(sys.prefix)");
407
408    #[cfg(windows)]
409    {
410        use smol::process::windows::CommandExt;
411        command.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0);
412    }
413
414    let command = command.output().await;
415
416    if let Ok(command) = command {
417        if command.status.success() {
418            let python_prefix = String::from_utf8(command.stdout);
419            if let Ok(python_prefix) = python_prefix {
420                let python_prefix = PathBuf::from(python_prefix.trim());
421                let python_data_dir = python_prefix.join("share").join("jupyter");
422                data_dirs.push(python_data_dir);
423            }
424        }
425    }
426
427    let kernel_dirs = data_dirs
428        .iter()
429        .map(|dir| dir.join("kernels"))
430        .map(|path| read_kernels_dir(path, fs.as_ref()))
431        .collect::<Vec<_>>();
432
433    let kernel_dirs = futures::future::join_all(kernel_dirs).await;
434    let kernel_dirs = kernel_dirs
435        .into_iter()
436        .filter_map(Result::ok)
437        .flatten()
438        .collect::<Vec<_>>();
439
440    Ok(kernel_dirs)
441}
442
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// `read_kernels_dir` must return one kernelspec per sub-directory of the `kernels`
    /// directory, named after that sub-directory.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                ".zed": {
                    "settings.json": r#"{ "tab_size": 8 }"#,
                    "tasks.json": r#"[{
                        "label": "cargo check",
                        "command": "cargo",
                        "args": ["check", "--all"]
                    },]"#,
                },
                "kernels": {
                    "python": {
                        "kernel.json": r#"{
                            "display_name": "Python 3",
                            "language": "python",
                            "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "deno": {
                        "kernel.json": r#"{
                            "display_name": "Deno",
                            "language": "typescript",
                            "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                            "env": {}
                        }"#
                    }
                },
            }),
        )
        .await;

        let kernels = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();

        // Directory listing order is not part of the contract, so compare sorted names.
        let mut kernel_names = kernels
            .into_iter()
            .map(|kernel| kernel.name)
            .collect::<Vec<_>>();
        kernel_names.sort();

        assert_eq!(kernel_names, vec!["deno", "python"]);
    }
}