kernels.rs

  1use anyhow::{Context as _, Result};
  2use futures::{
  3    channel::mpsc::{self, Receiver},
  4    future::Shared,
  5    stream::{self, SelectAll, StreamExt},
  6    SinkExt as _,
  7};
  8use gpui::{AppContext, EntityId, Task};
  9use project::Fs;
 10use runtimelib::{
 11    dirs, ConnectionInfo, ExecutionState, JupyterKernelspec, JupyterMessage, JupyterMessageContent,
 12    KernelInfoReply,
 13};
 14use smol::{net::TcpListener, process::Command};
 15use std::{
 16    env,
 17    fmt::Debug,
 18    net::{IpAddr, Ipv4Addr, SocketAddr},
 19    path::PathBuf,
 20    sync::Arc,
 21};
 22use uuid::Uuid;
 23
 24#[derive(Debug, Clone)]
 25pub struct KernelSpecification {
 26    pub name: String,
 27    pub path: PathBuf,
 28    pub kernelspec: JupyterKernelspec,
 29}
 30
 31impl KernelSpecification {
 32    #[must_use]
 33    fn command(&self, connection_path: &PathBuf) -> Result<Command> {
 34        let argv = &self.kernelspec.argv;
 35
 36        anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
 37        anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
 38        anyhow::ensure!(
 39            argv.iter().any(|arg| arg == "{connection_file}"),
 40            "Missing 'connection_file' in argv in kernelspec {}",
 41            self.name
 42        );
 43
 44        let mut cmd = Command::new(&argv[0]);
 45
 46        for arg in &argv[1..] {
 47            if arg == "{connection_file}" {
 48                cmd.arg(connection_path);
 49            } else {
 50                cmd.arg(arg);
 51            }
 52        }
 53
 54        if let Some(env) = &self.kernelspec.env {
 55            cmd.envs(env);
 56        }
 57
 58        Ok(cmd)
 59    }
 60}
 61
 62// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
 63// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
 64async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
 65    let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
 66    addr_zeroport.set_port(0);
 67    let mut ports: [u16; 5] = [0; 5];
 68    for i in 0..5 {
 69        let listener = TcpListener::bind(addr_zeroport).await?;
 70        let addr = listener.local_addr()?;
 71        ports[i] = addr.port();
 72    }
 73    Ok(ports)
 74}
 75
 76#[derive(Debug, Clone)]
 77pub enum KernelStatus {
 78    Idle,
 79    Busy,
 80    Starting,
 81    Error,
 82    ShuttingDown,
 83    Shutdown,
 84}
 85
 86impl KernelStatus {
 87    pub fn is_connected(&self) -> bool {
 88        match self {
 89            KernelStatus::Idle | KernelStatus::Busy => true,
 90            _ => false,
 91        }
 92    }
 93}
 94
 95impl ToString for KernelStatus {
 96    fn to_string(&self) -> String {
 97        match self {
 98            KernelStatus::Idle => "Idle".to_string(),
 99            KernelStatus::Busy => "Busy".to_string(),
100            KernelStatus::Starting => "Starting".to_string(),
101            KernelStatus::Error => "Error".to_string(),
102            KernelStatus::ShuttingDown => "Shutting Down".to_string(),
103            KernelStatus::Shutdown => "Shutdown".to_string(),
104        }
105    }
106}
107
108impl From<&Kernel> for KernelStatus {
109    fn from(kernel: &Kernel) -> Self {
110        match kernel {
111            Kernel::RunningKernel(kernel) => match kernel.execution_state {
112                ExecutionState::Idle => KernelStatus::Idle,
113                ExecutionState::Busy => KernelStatus::Busy,
114            },
115            Kernel::StartingKernel(_) => KernelStatus::Starting,
116            Kernel::ErroredLaunch(_) => KernelStatus::Error,
117            Kernel::ShuttingDown => KernelStatus::ShuttingDown,
118            Kernel::Shutdown => KernelStatus::Shutdown,
119        }
120    }
121}
122
123#[derive(Debug)]
124pub enum Kernel {
125    RunningKernel(RunningKernel),
126    StartingKernel(Shared<Task<()>>),
127    ErroredLaunch(String),
128    ShuttingDown,
129    Shutdown,
130}
131
132impl Kernel {
133    pub fn status(&self) -> KernelStatus {
134        self.into()
135    }
136
137    pub fn set_execution_state(&mut self, status: &ExecutionState) {
138        match self {
139            Kernel::RunningKernel(running_kernel) => {
140                running_kernel.execution_state = status.clone();
141            }
142            _ => {}
143        }
144    }
145
146    pub fn set_kernel_info(&mut self, kernel_info: &KernelInfoReply) {
147        match self {
148            Kernel::RunningKernel(running_kernel) => {
149                running_kernel.kernel_info = Some(kernel_info.clone());
150            }
151            _ => {}
152        }
153    }
154
155    pub fn is_shutting_down(&self) -> bool {
156        match self {
157            Kernel::ShuttingDown => true,
158            Kernel::RunningKernel(_)
159            | Kernel::StartingKernel(_)
160            | Kernel::ErroredLaunch(_)
161            | Kernel::Shutdown => false,
162        }
163    }
164}
165
166pub struct RunningKernel {
167    pub process: smol::process::Child,
168    _shell_task: Task<Result<()>>,
169    _iopub_task: Task<Result<()>>,
170    _control_task: Task<Result<()>>,
171    _routing_task: Task<Result<()>>,
172    connection_path: PathBuf,
173    pub working_directory: PathBuf,
174    pub request_tx: mpsc::Sender<JupyterMessage>,
175    pub execution_state: ExecutionState,
176    pub kernel_info: Option<KernelInfoReply>,
177}
178
179type JupyterMessageChannel = stream::SelectAll<Receiver<JupyterMessage>>;
180
181impl Debug for RunningKernel {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        f.debug_struct("RunningKernel")
184            .field("process", &self.process)
185            .finish()
186    }
187}
188
189impl RunningKernel {
190    pub fn new(
191        kernel_specification: KernelSpecification,
192        entity_id: EntityId,
193        working_directory: PathBuf,
194        fs: Arc<dyn Fs>,
195        cx: &mut AppContext,
196    ) -> Task<Result<(Self, JupyterMessageChannel)>> {
197        cx.spawn(|cx| async move {
198            let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
199            let ports = peek_ports(ip).await?;
200
201            let connection_info = ConnectionInfo {
202                transport: "tcp".to_string(),
203                ip: ip.to_string(),
204                stdin_port: ports[0],
205                control_port: ports[1],
206                hb_port: ports[2],
207                shell_port: ports[3],
208                iopub_port: ports[4],
209                signature_scheme: "hmac-sha256".to_string(),
210                key: uuid::Uuid::new_v4().to_string(),
211                kernel_name: Some(format!("zed-{}", kernel_specification.name)),
212            };
213
214            let runtime_dir = dirs::runtime_dir();
215            fs.create_dir(&runtime_dir)
216                .await
217                .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
218            let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
219            let content = serde_json::to_string(&connection_info)?;
220            fs.atomic_write(connection_path.clone(), content).await?;
221
222            let mut cmd = kernel_specification.command(&connection_path)?;
223
224            let process = cmd
225                .current_dir(&working_directory)
226                .stdout(std::process::Stdio::piped())
227                .stderr(std::process::Stdio::piped())
228                .stdin(std::process::Stdio::piped())
229                .kill_on_drop(true)
230                .spawn()
231                .context("failed to start the kernel process")?;
232
233            let session_id = Uuid::new_v4().to_string();
234
235            let mut iopub_socket = connection_info
236                .create_client_iopub_connection("", &session_id)
237                .await?;
238            let mut shell_socket = connection_info
239                .create_client_shell_connection(&session_id)
240                .await?;
241            let mut control_socket = connection_info
242                .create_client_control_connection(&session_id)
243                .await?;
244
245            let (mut iopub, iosub) = futures::channel::mpsc::channel(100);
246
247            let (request_tx, mut request_rx) =
248                futures::channel::mpsc::channel::<JupyterMessage>(100);
249
250            let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
251            let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);
252
253            let mut messages_rx = SelectAll::new();
254            messages_rx.push(iosub);
255            messages_rx.push(control_reply_rx);
256            messages_rx.push(shell_reply_rx);
257
258            let _iopub_task = cx.background_executor().spawn({
259                async move {
260                    while let Ok(message) = iopub_socket.read().await {
261                        iopub.send(message).await?;
262                    }
263                    anyhow::Ok(())
264                }
265            });
266
267            let (mut control_request_tx, mut control_request_rx) =
268                futures::channel::mpsc::channel(100);
269            let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);
270
271            let _routing_task = cx.background_executor().spawn({
272                async move {
273                    while let Some(message) = request_rx.next().await {
274                        match message.content {
275                            JupyterMessageContent::DebugRequest(_)
276                            | JupyterMessageContent::InterruptRequest(_)
277                            | JupyterMessageContent::ShutdownRequest(_) => {
278                                control_request_tx.send(message).await?;
279                            }
280                            _ => {
281                                shell_request_tx.send(message).await?;
282                            }
283                        }
284                    }
285                    anyhow::Ok(())
286                }
287            });
288
289            let _shell_task = cx.background_executor().spawn({
290                async move {
291                    while let Some(message) = shell_request_rx.next().await {
292                        shell_socket.send(message).await.ok();
293                        let reply = shell_socket.read().await?;
294                        shell_reply_tx.send(reply).await?;
295                    }
296                    anyhow::Ok(())
297                }
298            });
299
300            let _control_task = cx.background_executor().spawn({
301                async move {
302                    while let Some(message) = control_request_rx.next().await {
303                        control_socket.send(message).await.ok();
304                        let reply = control_socket.read().await?;
305                        control_reply_tx.send(reply).await?;
306                    }
307                    anyhow::Ok(())
308                }
309            });
310
311            anyhow::Ok((
312                Self {
313                    process,
314                    request_tx,
315                    working_directory,
316                    _shell_task,
317                    _iopub_task,
318                    _control_task,
319                    _routing_task,
320                    connection_path,
321                    execution_state: ExecutionState::Busy,
322                    kernel_info: None,
323                },
324                messages_rx,
325            ))
326        })
327    }
328}
329
330impl Drop for RunningKernel {
331    fn drop(&mut self) {
332        std::fs::remove_file(&self.connection_path).ok();
333        self.request_tx.close_channel();
334        self.process.kill().ok();
335    }
336}
337
338async fn read_kernelspec_at(
339    // Path should be a directory to a jupyter kernelspec, as in
340    // /usr/local/share/jupyter/kernels/python3
341    kernel_dir: PathBuf,
342    fs: &dyn Fs,
343) -> Result<KernelSpecification> {
344    let path = kernel_dir;
345    let kernel_name = if let Some(kernel_name) = path.file_name() {
346        kernel_name.to_string_lossy().to_string()
347    } else {
348        anyhow::bail!("Invalid kernelspec directory: {path:?}");
349    };
350
351    if !fs.is_dir(path.as_path()).await {
352        anyhow::bail!("Not a directory: {path:?}");
353    }
354
355    let expected_kernel_json = path.join("kernel.json");
356    let spec = fs.load(expected_kernel_json.as_path()).await?;
357    let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
358
359    Ok(KernelSpecification {
360        name: kernel_name,
361        path,
362        kernelspec: spec,
363    })
364}
365
366/// Read a directory of kernelspec directories
367async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<KernelSpecification>> {
368    let mut kernelspec_dirs = fs.read_dir(&path).await?;
369
370    let mut valid_kernelspecs = Vec::new();
371    while let Some(path) = kernelspec_dirs.next().await {
372        match path {
373            Ok(path) => {
374                if fs.is_dir(path.as_path()).await {
375                    if let Ok(kernelspec) = read_kernelspec_at(path, fs).await {
376                        valid_kernelspecs.push(kernelspec);
377                    }
378                }
379            }
380            Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
381        }
382    }
383
384    Ok(valid_kernelspecs)
385}
386
387pub async fn kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<KernelSpecification>> {
388    let mut data_dirs = dirs::data_dirs();
389
390    // Pick up any kernels from conda or conda environment
391    if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
392        let conda_prefix = PathBuf::from(conda_prefix);
393        let conda_data_dir = conda_prefix.join("share").join("jupyter");
394        data_dirs.push(conda_data_dir);
395    }
396
397    // Search for kernels inside the base python environment
398    let command = Command::new("python")
399        .arg("-c")
400        .arg("import sys; print(sys.prefix)")
401        .output()
402        .await;
403
404    if let Ok(command) = command {
405        if command.status.success() {
406            let python_prefix = String::from_utf8(command.stdout);
407            if let Ok(python_prefix) = python_prefix {
408                let python_prefix = PathBuf::from(python_prefix.trim());
409                let python_data_dir = python_prefix.join("share").join("jupyter");
410                data_dirs.push(python_data_dir);
411            }
412        }
413    }
414
415    let kernel_dirs = data_dirs
416        .iter()
417        .map(|dir| dir.join("kernels"))
418        .map(|path| read_kernels_dir(path, fs.as_ref()))
419        .collect::<Vec<_>>();
420
421    let kernel_dirs = futures::future::join_all(kernel_dirs).await;
422    let kernel_dirs = kernel_dirs
423        .into_iter()
424        .filter_map(Result::ok)
425        .flatten()
426        .collect::<Vec<_>>();
427
428    Ok(kernel_dirs)
429}
430
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// Scanning a kernels directory yields one specification per kernelspec
    /// sub-directory, named after that sub-directory.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                ".zed": {
                    "settings.json": r#"{ "tab_size": 8 }"#,
                    "tasks.json": r#"[{
                        "label": "cargo check",
                        "command": "cargo",
                        "args": ["check", "--all"]
                    },]"#,
                },
                "kernels": {
                    "python": {
                        "kernel.json": r#"{
                            "display_name": "Python 3",
                            "language": "python",
                            "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "deno": {
                        "kernel.json": r#"{
                            "display_name": "Deno",
                            "language": "typescript",
                            "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                            "env": {}
                        }"#
                    }
                },
            }),
        )
        .await;

        let kernels_root = PathBuf::from("/jupyter/kernels");
        let specs = read_kernels_dir(kernels_root, fs.as_ref())
            .await
            .unwrap();

        // Directory listing order is not part of the contract, so compare sorted names.
        let mut names = specs
            .iter()
            .map(|spec| spec.name.clone())
            .collect::<Vec<_>>();
        names.sort();

        assert_eq!(names, vec!["deno", "python"]);
    }
}