kernels.rs

  1use anyhow::{Context as _, Result};
  2use futures::{
  3    channel::mpsc::{self, Receiver},
  4    future::Shared,
  5    stream::{self, SelectAll, StreamExt},
  6    SinkExt as _,
  7};
  8use gpui::{AppContext, EntityId, Task};
  9use project::Fs;
 10use runtimelib::{
 11    dirs, ConnectionInfo, ExecutionState, JupyterKernelspec, JupyterMessage, JupyterMessageContent,
 12    KernelInfoReply,
 13};
 14use smol::{net::TcpListener, process::Command};
 15use std::{
 16    env,
 17    fmt::Debug,
 18    net::{IpAddr, Ipv4Addr, SocketAddr},
 19    path::PathBuf,
 20    sync::Arc,
 21};
 22use uuid::Uuid;
 23
/// A Jupyter kernelspec discovered on disk.
///
/// As built by `read_kernelspec_at`, `name` is the kernelspec directory's name,
/// `path` is that directory, and `kernelspec` is its parsed `kernel.json`.
#[derive(Debug, Clone)]
pub struct KernelSpecification {
    /// Kernel name (the kernelspec directory's final path component when
    /// produced by `read_kernelspec_at`).
    pub name: String,
    /// Kernelspec directory, i.e. the directory that contains `kernel.json`.
    pub path: PathBuf,
    /// Parsed contents of `kernel.json`; `argv` and `env` are used to launch the kernel.
    pub kernelspec: JupyterKernelspec,
}
 30
 31impl KernelSpecification {
 32    #[must_use]
 33    fn command(&self, connection_path: &PathBuf) -> Result<Command> {
 34        let argv = &self.kernelspec.argv;
 35
 36        anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
 37        anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
 38        anyhow::ensure!(
 39            argv.iter().any(|arg| arg == "{connection_file}"),
 40            "Missing 'connection_file' in argv in kernelspec {}",
 41            self.name
 42        );
 43
 44        let mut cmd = Command::new(&argv[0]);
 45
 46        for arg in &argv[1..] {
 47            if arg == "{connection_file}" {
 48                cmd.arg(connection_path);
 49            } else {
 50                cmd.arg(arg);
 51            }
 52        }
 53
 54        if let Some(env) = &self.kernelspec.env {
 55            cmd.envs(env);
 56        }
 57
 58        #[cfg(windows)]
 59        {
 60            use smol::process::windows::CommandExt;
 61            cmd.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0);
 62        }
 63
 64        Ok(cmd)
 65    }
 66}
 67
 68// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
 69// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
 70async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
 71    let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
 72    addr_zeroport.set_port(0);
 73    let mut ports: [u16; 5] = [0; 5];
 74    for i in 0..5 {
 75        let listener = TcpListener::bind(addr_zeroport).await?;
 76        let addr = listener.local_addr()?;
 77        ports[i] = addr.port();
 78    }
 79    Ok(ports)
 80}
 81
 82#[derive(Debug, Clone)]
 83pub enum KernelStatus {
 84    Idle,
 85    Busy,
 86    Starting,
 87    Error,
 88    ShuttingDown,
 89    Shutdown,
 90    Restarting,
 91}
 92
 93impl KernelStatus {
 94    pub fn is_connected(&self) -> bool {
 95        match self {
 96            KernelStatus::Idle | KernelStatus::Busy => true,
 97            _ => false,
 98        }
 99    }
100}
101
102impl ToString for KernelStatus {
103    fn to_string(&self) -> String {
104        match self {
105            KernelStatus::Idle => "Idle".to_string(),
106            KernelStatus::Busy => "Busy".to_string(),
107            KernelStatus::Starting => "Starting".to_string(),
108            KernelStatus::Error => "Error".to_string(),
109            KernelStatus::ShuttingDown => "Shutting Down".to_string(),
110            KernelStatus::Shutdown => "Shutdown".to_string(),
111            KernelStatus::Restarting => "Restarting".to_string(),
112        }
113    }
114}
115
116impl From<&Kernel> for KernelStatus {
117    fn from(kernel: &Kernel) -> Self {
118        match kernel {
119            Kernel::RunningKernel(kernel) => match kernel.execution_state {
120                ExecutionState::Idle => KernelStatus::Idle,
121                ExecutionState::Busy => KernelStatus::Busy,
122            },
123            Kernel::StartingKernel(_) => KernelStatus::Starting,
124            Kernel::ErroredLaunch(_) => KernelStatus::Error,
125            Kernel::ShuttingDown => KernelStatus::ShuttingDown,
126            Kernel::Shutdown => KernelStatus::Shutdown,
127            Kernel::Restarting => KernelStatus::Restarting,
128        }
129    }
130}
131
/// Lifecycle state of a kernel; `KernelStatus::from` derives the user-facing status.
#[derive(Debug)]
pub enum Kernel {
    /// A launched kernel with its message-pumping tasks (see `RunningKernel`).
    RunningKernel(RunningKernel),
    /// Launch in progress; holds the shared startup task.
    StartingKernel(Shared<Task<()>>),
    /// Launch failed. NOTE(review): the `String` is presumably a human-readable
    /// error description — confirm at the construction site.
    ErroredLaunch(String),
    /// Reported as `KernelStatus::ShuttingDown`; `is_shutting_down` returns `true`.
    ShuttingDown,
    /// Reported as `KernelStatus::Shutdown`.
    Shutdown,
    /// Reported as `KernelStatus::Restarting`; `is_shutting_down` returns `true`.
    Restarting,
}
141
142impl Kernel {
143    pub fn status(&self) -> KernelStatus {
144        self.into()
145    }
146
147    pub fn set_execution_state(&mut self, status: &ExecutionState) {
148        if let Kernel::RunningKernel(running_kernel) = self {
149            running_kernel.execution_state = status.clone();
150        }
151    }
152
153    pub fn set_kernel_info(&mut self, kernel_info: &KernelInfoReply) {
154        if let Kernel::RunningKernel(running_kernel) = self {
155            running_kernel.kernel_info = Some(kernel_info.clone());
156        }
157    }
158
159    pub fn is_shutting_down(&self) -> bool {
160        match self {
161            Kernel::Restarting | Kernel::ShuttingDown => true,
162            Kernel::RunningKernel(_)
163            | Kernel::StartingKernel(_)
164            | Kernel::ErroredLaunch(_)
165            | Kernel::Shutdown => false,
166        }
167    }
168}
169
/// A launched kernel process together with the background tasks (started in
/// `RunningKernel::new`) that move messages between its sockets and in-process
/// channels.
pub struct RunningKernel {
    /// The spawned kernel process (configured with `kill_on_drop(true)` in `new`).
    pub process: smol::process::Child,
    // Background tasks started in `new`, held so they live as long as this value.
    // NOTE(review): assumes dropping a gpui `Task` cancels it — confirm.
    _shell_task: Task<Result<()>>,
    _iopub_task: Task<Result<()>>,
    _control_task: Task<Result<()>>,
    _routing_task: Task<Result<()>>,
    /// Connection file written for the kernel; removed when this value is dropped.
    connection_path: PathBuf,
    /// Directory the kernel process was started in.
    pub working_directory: PathBuf,
    /// Outgoing requests. Debug, interrupt and shutdown requests are routed to
    /// the control socket; all others go to the shell socket.
    pub request_tx: mpsc::Sender<JupyterMessage>,
    /// Last recorded execution state; initialised to `Idle` in `new`.
    pub execution_state: ExecutionState,
    /// Kernel info reply, once recorded via `Kernel::set_kernel_info`.
    pub kernel_info: Option<KernelInfoReply>,
}
182
/// Merged stream of what the kernel sends back: iopub messages plus shell and
/// control replies, as combined in `RunningKernel::new`.
type JupyterMessageChannel = stream::SelectAll<Receiver<JupyterMessage>>;
184
185impl Debug for RunningKernel {
186    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187        f.debug_struct("RunningKernel")
188            .field("process", &self.process)
189            .finish()
190    }
191}
192
impl RunningKernel {
    /// Launches a kernel for `kernel_specification` and connects to it.
    ///
    /// The returned task does the following:
    /// 1. Reserves five local TCP ports on 127.0.0.1.
    /// 2. Writes a connection file named `kernel-zed-{entity_id}.json` into the
    ///    Jupyter runtime directory.
    /// 3. Spawns the kernel process in `working_directory`.
    /// 4. Opens the iopub, shell and control sockets and starts background tasks
    ///    that pump messages between those sockets and in-process channels.
    ///
    /// Resolves to the `RunningKernel` plus a merged stream of iopub messages and
    /// shell/control replies.
    ///
    /// # Errors
    ///
    /// The task fails if any of these steps fails:
    /// - reserving ports;
    /// - creating the runtime directory or writing the connection file;
    /// - building the command from the kernelspec;
    /// - spawning the process;
    /// - opening any socket.
    pub fn new(
        kernel_specification: KernelSpecification,
        entity_id: EntityId,
        working_directory: PathBuf,
        fs: Arc<dyn Fs>,
        cx: &mut AppContext,
    ) -> Task<Result<(Self, JupyterMessageChannel)>> {
        cx.spawn(|cx| async move {
            let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
            let ports = peek_ports(ip).await?;

            // Port assignment order: stdin, control, heartbeat, shell, iopub.
            // A fresh random UUID is used as the HMAC signing key for this launch.
            let connection_info = ConnectionInfo {
                transport: "tcp".to_string(),
                ip: ip.to_string(),
                stdin_port: ports[0],
                control_port: ports[1],
                hb_port: ports[2],
                shell_port: ports[3],
                iopub_port: ports[4],
                signature_scheme: "hmac-sha256".to_string(),
                key: uuid::Uuid::new_v4().to_string(),
                kernel_name: Some(format!("zed-{}", kernel_specification.name)),
            };

            let runtime_dir = dirs::runtime_dir();
            fs.create_dir(&runtime_dir)
                .await
                .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
            // One connection file per entity; a relaunch for the same entity overwrites it.
            // NOTE(review): if a later step in this task fails, the file is not
            // removed (only `Drop for RunningKernel` deletes it) — confirm this is acceptable.
            let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
            let content = serde_json::to_string(&connection_info)?;
            fs.atomic_write(connection_path.clone(), content).await?;

            let mut cmd = kernel_specification.command(&connection_path)?;

            // NOTE(review): stdout/stderr are piped but nothing in this function reads
            // them. If they are not drained elsewhere, a chatty kernel could block once
            // the pipe buffer fills — confirm.
            let process = cmd
                .current_dir(&working_directory)
                .stdout(std::process::Stdio::piped())
                .stderr(std::process::Stdio::piped())
                .stdin(std::process::Stdio::piped())
                .kill_on_drop(true)
                .spawn()
                .context("failed to start the kernel process")?;

            let session_id = Uuid::new_v4().to_string();

            let mut iopub_socket = connection_info
                .create_client_iopub_connection("", &session_id)
                .await?;
            let mut shell_socket = connection_info
                .create_client_shell_connection(&session_id)
                .await?;
            let mut control_socket = connection_info
                .create_client_control_connection(&session_id)
                .await?;

            let (mut iopub, iosub) = futures::channel::mpsc::channel(100);

            let (request_tx, mut request_rx) =
                futures::channel::mpsc::channel::<JupyterMessage>(100);

            let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
            let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);

            // Everything the kernel sends back is merged into one stream for the caller.
            let mut messages_rx = SelectAll::new();
            messages_rx.push(iosub);
            messages_rx.push(control_reply_rx);
            messages_rx.push(shell_reply_rx);

            // Forwards iopub broadcasts until the socket read fails or the receiver is gone.
            let iopub_task = cx.background_executor().spawn({
                async move {
                    while let Ok(message) = iopub_socket.read().await {
                        iopub.send(message).await?;
                    }
                    anyhow::Ok(())
                }
            });

            let (mut control_request_tx, mut control_request_rx) =
                futures::channel::mpsc::channel(100);
            let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);

            // Dispatches outgoing requests: debug/interrupt/shutdown go to the
            // control channel, everything else to the shell channel.
            let routing_task = cx.background_executor().spawn({
                async move {
                    while let Some(message) = request_rx.next().await {
                        match message.content {
                            JupyterMessageContent::DebugRequest(_)
                            | JupyterMessageContent::InterruptRequest(_)
                            | JupyterMessageContent::ShutdownRequest(_) => {
                                control_request_tx.send(message).await?;
                            }
                            _ => {
                                shell_request_tx.send(message).await?;
                            }
                        }
                    }
                    anyhow::Ok(())
                }
            });

            // Strict request/reply on the shell socket: send one request, then wait
            // for exactly one reply before taking the next. Send errors are ignored.
            let shell_task = cx.background_executor().spawn({
                async move {
                    while let Some(message) = shell_request_rx.next().await {
                        shell_socket.send(message).await.ok();
                        let reply = shell_socket.read().await?;
                        shell_reply_tx.send(reply).await?;
                    }
                    anyhow::Ok(())
                }
            });

            // Same request/reply loop as the shell task, for the control socket.
            let control_task = cx.background_executor().spawn({
                async move {
                    while let Some(message) = control_request_rx.next().await {
                        control_socket.send(message).await.ok();
                        let reply = control_socket.read().await?;
                        control_reply_tx.send(reply).await?;
                    }
                    anyhow::Ok(())
                }
            });

            anyhow::Ok((
                Self {
                    process,
                    request_tx,
                    working_directory,
                    _shell_task: shell_task,
                    _iopub_task: iopub_task,
                    _control_task: control_task,
                    _routing_task: routing_task,
                    connection_path,
                    execution_state: ExecutionState::Idle,
                    kernel_info: None,
                },
                messages_rx,
            ))
        })
    }
}
333
334impl Drop for RunningKernel {
335    fn drop(&mut self) {
336        std::fs::remove_file(&self.connection_path).ok();
337        self.request_tx.close_channel();
338        self.process.kill().ok();
339    }
340}
341
342async fn read_kernelspec_at(
343    // Path should be a directory to a jupyter kernelspec, as in
344    // /usr/local/share/jupyter/kernels/python3
345    kernel_dir: PathBuf,
346    fs: &dyn Fs,
347) -> Result<KernelSpecification> {
348    let path = kernel_dir;
349    let kernel_name = if let Some(kernel_name) = path.file_name() {
350        kernel_name.to_string_lossy().to_string()
351    } else {
352        anyhow::bail!("Invalid kernelspec directory: {path:?}");
353    };
354
355    if !fs.is_dir(path.as_path()).await {
356        anyhow::bail!("Not a directory: {path:?}");
357    }
358
359    let expected_kernel_json = path.join("kernel.json");
360    let spec = fs.load(expected_kernel_json.as_path()).await?;
361    let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
362
363    Ok(KernelSpecification {
364        name: kernel_name,
365        path,
366        kernelspec: spec,
367    })
368}
369
370/// Read a directory of kernelspec directories
371async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<KernelSpecification>> {
372    let mut kernelspec_dirs = fs.read_dir(&path).await?;
373
374    let mut valid_kernelspecs = Vec::new();
375    while let Some(path) = kernelspec_dirs.next().await {
376        match path {
377            Ok(path) => {
378                if fs.is_dir(path.as_path()).await {
379                    if let Ok(kernelspec) = read_kernelspec_at(path, fs).await {
380                        valid_kernelspecs.push(kernelspec);
381                    }
382                }
383            }
384            Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
385        }
386    }
387
388    Ok(valid_kernelspecs)
389}
390
391pub async fn kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<KernelSpecification>> {
392    let mut data_dirs = dirs::data_dirs();
393
394    // Pick up any kernels from conda or conda environment
395    if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
396        let conda_prefix = PathBuf::from(conda_prefix);
397        let conda_data_dir = conda_prefix.join("share").join("jupyter");
398        data_dirs.push(conda_data_dir);
399    }
400
401    // Search for kernels inside the base python environment
402    let mut command = Command::new("python");
403    command.arg("-c");
404    command.arg("import sys; print(sys.prefix)");
405
406    #[cfg(windows)]
407    {
408        use smol::process::windows::CommandExt;
409        command.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0);
410    }
411
412    let command = command.output().await;
413
414    if let Ok(command) = command {
415        if command.status.success() {
416            let python_prefix = String::from_utf8(command.stdout);
417            if let Ok(python_prefix) = python_prefix {
418                let python_prefix = PathBuf::from(python_prefix.trim());
419                let python_data_dir = python_prefix.join("share").join("jupyter");
420                data_dirs.push(python_data_dir);
421            }
422        }
423    }
424
425    let kernel_dirs = data_dirs
426        .iter()
427        .map(|dir| dir.join("kernels"))
428        .map(|path| read_kernels_dir(path, fs.as_ref()))
429        .collect::<Vec<_>>();
430
431    let kernel_dirs = futures::future::join_all(kernel_dirs).await;
432    let kernel_dirs = kernel_dirs
433        .into_iter()
434        .filter_map(Result::ok)
435        .flatten()
436        .collect::<Vec<_>>();
437
438    Ok(kernel_dirs)
439}
440
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// `read_kernels_dir` returns only the valid kernelspec directories and
    /// skips these entries without failing:
    /// - plain files;
    /// - directories without a `kernel.json`;
    /// - directories whose `kernel.json` is malformed.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                "kernels": {
                    "python": {
                        "kernel.json": r#"{
                            "display_name": "Python 3",
                            "language": "python",
                            "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "deno": {
                        "kernel.json": r#"{
                            "display_name": "Deno",
                            "language": "typescript",
                            "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    // A directory without a kernel.json is not a kernelspec.
                    "incomplete": {},
                    // A kernel.json that is not valid JSON is skipped.
                    "malformed": {
                        "kernel.json": "not json"
                    },
                    // Plain files next to kernelspec directories are ignored.
                    "README.md": "not a kernelspec",
                },
            }),
        )
        .await;

        let mut kernels = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();

        kernels.sort_by(|a, b| a.name.cmp(&b.name));

        assert_eq!(
            kernels.iter().map(|c| c.name.clone()).collect::<Vec<_>>(),
            vec!["deno", "python"]
        );
    }
}