kernels.rs

  1use anyhow::{Context as _, Result};
  2use futures::{
  3    channel::mpsc::{self, Receiver},
  4    future::Shared,
  5    stream::{self, SelectAll, StreamExt},
  6    SinkExt as _,
  7};
  8use gpui::{AppContext, EntityId, Task};
  9use project::Fs;
 10use runtimelib::{
 11    dirs, ConnectionInfo, ExecutionState, JupyterKernelspec, JupyterMessage, JupyterMessageContent,
 12    KernelInfoReply,
 13};
 14use smol::{net::TcpListener, process::Command};
 15use std::{
 16    env,
 17    fmt::Debug,
 18    net::{IpAddr, Ipv4Addr, SocketAddr},
 19    path::PathBuf,
 20    sync::Arc,
 21};
 22
 23#[derive(Debug, Clone)]
 24pub struct KernelSpecification {
 25    pub name: String,
 26    pub path: PathBuf,
 27    pub kernelspec: JupyterKernelspec,
 28}
 29
 30impl KernelSpecification {
 31    #[must_use]
 32    fn command(&self, connection_path: &PathBuf) -> Result<Command> {
 33        let argv = &self.kernelspec.argv;
 34
 35        anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
 36        anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
 37        anyhow::ensure!(
 38            argv.iter().any(|arg| arg == "{connection_file}"),
 39            "Missing 'connection_file' in argv in kernelspec {}",
 40            self.name
 41        );
 42
 43        let mut cmd = Command::new(&argv[0]);
 44
 45        for arg in &argv[1..] {
 46            if arg == "{connection_file}" {
 47                cmd.arg(connection_path);
 48            } else {
 49                cmd.arg(arg);
 50            }
 51        }
 52
 53        if let Some(env) = &self.kernelspec.env {
 54            cmd.envs(env);
 55        }
 56
 57        Ok(cmd)
 58    }
 59}
 60
 61// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
 62// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
 63async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
 64    let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
 65    addr_zeroport.set_port(0);
 66    let mut ports: [u16; 5] = [0; 5];
 67    for i in 0..5 {
 68        let listener = TcpListener::bind(addr_zeroport).await?;
 69        let addr = listener.local_addr()?;
 70        ports[i] = addr.port();
 71    }
 72    Ok(ports)
 73}
 74
 75#[derive(Debug, Clone)]
 76pub enum KernelStatus {
 77    Idle,
 78    Busy,
 79    Starting,
 80    Error,
 81    ShuttingDown,
 82    Shutdown,
 83}
 84
 85impl KernelStatus {
 86    pub fn is_connected(&self) -> bool {
 87        match self {
 88            KernelStatus::Idle | KernelStatus::Busy => true,
 89            _ => false,
 90        }
 91    }
 92}
 93
 94impl ToString for KernelStatus {
 95    fn to_string(&self) -> String {
 96        match self {
 97            KernelStatus::Idle => "Idle".to_string(),
 98            KernelStatus::Busy => "Busy".to_string(),
 99            KernelStatus::Starting => "Starting".to_string(),
100            KernelStatus::Error => "Error".to_string(),
101            KernelStatus::ShuttingDown => "Shutting Down".to_string(),
102            KernelStatus::Shutdown => "Shutdown".to_string(),
103        }
104    }
105}
106
107impl From<&Kernel> for KernelStatus {
108    fn from(kernel: &Kernel) -> Self {
109        match kernel {
110            Kernel::RunningKernel(kernel) => match kernel.execution_state {
111                ExecutionState::Idle => KernelStatus::Idle,
112                ExecutionState::Busy => KernelStatus::Busy,
113            },
114            Kernel::StartingKernel(_) => KernelStatus::Starting,
115            Kernel::ErroredLaunch(_) => KernelStatus::Error,
116            Kernel::ShuttingDown => KernelStatus::ShuttingDown,
117            Kernel::Shutdown => KernelStatus::Shutdown,
118        }
119    }
120}
121
122#[derive(Debug)]
123pub enum Kernel {
124    RunningKernel(RunningKernel),
125    StartingKernel(Shared<Task<()>>),
126    ErroredLaunch(String),
127    ShuttingDown,
128    Shutdown,
129}
130
131impl Kernel {
132    pub fn status(&self) -> KernelStatus {
133        self.into()
134    }
135
136    pub fn set_execution_state(&mut self, status: &ExecutionState) {
137        match self {
138            Kernel::RunningKernel(running_kernel) => {
139                running_kernel.execution_state = status.clone();
140            }
141            _ => {}
142        }
143    }
144
145    pub fn set_kernel_info(&mut self, kernel_info: &KernelInfoReply) {
146        match self {
147            Kernel::RunningKernel(running_kernel) => {
148                running_kernel.kernel_info = Some(kernel_info.clone());
149            }
150            _ => {}
151        }
152    }
153
154    pub fn is_shutting_down(&self) -> bool {
155        match self {
156            Kernel::ShuttingDown => true,
157            Kernel::RunningKernel(_)
158            | Kernel::StartingKernel(_)
159            | Kernel::ErroredLaunch(_)
160            | Kernel::Shutdown => false,
161        }
162    }
163}
164
165pub struct RunningKernel {
166    pub process: smol::process::Child,
167    _shell_task: Task<Result<()>>,
168    _iopub_task: Task<Result<()>>,
169    _control_task: Task<Result<()>>,
170    _routing_task: Task<Result<()>>,
171    connection_path: PathBuf,
172    pub working_directory: PathBuf,
173    pub request_tx: mpsc::Sender<JupyterMessage>,
174    pub execution_state: ExecutionState,
175    pub kernel_info: Option<KernelInfoReply>,
176}
177
178type JupyterMessageChannel = stream::SelectAll<Receiver<JupyterMessage>>;
179
180impl Debug for RunningKernel {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        f.debug_struct("RunningKernel")
183            .field("process", &self.process)
184            .finish()
185    }
186}
187
188impl RunningKernel {
189    pub fn new(
190        kernel_specification: KernelSpecification,
191        entity_id: EntityId,
192        working_directory: PathBuf,
193        fs: Arc<dyn Fs>,
194        cx: &mut AppContext,
195    ) -> Task<Result<(Self, JupyterMessageChannel)>> {
196        cx.spawn(|cx| async move {
197            let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
198            let ports = peek_ports(ip).await?;
199
200            let connection_info = ConnectionInfo {
201                transport: "tcp".to_string(),
202                ip: ip.to_string(),
203                stdin_port: ports[0],
204                control_port: ports[1],
205                hb_port: ports[2],
206                shell_port: ports[3],
207                iopub_port: ports[4],
208                signature_scheme: "hmac-sha256".to_string(),
209                key: uuid::Uuid::new_v4().to_string(),
210                kernel_name: Some(format!("zed-{}", kernel_specification.name)),
211            };
212
213            let runtime_dir = dirs::runtime_dir();
214            fs.create_dir(&runtime_dir)
215                .await
216                .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
217            let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
218            let content = serde_json::to_string(&connection_info)?;
219            fs.atomic_write(connection_path.clone(), content).await?;
220
221            let mut cmd = kernel_specification.command(&connection_path)?;
222
223            let process = cmd
224                .current_dir(&working_directory)
225                .stdout(std::process::Stdio::piped())
226                .stderr(std::process::Stdio::piped())
227                .kill_on_drop(true)
228                .spawn()
229                .context("failed to start the kernel process")?;
230
231            let mut iopub_socket = connection_info.create_client_iopub_connection("").await?;
232            let mut shell_socket = connection_info.create_client_shell_connection().await?;
233            let mut control_socket = connection_info.create_client_control_connection().await?;
234
235            let (mut iopub, iosub) = futures::channel::mpsc::channel(100);
236
237            let (request_tx, mut request_rx) =
238                futures::channel::mpsc::channel::<JupyterMessage>(100);
239
240            let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
241            let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);
242
243            let mut messages_rx = SelectAll::new();
244            messages_rx.push(iosub);
245            messages_rx.push(control_reply_rx);
246            messages_rx.push(shell_reply_rx);
247
248            let _iopub_task = cx.background_executor().spawn({
249                async move {
250                    while let Ok(message) = iopub_socket.read().await {
251                        iopub.send(message).await?;
252                    }
253                    anyhow::Ok(())
254                }
255            });
256
257            let (mut control_request_tx, mut control_request_rx) =
258                futures::channel::mpsc::channel(100);
259            let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);
260
261            let _routing_task = cx.background_executor().spawn({
262                async move {
263                    while let Some(message) = request_rx.next().await {
264                        match message.content {
265                            JupyterMessageContent::DebugRequest(_)
266                            | JupyterMessageContent::InterruptRequest(_)
267                            | JupyterMessageContent::ShutdownRequest(_) => {
268                                control_request_tx.send(message).await?;
269                            }
270                            _ => {
271                                shell_request_tx.send(message).await?;
272                            }
273                        }
274                    }
275                    anyhow::Ok(())
276                }
277            });
278
279            let _shell_task = cx.background_executor().spawn({
280                async move {
281                    while let Some(message) = shell_request_rx.next().await {
282                        shell_socket.send(message).await.ok();
283                        let reply = shell_socket.read().await?;
284                        shell_reply_tx.send(reply).await?;
285                    }
286                    anyhow::Ok(())
287                }
288            });
289
290            let _control_task = cx.background_executor().spawn({
291                async move {
292                    while let Some(message) = control_request_rx.next().await {
293                        control_socket.send(message).await.ok();
294                        let reply = control_socket.read().await?;
295                        control_reply_tx.send(reply).await?;
296                    }
297                    anyhow::Ok(())
298                }
299            });
300
301            anyhow::Ok((
302                Self {
303                    process,
304                    request_tx,
305                    working_directory,
306                    _shell_task,
307                    _iopub_task,
308                    _control_task,
309                    _routing_task,
310                    connection_path,
311                    execution_state: ExecutionState::Busy,
312                    kernel_info: None,
313                },
314                messages_rx,
315            ))
316        })
317    }
318}
319
320impl Drop for RunningKernel {
321    fn drop(&mut self) {
322        std::fs::remove_file(&self.connection_path).ok();
323
324        self.request_tx.close_channel();
325    }
326}
327
328async fn read_kernelspec_at(
329    // Path should be a directory to a jupyter kernelspec, as in
330    // /usr/local/share/jupyter/kernels/python3
331    kernel_dir: PathBuf,
332    fs: &dyn Fs,
333) -> Result<KernelSpecification> {
334    let path = kernel_dir;
335    let kernel_name = if let Some(kernel_name) = path.file_name() {
336        kernel_name.to_string_lossy().to_string()
337    } else {
338        anyhow::bail!("Invalid kernelspec directory: {path:?}");
339    };
340
341    if !fs.is_dir(path.as_path()).await {
342        anyhow::bail!("Not a directory: {path:?}");
343    }
344
345    let expected_kernel_json = path.join("kernel.json");
346    let spec = fs.load(expected_kernel_json.as_path()).await?;
347    let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
348
349    Ok(KernelSpecification {
350        name: kernel_name,
351        path,
352        kernelspec: spec,
353    })
354}
355
356/// Read a directory of kernelspec directories
357async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<KernelSpecification>> {
358    let mut kernelspec_dirs = fs.read_dir(&path).await?;
359
360    let mut valid_kernelspecs = Vec::new();
361    while let Some(path) = kernelspec_dirs.next().await {
362        match path {
363            Ok(path) => {
364                if fs.is_dir(path.as_path()).await {
365                    if let Ok(kernelspec) = read_kernelspec_at(path, fs).await {
366                        valid_kernelspecs.push(kernelspec);
367                    }
368                }
369            }
370            Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
371        }
372    }
373
374    Ok(valid_kernelspecs)
375}
376
377pub async fn kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<KernelSpecification>> {
378    let mut data_dirs = dirs::data_dirs();
379
380    // Pick up any kernels from conda or conda environment
381    if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
382        let conda_prefix = PathBuf::from(conda_prefix);
383        let conda_data_dir = conda_prefix.join("share").join("jupyter");
384        data_dirs.push(conda_data_dir);
385    }
386
387    // Search for kernels inside the base python environment
388    let command = Command::new("python")
389        .arg("-c")
390        .arg("import sys; print(sys.prefix)")
391        .output()
392        .await;
393
394    if let Ok(command) = command {
395        if command.status.success() {
396            let python_prefix = String::from_utf8(command.stdout);
397            if let Ok(python_prefix) = python_prefix {
398                let python_prefix = PathBuf::from(python_prefix.trim());
399                let python_data_dir = python_prefix.join("share").join("jupyter");
400                data_dirs.push(python_data_dir);
401            }
402        }
403    }
404
405    let kernel_dirs = data_dirs
406        .iter()
407        .map(|dir| dir.join("kernels"))
408        .map(|path| read_kernels_dir(path, fs.as_ref()))
409        .collect::<Vec<_>>();
410
411    let kernel_dirs = futures::future::join_all(kernel_dirs).await;
412    let kernel_dirs = kernel_dirs
413        .into_iter()
414        .filter_map(Result::ok)
415        .flatten()
416        .collect::<Vec<_>>();
417
418    Ok(kernel_dirs)
419}
420
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// Scanning a `kernels` directory yields one kernelspec per sub-directory,
    /// named after that sub-directory.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                ".zed": {
                    "settings.json": r#"{ "tab_size": 8 }"#,
                    "tasks.json": r#"[{
                        "label": "cargo check",
                        "command": "cargo",
                        "args": ["check", "--all"]
                    },]"#,
                },
                "kernels": {
                    "python": {
                        "kernel.json": r#"{
                            "display_name": "Python 3",
                            "language": "python",
                            "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "deno": {
                        "kernel.json": r#"{
                            "display_name": "Deno",
                            "language": "typescript",
                            "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                            "env": {}
                        }"#
                    }
                },
            }),
        )
        .await;

        let kernels = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();

        // Directory listing order is not part of the contract, so compare sorted names.
        let mut names: Vec<String> = kernels.into_iter().map(|spec| spec.name).collect();
        names.sort();

        assert_eq!(names, vec!["deno", "python"]);
    }
}