kernels.rs

  1use anyhow::{Context as _, Result};
  2use futures::{
  3    channel::mpsc::{self, Receiver},
  4    future::Shared,
  5    stream::{self, SelectAll, StreamExt},
  6    SinkExt as _,
  7};
  8use gpui::{AppContext, EntityId, Task};
  9use project::Fs;
 10use runtimelib::{
 11    dirs, ConnectionInfo, ExecutionState, JupyterKernelspec, JupyterMessage, JupyterMessageContent,
 12    KernelInfoReply,
 13};
 14use smol::{net::TcpListener, process::Command};
 15use std::{
 16    fmt::Debug,
 17    net::{IpAddr, Ipv4Addr, SocketAddr},
 18    path::PathBuf,
 19    sync::Arc,
 20};
 21
 22#[derive(Debug, Clone)]
 23pub struct KernelSpecification {
 24    pub name: String,
 25    pub path: PathBuf,
 26    pub kernelspec: JupyterKernelspec,
 27}
 28
 29impl KernelSpecification {
 30    #[must_use]
 31    fn command(&self, connection_path: &PathBuf) -> Result<Command> {
 32        let argv = &self.kernelspec.argv;
 33
 34        anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
 35        anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
 36        anyhow::ensure!(
 37            argv.iter().any(|arg| arg == "{connection_file}"),
 38            "Missing 'connection_file' in argv in kernelspec {}",
 39            self.name
 40        );
 41
 42        let mut cmd = Command::new(&argv[0]);
 43
 44        for arg in &argv[1..] {
 45            if arg == "{connection_file}" {
 46                cmd.arg(connection_path);
 47            } else {
 48                cmd.arg(arg);
 49            }
 50        }
 51
 52        if let Some(env) = &self.kernelspec.env {
 53            cmd.envs(env);
 54        }
 55
 56        Ok(cmd)
 57    }
 58}
 59
 60// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
 61// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
 62async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
 63    let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
 64    addr_zeroport.set_port(0);
 65    let mut ports: [u16; 5] = [0; 5];
 66    for i in 0..5 {
 67        let listener = TcpListener::bind(addr_zeroport).await?;
 68        let addr = listener.local_addr()?;
 69        ports[i] = addr.port();
 70    }
 71    Ok(ports)
 72}
 73
 /// Coarse lifecycle/activity status of a kernel, derived from [`Kernel`].
///
/// `PartialEq`/`Eq` allow direct comparisons such as `status == KernelStatus::Idle`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KernelStatus {
    /// Running and not executing anything.
    Idle,
    /// Running and executing a request.
    Busy,
    /// Launch in progress.
    Starting,
    /// Launch failed.
    Error,
    /// Shutdown requested but not finished.
    ShuttingDown,
    /// Shut down.
    Shutdown,
}
 83
 84impl KernelStatus {
 85    pub fn is_connected(&self) -> bool {
 86        match self {
 87            KernelStatus::Idle | KernelStatus::Busy => true,
 88            _ => false,
 89        }
 90    }
 91}
 92
 93impl ToString for KernelStatus {
 94    fn to_string(&self) -> String {
 95        match self {
 96            KernelStatus::Idle => "Idle".to_string(),
 97            KernelStatus::Busy => "Busy".to_string(),
 98            KernelStatus::Starting => "Starting".to_string(),
 99            KernelStatus::Error => "Error".to_string(),
100            KernelStatus::ShuttingDown => "Shutting Down".to_string(),
101            KernelStatus::Shutdown => "Shutdown".to_string(),
102        }
103    }
104}
105
106impl From<&Kernel> for KernelStatus {
107    fn from(kernel: &Kernel) -> Self {
108        match kernel {
109            Kernel::RunningKernel(kernel) => match kernel.execution_state {
110                ExecutionState::Idle => KernelStatus::Idle,
111                ExecutionState::Busy => KernelStatus::Busy,
112            },
113            Kernel::StartingKernel(_) => KernelStatus::Starting,
114            Kernel::ErroredLaunch(_) => KernelStatus::Error,
115            Kernel::ShuttingDown => KernelStatus::ShuttingDown,
116            Kernel::Shutdown => KernelStatus::Shutdown,
117        }
118    }
119}
120
121#[derive(Debug)]
122pub enum Kernel {
123    RunningKernel(RunningKernel),
124    StartingKernel(Shared<Task<()>>),
125    ErroredLaunch(String),
126    ShuttingDown,
127    Shutdown,
128}
129
130impl Kernel {
131    pub fn status(&self) -> KernelStatus {
132        self.into()
133    }
134
135    pub fn set_execution_state(&mut self, status: &ExecutionState) {
136        match self {
137            Kernel::RunningKernel(running_kernel) => {
138                running_kernel.execution_state = status.clone();
139            }
140            _ => {}
141        }
142    }
143
144    pub fn set_kernel_info(&mut self, kernel_info: &KernelInfoReply) {
145        match self {
146            Kernel::RunningKernel(running_kernel) => {
147                running_kernel.kernel_info = Some(kernel_info.clone());
148            }
149            _ => {}
150        }
151    }
152
153    pub fn is_shutting_down(&self) -> bool {
154        match self {
155            Kernel::ShuttingDown => true,
156            Kernel::RunningKernel(_)
157            | Kernel::StartingKernel(_)
158            | Kernel::ErroredLaunch(_)
159            | Kernel::Shutdown => false,
160        }
161    }
162}
163
164pub struct RunningKernel {
165    pub process: smol::process::Child,
166    _shell_task: Task<Result<()>>,
167    _iopub_task: Task<Result<()>>,
168    _control_task: Task<Result<()>>,
169    _routing_task: Task<Result<()>>,
170    connection_path: PathBuf,
171    pub working_directory: PathBuf,
172    pub request_tx: mpsc::Sender<JupyterMessage>,
173    pub execution_state: ExecutionState,
174    pub kernel_info: Option<KernelInfoReply>,
175}
176
177type JupyterMessageChannel = stream::SelectAll<Receiver<JupyterMessage>>;
178
179impl Debug for RunningKernel {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        f.debug_struct("RunningKernel")
182            .field("process", &self.process)
183            .finish()
184    }
185}
186
187impl RunningKernel {
188    pub fn new(
189        kernel_specification: KernelSpecification,
190        entity_id: EntityId,
191        working_directory: PathBuf,
192        fs: Arc<dyn Fs>,
193        cx: &mut AppContext,
194    ) -> Task<Result<(Self, JupyterMessageChannel)>> {
195        cx.spawn(|cx| async move {
196            let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
197            let ports = peek_ports(ip).await?;
198
199            let connection_info = ConnectionInfo {
200                transport: "tcp".to_string(),
201                ip: ip.to_string(),
202                stdin_port: ports[0],
203                control_port: ports[1],
204                hb_port: ports[2],
205                shell_port: ports[3],
206                iopub_port: ports[4],
207                signature_scheme: "hmac-sha256".to_string(),
208                key: uuid::Uuid::new_v4().to_string(),
209                kernel_name: Some(format!("zed-{}", kernel_specification.name)),
210            };
211
212            let runtime_dir = dirs::runtime_dir();
213            fs.create_dir(&runtime_dir)
214                .await
215                .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
216            let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
217            let content = serde_json::to_string(&connection_info)?;
218            fs.atomic_write(connection_path.clone(), content).await?;
219
220            let mut cmd = kernel_specification.command(&connection_path)?;
221
222            let process = cmd
223                .current_dir(&working_directory)
224                // .stdout(Stdio::null())
225                // .stderr(Stdio::null())
226                .kill_on_drop(true)
227                .spawn()
228                .context("failed to start the kernel process")?;
229
230            let mut iopub_socket = connection_info.create_client_iopub_connection("").await?;
231            let mut shell_socket = connection_info.create_client_shell_connection().await?;
232            let mut control_socket = connection_info.create_client_control_connection().await?;
233
234            let (mut iopub, iosub) = futures::channel::mpsc::channel(100);
235
236            let (request_tx, mut request_rx) =
237                futures::channel::mpsc::channel::<JupyterMessage>(100);
238
239            let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
240            let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);
241
242            let mut messages_rx = SelectAll::new();
243            messages_rx.push(iosub);
244            messages_rx.push(control_reply_rx);
245            messages_rx.push(shell_reply_rx);
246
247            let _iopub_task = cx.background_executor().spawn({
248                async move {
249                    while let Ok(message) = iopub_socket.read().await {
250                        iopub.send(message).await?;
251                    }
252                    anyhow::Ok(())
253                }
254            });
255
256            let (mut control_request_tx, mut control_request_rx) =
257                futures::channel::mpsc::channel(100);
258            let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);
259
260            let _routing_task = cx.background_executor().spawn({
261                async move {
262                    while let Some(message) = request_rx.next().await {
263                        match message.content {
264                            JupyterMessageContent::DebugRequest(_)
265                            | JupyterMessageContent::InterruptRequest(_)
266                            | JupyterMessageContent::ShutdownRequest(_) => {
267                                control_request_tx.send(message).await?;
268                            }
269                            _ => {
270                                shell_request_tx.send(message).await?;
271                            }
272                        }
273                    }
274                    anyhow::Ok(())
275                }
276            });
277
278            let _shell_task = cx.background_executor().spawn({
279                async move {
280                    while let Some(message) = shell_request_rx.next().await {
281                        shell_socket.send(message).await.ok();
282                        let reply = shell_socket.read().await?;
283                        shell_reply_tx.send(reply).await?;
284                    }
285                    anyhow::Ok(())
286                }
287            });
288
289            let _control_task = cx.background_executor().spawn({
290                async move {
291                    while let Some(message) = control_request_rx.next().await {
292                        control_socket.send(message).await.ok();
293                        let reply = control_socket.read().await?;
294                        control_reply_tx.send(reply).await?;
295                    }
296                    anyhow::Ok(())
297                }
298            });
299
300            anyhow::Ok((
301                Self {
302                    process,
303                    request_tx,
304                    working_directory,
305                    _shell_task,
306                    _iopub_task,
307                    _control_task,
308                    _routing_task,
309                    connection_path,
310                    execution_state: ExecutionState::Busy,
311                    kernel_info: None,
312                },
313                messages_rx,
314            ))
315        })
316    }
317}
318
319impl Drop for RunningKernel {
320    fn drop(&mut self) {
321        std::fs::remove_file(&self.connection_path).ok();
322
323        self.request_tx.close_channel();
324    }
325}
326
327async fn read_kernelspec_at(
328    // Path should be a directory to a jupyter kernelspec, as in
329    // /usr/local/share/jupyter/kernels/python3
330    kernel_dir: PathBuf,
331    fs: &dyn Fs,
332) -> Result<KernelSpecification> {
333    let path = kernel_dir;
334    let kernel_name = if let Some(kernel_name) = path.file_name() {
335        kernel_name.to_string_lossy().to_string()
336    } else {
337        anyhow::bail!("Invalid kernelspec directory: {path:?}");
338    };
339
340    if !fs.is_dir(path.as_path()).await {
341        anyhow::bail!("Not a directory: {path:?}");
342    }
343
344    let expected_kernel_json = path.join("kernel.json");
345    let spec = fs.load(expected_kernel_json.as_path()).await?;
346    let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
347
348    Ok(KernelSpecification {
349        name: kernel_name,
350        path,
351        kernelspec: spec,
352    })
353}
354
355/// Read a directory of kernelspec directories
356async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<KernelSpecification>> {
357    let mut kernelspec_dirs = fs.read_dir(&path).await?;
358
359    let mut valid_kernelspecs = Vec::new();
360    while let Some(path) = kernelspec_dirs.next().await {
361        match path {
362            Ok(path) => {
363                if fs.is_dir(path.as_path()).await {
364                    if let Ok(kernelspec) = read_kernelspec_at(path, fs).await {
365                        valid_kernelspecs.push(kernelspec);
366                    }
367                }
368            }
369            Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
370        }
371    }
372
373    Ok(valid_kernelspecs)
374}
375
376pub async fn kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<KernelSpecification>> {
377    let data_dirs = dirs::data_dirs();
378    let kernel_dirs = data_dirs
379        .iter()
380        .map(|dir| dir.join("kernels"))
381        .map(|path| read_kernels_dir(path, fs.as_ref()))
382        .collect::<Vec<_>>();
383
384    let kernel_dirs = futures::future::join_all(kernel_dirs).await;
385    let kernel_dirs = kernel_dirs
386        .into_iter()
387        .filter_map(Result::ok)
388        .flatten()
389        .collect::<Vec<_>>();
390
391    Ok(kernel_dirs)
392}
393
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// Valid kernelspec directories are discovered by name. A directory with a
    /// malformed `kernel.json` and a stray plain file are both skipped.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                "kernels": {
                    "python": {
                        "kernel.json": r#"{
                            "display_name": "Python 3",
                            "language": "python",
                            "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "deno": {
                        "kernel.json": r#"{
                            "display_name": "Deno",
                            "language": "typescript",
                            "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    // Unparseable kernel.json: must be skipped, not fail the scan.
                    "broken": {
                        "kernel.json": "this is not json"
                    },
                    // Plain files next to kernelspec directories are ignored.
                    "README.md": "not a kernelspec"
                },
            }),
        )
        .await;

        let mut kernels = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();

        kernels.sort_by(|a, b| a.name.cmp(&b.name));

        assert_eq!(
            kernels.iter().map(|c| c.name.clone()).collect::<Vec<_>>(),
            vec!["deno", "python"]
        );

        // The parsed argv is carried through, placeholder included.
        assert!(kernels.iter().all(|kernel| {
            kernel.kernelspec.argv.last().map(String::as_str) == Some("{connection_file}")
        }));
    }
}